import os
import sys
import json

sys.path.append(os.path.dirname(sys.path[0]))
from utils.ssh_util import SSHUtil
from utils.common_util import CommonUtil
from utils.db_util import DBUtil
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql import functions as F


def _find_partition_conf(partition_conf_list, site_name, date_type):
    """Return the first ``partition_conf`` entry matching ``site_name`` and ``date_type``.

    :param partition_conf_list: the ``partition_conf`` list from a rule's ``config_json``
        (dicts with ``site_name`` / ``date_type`` keys); ``None`` is treated as empty.
    :param site_name: site partition value currently being verified.
    :param date_type: date_type partition value currently being verified.
    :return: the matching dict, or ``None`` when no entry applies.
    """
    for conf in partition_conf_list or []:
        if conf.get("site_name") == site_name and conf.get("date_type") == date_type:
            return conf
    return None


def vertify_data(hive_table, partition_dict):
    """Run the field-verification rules configured for one Hive partition.

    Rules are read from the MySQL table ``hive_field_verify_config`` (rows for
    ``hive_table`` / the partition's ``site_name`` with ``use_flag = 1``). A rule is
    executed only if its ``partition_conf`` has an entry matching the partition's
    ``site_name`` and ``date_type``. All results are printed; if any check fails,
    a WeChat alert is sent and an exception is raised.

    :param hive_table: name of the Hive table being verified.
    :param partition_dict: partition column -> value; ``None`` values are ignored
        when building the partition filter. ``site_name`` and ``date_type`` select
        which rule configurations apply.
    :raises Exception: when at least one executed check did not pass.
    """
    # Build the partition filter and a human-readable partition label for messages.
    msg_params = ""
    partition_conditions = []
    for key, value in partition_dict.items():
        if value is not None:
            msg_params += f"{value} "
            partition_conditions.append(f"{key} = '{value}'")
    base_condition = ' AND '.join(partition_conditions)
    site_name = partition_dict.get("site_name")
    # Taken from the partition itself rather than from a module-level global.
    date_type = partition_dict.get("date_type")

    # Active verification rules maintained for this table/site.
    config_table_query = f"""select * from hive_field_verify_config 
                                            where table_name ='{hive_table}' 
                                            and site_name = '{site_name}'
                                            and use_flag = 1 """

    spark_session = SparkUtil.get_spark_sessionV3("check_fields_rule")
    try:
        conn_info = DBUtil.get_connection_info('mysql', 'us')
        check_field_df = SparkUtil.read_jdbc_query(
            session=spark_session,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=config_table_query
        )
        check_field_list = check_field_df.select('field_name', 'verify_desc', 'verify_type', 'config_json',
                                                 'msg_usr_list').collect()
        if not check_field_list:
            print("============================无验证匹配条件跳过验证===================================")
            return

        # Result table layout: description, type, field, measured value, threshold, passed (1/0).
        schema = StructType([
            StructField("验证描述", StringType(), True),
            StructField("验证类型", StringType(), True),
            StructField("校验字段", StringType(), True),
            StructField("校验条件查询数值", StringType(), True),
            StructField("验证临界值", StringType(), True),
            StructField("是否验证通过", IntegerType(), True),
        ])

        results = []
        for row in check_field_list:
            field_name = row['field_name']
            verify_type = row['verify_type']
            config_json = json.loads(row['config_json'] or "{}")

            conf = _find_partition_conf(config_json.get('partition_conf'), site_name, date_type)
            if conf is None:
                # No matching dimension for this rule: skip only this rule, keep checking the rest.
                continue

            if verify_type == "自定义sql验证":
                base_num = conf['max_rate']
                confirm_sql = str(config_json['confirm_sql'])
                # The stored SQL contains the literal placeholder `base_condition`,
                # which is replaced by the current partition filter.
                confirm_sql = confirm_sql.replace("base_condition", base_condition)
                confirm_row = spark_session.sql(confirm_sql).first()
                # The custom SQL is expected to expose `confirm_num` and `confirm_result`;
                # an empty result counts as a failed check.
                confirm_num = confirm_row["confirm_num"] if confirm_row is not None else None
                confirm_result = confirm_row["confirm_result"] if confirm_row is not None else 0

            elif verify_type == "分类销量最大排名验证":
                base_num = conf['max_rate']
                sql_condition = config_json.get('sql_condition')
                confirm_sql = CommonUtil.generate_min_max_query(hive_table, field_name, partition_dict)
                # Append the rule's extra filter to the generated query.
                if sql_condition:
                    confirm_sql = confirm_sql + f" AND {sql_condition} "
                confirm_row = spark_session.sql(confirm_sql).first()
                confirm_num = confirm_row["max_value"] if confirm_row is not None else None
                # Passes only when the max value reaches the configured threshold.
                confirm_result = 1 if (confirm_num and confirm_num >= base_num) else 0

            else:
                # Unknown rule type: previously this re-used the prior rule's result row.
                print(f"未支持的验证类型：{verify_type}，字段：{field_name}，跳过该验证")
                continue

            results.append((row['verify_desc'], verify_type, field_name, confirm_num, base_num, confirm_result))

        if not results:
            print("无验证项验证")
            return

        # Build the result table once instead of unioning one-row DataFrames.
        check_df = spark_session.createDataFrame(results, schema).repartition(1)
        check_df.show(50, truncate=False)
        # min() over the pass flags: any 0 (or all nulls) means the verification failed.
        schema_flag = bool(check_df.select(F.min("是否验证通过").alias("result")).first().asDict()['result'])
        if not schema_flag:
            msg = f"数据表：{hive_table} {msg_params}，计算数据存在验证不通过，请检查数据是否异常！！具体信息请查看日志！！"
            # NOTE(review): recipients are hard-coded and the per-rule `msg_usr_list` column is unused;
            # also "pengyanbing" here differs from "pengyanbin" used by the import checks — confirm the id.
            CommonUtil.send_wx_msg(['fangxingjun','pengyanbing','chenjianyun'], f"\u26A0 {hive_table} {msg_params}数据导入验证异常", msg)
            raise Exception(msg)
    finally:
        # Previously unreachable (placed after `raise`); now always releases the session.
        spark_session.stop()


if __name__ == '__main__':
    # Usage: <script> site_name date_type date_info
    # Imports one month of `<site>_one_category_report` from MySQL into Hive
    # `ods_one_category_report`, then runs consistency and field verifications.
    site_name = CommonUtil.get_sys_arg(1, None)
    date_type = CommonUtil.get_sys_arg(2, None)
    date_info = CommonUtil.get_sys_arg(3, None)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if site_name is None:
        raise ValueError("site_name 不能为空！")
    if date_type is None:
        raise ValueError("date_type 不能为空！")
    if date_info is None:
        raise ValueError("date_info 不能为空！")
    hive_tb = "ods_one_category_report"
    db_type = "mysql"
    import_msg_usr = ['fangxingjun', 'pengyanbin']

    if date_type not in ('week', 'month', 'month_week'):
        raise ValueError("入参date_type类型存在问题，请检查！")

    # The target table is monthly: a week period is mapped to the month its first weekday belongs to.
    if date_type == 'week':
        engine = DBUtil.get_db_engine('mysql', 'us')
        try:
            sql = f"""select `year_month` from date_20_to_30 where `year_week`='{date_info}' and week_day = 1 """
            result = DBUtil.engine_exec_sql(engine, sql)
            year_month = result.scalar()
        finally:
            engine.dispose()
        if year_month is None:
            raise ValueError(f"未找到周期 {date_info} 对应的月份，请检查date_20_to_30表！")
        print(f"当前传入的周期为周维度，date_type：{date_type},date_info:{date_info};对应转换月为：{year_month}")
        date_type = 'month'
        date_info = year_month

    # date_info is expected as "YYYY-MM"; the source table stores year_month as "YYYY_M".
    year, month = date_info.split("-")

    partition_dict = {
        "site_name": site_name,
        "date_type": date_type,
        "date_info": date_info
    }

    hdfs_path = CommonUtil.build_hdfs_path(hive_tb, partition_dict=partition_dict)
    print(f"hdfs_path is {hdfs_path}")

    import_tb = f"{site_name}_one_category_report"
    cols = "id,cate_1_id,name,rank,orders,orders_day,`year_month`,week,created_at,updated_at,category_id"

    # `\\$` yields the same `\$CONDITIONS` text as before without an invalid escape sequence;
    # presumably the backslash keeps the shell from expanding the import tool's placeholder.
    query = f"""
       select 
       {cols}
        from {import_tb}
        where `year_month` = '{year}_{int(month)}'
        and \\$CONDITIONS
"""
    print(query)
    empty_flag, check_flag = CommonUtil.check_schema_before_import(db_type=db_type,
                                                                   site_name=site_name,
                                                                   query=query,
                                                                   hive_tb_name=hive_tb,
                                                                   msg_usr=import_msg_usr
                                                                   )
    if not check_flag:
        raise RuntimeError(f"导入hive表{hive_tb}表结构检查失败！请检查query是否异常！！")

    if not empty_flag:
        sh = CommonUtil.build_import_sh(site_name=site_name,
                                        db_type=db_type,
                                        query=query,
                                        hdfs_path=hdfs_path)
        # Clear the target path before importing.
        HdfsUtils.delete_hdfs_file(hdfs_path)
        client = SSHUtil.get_ssh_client()
        try:
            SSHUtil.exec_command_async(client, sh, ignore_err=False)
            CommonUtil.after_import(hdfs_path=hdfs_path, hive_tb=hive_tb)
        finally:
            # Close the SSH connection even if the import fails.
            client.close()

        # Post-import check: row counts must match between source and Hive.
        CommonUtil.check_import_sync_num(db_type=db_type,
                                         partition_dict=partition_dict,
                                         import_query=query,
                                         hive_tb_name=hive_tb,
                                         msg_usr=import_msg_usr
                                         )
        vertify_data(hive_table=hive_tb, partition_dict=partition_dict)
