import calendar
import json
import os
import sys
from datetime import timedelta
from enum import Enum
from typing import Dict

import requests
from textblob import Word

sys.path.append(os.path.dirname(sys.path[0]))
from pyspark.sql import functions as F
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.session import SparkSession
from pyspark.sql.types import *
from pyspark.sql import Window

from utils.db_util import DBUtil
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from utils.ssh_util import SSHUtil
from utils.es_util import EsUtils
from utils.DolphinschedulerHelper import DolphinschedulerHelper
from datetime import datetime
from yswg_utils.common_udf import udf_parse_amazon_orders
from utils.StarRocksHelper import StarRocksHelper

class DateTypes(Enum):
    """
    Supported ``date_type`` granularities.

    Every member's value is identical to its name, so e.g.
    ``DateTypes.week.name == DateTypes.week.value == "week"``.
    Member order is significant for iteration and is kept stable.
    """
    day = "day"
    last30day = "last30day"
    week = "week"
    month = "month"
    month_week = "month_week"
    month_old = "month_old"
    last365day = "last365day"
    year = "year"


class CommonUtil(object):
    # Site codes known to this utility (not referenced in the visible part of this file).
    __SITE_SET__ = {'us', 'uk', 'de', 'fr', 'es', 'it', 'au', 'ca'}

    # Sqoop 1.4.6 launcher used by the build_*_sh command builders.
    __sqoop_home__ = "/opt/module/sqoop-1.4.6/bin/sqoop"
    # Sqoop 1.4.7 launcher (not referenced in the visible part of this file).
    __sqoop_1_4_7_home__ = "/mnt/opt/module/sqoop-1.4.7/bin/sqoop"
    # Interpreter and deployment directory paths (not referenced in the visible part of this file).
    __python_home__ = "/opt/module/anaconda3/envs/pyspark/bin/python3.8"
    __desploy_home__ = "/opt/module/spark/demo/py_demo"

    # __hive_home__ = "/opt/module/hive/bin/hive"
    # Hive CLI used by after_import() and hive_cmd_exec().
    __hive_home__ = "/opt/datasophon/hive-3.1.0/bin/hive"
    # Hadoop CLI used by after_import() to run the LZO indexer job.
    __hadoop_home__ = "/opt/module/hadoop/bin/hadoop"

    # Default alert recipients for the check_* methods (passed on to send_wx_msg).
    # It is bound as a default argument value, so every call shares this same list object.
    __msg_usr__ = ['wujicang', 'huangjian', 'fangxingjun', 'chenjianyun', 'wangrui4']

    # Java/Spark-style datetime pattern (not a Python strftime pattern).
    _date_time_format = "yyyy-MM-dd HH:mm:ss"

    # Python strftime datetime pattern; default format of format_timestamp().
    _py_date_time_format = '%Y-%m-%d %H:%M:%S'

    # Python strftime date pattern.
    _py_date_format = '%Y-%m-%d'

    # Morning / afternoon HH:MM boundary strings (not referenced in the visible part of this file).
    __start_time_of_morning_str__ = "08:00"
    __end_time_of_morning_str__ = "12:30"
    __start_time_of_afternoon_str__ = "14:00"
    __end_time_of_afternoon_str__ = "19:00"

    # Maps a process label to the scripts that belong to it (labels roughly: new ABA flow,
    # reverse search-term lookup, store flow, traffic-based product selection).
    # Not referenced in the visible part of this file.
    __export_process_map__ = {
        "新ABA流程": ['dwt_aba_st_analytics.py', 'dwd_st_volume_fba.py', 'dwt_aba_st_analytics_report_pg.py',
                   'dwt_aba_last_change_rate.py', 'dwt_st_market_pg.py'],
        "反查搜索词": ['dwt_st_asin_reverse.py'],
        "店铺流程": ['dwt_fb_asin_info.py', 'dwt_fb_base_report.py', 'dwt_fb_category_report.py', 'dwt_fb_top20_asin_info.py'],
        "流量选品": ['es_flow_asin.py']
    }
    # Spark UDF wrapping udf_parse_amazon_orders with an IntegerType result. It is created when the
    # class body executes (at import time), so importing this module requires pyspark.
    u_parse_amazon_orders = F.udf(udf_parse_amazon_orders, IntegerType())

    # NOTE(review): the string below is a no-op expression statement, not the class docstring
    # (a class docstring must be the first statement of the class body). It reads "general utility class".
    """
    一般工具类
    """

    @classmethod
    def to_int(cls, obj, defval=None):
        """
        安全转为 int
        :param obj:
        :param defval: 默认值
        :return:
        """
        if CommonUtil.notBlank(obj):
            return int(obj)
        return defval

    @classmethod
    def to_float(cls, obj, defval=None):
        """
        安全转为 float
        :param obj:
        :param defval:默认值
        :return:
        """
        if CommonUtil.notBlank(obj):
            return float(obj)
        return defval

    @classmethod
    def to_str(cls, obj, defval=None):
        """
        安全转为 str
        :param obj:
        :param defval:默认值
        :return:
        """
        if CommonUtil.notBlank(obj):
            return str(obj)
        return defval

    @staticmethod
    def reset_partitions(site_name, partitions_num=10):
        """
        按不同站点划分分区数量
        :param site_name: 站点名称
        :param partitions_num: 自定义分区数量
        :return: partitions_num
        """
        print("重置分区数")
        if site_name in ['us']:
            partitions_num = partitions_num
        elif site_name in ['uk', 'de']:
            partitions_num = partitions_num // 2 if partitions_num // 2 > 0 else 1
        elif site_name in ['es', 'fr', 'it']:
            partitions_num = partitions_num // 4 if partitions_num // 4 > 0 else 1
        return partitions_num

    @staticmethod
    def split_month_week_date(date_type, date_info):
        """
        Split a ``YYYY-WW`` / ``YYYY-MM`` date_info into two integers.

        Only week, month and month_week date types are handled; any other
        date_type returns None without inspecting ``date_info``.

        :param date_type: partition date type
        :param date_info: value such as ``2023-05``
        :return: ``(year, week_or_month)`` as ints, or None
        """
        handled = (DateTypes.week.name, DateTypes.month.name, DateTypes.month_week.name)
        if date_type not in handled:
            return None
        first, second = date_info.split('-')
        return int(first), int(second)

    @staticmethod
    def safeIndex(list: list, index: int, default: object = None):
        """
        安全获取list的索引对应的值
        :param list: 列表
        :param index: 索引
        :param default: 默认值
        :return:
        """
        if (index <= len(list) - 1):
            return list[index]
        return default

    @staticmethod
    def get_calDay_by_dateInfo(spark_session: SparkSession, date_type: str, date_info: str):
        """
        Return the reference ("calculation") day for a date_type / date_info pair.

        - day / last30day: ``date_info`` itself.
        - week / month: the latest ``date`` in ``dim_date_20_to_30`` whose
          ``year_week`` / ``year_month`` column equals ``date_info``.
        - 4_week: the latest ``date`` whose ``year_week`` column equals ``date_info``.
        - month_week: today's date (driver local time) formatted ``%Y-%m-%d``.
        - any other date_type: None.

        The date dimension table is only read for the week / month / 4_week cases.

        :param spark_session: spark session used to read ``dim_date_20_to_30``
        :param date_type: date granularity
        :param date_info: date value for that granularity
        :return: the reference day as a string, or None for unsupported date types
        :raises AssertionError: if date_type or date_info is None
        :raises IndexError: if no row of the date dimension matches date_info
        """
        assert date_type is not None, "date_type不能为空!"
        assert date_info is not None, "date_info不能为空!"
        if date_type in [DateTypes.day.name, DateTypes.last30day.name]:
            cal_day = date_info
        elif date_type in [DateTypes.week.name, DateTypes.month.name, '4_week']:
            # 4_week is looked up through year_week; week / month through year_<date_type>.
            col_name = 'year_week' if date_type == '4_week' else f'year_{date_type}'
            # Load the dimension only in the branches that actually need it.
            df = spark_session.sql("select * from dim_date_20_to_30").toPandas()
            matched = df.loc[df[col_name] == f"{date_info}"]
            if matched.empty:
                raise IndexError(f"dim_date_20_to_30 has no row with {col_name} = {date_info}")
            cal_day = matched.sort_values('date', ascending=False)['date'].iloc[0]
        elif date_type == DateTypes.month_week.name:
            cal_day = datetime.now().date().strftime("%Y-%m-%d")
        else:
            return None
        print("cal_day:", str(cal_day))
        return str(cal_day)

    @staticmethod
    def get_rel_exception_info():
        import sys
        exc_type, exc_value, exc_traceback = sys.exc_info()
        print(exc_traceback)
        return exc_type, f"""{exc_value} in {exc_traceback.tb_frame.f_code.co_filename}  {exc_traceback.tb_lineno} line"""

    @staticmethod
    def get_sys_arg(index: int, defVal: object):
        """
        获取main系统输入参数脚标从1开始
        :param index: 索引
        :param defVal: 默认值
        :return:
        """
        return CommonUtil.safeIndex(sys.argv, index, defVal)

    @staticmethod
    def listNotNone(listVal: list = None):
        """
        判断是否是空数组
        """
        return listVal is not None or len(listVal) > 0

    @staticmethod
    def notNone(obj: object = None):
        """
        判断是否是None
        """
        return obj is not None

    @staticmethod
    def notBlank(strVal: str = None):
        """
        判断是否是空字符串
        """
        return strVal is not None and strVal != ''

    @staticmethod
    def get_day_offset(day: str, offset: int):
        """
        获取日期偏移值
        :param day: 类似 2022-11-01
        :param offset: 偏移值
        :return: 过去或将来的时间
        """
        pattern = "%Y-%m-%d"
        rel_day = datetime.strptime(day, pattern)
        d = rel_day + timedelta(days=offset)
        return d.strftime(pattern)

    @staticmethod
    def get_month_offset(month: str, offset: int):
        """
        获取月份偏移值
        :param month: 类似 2022-11
        :param offset: 偏移值
        :return: 过去或将来的月份
        """
        year_int = int(CommonUtil.safeIndex(month.split("-"), 0, None))
        month_int = int(CommonUtil.safeIndex(month.split("-"), 1, None))

        if offset > 0:
            for i in range(0, offset):
                year_int, month_int = calendar._nextmonth(year_int, month_int)

        if offset < 0:
            for i in range(0, abs(offset)):
                year_int, month_int = calendar._prevmonth(year_int, month_int)

        return datetime(year_int, month_int, 1).strftime("%Y-%m")

    @staticmethod
    def reformat_date(date_str: str, from_format: str, to_format: str):
        """
        重新格式化日期
        :param date_str:
        :param from_format:
        :param to_format:
        :return:
        """
        return datetime.strptime(date_str, from_format).strftime(to_format)

    @staticmethod
    def format_now(from_format: str):
        now = datetime.now()
        return datetime.strftime(now, from_format)

    @staticmethod
    def format_timestamp(timestamp: int, format: str = _py_date_time_format):
        """
        格式化毫秒级别时间戳
        :param timestamp:
        :param format:
        :return:
        """
        from datetime import datetime
        return datetime.strftime(datetime.fromtimestamp(timestamp / 1000), format)

    @staticmethod
    def calculate_date_offset(date1, date2):
        """
        计算日期偏移量
        :param date1: 日期1 格式:%Y-%m-%d
        :param date2: 日期2 格式:%Y-%m-%d
        :return: 日期差值
        """
        if date1 is None or date2 is None:
            return None

        date_format = "%Y-%m-%d"
        try:
            # 将日期字符串转换为 datetime 对象
            datetime1 = datetime.strptime(date1, date_format)
            datetime2 = datetime.strptime(date2, date_format)

            # 计算日期的偏移量
            offset = abs((datetime2 - datetime1).days)
            return offset
        except ValueError:
            # 日期字符串格式不正确
            return None

    @staticmethod
    def list_to_insql(arr: list):
        """
        数组转为in中的sql
        :param arr:
        :return:
        """
        return str.join(",", list(map(lambda item: f"'{item}'", arr)))

    @staticmethod
    def arr_to_spark_col(arr: list):
        """
        Turn a Python list into a Spark array column of literals.
        """
        literal_cols = [F.lit(value) for value in arr]
        return F.array(literal_cols)

    @staticmethod
    def build_es_option(site_name="us"):
        """
        构建spark导出用 es参数
        """
        site_port = {
            "us": 9200,
            "uk": 9201,
            "de": 9201,
        }
        return {
            "es.nodes": "120.79.147.190",
            "es.port": site_port[site_name],
            "es.net.http.auth.user": "elastic",
            "es.net.http.auth.pass": "selection2021.+",
            "es.nodes.wan.only": True,
            "es.index.auto.create": True
        }

    @staticmethod
    def str_compress(strLines: str):
        """
        多行字符串压缩
        :param strLines:
        :return:
        """
        strArr = []
        splitArr = strLines.splitlines()
        for s in splitArr:
            strArr.append(s.strip())
        return ' '.join(strArr)

    @staticmethod
    def build_export_sh(site_name: str,
                        db_type: str,
                        hive_tb: str,
                        export_tb: str,
                        col: list,
                        partition_dict: dict,
                        num_mappers=20
                        ):
        """
        Build a sqoop ``export`` command that copies a Hive table (read through
        HCatalog from ``big_data_selection``) into a relational table.

        :param site_name: site code used to look up the target connection
        :param db_type: target database type passed to DBUtil.get_connection_info
        :param hive_tb: source hive table
        :param export_tb: target table in the relational database
        :param col: columns to export
        :param partition_dict: hive partition key -> value; when empty (or None)
            no partition filter is added
        :param num_mappers: number of sqoop mappers
        :return: the multi-line shell command (one option per line, joined with
            shell line continuations)
        """
        conn_info = DBUtil.get_connection_info(db_type, site_name)
        head = f"{CommonUtil.__sqoop_home__} export -D mapred.job.queue.name=default -D mapred.task.timeout=0"
        options = [
            f"--connect {conn_info['url']}",
            f"--username {conn_info['username']}",
            # NOTE(review): the password is placed on the command line.
            f"--password {conn_info['pwd']}",
            f"--table {export_tb}",
            "--input-fields-terminated-by '\\001'",
            "--hcatalog-database big_data_selection",
            f"--hcatalog-table {hive_tb}",
        ]
        if partition_dict:
            # str() so non-string partition values (e.g. ints) do not break the join.
            p_keys = ",".join(str(k) for k in partition_dict.keys())
            p_values = ",".join(str(v) for v in partition_dict.values())
            options += [
                f"--hcatalog-partition-keys {p_keys}",
                f"--hcatalog-partition-values {p_values}",
            ]
        cols = ",".join(col)
        options += [
            "--input-null-string '\\\\N'",
            "--input-null-non-string '\\\\N'",
            f"--num-mappers {num_mappers}",
            f"--columns {cols}",
            '--outdir "/tmp/sqoop/"',
        ]
        return "\n    " + " \\\n    ".join([head] + options) + "\n    "

    @staticmethod
    def build_import_sh_tmp_inner(conn_info: Dict,
                                  query: str,
                                  hive_tb_name: str,
                                  map_num: int = 1,
                                  split_by: str = None
                                  ):
        """
        直接导入到临时内部临时表用于一次性计算用
        :param conn_info:
        :param query:
        :param hive_tb_name:
        :return:
        """
        default_db = 'big_data_selection'
        cmd = f"""
    {CommonUtil.__sqoop_home__} yswg_import -D mapred.job.queue.name=default -D mapred.task.timeout=0 \\
    --connect {conn_info['url']}  \\
    --username {conn_info['username']}  \\
    --password {conn_info['pwd']} \\
    --query "{query}"  \\
    --mapreduce-job-name f"sqoop_task{hive_tb_name}"  \\
    --hcatalog-database {default_db} \\
    --create-hcatalog-table \\
    --hcatalog-table {hive_tb_name} \\
    --fields-terminated-by '\\t'  \\
    --hive-drop-import-delims  \\
    --null-string '\\\\N'  \\
    --null-non-string '\\\\N'  \\
    --m {map_num} \\
    --split-by {split_by} \\
    --outdir "/tmp/sqoop/"
    """
        return cmd

    @staticmethod
    def build_hive_import_sh(site_name: str,
                             db_type: str,
                             query: str,
                             hive_table: str,
                             partition_dict: dict
                             ):
        """
        Build a sqoop ``yswg_import`` command that imports ``query`` into the
        given partition of a Hive table in ``big_data_selection`` (HCatalog),
        stored as ORC.

        :param site_name: site code used to look up the source connection
        :param db_type: source database type passed to DBUtil.get_connection_info
        :param query: source query; it is stripped and backticks are escaped,
            because it is placed inside double quotes in the shell command
        :param hive_table: target hive table (also used in the YARN job name)
        :param partition_dict: partition key -> value of the target partition
        :return: the shell command string
        """
        default_db = 'big_data_selection'
        conn_info = DBUtil.get_connection_info(db_type, site_name)
        # Escape backticks so the shell does not treat them as command substitution.
        query = query.strip()
        query = query.replace("`", r"\`")
        keys = ",".join(partition_dict.keys())
        # str() so non-string partition values (e.g. ints) do not break the join.
        values = ",".join(str(v) for v in partition_dict.values())

        return f"""
    {CommonUtil.__sqoop_home__} yswg_import -D mapred.job.queue.name=default -D mapred.task.timeout=0 \\
    --connect {conn_info['url']}  \\
    --username {conn_info['username']}  \\
    --password {conn_info['pwd']} \\
    --query "{query}"  \\
    --mapreduce-job-name "sqoop_task:{hive_table}"  \\
    --hcatalog-database {default_db} \\
    --hcatalog-table {hive_table} \\
    --hcatalog-partition-keys {keys} \\
    --hcatalog-partition-values {values} \\
    --hcatalog-storage-stanza "stored as orcfile" \\
    --m 1 \\
    --outdir "/tmp/sqoop/"
    """

    @staticmethod
    def build_import_sh(site_name: str,
                        db_type: str,
                        query: str,
                        hdfs_path: str,
                        map_num: int = 1,
                        key: str = None
                        ):
        """
        导入到hdfs外部表
        :param site_name:
        :param db_type:
        :param query:
        :param hdfs_path:
        :param map_num:
        :param key:
        :return:
        """
        conn_info = DBUtil.get_connection_info(db_type, site_name)
        #  对query中的特殊字符自动转义
        query = query.strip()
        query = query.replace("`", r"\`")
        start_name = CommonUtil.get_start_name_from_hdfs_path(hdfs_path)
        if start_name:
            start_name = "sqoop_task:"+start_name
        else:
            start_name = "sqoop_task"
        return f"""
    {CommonUtil.__sqoop_home__} yswg_import -D mapred.job.queue.name=default -D mapred.task.timeout=0  --append \\
    --connect {conn_info['url']}  \\
    --username {conn_info['username']}  \\
    --password {conn_info['pwd']} \\
    --target-dir {hdfs_path}  \\
    --mapreduce-job-name "{start_name}"  \\
    --query "{query}"  \\
    --fields-terminated-by '\\t'  \\
    --hive-drop-import-delims  \\
    --null-string '\\\\N'  \\
    --null-non-string '\\\\N'  \\
    --compress \\
    -m {map_num} \\
    --split-by {key} \\
    --compression-codec lzop  \\
    --outdir "/tmp/sqoop/"
    """

    @staticmethod
    def after_import(hdfs_path: str, hive_tb: str):
        """
        Post-process an HDFS import over SSH.

        1. Run hadoop-lzo's DistributedLzoIndexer on ``hdfs_path``.
        2. Run ``msck repair table big_data_selection.<hive_tb>`` so hive picks
           up new partitions.

        The SSH client is always closed, even if a command raises.

        :param hdfs_path: HDFS directory that was imported into
        :param hive_tb: hive table (in big_data_selection) to repair
        """
        cmd = rf"""
    {CommonUtil.__hadoop_home__} jar  \
    /opt/module/hadoop/share/hadoop/common/hadoop-lzo-0.4.20.jar  \
    com.hadoop.compression.lzo.DistributedLzoIndexer -Dmapreduce.job.queuename=default -Dmapreduce.framework.name=local\
    {hdfs_path}
    """
        print("lzo 压缩中")
        print(cmd)
        client = SSHUtil.get_ssh_client()
        try:
            SSHUtil.exec_command_async(client, cmd, ignore_err=False)
            print(f"修复表{hive_tb}中")
            repair_cmd = rf"""{CommonUtil.__hive_home__} -e "set hive.msck.path.validation=ignore; msck repair table big_data_selection.{hive_tb};" """
            print(repair_cmd)
            SSHUtil.exec_command_async(client, repair_cmd, ignore_err=False)
        finally:
            client.close()

    @staticmethod
    def hive_cmd_exec(cmd: str):
        """
        Run HiveQL through the local Hive CLI (``hive -e``).

        The statement is passed as a single argv element instead of being
        wrapped in single quotes for ``/bin/sh``. Statements that contain single
        quotes (e.g. ``partition (site_name='us')``) therefore reach Hive unchanged.

        :param cmd: HiveQL statement(s) to execute
        :return: the Hive CLI exit code (0 on success). A non-zero exit is not
            raised; check the return value if it matters.
        :raises OSError: if the Hive executable cannot be started
        """
        import subprocess
        # Shell-style rendering kept only for the log line.
        hive_cmd = rf"""{CommonUtil.__hive_home__} -e '{cmd}' """
        print(f"执行hive命令中{hive_cmd}")
        completed = subprocess.run([CommonUtil.__hive_home__, "-e", cmd])
        return completed.returncode

    @staticmethod
    def orctable_concatenate(hive_table: str,
                             partition_dict: Dict,
                             innerFlag: bool = False,
                             min_part_num: int = 5,
                             max_retry_time: int = 10):
        """
        Merge small files of a hive table (or one partition) with ``alter table ... concatenate``.

        Repeats the concatenate while the table/partition directory still has more
        than ``min_part_num`` files, up to ``max_retry_time + 1`` attempts.

        :param hive_table: table in big_data_selection
        :param partition_dict: partition key -> value; empty means the whole table
        :param innerFlag: passed through to build_hdfs_path
        :param min_part_num: file count at or below which no further merge is attempted
        :param max_retry_time: cap on retries so a merge that never shrinks the file count terminates
        """
        path = CommonUtil.build_hdfs_path(hive_table, partition_dict, innerFlag)
        part_list = HdfsUtils.read_list(path)
        if part_list is None:
            return
        default_db = 'big_data_selection'
        partition_str = ",".join(f""" {key}="{value}" """ for key, value in partition_dict.items())
        retry_time = 0
        while len(part_list) > min_part_num and retry_time <= max_retry_time:
            if len(partition_dict) == 0:
                # whole table
                CommonUtil.hive_cmd_exec(f"""alter table {default_db}.{hive_table} concatenate;""")
            else:
                # single partition
                CommonUtil.hive_cmd_exec(f"""alter table {default_db}.{hive_table} partition ({partition_str}) concatenate;""")
            retry_time += 1
            part_list = HdfsUtils.read_list(path)
            if part_list is None:
                return

    @staticmethod
    def check_schema(spark_session: SparkSession, df_schema: DataFrame, save_tb_name: str, filter_cols: list = None):
        """
        Compare the column names of table ``save_tb_name`` with those of DataFrame ``df_schema``.

        Both column lists (name, simple type) are printed side by side, then the
        columns that exist on only one side are printed. Types are displayed but
        not compared.

        :param spark_session: spark session
        :param df_schema: DataFrame whose schema is compared
        :param save_tb_name: table to compare against
        :param filter_cols: extra column names to ignore, in addition to
            site_name / date_type / date_info
        :return: True if both sides have the same set of (non-ignored) column names, else False
        """
        base_filter_cols = ['site_name', 'date_type', 'date_info']
        # Build a new collection; list.extend() returns None and must not be assigned.
        skip_cols = set(base_filter_cols) | set(filter_cols or [])

        tb_schema = spark_session.sql(f"select * from {save_tb_name} limit 0").schema
        hive_fields = [(f.name, f.dataType.simpleString()) for f in tb_schema.fields if f.name not in skip_cols]
        df_fields = [(f.name, f.dataType.simpleString()) for f in df_schema.schema.fields if f.name not in skip_cols]

        # Explicit schema so an empty field list does not fail schema inference.
        df1 = spark_session.createDataFrame(hive_fields, "name string, type string")
        df2 = spark_session.createDataFrame(df_fields, "name string, type string")

        show_df = df1.join(df2, "name", how="outer").select(
            df1.name.alias("hive_column"),
            df1.type.alias("hive_column_type"),
            df2.name.alias("df_column"),
            df2.type.alias("df_column_type"),
        ).cache()
        try:
            show_df.show(n=300, truncate=False)
            # columns present on only one side
            show_df_diff = show_df.filter('hive_column is null or df_column is null')
            show_df_diff.show(n=300, truncate=False)
            return show_df_diff.count() == 0
        finally:
            show_df.unpersist()

    @staticmethod
    def check_ods_sync_schema(spark_session: SparkSession, import_table: str, db_type: str, site_name: str,
                              hive_table: str, msg_usr: list = __msg_usr__):
        """
        Check whether the source (ODS) table's columns still match the hive ODS table.

        The source schema is read through JDBC with a ``limit 0`` query and
        compared via check_schema. On a mismatch an alert is sent with send_wx_msg.

        :param spark_session: spark session
        :param import_table: source table that is imported
        :param db_type: source connection type (e.g. mysql / pgsql)
        :param site_name: site code
        :param hive_table: hive ODS table the source is imported into
        :param msg_usr: alert recipients (defaults to the class-level __msg_usr__ list)
        """
        conn_info = DBUtil.get_connection_info(db_type, site_name)
        source_df = SparkUtil.read_jdbc_query(
            session=spark_session,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=f"select * from {import_table} limit 0"
        )
        # TODO(review): the alerting behaviour still needs confirmation (per the original note).
        if CommonUtil.check_schema(spark_session, source_df, hive_table):
            return
        alert_text = f"{hive_table} 与 {import_table} 数据得schema不一致,请查看日志!!  "
        CommonUtil.send_wx_msg(msg_usr, f"\u26A0 {hive_table}同步schema校验异常 \u26A0", alert_text)

    @staticmethod
    def check_tb_schema_same(spark_session: SparkSession, tb1_name: str, tb2_name: str):
        """
        Check that two tables have the same column names in the same positions.

        Columns are matched by position. A position that exists in only one of
        the tables counts as a mismatch, so extra trailing columns on either side
        are detected. Types are displayed but not compared. The side-by-side view
        is printed only when a mismatch is found.

        :param spark_session: spark session
        :param tb1_name: first table
        :param tb2_name: second table
        :return: True if every position has the same column name in both tables
        """
        tb_schema1 = spark_session.sql(f"select * from {tb1_name} limit 0")
        tb_schema2 = spark_session.sql(f"select * from {tb2_name} limit 0")
        list1 = [(item.name, item.dataType.simpleString(), i) for i, item in enumerate(tb_schema1.schema.fields)]
        list2 = [(item.name, item.dataType.simpleString(), i) for i, item in enumerate(tb_schema2.schema.fields)]

        df1 = spark_session.createDataFrame(list1, ('name', 'type', "index"))
        df2 = spark_session.createDataFrame(list2, ('name', 'type', "index"))

        # Full outer join so positions present only in tb1 are not dropped; a null
        # name on either side makes the equality null, which maps to 0 (mismatch).
        show_df = df2.join(df1, "index", how="full_outer").select(
            F.col("index"),
            df2.name.alias(f"表2{tb2_name}字段"),
            df2.type.alias(f"表2{tb2_name}类型"),
            df1.name.alias(f"表1{tb1_name}字段"),
            df1.type.alias(f"表1{tb1_name}类型"),
            F.when(df1['name'] == df2['name'], F.lit(1)).otherwise(0).alias("是否一致")
        ).orderBy("index")

        # min == 1 means every position matched; min == 0 means at least one mismatch.
        schema_flag = bool(show_df.select(F.min("是否一致").alias("result")).first().asDict()['result'])
        if not schema_flag:
            show_df.show(n=300, truncate=False)
        return schema_flag

    @staticmethod
    def check_schema_before_import(db_type: str,
                                   site_name: str,
                                   query: str,
                                   hive_tb_name: str,
                                   msg_usr: list = __msg_usr__,
                                   partition_dict: Dict = None):
        """
        Pre-import check: the source query must return data, and its column names
        must line up position by position with the hive table.

        The sqoop ``and \\$CONDITIONS`` placeholder is removed. A trailing LIMIT
        clause (last standalone ``limit`` keyword, case-insensitive) is cut off and
        the query is run with ``limit 1``. On failure an alert is sent and an
        Exception is raised. The spark session is stopped in every case.

        :param db_type: source connection type
        :param site_name: site code
        :param query: sqoop import query
        :param hive_tb_name: target hive table
        :param msg_usr: alert recipients (the alert also goes to a fixed extra list)
        :param partition_dict: sync parameters; non-None values are echoed in the alert
        :return: ``(empty_flag, schema_flag)`` -- always ``(False, True)`` when returning normally
        :raises Exception: if the source query returns no rows or the columns do not match
        """
        import re
        msg_params = ""
        if partition_dict is not None:
            msg_params = "".join(f"{value} " for value in partition_dict.values() if value is not None)

        def _alert_and_fail():
            # Notify owners plus the fixed on-call list, then abort the task.
            person_in_charge = ",".join(msg_usr)
            msg = f"任务信息:{hive_tb_name} {msg_params}\n负责人:{person_in_charge}"
            CommonUtil.send_wx_msg(msg_usr + ['chenyuanjie', 'chenjianyun', 'leichao', 'chenbo'], "\u26A0 数据同步异常", msg)
            raise Exception(msg)

        spark_session = SparkUtil.get_spark_session("check_schema")
        try:
            rel_query = query.strip()
            rel_query = rel_query.replace(r"and \$CONDITIONS", "")
            # Whole-word, case-insensitive match so identifiers like credit_limit are not cut.
            limit_matches = list(re.finditer(r"\blimit\b", rel_query, flags=re.IGNORECASE))
            if limit_matches:
                rel_query = rel_query[:limit_matches[-1].start()]
            rel_query = f"""{rel_query} limit 1"""

            conn_info = DBUtil.get_connection_info(db_type, site_name)
            import_tb_schema = SparkUtil.read_jdbc_query(
                session=spark_session,
                url=conn_info["url"],
                pwd=conn_info["pwd"],
                username=conn_info["username"],
                query=rel_query
            )

            empty_flag = import_tb_schema.count() == 0
            if empty_flag:
                _alert_and_fail()

            tb_schema = spark_session.sql(f"select * from {hive_tb_name} limit 0")
            list1 = [(item.name, item.dataType.simpleString(), i) for i, item in enumerate(tb_schema.schema.fields)]
            list2 = [(item.name, item.dataType.simpleString(), i) for i, item in enumerate(import_tb_schema.schema.fields)]

            df1 = spark_session.createDataFrame(list1, ('name', 'type', "index"))
            df2 = spark_session.createDataFrame(list2, ('name', 'type', "index"))

            show_df = df2.join(df1, "index", how="left").select(
                df2['index'].alias("index"),
                df2.name.alias(f"导入表字段"),
                df2.type.alias(f"导入表类型"),
                df1.name.alias(f"hive表{hive_tb_name}字段"),
                df1.type.alias(f"hive表{hive_tb_name}类型"),
                F.when(df1['name'] == df2['name'], F.lit(1)).otherwise(0).alias("是否一致")
            )
            show_df.show(n=300, truncate=False)

            # min == 1 means every position matched; min == 0 means at least one mismatch.
            schema_flag = bool(show_df.select(F.min("是否一致").alias("result")).first().asDict()['result'])
            if not schema_flag:
                _alert_and_fail()
            return empty_flag, schema_flag
        finally:
            spark_session.stop()

    @staticmethod
    def check_import_sync_num(db_type: str,
                              partition_dict: Dict,
                              import_query: str,
                              hive_tb_name: str,
                              msg_usr: list = __msg_usr__):
        """
        Post-import check: the source row count must equal the hive partition row count.

        The source count is derived from ``import_query`` by replacing everything
        before its first standalone FROM keyword (case-insensitive) with
        ``select count(1) as import_total_num``. The hive count is filtered by
        every non-None entry of ``partition_dict``. On a mismatch an alert is sent
        and an Exception is raised. The spark session is stopped in every case.

        :param db_type: source connection type
        :param partition_dict: partition values; ``site_name`` also selects the source connection
        :param import_query: sqoop import query
        :param hive_tb_name: target hive table
        :param msg_usr: alert recipients (the alert also goes to a fixed extra list)
        :raises ValueError: if ``import_query`` contains no FROM keyword
        :raises Exception: if the counts differ
        """
        import re
        spark_session = SparkUtil.get_spark_sessionV3("check_sync_num")
        try:
            site_name = partition_dict.get("site_name")
            conn_info = DBUtil.get_connection_info(db_type, site_name)
            import_query = import_query.replace(r"and \$CONDITIONS", "").strip()
            # Split once on a standalone FROM so columns like from_date and any
            # sub-query FROM clauses after the first stay intact.
            parts = re.split(r"\bfrom\b", import_query, maxsplit=1, flags=re.IGNORECASE)
            if len(parts) < 2:
                raise ValueError(f"cannot find a FROM clause in import query: {import_query}")
            import_count_sql = "select count(1) as import_total_num from" + parts[1]
            print(import_count_sql)
            import_tb_df = SparkUtil.read_jdbc_query(
                session=spark_session,
                url=conn_info["url"],
                pwd=conn_info["pwd"],
                username=conn_info["username"],
                query=import_count_sql
            )
            import_tb_count = import_tb_df.collect()[0]['import_total_num']

            # partition filters and alert parameters from the non-None partition values
            partition_conditions = []
            msg_params = ""
            for key, value in partition_dict.items():
                if value is not None:
                    partition_conditions.append(f"{key} = '{value}'")
                    msg_params += f"{value} "

            partition_query = f"SELECT count(1) as hive_total_num  FROM {hive_tb_name}"
            if partition_conditions:
                partition_query += f" WHERE {' AND '.join(partition_conditions)}"
            hive_tb_count = spark_session.sql(partition_query).collect()[0]['hive_total_num']

            total_num_flag = bool(import_tb_count == hive_tb_count)
            print(f"import_total_num:{import_tb_count}")
            print(f"{hive_tb_name} total_num:{hive_tb_count}")

            if not total_num_flag:
                person_in_charge = ",".join(msg_usr)
                msg = f"任务信息:{hive_tb_name} {msg_params}\n负责人:{person_in_charge}"
                CommonUtil.send_wx_msg(msg_usr + ['chenyuanjie', 'chenjianyun', 'leichao', 'chenbo'], "\u26A0 数据同步异常", msg)
                raise Exception(msg)
        finally:
            spark_session.stop()

    @staticmethod
    def check_fields_and_warning(hive_tb_name: str, partition_dict: Dict):
        """
        对配置表(hive_field_verify_config) 配置的相应表和相应字段进行校验
        :param hive_tb_name:校验表的表名
        :param partition_dict:校验表的分区条件
        :param msg_usr:异常消息通知人
        :return:
        """
        # 获取计算分区
        msg_params = ""
        for key, value in partition_dict.items():
            if value is not None:
                msg_params += f"{value} "
        base_msg = f"{hive_tb_name} {msg_params} "
        site_name = partition_dict.get("site_name")
        date_type = partition_dict.get("date_type")
        spark_session = SparkUtil.get_spark_sessionV3("check_fields_rule")
        # 获取维护的字段验证配置表数据
        config_table_query = f"""select * from hive_field_verify_config 
                                    where table_name ='{hive_tb_name}' 
                                    and site_name = '{site_name}'
                                    and use_flag = 1 """
        conn_info = DBUtil.get_connection_info('mysql', 'us')
        check_field_df = SparkUtil.read_jdbc_query(
            session=spark_session,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=config_table_query
        )
        check_field_list = check_field_df.select('field_name', 'verify_desc', 'verify_type', 'config_json',
                                                 'msg_usr_list').collect()
        if not check_field_list:
            print("============================无验证匹配条件跳过验证===================================")
            return
        for row in check_field_list:
            field_name = row['field_name']
            verify_type = row['verify_type']
            config_json = json.loads(row['config_json'])
            msg_usr = row['msg_usr_list']
            msg_usr_list = [user.strip() for user in msg_usr.split(",")] if msg_usr else []
            if verify_type == "空值率验证":
                query = CommonUtil.generate_null_ratio_query(table_name=hive_tb_name,
                                                             field_name=field_name,
                                                             partition_dict=partition_dict)
                ratio_df = spark_session.sql(query).cache()
                ratio_num = float(ratio_df.collect()[0]['null_ratio'])
                waring_max = float(config_json['max'])
                ratio_df = ratio_df.select(
                    F.col('field_name').alias('校验字段'),
                    F.lit(verify_type).alias('校验类型'),
                    F.col('null_ratio').alias('校验字段空值率'),
                    F.lit(waring_max).alias('空值率阈值'),
                    F.when((F.col('null_ratio') < waring_max), 1).otherwise(0).alias('是否验证通过')
                )
                ratio_df.show(10, truncate=False)

                if ratio_num >= waring_max:
                    # 进行微信推送
                    msg = f"{base_msg} 字段:{field_name}的{verify_type}不通过,请注意该字段的使用!!"
                    CommonUtil.send_wx_msg(msg_usr_list, f"\u26A0 {hive_tb_name} {msg_params}数据{verify_type}异常",
                                           msg)
            elif verify_type == "最大最小值验证":
                query = CommonUtil.generate_min_max_query(table_name=hive_tb_name,
                                                          field_name=field_name,
                                                          partition_dict=partition_dict)
                ratio_df = spark_session.sql(query).cache()
                field_max_vale = float(ratio_df.collect()[0]['max_value'])
                field_min_vale = float(ratio_df.collect()[0]['min_value'])
                waring_max = float(config_json['max'])
                waring_min = float(config_json['min'])
                ratio_df = ratio_df.select(
                    F.col('field_name').alias('校验字段'),
                    F.lit(verify_type).alias('校验类型'),
                    F.col('max_value').alias('校验字段最大值'),
                    F.col('min_value').alias('校验字段最小值'),
                    F.lit(waring_max).alias('最大值上限'),
                    F.lit(waring_min).alias('最小值下限'),
                    F.when((F.col('max_value') <= waring_max) | (F.col('min_value') >= waring_min), 1).otherwise(
                        0).alias('是否验证通过')
                )
                ratio_df.show(10, truncate=False)
                if field_max_vale > waring_max:
                    # 进行微信推送
                    msg = f"{base_msg} 字段:{field_name}的最大值上限验证不通过,请注意该字段的使用!!"
                    CommonUtil.send_wx_msg(msg_usr_list, f"\u26A0 {hive_tb_name} {msg_params}数据{verify_type}异常",
                                           msg)

                if field_min_vale < waring_min:
                    # 进行微信推送
                    msg = f"{base_msg} 字段:{field_name}的最小值下限验证不通过,请注意该字段的使用!!"
                    CommonUtil.send_wx_msg(msg_usr_list, f"\u26A0 {hive_tb_name} {msg_params}数据{verify_type}异常",
                                           msg)
            # elif verify_type == "数据量合法验证":
            #     sql_condition = config_json['sql_condition']
            #     partition_conf_list = config_json['partition_conf']
            #     for conf in partition_conf_list:
            #         conf_site_name = conf["site_name"]
            #         conf_date_type = conf["date_type"]
            #
            #         if site_name == conf_site_name and date_type == conf_date_type:
            #             base_count = conf["base_count"]
            #             break
            #     assert base_count is not None, f"未配置{field_name}验证周期{date_type}的基准值,请检查!"
            #
            #     query = CommonUtil.generate_total_cal_query(table_name=hive_tb_name,
            #                                                 field_name=field_name,
            #                                                 partition_dict=partition_dict,
            #                                                 sql_condition=sql_condition)
            #     ratio_df = spark_session.sql(query).cache()
            #     verify_total_count = int(ratio_df.collect()[0]['verify_total_count'])
            #     waring_max = int(base_count * config_json['max_rate'])
            #     waring_min = int(base_count * config_json['min_rate'])
            #     ratio_df = ratio_df.select(
            #         F.lit(row['verify_desc']).alias('验证描述'),
            #         F.lit(verify_type).alias('验证类型'),
            #         F.col('field_name').alias('校验字段'),
            #         F.col('verify_total_count').alias('校验字段统计值'),
            #         F.lit(waring_max).alias('最大临界值上限'),
            #         F.lit(waring_min).alias('最小临界值下限'),
            #         F.when((F.col('verify_total_count') <= waring_max) | (F.col('verify_total_count') >= waring_min),
            #                F.lit(1)).otherwise(F.lit(0)).alias('是否验证通过')
            #     )
            #
            #     ratio_df.show(10, truncate=False)
            #     if verify_total_count > waring_max:
            #         # 进行微信推送
            #         msg = f"{base_msg} 字段:{field_name}的值{verify_total_count}超出限定最大值:{waring_max},请注意该字段的使用!!"
            #         CommonUtil.send_wx_msg(msg_usr_list, f"\u26A0 {hive_tb_name} {msg_params}数据{verify_type}异常",
            #                                msg)
            #     if verify_total_count < waring_min:
            #         # 进行微信推送
            #         msg = f"{base_msg} 字段:{field_name}的值:{verify_total_count}低于限定最小值:{waring_min},请注意该字段的使用!!"
            #         CommonUtil.send_wx_msg(msg_usr_list, f"\u26A0 {hive_tb_name} {msg_params}数据{verify_type}异常",
            #                                msg)
        pass

    @staticmethod
    def format_df_with_template(spark_session: SparkSession, save_df: DataFrame, save_tb_name: str,
                                roundDouble: bool = False):
        """
        Align a DataFrame to the column layout of an existing table before an ``insert into``.

        An empty (``limit 0``) query against ``save_tb_name`` acts as the template, and the data is
        combined with it through ``unionByName(..., allowMissingColumns=False)``. Columns are therefore
        matched by name, and the column sets of the template and ``save_df`` must agree.

        :param spark_session: active Spark session used to query the template table
        :param save_df: DataFrame that is going to be written
        :param save_tb_name: target table whose columns serve as the template
        :param roundDouble: when True, every DoubleType column of ``save_df`` is first rounded to 4 decimals
        :return: union of the empty template and ``save_df``
        """
        template_df = spark_session.sql(f"select * from {save_tb_name} limit 0")
        result_df = save_df
        if roundDouble:
            digits = 4
            # the DoubleType check is done against the incoming schema, before any rounding
            double_fields = [fld for fld in save_df.schema.fields if fld.dataType == DoubleType()]
            for fld in double_fields:
                print(f"{fld.name}从{fld.dataType}保留小数位数为{digits}中...")
                result_df = result_df.withColumn(fld.name, F.round(F.col(fld.name), digits))
        return template_df.unionByName(result_df, allowMissingColumns=False)

    @staticmethod
    def auto_transfer_type(spark_session: SparkSession, save_df: DataFrame, hive_tb: str, transfer_dict: Dict = None):
        """
        Cast the columns of ``save_df`` to the types of the same-named columns of Hive table ``hive_tb``.

        A column is cast when its Spark type differs from the Hive column type. Columns whose Hive type is
        DoubleType are always cast: to ``transfer_dict[col_name]`` when that entry is present and truthy,
        otherwise to DecimalType(10, 3). Columns that do not exist in ``hive_tb`` are left untouched.

        :param spark_session: active Spark session used to read the Hive table schema
        :param save_df: DataFrame whose columns are cast
        :param hive_tb: Hive table that provides the target schema
        :param transfer_dict: optional {column name: DataType} overrides; only consulted for Hive DoubleType columns
        :return: the DataFrame with all casts applied
        """
        overrides = transfer_dict or {}
        hive_schema = spark_session.sql(f"select * from {hive_tb} limit 0").schema
        hive_types = {hive_field.name: hive_field.dataType for hive_field in hive_schema.fields}

        # iterate over the incoming schema; save_df is rebound below but the field list is fixed here
        for src_field in save_df.schema.fields:
            name = src_field.name
            if name not in hive_types:
                continue
            target_type = hive_types[name]
            needs_cast = src_field.dataType != target_type
            if target_type == DoubleType():
                target_type = overrides.get(name) or DecimalType(10, 3)
                needs_cast = True
            if needs_cast:
                print(f"{name}从{src_field.dataType}转化为{target_type}")
                save_df = save_df.withColumn(name, F.col(name).cast(target_type))

        return save_df

    @staticmethod
    def select_partitions_df(spark_session: SparkSession, tb_name: str):
        """
        Return the partitions of ``tb_name`` as a DataFrame: one row per partition, one column per partition key.

        Each ``show partitions`` entry (for example ``site_name=us/date_type=month``) is split on ``/``
        into ``key=value`` pairs, and each pair is split on ``=``; the first piece is the key and the
        second piece the value.
        NOTE(review): for a table without partitions ``createDataFrame`` receives an empty list and
        cannot infer a schema — confirm callers never hit that case.
        """
        partition_specs = (
            spark_session.sql(f"show partitions {tb_name}")
            .select("partition")
            .rdd.flatMap(lambda x: x)
            .collect()
        )

        rows = []
        for spec in partition_specs:
            row = {}
            for pair in spec.split("/"):
                pieces = pair.split("=")
                row[pieces[0]] = pieces[1]
            rows.append(row)

        return spark_session.createDataFrame(rows)

    @staticmethod
    def select_col_all(df: DataFrame):
        """
        Build a list of column expressions covering every column of ``df``.

        Each expression is ``df[name]`` aliased back to ``name``, in ``df.columns`` order.
        """

        def _aliased(column_name):
            # alias each column to its own name
            return df[column_name].alias(column_name)

        return list(map(_aliased, df.columns))

    @staticmethod
    def df_export_csv(spark_session: SparkSession, export_df: DataFrame, csv_name: str, limit: int = 20 * 10000):
        """
        Export at most ``limit`` rows of a DataFrame into a single CSV file (with header) on HDFS.

        The rows are written with one partition into the directory ``/tmp/csv/{csv_name}``. The generated
        part file is then moved to ``/tmp/csv/{csv_name}.csv`` and the directory is removed.
        ``mapred.output.compress`` is set to false for the export and restored to its previous value
        afterwards, also when the export fails.

        :param spark_session: active Spark session
        :param export_df: DataFrame to export
        :param csv_name: base name of the CSV file under /tmp/csv
        :param limit: maximum number of exported rows (default 200,000)
        :return: HDFS path of the generated CSV file
        :raises FileNotFoundError: if no part file ending in "csv" was produced
        """
        # remember the current setting so it can be restored after the export
        compress_flag = spark_session.conf.get("mapred.output.compress")
        # output is not compressed (presumably so the part file stays plain text with a "csv" suffix)
        spark_session.sql("set mapred.output.compress=false")
        try:
            csv_path = f"/tmp/csv/{csv_name}"
            export_df.limit(limit).repartition(1).write.mode("overwrite").option("header", True).csv(csv_path)

            # merge into a single file: move the only part file next to the directory, then drop the directory
            client = HdfsUtils.get_hdfs_cilent()
            part_files = [path for path in client.list(csv_path) if str(path).endswith("csv")]
            if not part_files:
                raise FileNotFoundError(f"{csv_path}下未生成csv文件,请检查!")
            rel_path = f"{csv_path}.csv"
            client.delete(rel_path, True)
            client.rename(f"{csv_path}/{part_files[0]}", rel_path)
            client.delete(csv_path, True)
        finally:
            # restore the session setting even when the write or the HDFS moves fail
            spark_session.sql(f"set mapred.output.compress={compress_flag}")

        print("======================csv生成hdfs文件路径如下======================")
        print(rel_path)
        return rel_path

    @classmethod
    def transform_week_tuple(cls, spark_session: SparkSession, date_type: str, date_info: str):
        """
        对周流程进行日期转换,返回日期元祖:如传入month,则返回该月下所有的周
        周流程的week元祖获取
        :param spark_session: spark对象
        :param date_type: 日期类型date_type
        :param date_info: 具体日期date_info
        :return: complete_date_info_tuple: 周数据元祖
        """
        complete_date_info_tuple = tuple()
        df_date = spark_session.sql(f"select * from dim_date_20_to_30 ;")
        df = df_date.toPandas()
        if date_type == 'week':
            complete_date_info_tuple = f"('{date_info}')"
        elif date_type == '4_week':
            print(date_info)
            df_loc = df.loc[(df.year_week == f"{date_info}") & (df.week_day == 1)]
            cur_id = list(df_loc.id)[0]
            df_loc = df.loc[df.id == int(cur_id)]
            week1 = list(df_loc.year_week)[0]
            df_loc = df.loc[df.id == int(cur_id) - 7]
            week2 = list(df_loc.year_week)[0]
            df_loc = df.loc[df.id == int(cur_id) - 14]
            week3 = list(df_loc.year_week)[0]
            df_loc = df.loc[df.id == int(cur_id) - 21]
            week4 = list(df_loc.year_week)[0]
            complete_date_info_tuple = (week1, week2, week3, week4)
        elif date_type == 'month':
            df_loc = df.loc[(df.year_month == f"{date_info}") & (df.week_day == 1)]
            complete_date_info_tuple = tuple(df_loc.year_week)
        print("complete_date_info_tuple:", complete_date_info_tuple)
        return complete_date_info_tuple

    @classmethod
    def create_tmp_tb(cls, spark_session: SparkSession, ddl: str, tb_name: str, drop_exist: bool = False):
        """
        Create a table from ``ddl`` and strip the Spark-specific table properties afterwards.

        Creating a table through Spark DDL adds properties such as ``spark.sql.sources.schema.numParts``.
        Every property whose key starts with ``spark.sql.create.version`` or ``spark.sql.sources.schema``
        is unset once the table exists.

        :param spark_session: active Spark session
        :param ddl: create-table statement to execute
        :param tb_name: name of the table created by ``ddl``
        :param drop_exist: drop ``tb_name`` (if it exists) before executing ``ddl``
        :return: always True
        """
        if drop_exist:
            print(f"drop table {tb_name}")
            spark_session.sql(f"drop table if exists  {tb_name}")
        print("创建临时表中:ddl sql 为")
        print(ddl)
        spark_session.sql(ddl)

        props_df = spark_session.sql(f'show tblproperties {tb_name};')
        print(props_df)
        spark_prefixes = ("spark.sql.create.version", "spark.sql.sources.schema")
        quoted_keys = [
            f"'{prop_key}'"
            for prop_key in props_df.select("key").rdd.flatMap(lambda ele: ele).collect()
            if str(prop_key).startswith(spark_prefixes)
        ]
        if quoted_keys:
            spark_session.sql(f"""alter table {tb_name} unset tblproperties ({",".join(quoted_keys)});""")
        return True

    @classmethod
    def save_or_update_table(cls, spark_session: SparkSession,
                             hive_tb_name: str,
                             partition_dict: Dict,
                             df_save: DataFrame,
                             drop_exist_tmp_flag=True
                             ):
        """
        Insert a partition into ``hive_tb_name``, or replace it when its HDFS directory already exists.

        - Partition directory missing: ``df_save`` is appended directly into ``hive_tb_name``.
        - Partition directory present: ``df_save`` is appended into ``{hive_tb_name}_copy`` (created
          ``like`` the target, after clearing that partition's directory in the copy table), and the
          partition directories of the target and the copy table are then swapped via
          ``exchange_partition_data``.

        :param spark_session: active Spark session
        :param hive_tb_name: name of the table actually written to
        :param partition_dict: {partition column: value} of the partition to write
        :param df_save: data to write
        :param drop_exist_tmp_flag: drop and recreate the copy table first; when False the existing copy
                                    table is reused after verifying its schema matches the target
        :raises Exception: if the reused copy table's schema differs from the target table
        """
        partition_cols = list(partition_dict.keys())
        target_path = CommonUtil.build_hdfs_path(hive_tb_name, partition_dict)
        if not HdfsUtils.path_exist(target_path):
            # partition directory not found: append straight into the target table
            df_save.write.saveAsTable(name=hive_tb_name, format='hive', mode='append', partitionBy=partition_cols)
            print("success")
            return

        copy_tb = f"{hive_tb_name}_copy"
        CommonUtil.create_tmp_tb(
            spark_session,
            ddl=f"""create table if not exists {copy_tb} like {hive_tb_name}""",
            tb_name=copy_tb,
            drop_exist=drop_exist_tmp_flag
        )
        print(f"当前存储的临时表名为:{copy_tb},分区为{partition_cols}")

        # a reused copy table must have the same layout, otherwise swapped data would be misaligned
        if not drop_exist_tmp_flag:
            same_schema = CommonUtil.check_tb_schema_same(spark_session, tb1_name=hive_tb_name, tb2_name=copy_tb)
            if not same_schema:
                raise Exception(f"{copy_tb}表结构同{hive_tb_name}不一致,交换分区后可能存在错位现象,请检查!!")

        # clear any leftover data of this partition in the copy table before appending
        copy_part_path = CommonUtil.build_hdfs_path(hive_tb=copy_tb, partition_dict=partition_dict, innerFlag=True)
        if HdfsUtils.path_exist(copy_part_path):
            HdfsUtils.delete_hdfs_file(copy_part_path)

        df_save.write.saveAsTable(name=copy_tb, format='hive', mode='append', partitionBy=partition_cols)
        # swap the partition directories of the target and the copy table
        CommonUtil.exchange_partition_data(
            spark_session=spark_session,
            tb_src=hive_tb_name,
            partition_dict_src=partition_dict,
            tb_target=copy_tb,
            partition_dict_target=partition_dict
        )
        print("success")

    @classmethod
    def exchange_partition_data(cls, spark_session: SparkSession,
                                tb_src: str,
                                partition_dict_src: Dict,
                                tb_target: str,
                                partition_dict_target: Dict,
                                ):
        """
        Swap the HDFS data directories of one partition of ``tb_src`` and one partition of ``tb_target``.

        Each table's root location is read from ``describe formatted`` (with the
        ``hdfs://nameservice1:8020`` prefix stripped), and the ``key=value`` segments of its partition
        dict are appended in dict order. Both directories must exist; the swap itself is delegated to
        ``HdfsUtils.exchange_path``.

        :param spark_session: active Spark session
        :param tb_src: first partitioned table
        :param partition_dict_src: partition of ``tb_src`` as {partition column: value}
        :param tb_target: second partitioned table
        :param partition_dict_target: partition of ``tb_target`` as {partition column: value}
        :return: True
        :raises AssertionError: if either partition directory does not exist
        """

        def _partition_location(tb_name: str, partition_dict: Dict) -> str:
            # table root location reported by `describe formatted`, without the namenode prefix
            location: str = spark_session.sql(
                f"""describe formatted {tb_name};"""
            ).where("col_name = 'Location' ").first().asDict().get("data_type").replace("hdfs://nameservice1:8020", "")
            for key, value in partition_dict.items():
                location += f"/{key}={value}"
            return location

        location1 = _partition_location(tb_src, partition_dict_src)
        location2 = _partition_location(tb_target, partition_dict_target)

        # explicit raise instead of `assert`, so the check is not stripped under `python -O`
        for location in (location1, location2):
            if not HdfsUtils.path_exist(location):
                raise AssertionError(f"分区【{location}】不存在!")

        HdfsUtils.exchange_path(path_one=location1, path_two=location2)
        return True

    @classmethod
    def get_next_val(cls, date_type: str, date_info: str):
        """
        Return the period that follows ``date_info`` for the given ``date_type``.

        - day: ``cls.get_day_offset(date_info, 1)``
        - week: the smallest ``year_week`` greater than ``date_info`` in the MySQL table ``date_20_to_30``
        - month: ``cls.get_month_offset(date_info, 1)``

        :param date_type: name of DateTypes.day / DateTypes.week / DateTypes.month
        :param date_info: current period
        :return: the next period
        :raises Exception: if ``date_type`` is unsupported, or no later week exists in ``date_20_to_30``
        """
        if date_type == DateTypes.day.name:
            return cls.get_day_offset(date_info, 1)

        if date_type == DateTypes.week.name:
            engine = DBUtil.get_db_engine("mysql", "us")
            sql = f"""
                select year_week
                from date_20_to_30
                where year_week > '{date_info}'
                order by year_week
                limit 1  """
            print("================================执行sql================================")
            print(sql)
            with engine.connect() as connection:
                # first column of the first row, or None when the query returns no row
                next_week = connection.execute(sql).scalar()
            if next_week is None:
                raise Exception(f"date_20_to_30表中不存在{date_info}之后的周,请检查!")
            return next_week

        if date_type == DateTypes.month.name:
            return cls.get_month_offset(date_info, 1)

        raise Exception(f"时间类型{date_type}不支持")

    @classmethod
    def build_ddl_form_df(cls, df: DataFrame, partition_list: list, tb_name: str):
        """
        df 生成 ddl sql
        :param df:
        :param partition_list:
        :param tb_name:
        :return:
        """

        df.schema.fieldNames()
        part = partition_list

        type_dict = {
            DoubleType(): "double",
            DecimalType(): "double",
            StringType(): "string",
            LongType(): "int",
            IntegerType(): "int",
        }
        line1 = []
        line2 = []
        for field in df.schema.fields:
            type = type_dict.get(field.dataType)

            if field.name in part:
                line2.append(f"{field.name}         {type}")
            else:
                line1.append(f"{field.name}         {type}")

        str1 = ",\n".join(line1)
        str2 = ",\n".join(line2)

        ddl = f"""
create table {tb_name}
(
{str1}
)
partitioned by
(
{str2} 
)
row format serde 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
stored as
inputformat 'com.hadoop.mapred.DeprecatedLzoTextInputFormat'
outputformat 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
    """
        return ddl

    @classmethod
    def get_rel_date_type(cls, tb_name: str, date_type: str):
        assert tb_name is not None, "表名不能为空!"
        # 需要特殊处理的就算这些表支持 month_old 其他的不支持
        tmp_list = [
            'dim_st_detail',
            'dwd_st_measure',
            'dwd_st_asin_measure',
            'dwd_asin_measure',
            'dwd_st_volume_fba',
            'dwt_st_market',
            'dws_st_num_stats',
            'dwt_aba_st_analytics'
        ]
        # if date_type in ['month_old'] and date_info < '2023-10':
        if date_type in ['month_old'] and tb_name not in tmp_list:
            return 'month'
        return date_type

    @staticmethod
    def build_hdfs_path(hive_tb: str, partition_dict: Dict = None, innerFlag: bool = False):
        """
        构建对应的表名称
        :param hive_tb:
        :param partition_dict:
        :param innerFlag:
        :return:
        """
        suffix = ""
        if partition_dict is not None:
            tmp = []
            for key in partition_dict.keys():
                tmp.append(f"{key}={partition_dict.get(key)}")
            suffix += "/".join(tmp)

        hdfs_path = None

        if innerFlag:
            if partition_dict is not None:
                hdfs_path = f"/user/hive/warehouse/big_data_selection.db/{hive_tb}/{suffix}"
            else:
                hdfs_path = f"/user/hive/warehouse/big_data_selection.db/{hive_tb}"
        else:
            allow = ['ods', 'dim', 'dwd', 'dws', 'dwt', 'tmp']
            prefix = None
            for tmp in allow:
                if hive_tb.startswith(tmp):
                    prefix = tmp
                pass
            assert prefix is not None, f"{hive_tb}表名不合规,请检查!"
            if partition_dict is not None:
                hdfs_path = f"/home/big_data_selection/{prefix}/{hive_tb}/{suffix}"
            else:
                hdfs_path = f"/home/big_data_selection/{prefix}/{hive_tb}"

        return hdfs_path

    @staticmethod
    def send_wx_msg(users: list, title: str, content: str, msgtype: str = "textcard" ):
        """
        通过选品wx消息推送接口,推送消息到oa
        :param users: 填写需要推送的微信用户名list
        :param title: 推送的标题(如果msgtype采用markdown形式,则不附带标题)
        :param content: 推送的主体内容
        :param msgtype: 推送的消息类型(textcard:默认卡片类型;markdown:markdaown结构)
        """
        if users is not None:
            accounts = ",".join(users)
            # 排除users_list=[''] 无需发送
            if bool(accounts):
                host = "http://120.79.147.190:8080"
                url = f'{host}/soundasia_selection/dolphinScheduler/sendMessage'
                data = {
                    'account': accounts,
                    'title': title,
                    'content': content,
                    'msgtype': msgtype
                }
                try:
                    requests.post(url=url, data=data, timeout=15)
                except:
                    pass
        return True

    @classmethod
    def print_hive_ddl(cls,
                       db_type: str,
                       site_name: str,
                       from_tb: str,
                       hive_tb: str,
                       partition_dict: Dict
                       ):

        engine = DBUtil.get_db_engine(db_type, site_name)

        hdfs_path = cls.build_hdfs_path(hive_tb, partition_dict)
        cols = []
        with engine.connect() as connection:
            sql = f"""
                        select a.attname                                                                           col_name,
                        d.description                                                                              col_desc,
                        concat_ws('', t.typname, SUBSTRING(format_type(a.atttypid, a.atttypmod) from '\(.*\)')) as col_type
                        from pg_class c
                        left join pg_attribute a on a.attrelid = c.oid
                        left join pg_type t on t.oid = a.atttypid
                        left join pg_description d on d.objoid = a.attrelid and d.objsubid = a.attnum
                        where 1 = 1
                        and a.attnum > 0
                        and c.relname in (select tablename from pg_tables where schemaname = 'public')
                        and c.relname = '{from_tb}'
                        and t.typname is not null
                        order by c.relname, a.attnum;
                        """
            for row in list(connection.execute(sql)):
                col_name = row['col_name']
                col_desc = row['col_desc']
                col_type = row['col_type']

                if "int" in col_type:
                    hive_col_type = 'int'
                elif "varchar" in col_type or "text" in col_type:
                    hive_col_type = 'string'
                elif "numeric" in col_type:
                    hive_col_type = 'double'
                elif "float8" in col_type:
                    hive_col_type = 'double'
                elif "date" in col_type:
                    hive_col_type = 'string'
                elif "timestamp" in col_type:
                    hive_col_type = 'string'
                else:
                    hive_col_type = 'string'
                cols.append(f"{col_name}\t{hive_col_type}\tcomment\t'{col_desc}'")
            print("================================执行sql================================")

        partitioned_by = []
        for key in partition_dict.keys():
            partitioned_by.append(f"{key} string comment 'you comment' ")

        col_str = ",\n".join(cols)

        ddl = f"""
        create external table {hive_tb}
        (
            {col_str}
        )
            partitioned by ({",".join(partitioned_by)})
            row format serde 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
            stored as
                inputformat 'com.hadoop.mapred.DeprecatedLzoTextInputFormat'
                outputformat 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
            location 'hdfs://nameservice1:8020{hdfs_path}';
        
        alter table {hive_tb}
            set tblproperties ('comment' = '{hive_tb}表注释');
"""
        print(ddl)
        return ddl

    @staticmethod
    def drop_part(hive_tb: str, partition_dict: Dict):
        """
        删除hive分区数据默认是外部表 不仅仅是删除数据还是删除hive的分区
        :param hive_tb:
        :param partition_dict:
        :return:
        """
        tmparr = []
        for key in partition_dict.keys():
            tmparr.append(f"{key} = '{partition_dict.get(key)}'")
        part_str = ','.join(tmparr)
        hive_ddl = f"""alter table big_data_selection.{hive_tb} drop if exists partition ({part_str});"""
        cmd = rf"""{CommonUtil.__hive_home__} -e "{hive_ddl}" """
        print(f"=============================删除分区中==============================")
        print(cmd)
        client = SSHUtil.get_ssh_client()
        SSHUtil.exec_command_async(client, cmd, ignore_err=False)
        client.close()
        path = CommonUtil.build_hdfs_path(hive_tb=hive_tb, partition_dict=partition_dict)
        print(f"=============================删除分区数据中==============================")
        print(path)
        HdfsUtils.delete_hdfs_file(path)
        pass

    @staticmethod
    def generate_null_ratio_query(table_name: str, field_name: str, partition_dict: Dict):
        """
        构建空值率计算query
        :param table_name: hive表名称
        :param field_name: 需要验证空值率字段
        :param partition_dict:校验表的分区条件
        :return: query: 计算空值率的query
        """

        # 计算空值率sql
        query = f"""SELECT '{field_name}' AS field_name, 
                    COUNT(1) AS total_count, 
                    COUNT(CASE WHEN {field_name} IS NULL THEN 1 WHEN {field_name} = -1 THEN 1 END) AS null_count, 
                    COUNT(CASE WHEN {field_name} IS NOT NULL THEN 1 WHEN {field_name} != -1 THEN 1 END) AS not_null_count, 
                    ROUND(COUNT(CASE WHEN {field_name} IS NULL THEN 1 WHEN {field_name} = -1 THEN 1 END)  / COUNT(1), 4) AS null_ratio 
                    FROM {table_name} """

        # 解析partition_dict获取分区查询条件
        partition_conditions = []
        for key, value in partition_dict.items():
            if value is not None:
                partition_conditions.append(f"{key} = '{value}'")
        # 拼接where条件
        if partition_conditions:
            query += f" WHERE {' AND '.join(partition_conditions)}"
        return query

    @staticmethod
    def generate_min_max_query(table_name: str, field_name: str, partition_dict: Dict):
        """
        构建最大最小值计算query
        :param table_name: hive表名称
        :param field_name: 需要验证空值率字段
        :param partition_dict:校验表的分区条件
        :return: query: 计算空值率的query
        """
        query = f"SELECT '{field_name}' AS field_name, " \
                f"MIN({field_name}) AS min_value, " \
                f"MAX({field_name}) AS max_value " \
                f"FROM {table_name}"

        # 解析partition_dict获取分区查询条件
        partition_conditions = []
        for key, value in partition_dict.items():
            if value is not None:
                partition_conditions.append(f"{key} = '{value}'")
        # 拼接where条件
        if partition_conditions:
            query += f" WHERE {' AND '.join(partition_conditions)}"
        return query

    @staticmethod
    def generate_total_cal_query(table_name: str, field_name: str, partition_dict: Dict, sql_condition: str):
        """
        计算带条件判断的单字段总数query
        :param table_name: hive表名称
        :param field_name: 需要验证空值率字段
        :param partition_dict:校验表的分区条件
        :param sql_condition:其他过滤条件补充
        :return: query: 计算返回使用query
        """
        query = f"SELECT '{field_name}' AS field_name, " \
                f"count({field_name}) AS verify_total_count " \
                f"FROM {table_name}"

        # 解析partition_dict获取分区查询条件
        partition_conditions = []
        for key, value in partition_dict.items():
            if value is not None:
                partition_conditions.append(f"{key} = '{value}'")
        # 拼接where条件
        if partition_conditions:
            query += f" WHERE {' AND '.join(partition_conditions)}"

        # 拼接外部查询条件
        if sql_condition:
            query += f" AND {sql_condition} "

        return query

    @staticmethod
    def judge_is_work_hours(site_name: str = 'us', date_type: str = None, date_info: str = None,
                            principal: str = "wangrui4, huangjian", priority: int = 1, export_tools_type: int = 1,
                            belonging_to_process: str = None):
        """
        Export-time constraint: register this export script for later execution instead of running it now.

        When the script runs inside a DolphinScheduler process and the last CLI argument is not ``test``,
        the full command line of the script is upserted into the PostgreSQL table
        ``export_command_records`` (status 1, conflict key ``commands``), and the process exits with code 0.
        For test or local runs the function just returns.
        NOTE(review): despite the name, no working-hours check happens here; presumably the recorded
        commands are run outside working hours by a separate job.

        :param site_name: site
        :param date_type: schedule date type
        :param date_info: schedule period
        :param principal: process maintainers (must match enterprise-WeChat accounts)
        :param priority: priority (small numbers for short tasks, large numbers for long ones;
                         keep <= 3 for tasks other than the reverse search-term lookup)
        :param export_tools_type: export tool (1: sqoop, 2: elasticsearch); selects the script directory
        :param belonging_to_process: owning process; stored as "{process_id}|{belonging_to_process}"
        :return: None for test/local runs; otherwise the process exits via sys.exit(0)
        """
        exec_env = CommonUtil.__python_home__
        # id of the running DolphinScheduler process; None means the script was not started by a process
        process_id = None
        ds_result = DolphinschedulerHelper.get_running_process_task()
        if CommonUtil.notNone(ds_result):
            process_id = str(ds_result.get('process_df_id'))
        # the last CLI argument "test" marks a test export
        test_flag = CommonUtil.get_sys_arg(len(sys.argv) - 1, None)
        if test_flag == 'test' or process_id is None:
            print("为测试导出或者本地导出,无需监控!")
            return

        script_name = sys.argv[0].split("/")[-1]
        print("当前执行的脚本是:" + script_name)
        script_path = ''
        if export_tools_type == 1:
            script_path = "/opt/module/spark/demo/py_demo/sqoop_export/" + script_name
        elif export_tools_type == 2:
            script_path = "/opt/module/spark/demo/py_demo/export_es/" + script_name
        arguments_str = ' '.join(sys.argv[1:])
        # full command used to re-run the export later
        commands = exec_env + " " + script_path + " " + arguments_str
        # process name used later to group the export tasks of one process
        process_name = process_id if belonging_to_process is None else process_id + "|" + belonging_to_process
        print("执行脚本命令的语句是:" + commands)

        def _sql_str(value) -> str:
            # SQL string literal with embedded single quotes doubled (CLI args may contain quotes)
            return "'" + str(value).replace("'", "''") + "'"

        exec_sql = f"""
                    INSERT INTO export_command_records
                    (site_name, date_type, date_info, script_name, commands, status, principal, priority, export_tools_type, belonging_to_process)
                    VALUES
                      ({_sql_str(site_name)}, {_sql_str(date_type)}, {_sql_str(date_info)}, {_sql_str(script_name)}, {_sql_str(commands)}, 1, {_sql_str(principal)}, {priority}, {export_tools_type}, {_sql_str(process_name)})
                    ON CONFLICT (commands) DO UPDATE
                    SET 
                        site_name = excluded.site_name,
                        date_type = excluded.date_type,
                        date_info = excluded.date_info,
                        script_name = excluded.script_name,
                        status = excluded.status,
                        principal = excluded.principal,
                        priority = excluded.priority,
                        export_tools_type = excluded.export_tools_type,
                        belonging_to_process = excluded.belonging_to_process;
                      """
        print("exec_sql:" + exec_sql)
        DBUtil.exec_sql("postgresql_cluster", "us", exec_sql)
        sys.exit(0)

    @staticmethod
    def modify_export_workflow_status(update_workflow_sql: str, site_name: str = 'us',
                                      date_type: str = None,
                                      date_info: str = None):
        """
        Mark the running export script as finished and, once every script of its export
        process has finished, execute the workflow update statement.

        Steps:
          1. If the last command-line argument is ``test``, do nothing (test export).
          2. Look up ``belonging_to_process`` in ``export_command_records`` (PostgreSQL
             cluster, ``us``) for this site/date_type/date_info and for the running script,
             identified by the file name of ``sys.argv[0]``, among records whose status is
             not 3. Stop if no such record exists.
          3. Set status = 3 for this script's record(s) in that process.
          4. If no record of that process still has status != 3, execute
             ``update_workflow_sql`` on the MySQL ``us`` database.

        Both database engines are disposed on every exit path (early return, success or error).

        :param update_workflow_sql: workflow update statement; must not be None when step 4 runs.
        :param site_name: site name of the export record.
        :param date_type: date dimension type of the export record.
        :param date_info: date value of the export record.
        :return: None
        :raises AssertionError: when step 4 runs and ``update_workflow_sql`` is None.
        """
        # A trailing "test" CLI argument marks a test export that must not be monitored.
        test_flag = CommonUtil.get_sys_arg(len(sys.argv) - 1, None)
        if test_flag == 'test':
            print("测试导出,无需监控!")
            return
        mysql_engine = None
        pg_engine = None
        try:
            mysql_engine = DBUtil.get_db_engine('mysql', 'us')
            pg_engine = DBUtil.get_db_engine('postgresql_cluster', 'us')
            # The running script is identified by its file name.
            script_name = sys.argv[0].split("/")[-1]
            get_process_sql = f"""select belonging_to_process from export_command_records 
                                    where site_name='{site_name}' 
                                    and date_type = '{date_type}' 
                                    and date_info = '{date_info}' 
                                    and script_name = '{script_name}'
                                    and status != 3
                                    """
            with pg_engine.connect() as connection:
                result = connection.execute(get_process_sql).first()
                belonging_to_process = result['belonging_to_process'] if result else None
            if belonging_to_process is None:
                print("export_command_records库中未记录该流程,无需监控!")
                return
            exec_sql = f"""update export_command_records set status = 3 
                            where script_name='{script_name}' 
                            and belonging_to_process = '{belonging_to_process}'
                            """
            DBUtil.engine_exec_sql(pg_engine, exec_sql)
            # Count the scripts of this process that have not yet reached status 3.
            check_process_sql = f"""select count(script_name) as uncompleted_num from export_command_records 
                                      where belonging_to_process = '{belonging_to_process}'
                                      and status != 3
                                      """
            with pg_engine.connect() as connection:
                uncompleted_num = connection.execute(check_process_sql).scalar()

            # Zero unfinished scripts means the whole export process is done.
            if int(uncompleted_num) == 0:
                # Validate before logging so a missing statement fails with the intended message.
                assert update_workflow_sql is not None, "流程更新语句不能为空!!请检查!"
                # Log the statement that is actually executed (the workflow update).
                print("执行流程更改:" + update_workflow_sql)
                DBUtil.engine_exec_sql(mysql_engine, update_workflow_sql)
            else:
                print("当前流程下仍有脚本未执行完成!暂未更改流程状态!")
        finally:
            # Release connection pools on every path, including the early return above.
            for engine in (mysql_engine, pg_engine):
                if engine is not None:
                    engine.dispose()

    @staticmethod
    def judge_not_working_hour():
        """
        判断当前时间是不是用户上班时间
        :return:
        """
        from datetime import datetime
        now = datetime.now()
        hour_minute = CommonUtil.format_now("%H:%M")
        if now.weekday() + 1 in [1, 2, 3, 4, 5]:
            if ('08:20' <= hour_minute <= '12:35') or ('13:30' <= hour_minute <= '18:50'):
                return False
            return True
        else:
            return True

    @staticmethod
    def convert_singular_plural(word):
        """
        单词单复数转换,与java一致的转换词库textblob库支持
        :param word: 需要进行单复数转换的词
        :return:convert_word 根据传入的单词word解析单复数词性后,返回转换词
        """
        if not word:
            return None

        word_object = Word(word)
        singular_form = word_object.singularize()
        plural_form = word_object.pluralize()
        # 判断word到底原始词性是单数还是复数,决定返回转换后的值
        convert_word = plural_form if word == singular_form else singular_form
        return convert_word

    @staticmethod
    def list_build_sqlin_condition(word_list):
        """
        将list转换成 where condition in (a,x,c) 子条件语句
        :param word_list: 需要组装成in语句的列表
        :return: condition: 组装好后的 in的子条件语句如 (a,b,x)
        """
        # 检查列表长度
        if len(word_list) == 1:
            condition = f"'({word_list[0]})'"
        else:
            condition = tuple(word_list)
        return condition

    @staticmethod
    def get_start_name_from_hdfs_path(hdfs_path: str):
        locate_index = ['ods', 'dim', 'dwd', 'dws', 'dwt', 'tmp']
        hidden_param = ['site_name', 'date_type', 'date_info']
        for locate_word in locate_index:
            word_index = hdfs_path.find(locate_word)
            if word_index != -1:
                break
        if word_index != -1:
            content_extract = hdfs_path[word_index + 4:]  # 4 是 "xxx/" 的长度
            content_extract = content_extract.replace("/", ":")
            for hidden_word in hidden_param:
                content_extract = content_extract.replace(hidden_word + "=", "")
            return content_extract
        else:
            return None

    @classmethod
    def get_asin_variant_attribute(cls, df_asin_detail: DataFrame, df_asin_measure: DataFrame, partition_num: int=80, use_type: int=0):
        """
        Enrich ASIN details with measures, monthly sales and parent-level (variant-aggregated)
        AO value and flow proportion; optionally attach the ASIN's own variant attributes.

        Param df_asin_detail: ASIN detail DataFrame. Required columns:
                asin;
                asin_vartion_list (available from Kafka; schema:
                    StructField("asin_vartion_list", ArrayType(ArrayType(StringType()), True), True));
                buy_sales (available from Kafka; schema: StructField("buy_sales", StringType(), True)).
        Param df_asin_measure: ASIN measure DataFrame. Required columns:
                asin, asin_zr_counts, asin_adv_counts, asin_st_counts, asin_amazon_orders,
                asin_zr_flow_proportion, asin_ao_val.
        Param partition_num: parallelism used for repartitioning (set according to the script's resources).
        Param use_type: usage type (0: default, plugin; 1: traffic-based product selection).
        return: the enriched detail DataFrame, containing
                1. columns carried from df_asin_measure:
                    asin, asin_zr_counts, asin_adv_counts, asin_st_counts, asin_amazon_orders,
                    asin_zr_flow_proportion, asin_ao_val
                2. computed columns:
                    matrix_ao_val, matrix_flow_proportion, asin_amazon_orders (overridden when
                    buy_sales parses), variant_info (set of variant ASINs)
                3. only when use_type == 1: color, size, style
                4. any other columns df_asin_measure carries.
        """
        # 1. Left-join the measures (AO, per-type counts, flow proportion, monthly sales, ...) onto the details.
        df_asin_detail = df_asin_detail.repartition(partition_num)
        df_asin_measure = df_asin_measure.repartition(partition_num)
        df_asin_detail = df_asin_detail.join(
            df_asin_measure, on=['asin'], how='left'
        )
        # 2. Parse Amazon's monthly sales from buy_sales. cls.u_parse_amazon_orders is defined
        #    elsewhere on the class (presumably a UDF around udf_parse_amazon_orders — TODO confirm).
        df_asin_detail = df_asin_detail.withColumn(
            "bought_month",
            F.when(F.col("buy_sales").isNotNull(), cls.u_parse_amazon_orders(F.col("buy_sales"))).otherwise(F.lit(None))
        )
        # The parsed value takes precedence; the measure's asin_amazon_orders is the fallback.
        df_asin_detail = df_asin_detail.withColumn("asin_amazon_orders", F.coalesce(F.col("bought_month"), F.col("asin_amazon_orders"))).drop("bought_month")
        # 3. Aggregate parent-level AO value and flow proportion over all variants.
        df_with_variant_attribute = df_asin_detail.filter(F.expr("size(asin_vartion_list) > 0"))
        # One row per (asin, variant entry). Within an entry, index 0 is the variant ASIN,
        # 1 is color, 3 is size and 5 is style.
        df_explode_variant_attribute = df_with_variant_attribute.select(
            "asin", F.explode("asin_vartion_list").alias("variant_attribute")
        ).select(
            "asin", F.col("variant_attribute")[0].alias("variant_asin"), F.col("variant_attribute")[1].alias("color"),
            F.col("variant_attribute")[3].alias("size"), F.col("variant_attribute")[5].alias("style")
        )
        df_variant_asin_detail = df_asin_measure.select(F.col("asin").alias("variant_asin"), "asin_zr_counts", "asin_adv_counts", "asin_st_counts")
        df_explode_variant_attribute = df_explode_variant_attribute.repartition(partition_num)
        # Inner join: variants without a row in df_asin_measure contribute neither to the sums
        # nor to variant_info.
        df_explode_variant_attribute_detail = df_explode_variant_attribute.join(
            df_variant_asin_detail, on=["variant_asin"], how="inner"
        )
        df_explode_variant_attribute_agg = df_explode_variant_attribute_detail.groupby(['asin']).agg(
            F.sum("asin_zr_counts").alias("sum_zr_counts"),
            F.sum("asin_adv_counts").alias("sum_adv_counts"),
            F.sum("asin_st_counts").alias("sum_st_counts"),
            F.collect_set(F.col("variant_asin")).alias("variant_info")
        )
        df_explode_variant_attribute_agg = df_explode_variant_attribute_agg.repartition(partition_num)
        # matrix_flow_proportion = sum_zr / sum_st (4 decimals); matrix_ao_val = sum_adv / sum_zr (3 decimals).
        # F.when without otherwise() yields null when the denominator sum is null.
        # NOTE(review): a zero denominator is not guarded; with Spark ANSI mode off it yields null,
        # with ANSI mode on it raises — confirm the cluster setting.
        df_explode_variant_attribute_agg = df_explode_variant_attribute_agg.withColumn(
            "matrix_flow_proportion",
            F.when(F.col("sum_st_counts").isNotNull(), F.round(F.col("sum_zr_counts") / F.col("sum_st_counts"), 4))
        ).withColumn(
            "matrix_ao_val",
            F.when(F.col("sum_zr_counts").isNotNull(), F.round(F.col("sum_adv_counts") / F.col("sum_zr_counts"), 3))
        ).drop("sum_zr_counts", "sum_adv_counts", "sum_st_counts")
        df_asin_detail = df_asin_detail.join(
            df_explode_variant_attribute_agg, on=['asin'], how='left'
        )
        # Without a variant aggregate, fall back to the ASIN's own AO value / flow proportion.
        df_asin_detail = df_asin_detail.withColumn(
            "matrix_ao_val", F.coalesce(F.col("matrix_ao_val"), F.col("asin_ao_val"))
        ).withColumn(
            "matrix_flow_proportion", F.coalesce(F.col("matrix_flow_proportion"), F.col("asin_zr_flow_proportion"))
        )
        # 4. Attach the variant attributes (color, size, style, ...) of the ASIN itself.
        if use_type == 1:
            # The entry whose variant ASIN equals the row's own ASIN holds that ASIN's attributes.
            # NOTE(review): if the variant list contains the ASIN more than once, this left join
            # duplicates detail rows.
            df_asin_attribute = df_explode_variant_attribute.filter(F.col("asin") == F.col("variant_asin")).drop("variant_asin")
            df_asin_detail = df_asin_detail.join(
                df_asin_attribute, on=['asin'], how='left'
            )
        return df_asin_detail

    @staticmethod
    def unified_variant_asin_basic_detail(df_asin_detail: DataFrame, columns_list: list, partition_num: int=80, use_type: int=0):
        """
        Make the shared attributes of variant ASINs (rows with the same parent_asin) consistent
        within one batch of ASIN details.

        For each parent_asin, the row with the latest ``asinUpdateTime`` (nulls ranked last)
        supplies the values of ``columns_list``. Those values overwrite the corresponding
        columns of every row with that parent_asin. Rows without a parent_asin, or whose
        winning value is null, keep their own value.

        :param df_asin_detail: one batch of ASIN detail rows.
        :param columns_list: attribute columns shared between variant ASINs (chosen per scenario).
        :param partition_num: number of partitions to repartition to (sized to the job's resources).
        :param use_type: scenario — 0: plugin (input column ``parentAsin`` is renamed to
                         ``parent_asin``), 1: traffic-based product selection.
        :return: tuple ``(df_asin_detail, df_latest_asin_detail_with_parent)``:
                 the batch with unified attributes, and one row per parent_asin holding
                 ``parent_asin`` plus the latest shared attribute values under their original names.
        """
        detail_df = df_asin_detail
        if use_type == 0:
            # The plugin feed names the parent column in camelCase.
            detail_df = detail_df.withColumnRenamed("parentAsin", "parent_asin")
        detail_df = detail_df.repartition(partition_num)

        # Shared columns travel under a "new_" prefix so they do not clash with the originals after the join.
        shadow_names = [f"new_{name}" for name in columns_list]

        newest_first = Window.partitionBy("parent_asin").orderBy(F.desc_nulls_last("asinUpdateTime"))
        latest_per_parent = (
            detail_df.filter("parent_asin is not null")
            .select("parent_asin", "asinUpdateTime",
                    *[F.col(name).alias(shadow) for name, shadow in zip(columns_list, shadow_names)])
            .withColumn("ct_rank", F.row_number().over(newest_first))
            .repartition(partition_num)
            .filter("ct_rank = 1")
            .drop("ct_rank", "asinUpdateTime")
        )

        unified_df = detail_df.join(latest_per_parent, on=['parent_asin'], how='left')
        for name, shadow in zip(columns_list, shadow_names):
            # Prefer the per-parent value; fall back to the row's own value when it is null.
            unified_df = unified_df.withColumn(name, F.coalesce(F.col(shadow), F.col(name))).drop(shadow)

        # Expose the per-parent attributes under their original column names.
        for name, shadow in zip(columns_list, shadow_names):
            latest_per_parent = latest_per_parent.withColumnRenamed(shadow, name)

        return unified_df, latest_per_parent