import os
import sys
from enum import Enum

from sqlalchemy import create_engine, text
from datetime import datetime, timedelta
from typing import Dict
from sqlalchemy.engine import Engine

# Append the parent directory of sys.path[0] to the import search path
# (presumably so modules one level above this file's directory can be imported).
sys.path.append(os.path.dirname(sys.path[0]))


class DbTypes(Enum):
    """
    Database types used for export.

    ``DBUtil.get_connection_info`` compares its ``db_type`` argument with the
    ``name`` of these members to choose a connection profile.
    """
    mysql = "mysql"
    postgresql_test = "postgresql_test"
    postgresql = "postgresql"
    postgresql_14 = "postgresql_14"
    postgresql_16 = "postgresql_16"
    postgresql_8 = "postgresql_8"
    postgresql_cluster = "postgresql_cluster"
    adv_us = "adv_us"
    adv_uk = "adv_uk"
    adv_other = "adv_other"
    srs = "srs"


class DBUtil(object):
    """
    Helpers for resolving connection settings, creating SQLAlchemy engines,
    running SQL, and swapping PostgreSQL tables / partitions.

    NOTE(review): hosts, user names and passwords below are hard-coded in the
    source file; consider loading them from the environment or a secret store.
    """
    # Set of site codes. Note that get_connection_info only maps
    # us/uk/de/es/fr/it to a database name.
    __SITE_SET__ = {'us', 'uk', 'de', 'fr', 'es', 'it', 'au', 'ca'}

    # MySQL connection parameters
    __mysql_host__ = "rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com"
    __mysql_port__ = "3306"
    __mysql_username__ = "adv_test"
    # NOTE(review): contains "%40", presumably a URL-encoded "@" - confirm.
    __mysql_pwd__ = "Yswg%40XP_test"

    # PostgreSQL connection parameters, production database (h15)
    __pgsql_host__ = "192.168.10.224"
    __pgsql_port__ = "5433"
    # __pgsql_username__ = "postgres"
    # __pgsql_pwd__ = "fazAqRRVV9vDmwDNRNb593ht5TxYVrfTyHJSJ3BS"
    __pgsql_username__ = "yswg_postgres"
    __pgsql_pwd__ = "yswg_postgres"

    # PostgreSQL connection parameters, production database (h14)
    __pgsql_h14_host__ = "192.168.10.223"
    __pgsql_h14_port__ = "5433"
    __pgsql_h14_username__ = "yswg_postgres"
    __pgsql_h14_pwd__ = "yswg_postgres"

    # PostgreSQL test database connection parameters
    __pgsql_test_host__ = "192.168.10.217"
    __pgsql_test_port__ = "5433"
    __pgsql_test_username__ = "postgres"
    __pgsql_test_pwd__ = "fazAqRRVV9vDmwDNRNb593ht5TxYVrfTyHJSJ3BS"

    # PostgreSQL cluster connection parameters
    # __pgsql_cluster_host__ = "192.168.10.221"
    __pgsql_cluster_host__ = "192.168.10.155"
    __pgsql_cluster_port__ = "6432"
    # __pgsql_cluster_username__ = "postgres"
    # __pgsql_cluster_pwd__ = "fazAqRRVV9vDmwDNRNb593ht5TxYVrfTyHJSJ3BS"
    __pgsql_cluster_username__ = "yswg_postgres"
    __pgsql_cluster_pwd__ = "yswg_postgres"

    # Advertising system, US site: Alibaba Cloud MySQL connection parameters
    __adv_us_host__ = "rm-wz93hkkmzm5f3868iho.mysql.rds.aliyuncs.com"
    __adv_us_port__ = "3306"
    __adv_us_username__ = "adv_yswg"
    __adv_us_pwd__ = "S4FeR09bFF441lTz"

    # Advertising system, UK site: Alibaba Cloud MySQL connection parameters
    __adv_uk_host__ = "rm-wz9rrjr9ub877ue8bko.mysql.rds.aliyuncs.com"
    __adv_uk_port__ = "3306"
    __adv_uk_username__ = "adv_yswg"
    __adv_uk_pwd__ = "S4FeR09bFF441lTz"

    # Advertising system, smaller sites: MySQL connection parameters
    __adv_other_host__ = "192.168.10.226"
    __adv_other_port__ = "3306"
    __adv_other_username__ = "root"
    __adv_other_pwd__ = "8cUrBPb0IMY1hfBy"

    # PostgreSQL connection parameters, production database (h8)
    __pgsql_h8_host__ = "192.168.10.210"
    __pgsql_h8_port__ = "5432"
    __pgsql_h8_username__ = "postgres"
    __pgsql_h8_pwd__ = "T#4$4%qPbR7mJx"

    # PostgreSQL (h16)
    __pgsql_h16_host__ = "192.168.10.225"
    __pgsql_h16_port__ = "5433"
    __pgsql_h16_username__ = "yswg_postgres"
    __pgsql_h16_pwd__ = "yswg_postgres"

    # StarRocks
    __srs_cluster_host__ = "192.168.10.151"
    __srs_cluster_port__ = "19030"
    __srs_cluster_username__ = "spark"
    __srs_cluster_pwd__ = "yswg123"

    @staticmethod
    def get_connection_info(db_type: str, site_name: str):
        """
        Build the connection settings for one database type and site.

        :param db_type: name of a ``DbTypes`` member; selects the server profile
        :param site_name: site code; selects the database name for the profiles
                          that keep one database per site
        :return: dict with the keys ``db_name``, ``db_type``, ``url`` (a JDBC url),
                 ``username``, ``pwd``, ``host`` and ``port``, or ``None`` when
                 ``db_type`` matches no profile
        :raises KeyError: when the profile resolves its database by site and
                          ``site_name`` is not one of us/uk/de/es/fr/it
        """
        assert db_type is not None, "db_type 不能为空！！"
        assert site_name is not None, "站点不能为空！！"

        sites = ("us", "uk", "de", "es", "fr", "it")
        # "selection" for us, "selection_<site>" for every other site.
        selection_dbs = {code: "selection" if code == "us" else f"selection_{code}" for code in sites}
        # All sites share one advertising database name.
        adv_dbs = dict.fromkeys(sites, "advertising_manager")

        def describe(db_name, kind, scheme, host, port, username, pwd, url_db=None, pwd_last=False):
            # url_db overrides the database used in the url only; pwd_last keeps
            # the key order (host, port, pwd) that some profiles have always used.
            info = {
                "db_name": db_name,
                "db_type": kind,
                "url": f"jdbc:{scheme}://{host}:{port}/{db_name if url_db is None else url_db}",
                "username": username,
            }
            values = {"host": host, "port": port, "pwd": pwd}
            for key in (("host", "port", "pwd") if pwd_last else ("pwd", "host", "port")):
                info[key] = values[key]
            return info

        cfg = DBUtil

        if db_type == DbTypes.mysql.name:
            return describe(selection_dbs[site_name], "mysql", "mysql",
                            cfg.__mysql_host__, cfg.__mysql_port__,
                            cfg.__mysql_username__, cfg.__mysql_pwd__)

        if db_type == DbTypes.srs.name:
            # NOTE(review): the reported db_name is always "selection" while the
            # url still carries the per-site database - confirm this is intended.
            return describe("selection", "mysql", "mysql",
                            cfg.__srs_cluster_host__, cfg.__srs_cluster_port__,
                            cfg.__srs_cluster_username__, cfg.__srs_cluster_pwd__,
                            url_db=selection_dbs[site_name])

        if db_type == DbTypes.postgresql.name:
            return describe(selection_dbs[site_name], "postgresql", "postgresql",
                            cfg.__pgsql_host__, cfg.__pgsql_port__,
                            cfg.__pgsql_username__, cfg.__pgsql_pwd__,
                            pwd_last=True)

        if db_type == DbTypes.postgresql_14.name:
            return describe(selection_dbs[site_name], "postgresql", "postgresql",
                            cfg.__pgsql_h14_host__, cfg.__pgsql_h14_port__,
                            cfg.__pgsql_h14_username__, cfg.__pgsql_h14_pwd__,
                            pwd_last=True)

        if db_type == DbTypes.postgresql_8.name:
            return describe(selection_dbs[site_name], "postgresql", "postgresql",
                            cfg.__pgsql_h8_host__, cfg.__pgsql_h8_port__,
                            cfg.__pgsql_h8_username__, cfg.__pgsql_h8_pwd__,
                            pwd_last=True)

        if db_type == DbTypes.postgresql_16.name:
            return describe(selection_dbs[site_name], "postgresql", "postgresql",
                            cfg.__pgsql_h16_host__, cfg.__pgsql_h16_port__,
                            cfg.__pgsql_h16_username__, cfg.__pgsql_h16_pwd__,
                            pwd_last=True)

        if db_type == DbTypes.postgresql_test.name:
            return describe(selection_dbs[site_name], "postgresql", "postgresql",
                            cfg.__pgsql_test_host__, cfg.__pgsql_test_port__,
                            cfg.__pgsql_test_username__, cfg.__pgsql_test_pwd__)

        if db_type == DbTypes.postgresql_cluster.name:
            # The cluster is not split per site: always the "postgres" database,
            # and site_name is not looked up at all.
            return describe("postgres", "postgresql", "postgresql",
                            cfg.__pgsql_cluster_host__, cfg.__pgsql_cluster_port__,
                            cfg.__pgsql_cluster_username__, cfg.__pgsql_cluster_pwd__)

        if db_type == DbTypes.adv_us.name:
            return describe(adv_dbs[site_name], "adv_us", "mysql",
                            cfg.__adv_us_host__, cfg.__adv_us_port__,
                            cfg.__adv_us_username__, cfg.__adv_us_pwd__)

        if db_type == DbTypes.adv_uk.name:
            return describe(adv_dbs[site_name], "adv_uk", "mysql",
                            cfg.__adv_uk_host__, cfg.__adv_uk_port__,
                            cfg.__adv_uk_username__, cfg.__adv_uk_pwd__)

        if db_type == DbTypes.adv_other.name:
            return describe(adv_dbs[site_name], "adv_other", "mysql",
                            cfg.__adv_other_host__, cfg.__adv_other_port__,
                            cfg.__adv_other_username__, cfg.__adv_other_pwd__)

        # Unknown db_type: callers receive None.
        return None

    @staticmethod
    def get_db_engine(db_type, site_name) -> Engine:
        """
        Create a SQLAlchemy engine for the given database type and site.

        The ``adv_us`` / ``adv_uk`` / ``adv_other`` profiles point at MySQL
        servers (their JDBC url is ``jdbc:mysql://...``), so they are opened
        with the same ``mysql+pymysql`` driver as the plain ``mysql`` profile
        instead of being rejected as unsupported.

        :param db_type: database type (name of a ``DbTypes`` member)
        :param site_name: site code
        :return: a new engine with ``pool_size=10`` and ``max_overflow=20``
        :raises AssertionError: when no connection settings exist for ``db_type``
        :raises Exception: when the resolved ``db_type`` has no known driver
        """
        conn = DBUtil.get_connection_info(db_type, site_name)
        assert conn is not None, "获取连接配置信息错误，请检查"

        # Connection-info db_type -> SQLAlchemy dialect+driver prefix.
        sqlalchemy_drivers = {
            "postgresql": "postgresql",
            "mysql": "mysql+pymysql",
            "adv_us": "mysql+pymysql",
            "adv_uk": "mysql+pymysql",
            "adv_other": "mysql+pymysql",
        }
        driver = sqlalchemy_drivers.get(conn['db_type'])
        if driver is None:
            raise Exception("数据库不支持")

        return create_engine(
            f"{driver}://{conn['username']}:{conn['pwd']}@{conn['host']}:{conn['port']}/{conn['db_name']}",
            pool_size=10,  # size of the connection pool
            max_overflow=20  # connections allowed beyond pool_size
        )

    @staticmethod
    def exec_sql(db_type, site_name, sql: str, dispose_flag=False):
        """
        Open an engine for the given database type / site and run ``sql``.

        ``sql`` is split on ``;`` and every non-empty piece is executed in
        order on one connection. The split is purely textual, so a semicolon
        inside a string literal also splits the statement.

        :param db_type: database type (name of a ``DbTypes`` member)
        :param site_name: site code
        :param sql: one or more statements separated by ``;``
        :param dispose_flag: dispose the engine afterwards; this now also
                             happens when a statement raises
        """
        engine = DBUtil.get_db_engine(db_type, site_name)
        try:
            with engine.connect() as connection:
                for statement in sql.split(";"):
                    statement = statement.strip()
                    if statement:
                        print("==========================执行sql中=====================================")
                        print(statement)
                        # NOTE(review): a plain string is passed to execute();
                        # SQLAlchemy 2.0 only accepts text() / exec_driver_sql().
                        connection.execute(statement)
            print("==========================sql执行完毕=====================================")
        finally:
            # Release the pool even on failure so a failed batch does not
            # leave pooled connections open.
            if dispose_flag:
                engine.dispose()

    @staticmethod
    def engine_exec_sql(engine, sql: str):
        """
        Run one SQL string on an already created engine.

        :param engine: database engine to take a connection from
        :param sql: SQL text; surrounding whitespace is stripped before execution
        :return: the object returned by ``connection.execute``
        """
        print("==========================执行sql中=====================================")
        print(sql)
        with engine.connect() as conn:
            result = conn.execute(text(sql.strip()))
            print("==========================sql执行完毕=====================================")
        # NOTE(review): the connection is already closed when the result is
        # handed back; callers that read rows from it do so afterwards.
        return result

    @classmethod
    def copy_pg_tb_index(cls, engine: Engine, source_tb_name: str, target_tb_name: str):
        """
        Generate ``create index`` statements that replicate the indexes of
        ``source_tb_name`` on ``target_tb_name``.

        Nothing is created here; the statements are only returned. The target
        table should not have these indexes yet, otherwise running the
        statements may build duplicate indexes.

        :param engine: engine used to read ``pg_indexes``
        :param source_tb_name: table whose index definitions are read
        :param target_tb_name: table the generated statements create indexes on
        :return: ``(indexs, indexs_cp)`` - a list of dicts describing every
                 source index and the list of generated (lower-cased) statements
        """
        index_infos = []
        create_statements = []
        with engine.connect() as connection:
            lookup_sql = f"""select schemaname, tablename, indexname, indexdef from pg_indexes  where tablename = '{source_tb_name}'"""
            for seq, row in enumerate(connection.execute(lookup_sql), start=1):
                # TODO: detect duplicate indexes
                definition = str(row['indexdef'])
                schema = row['schemaname']
                index_name = row['indexname']
                # Everything from the last "USING" on: access method + column list.
                using_clause = definition[definition.rfind("USING"):]
                # Everything before the last "INDEX", e.g. "CREATE UNIQUE ".
                create_clause = definition[:definition.rfind("INDEX")]
                columns = using_clause[using_clause.rfind("(") + 1: using_clause.rfind(")")].split(",")
                pk_suffix = "_pk" if str(index_name).endswith("pk") else ""
                # The timestamp makes the generated index name unique per run.
                stamp = datetime.now().strftime('%m_%d_%H%M')
                new_index_name = f"{source_tb_name}_idx{seq}_{stamp}{pk_suffix}"
                statement = f"{create_clause} index {new_index_name} on {schema}.{target_tb_name} {using_clause};"
                create_statements.append(statement.lower())
                index_infos.append({
                    "schemaname": schema,
                    "indexname": row['indexname'],
                    "tablename": row['tablename'],
                    "suffix": using_clause,
                    "fields": columns,
                })

        return index_infos, create_statements

    @classmethod
    def exchange_tb(cls, engine: Engine, source_tb_name: str, target_tb_name: str, cp_index_flag: bool):
        """
        Swap the names of two tables. Both must be regular tables, not
        partitioned ones.

        The swap is three renames (target -> backup, source -> target,
        backup -> source). They run inside one transaction so a failure part
        way through does not leave the target renamed to its backup name
        (on PostgreSQL, where DDL is transactional).

        :param engine: database engine
        :param source_tb_name: table that takes over the target's name
        :param target_tb_name: table currently holding the target name
        :param cp_index_flag: when true, first create on ``source_tb_name``
                              copies of the indexes found on ``target_tb_name``
        :return: ``True`` once the swap has been executed
        """
        import random
        index_cp_sql = None
        if cp_index_flag:
            _, indexs_cp = cls.copy_pg_tb_index(engine, target_tb_name, source_tb_name)
            if indexs_cp:
                index_cp_sql = "\n".join(indexs_cp)

        with engine.connect() as connection:
            # Build the copied indexes first.
            if index_cp_sql is not None:
                print("================================重新构建索引中================================")
                print(index_cp_sql)
                connection.execute(index_cp_sql)

            # Random suffix for the temporary backup name used during the swap.
            suffix_int = random.randint(1, 200)
            backup_tb_name = f"{target_tb_name}_back_{suffix_int}"
            rename_statements = (
                f"alter table {target_tb_name} rename to {backup_tb_name};",
                f"alter table {source_tb_name} rename to {target_tb_name};",
                f"alter table {backup_tb_name} rename to {source_tb_name};",
            )
            print(f"================================交换表{source_tb_name}到{target_tb_name}中================================")
            # One transaction: either all three renames are committed or none.
            with connection.begin():
                for statement in rename_statements:
                    print(statement)
                    connection.execute(text(statement))
        return True

    @classmethod
    def add_pg_part(cls, engine: Engine,
                    source_tb_name: str,
                    part_master_tb: str,
                    cp_index_flag: bool,
                    part_val: Dict,
                    print_flag: bool = False
                    ):
        """
        Attach the regular table ``source_tb_name`` to a partitioned table.

        :param engine: database engine
        :param source_tb_name: regular table to attach
        :param part_master_tb: partitioned parent table
        :param cp_index_flag: when true, first create on ``source_tb_name``
                              copies of the indexes found on ``part_master_tb``
        :param part_val: dict with ``from`` and ``to`` keys, each an iterable
                         of partition bound values
        :param print_flag: when true, only print the SQL and execute nothing
        :return: ``None`` in print-only mode, otherwise ``True``
        """

        def quoted_csv(values):
            # 'v1','v2',... with each value stringified and stripped.
            return ",".join([f"'{str(item).strip()}'" for item in values])

        lower_bound = quoted_csv(part_val['from'])
        upper_bound = quoted_csv(part_val['to'])

        index_cp_sql = None
        if cp_index_flag:
            _, statements = cls.copy_pg_tb_index(engine, part_master_tb, source_tb_name)
            if statements:
                index_cp_sql = "\n".join(statements)

        add_sql = f"""alter table {part_master_tb} attach partition {source_tb_name} for values from ({lower_bound}) to ({upper_bound});"""

        if print_flag:
            # Dry run: show what would be executed and stop.
            if index_cp_sql is not None:
                print("================================构建索引如下================================")
                print(index_cp_sql)

            print("================================普通表加入分区sql如下================================")
            print(add_sql)
            return

        # Real execution: indexes first, then the attach.
        with engine.connect() as connection:
            if index_cp_sql is not None:
                print("================================构建索引中================================")
                print(index_cp_sql)
                connection.execute(index_cp_sql)
            print(f"==========================加入普通表{source_tb_name}到分区表{part_master_tb}中===========================")
            print(add_sql)
            connection.execute(add_sql)
        return True

    @classmethod
    def exchange_pg_part_distributed_tb(cls, engine: Engine,
                                        source_tb_name: str,
                                        part_master_tb: str,
                                        part_target_tb: str,
                                        part_val: Dict,
                                        drop_old: bool = True
                                        ):
        """
        Put the regular table ``source_tb_name`` into a partition slot of a
        distributed partitioned table: attach it directly when the target
        partition does not exist yet, otherwise swap it with the existing one.

        :param engine: database engine
        :param source_tb_name: regular table to attach to the parent
        :param part_master_tb: partitioned parent table
        :param part_target_tb: name of the partition (child table) to fill
        :param part_val: dict with ``from`` and ``to`` keys, each an iterable
                         of partition bound values
        :param drop_old: when true, drop the replaced partition
                         (``<part_target_tb>_back``); when false and a partition
                         was replaced, rename it to ``source_tb_name`` instead
        :return: None
        """
        # Build the bounds before touching any table so a malformed part_val
        # fails before the partition has been detached / renamed.
        val1 = ",".join(f"'{str(it).strip()}'" for it in part_val['from'])
        val2 = ",".join(f"'{str(it).strip()}'" for it in part_val['to'])

        # Check whether the target partition already exists.
        val = DBUtil.engine_exec_sql(engine, f"""
        SELECT * FROM pg_partition_tree('{part_master_tb}') where relid::varchar='{part_target_tb}'
""")
        exist_flag = len(list(val)) > 0
        if exist_flag:
            # Move the current partition out of the way under a backup name.
            DBUtil.engine_exec_sql(engine, f"""alter table {part_master_tb} detach partition {part_target_tb}""")
            DBUtil.engine_exec_sql(engine, f"""alter table {part_target_tb} rename to {part_target_tb}_back""")

        if source_tb_name != part_target_tb:
            DBUtil.engine_exec_sql(engine, f"""alter table {source_tb_name} rename to {part_target_tb};""")

        # Attach the (renamed) new table as the partition.
        DBUtil.engine_exec_sql(engine, f"""
        alter table {part_master_tb} attach partition {part_target_tb} for values from ({val1}) to ({val2});
""")
        if drop_old:
            DBUtil.engine_exec_sql(engine, f"""drop table  if exists {part_target_tb}_back """)
        elif exist_flag:
            # Hand the old partition back under the source name. Only possible
            # when a partition was actually replaced: without one there is no
            # "_back" table and the rename would fail after a successful attach.
            DBUtil.engine_exec_sql(engine, f"""alter table {part_target_tb}_back rename to {source_tb_name}""")

        print(f"==================表{source_tb_name}加入成功==================================")

    @classmethod
    def exchange_pg_part_tb(cls, engine: Engine,
                            source_tb_name: str,
                            part_master_tb: str,
                            part_target_tb: str,
                            cp_index_flag: bool,
                            part_val: Dict,
                            print_flag: bool = False
                            ):
        """
        Swap the regular table ``source_tb_name`` with an existing partition.

        The generated SQL detaches ``part_target_tb`` from the parent, attaches
        ``source_tb_name`` for the given bounds and then swaps the two table
        names so the new data ends up under the partition's name.

        :param engine: database engine
        :param source_tb_name: regular table holding the new data
        :param part_master_tb: partitioned parent table
        :param part_target_tb: partition to be replaced
        :param cp_index_flag: when true, first create on ``source_tb_name``
                              copies of the indexes found on ``part_target_tb``
        :param part_val: dict with ``from`` and ``to`` keys, each an iterable
                         of partition bound values
        :param print_flag: when true, only print the SQL and execute nothing
        :return: ``None`` in print-only mode, otherwise ``True``
        """

        def quoted_csv(values):
            # 'v1','v2',... with each value stringified and stripped.
            return ",".join([f"'{str(item).strip()}'" for item in values])

        val1 = quoted_csv(part_val['from'])
        val2 = quoted_csv(part_val['to'])

        index_cp_sql = None
        if cp_index_flag:
            _, statements = cls.copy_pg_tb_index(engine, part_target_tb, source_tb_name)
            if statements:
                index_cp_sql = "\n".join(statements)

        # 1. detach the old partition  2. attach the new table  3. swap names
        exchange_sql = f"""
                        alter table {part_master_tb} detach partition {part_target_tb};
                        alter table {part_master_tb} attach partition {source_tb_name} for values from ({val1}) to ({val2});
                        alter table {part_target_tb} rename to {part_target_tb}_back;
                        alter table {source_tb_name} rename to {part_target_tb};
                        alter table {part_target_tb}_back rename to {source_tb_name};
                        """

        if print_flag:
            # Dry run: show what would be executed and stop.
            if index_cp_sql is not None:
                print("================================构建索引如下================================")
                print(index_cp_sql)

            print("================================构建交换sql如下================================")
            print(exchange_sql)
            return

        # Real execution: indexes first, then the whole exchange script at once.
        with engine.connect() as connection:
            if index_cp_sql is not None:
                print("================================重新构建索引中================================")
                print(index_cp_sql)
                connection.execute(index_cp_sql)
            print(
                f"================================交换普通表{source_tb_name}到分区表{part_master_tb}中================================")
            print(exchange_sql)
            connection.execute(exchange_sql)
        return True

    @classmethod
    def list_partition(cls, engine: Engine, tb_name):
        """
        List the partitions of a PostgreSQL partitioned table.

        :param engine: database engine
        :param tb_name: name of the parent table (resolved with ``::regclass``)
        :return: list of rows with ``partition_name`` and ``partition_expression``
        """
        partition_query = f"""
    select pt.relname                                 as partition_name,
   pg_get_expr(pt.relpartbound, pt.oid, true) as partition_expression
from pg_class base_tb
     join pg_inherits i on i.inhparent = base_tb.oid
     join pg_class pt on pt.oid = i.inhrelid
where base_tb.oid = '{tb_name}'::regclass;
        """
        with engine.connect() as connection:
            # Materialise the rows while the connection is still open.
            return list(connection.execute(partition_query))
