import os
import sys
import re

sys.path.append(os.path.dirname(sys.path[0]))
from pyspark.sql.types import StringType, MapType, IntegerType, ArrayType
from utils.common_util import CommonUtil
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from utils.db_util import DBUtil
from pyspark.sql import functions as F
from yswg_utils.common_udf import udf_handle_string_null_value
from utils.templates import Templates


class DwsStTheme(Templates):
    """Tag ABA search terms with theme labels for one site / date partition.

    Labels come from two sources that are unioned together:

    * ``pattern_type`` 0: whole-word matches against the dictionary in MySQL
      ``selection.aba_match_theme`` (column ``label_en_lower``);
    * ``pattern_type`` 1: regex rules assembled from MySQL ``aba_match_theme_rules``, which
      capture keywords and number/unit combinations.

    A label that occurs as a whole word inside another, different label matched for the same
    search term is dropped (labels starting with a digit are always kept). ``self.df_save``
    holds the final rows for table ``self.db_save``; persisting them is presumably done by
    ``Templates.run`` (not visible in this file).
    """

    def __init__(self, site_name, date_type, date_info):
        super().__init__()
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        app_name = f"{self.__class__.__name__}:{site_name},{date_type},{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.db_save = "dws_st_theme"
        self.partition_dict = {
            "site_name": site_name,
            "date_type": date_type,
            "date_info": date_info
        }
        # Empty the target partition's HDFS folder before it is recomputed.
        hdfs_path = CommonUtil.build_hdfs_path(self.db_save, partition_dict=self.partition_dict)
        print(f"清除hdfs目录中:{hdfs_path}")
        HdfsUtils.delete_file_in_folder(hdfs_path)
        self.partitions_by = ["site_name", "date_type", "date_info"]
        self.reset_partitions(1)

        # Placeholder DataFrames (DataFrames are immutable, so one instance can be shared);
        # each attribute is replaced by read_data() / the handle_* steps.
        placeholder_df = self.spark.sql("select 1+1;")
        self.df_st_asin_info = placeholder_df
        self.df_st_detail = placeholder_df
        self.df_st_key = placeholder_df
        self.df_theme = placeholder_df
        self.df_st_base = placeholder_df
        self.df_st_theme = placeholder_df
        self.df_st_theme_vertical = placeholder_df
        self.df_st_save = placeholder_df
        self.df_st_topic_base = placeholder_df
        self.df_st_match_topic_detail = placeholder_df
        self.df_save = placeholder_df
        self.topic_rules_regexp_dict = self.create_regexp_rules()

        # Register UDFs. The instance attribute intentionally shadows the static method of the
        # same name, so later code calls the Spark UDF rather than the plain Python function.
        self.u_theme_contain_judge = F.udf(self.u_theme_contain_judge, IntegerType())
        self.u_handle_string_null_value = F.udf(udf_handle_string_null_value, StringType())

    @staticmethod
    def parse_ele_match_regexp(pattern):
        """Build a UDF returning the distinct substrings of a text matched by ``pattern``.

        :param pattern: regex (string or compiled pattern) passed to ``re.findall``; it should
            have at most one capturing group so that matches are plain strings.
        :return: StringType UDF yielding the distinct matches joined by ``","`` (in no defined
            order), or None when the input is null/empty or nothing matches.
        """
        def match_elements(match_text):
            if not match_text:
                return None
            ele_list = re.findall(pattern, match_text)
            return ','.join(set(ele_list)) if ele_list else None

        return F.udf(match_elements, StringType())

    # Assemble regex rules from the theme regex-relation table -- 汪瑞
    @staticmethod
    def create_regexp_rules():
        """Build ``{theme_ch: [regex, ...]}`` from MySQL table ``aba_match_theme_rules``.

        Keywords (``label_en_lower``) are grouped per (theme_ch, regular_expression_type), joined
        into one alternation and wrapped by a type-specific template:
        0 keyword, 1 ``<num> +<kw>``, 2 ``<kw> +<num>``, 3 ``<num><kw>``, 4 ``<num> <kw> <num>``,
        5 ``<num> - <num> <kw>``, any other type ``<kw> <num> - <num>``.
        An extra rule for ``<digits> ''`` (e.g. ``12''``) is always appended to theme ``尺寸``.

        :return: dict mapping theme_ch -> list of regex strings, in query row order.
        """
        # (prefix, suffix) placed around the "|"-joined keyword alternation, per rule type.
        regexp_affixes = {
            0: (r"(?<!\+|\*|\-|\%|\.|\')\b(", r")\b"),
            1: (r"(\d+(?:\.\d+)?) +(", r")\b"),
            2: (r"\b(", r") +(\d+(?:\.\d+)?)"),
            3: (r"(\d+(?:\.\d+)?)(", r")\b"),
            4: (r"(\d+(?:\.\d+)?) *(", r") *(\d+(?:\.\d+)?)(?: |$)"),
            5: (r"(\d+(?:\.\d+)?) *(-) *(\d+(?:\.\d+)?) *(", r")\b"),
        }
        default_affixes = (r"\b(", r") +(\d+(?:\.\d+)?) *(-) *(\d+(?:\.\d+)?)")

        theme_rules_regexp_dict = {}
        engine = DBUtil.get_db_engine('mysql', 'us')
        rules_sql = """
                       select theme_ch,regular_expression_type, GROUP_CONCAT(label_en_lower) as key_list from aba_match_theme_rules group by theme_ch, regular_expression_type;
                   """
        # NOTE(review): MySQL truncates GROUP_CONCAT output at group_concat_max_len (1024 bytes by
        # default), which would silently drop keywords from long lists - confirm the server setting.
        with engine.connect() as connect:
            for rules_row in connect.execute(rules_sql):
                if rules_row.key_list is None:
                    # Only NULL keywords in this group; str(None) would create a bogus "None" keyword.
                    continue
                prefix, suffix = regexp_affixes.get(rules_row.regular_expression_type, default_affixes)
                # GROUP_CONCAT order is unspecified. Longest keyword first makes the pattern
                # deterministic and lets the alternation prefer the longest keyword.
                # NOTE(review): keywords are not re.escape'd, so regex metacharacters in them are live.
                keys = sorted(str(rules_row.key_list).split(','), key=lambda key: (-len(key), key))
                regexp_str = prefix + "|".join(keys) + suffix
                theme_rules_regexp_dict.setdefault(rules_row.theme_ch, []).append(regexp_str)
        # setdefault: do not fail when the table has no 尺寸 rules.
        theme_rules_regexp_dict.setdefault('尺寸', []).append(r'\b(\d+) *(\'\')')
        return theme_rules_regexp_dict

    @staticmethod
    def parse_search_term_theme(topic_rules_regexp_dict):
        """Build a UDF extracting rule-based theme matches from a search term.

        For every theme and every regex of that theme, only the first match is used.

        :param topic_rules_regexp_dict: mapping theme -> list of regex strings, as returned by
            ``create_regexp_rules``; matching is case-insensitive.
        :return: UDF ``search_term -> array<map<string,string>>``. Each map has the keys
            ``theme``, ``num_info`` and ``unit_info``. Returns None for a null search term or
            when nothing matches, so an ``explode`` over it yields no rows.
        """
        # Compile once on the driver; compiled patterns are pickled into the UDF closure.
        compiled_rules = {
            theme: [re.compile(regexp, re.IGNORECASE) for regexp in regexp_list]
            for theme, regexp_list in topic_rules_regexp_dict.items()
        }
        num_pattern = re.compile(r'^\d+(\.\d+)?$')

        def split_num_unit(target_match):
            """Split one ``findall`` match into (num_info, unit_info) by its number of groups."""
            if isinstance(target_match, str):
                # Single-group rule (type 0): the match itself is the keyword; case is kept.
                return "", target_match
            groups = [group.lower() for group in target_match]
            if len(groups) == 2:
                # "<num> <kw>" or "<kw> <num>": whichever group is numeric is the number.
                if num_pattern.match(str(groups[0])):
                    return groups[0], groups[1]
                return groups[1], groups[0]
            if len(groups) == 3:
                # "<num> <kw> <num>": keep the whole span (spaces removed); the middle is the unit.
                return "".join(groups), groups[1]
            if len(groups) == 4:
                if num_pattern.match(str(groups[0])):
                    # "<num> - <num> <kw>"
                    return "".join(groups[0:3]), groups[3]
                # "<kw> <num> - <num>"
                return "".join(groups[1:4]), groups[0]
            return "", ""

        def parse_theme(search_term):
            if not compiled_rules or search_term is None:
                return None
            theme_array = []
            for theme_rule, regexp_list in compiled_rules.items():
                for regexp in regexp_list:
                    matches = regexp.findall(search_term)
                    if matches:
                        num_info, unit_info = split_num_unit(matches[0])
                        theme_array.append(
                            {"theme": str(theme_rule), "num_info": str(num_info), "unit_info": str(unit_info)})
            return theme_array or None

        return F.udf(parse_theme, ArrayType(MapType(StringType(), StringType())))

    @staticmethod
    def u_theme_contain_judge(pattern_word, pattern_list):
        """Decide whether a matched label should be kept.

        :param pattern_word: the label (``theme_label_en``) of the current row.
        :param pattern_list: all labels matched for the same (st_key, search_term).
        :return: 0 when ``pattern_word`` occurs as a whole word in another, different label of
            ``pattern_list`` (i.e. a longer label already covers it), otherwise 1. Labels that
            start with a digit always get 1.
        """
        if pattern_word is None or re.match(r'^\d', pattern_word):
            return 1
        word_regexp = re.compile(r'\b{}\b'.format(re.escape(pattern_word)))
        # Count distinct labels: the same label matched twice (e.g. by the dictionary and by a
        # keyword rule) is not "another" label and must not cause every copy to be dropped.
        count = sum(1 for word in set(pattern_list or []) if word_regexp.search(word))
        # The word always matches itself; more than one hit means a longer label contains it.
        return 0 if count > 1 else 1

    def read_data(self):
        """Load the source search terms and the theme dictionary for this partition.

        Sets ``df_st_asin_info`` and ``df_st_detail`` (search terms of the partition),
        ``df_st_key`` (st_key per search term for the site) and ``df_theme`` (dictionary rows
        from MySQL with a non-null ``label_ch``).
        """
        # NOTE(review): partition values are interpolated into the SQL unescaped; they come from
        # the command line, so only trusted callers should launch this job.
        sql1 = f"""
            select 
                search_term
            from dim_st_asin_info
            where site_name = '{self.site_name}'
            and date_type = '{self.date_type}'
            and date_info = '{self.date_info}';
        """
        print("sql:", sql1)
        self.df_st_asin_info = self.spark.sql(sql1).repartition(40, 'search_term').cache()

        sql2 = f"""
            select 
                search_term
            from dim_st_detail
            where site_name = '{self.site_name}'
            and date_type = '{self.date_type}'
            and date_info = '{self.date_info}';
        """
        print("sql:", sql2)
        self.df_st_detail = self.spark.sql(sql2).repartition(40, 'search_term').cache()

        sql3 = f"""
            select 
                st_key,
                search_term
            from ods_st_key
            where site_name = '{self.site_name}';
        """
        print("sql:", sql3)
        self.df_st_key = self.spark.sql(sql3).repartition(40, 'search_term').cache()

        # Theme dictionary words
        sql4 = f"""
            select 
                theme_en,
                theme_ch,
                label_ch,
                label_en_lower 
            from selection.aba_match_theme 
            where label_ch is not null
        """
        print("sql:", sql4)
        conn_info = DBUtil.get_connection_info("mysql", "us")
        self.df_theme = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=sql4
        ).cache()

    def handle_data(self):
        """Run the transformation steps in order; each step consumes the previous step's output."""
        self.handle_base()
        self.handle_st_theme()
        self.match_st_topic()
        self.match_search_term_ch_unit()
        self.handle_contains_theme()
        self.handle_save()

    # Deduplication
    def handle_base(self):
        """Union the distinct search terms of both dim tables and attach st_key (inner join)."""
        distinct_terms = self.df_st_asin_info.unionByName(
            self.df_st_detail
        ).drop_duplicates(['search_term'])
        self.df_st_base = distinct_terms.join(
            self.df_st_key, on='search_term', how='inner'
        ).cache()

    # Tag every search term with its theme labels
    def handle_st_theme(self):
        """Tag search terms with dictionary labels (pattern_type 0) into ``df_st_theme_vertical``."""
        pdf_theme = self.df_theme.toPandas()
        # Longest label first: a regex alternation takes the first alternative that matches, so
        # this prefers e.g. "red wine" over "red", and sorting makes the pattern deterministic
        # (iteration order of a set of str depends on the per-process hash seed). Null/empty
        # labels are skipped because re.escape(None) raises.
        theme_list = sorted(
            {label for label in pdf_theme.label_en_lower if isinstance(label, str) and label},
            key=lambda label: (-len(label), label)
        )
        # "(?!)" never matches, so an empty dictionary yields no tags instead of matching "".
        alternation = '|'.join(re.escape(label) for label in theme_list) or '(?!)'
        pattern = re.compile(r'(?<!\+|\*|\-|\%|\.|\')\b({})\b'.format(alternation), flags=re.IGNORECASE)
        self.df_st_theme = self.df_st_base.withColumn(
            "theme_en_pattern_str",
            self.parse_ele_match_regexp(pattern)(F.col("search_term"))
        )
        # Split the comma-joined matches into a list
        self.df_st_theme = self.df_st_theme.withColumn(
            "label_en_lower_list",
            F.split(F.col("theme_en_pattern_str"), ",")
        )
        # Explode the list into one row per matched label
        self.df_st_theme = self.df_st_theme.withColumn(
            "label_en_lower",
            F.explode(F.col("label_en_lower_list"))
        )
        self.df_st_theme = self.df_st_theme.select(
            'st_key', 'search_term', 'label_en_lower'
        )
        # Bring back the theme columns by joining on the matched word (label_en_lower)
        self.df_st_theme = self.df_st_theme.join(
            self.df_theme, on=['label_en_lower'], how='left'
        )

        self.df_st_theme_vertical = self.df_st_theme.drop_duplicates(['st_key', 'search_term', 'theme_en', 'label_ch'])
        self.df_st_theme_vertical = self.df_st_theme_vertical.filter('theme_en is not null')
        self.df_st_theme_vertical = self.df_st_theme_vertical.select(
            F.col('st_key'),
            F.col('search_term'),
            F.col('theme_ch'),
            F.col('theme_en'),
            F.col('label_ch').alias('theme_label_ch'),
            F.col('label_en_lower').alias('theme_label_en'),
            F.lit(0).alias('pattern_type'),
            # Explicit string type so the later unionByName does not depend on NullType coercion.
            F.lit(None).cast(StringType()).alias('theme_label_num_info'),
            F.lit(None).cast(StringType()).alias('theme_label_unit_info')
        )

    # Match search terms against the regex rules
    def match_st_topic(self):
        """Explode the regex-rule matches into ``theme`` / ``num_info`` / ``unit_info`` columns."""
        df_topic = self.df_st_base.withColumn(
            "all_theme_field",
            F.explode(self.parse_search_term_theme(self.topic_rules_regexp_dict)(F.col("search_term")))
        )
        for field_name in ("theme", "num_info", "unit_info"):
            df_topic = df_topic.withColumn(field_name, F.col("all_theme_field").getItem(field_name))
        self.df_st_topic_base = df_topic.drop("all_theme_field")

    # Use the regex matches to fetch the Chinese label / unit information
    def match_search_term_ch_unit(self):
        """Attach theme and Chinese label info to the regex matches (pattern_type 1).

        Matches without a number are joined with the keyword rules (regular_expression_type = 0),
        matches with a number with the numeric rules (regular_expression_type != 0); both joins
        are inner joins on ``theme`` + ``unit_info``.

        :raises RuntimeError: when no MySQL connection info is available.
        """
        # Drop matches that carry neither a number nor a unit (null or empty counts as missing).
        num_missing = F.col("num_info").isNull() | (F.col("num_info") == "")
        unit_missing = F.col("unit_info").isNull() | (F.col("unit_info") == "")
        df_st_mate = self.df_st_topic_base.filter(~(num_missing & unit_missing))
        df_st_no_num = df_st_mate.filter(F.col("num_info") == '')
        # Exact complement of df_st_no_num (null-safe, so a null num_info lands here) without
        # the extra shuffle an exceptAll would need.
        df_st_with_num = df_st_mate.filter(~F.col("num_info").eqNullSafe(''))
        # Deduplicated the same way as the numeric-rule query below, so that a repeated rule row
        # cannot duplicate output rows.
        key_word_in_theme_info_sql = """
            select
            theme_ch as theme, 
            theme_en, 
            label_en_lower as unit_info,
            label_ch 
            from aba_match_theme_rules 
            where regular_expression_type = 0
            group by theme_ch, theme_en, label_en_lower, label_ch
        """
        key_word_with_num_in_theme_info_sql = """
            select 
            theme_ch as theme, 
            theme_en, 
            label_en_lower as unit_info,
            label_ch 
            from aba_match_theme_rules 
            where regular_expression_type != 0 
            group by theme_ch, theme_en, label_en_lower, label_ch
        """
        con_info = DBUtil.get_connection_info('mysql', 'us')
        if con_info is None:
            # Without this, df_st_match_topic_detail would stay a placeholder and the union in
            # handle_contains_theme would fail with a much less obvious error.
            raise RuntimeError("MySQL connection info ('mysql', 'us') is not available")
        df_key_word_in_theme = SparkUtil.read_jdbc_query(
            session=self.spark, url=con_info['url'],
            pwd=con_info['pwd'],
            username=con_info['username'],
            query=key_word_in_theme_info_sql
        )
        df_key_word_with_num_in_theme = SparkUtil.read_jdbc_query(
            session=self.spark, url=con_info['url'],
            pwd=con_info['pwd'],
            username=con_info['username'],
            query=key_word_with_num_in_theme_info_sql
        )
        df_st_no_num = df_st_no_num.join(
            df_key_word_in_theme, how='inner', on=['theme', 'unit_info']
        )
        df_st_with_num = df_st_with_num.join(
            df_key_word_with_num_in_theme, how='inner', on=['theme', 'unit_info']
        )
        self.df_st_match_topic_detail = df_st_no_num.unionByName(df_st_with_num).select(
            F.col('st_key'),
            F.col('search_term'),
            F.col('theme').alias('theme_ch'),
            F.col('theme_en'),
            F.col('label_ch').alias('theme_label_ch'),
            # Keyword only -> the keyword; "x"/"by" dimension rules -> the captured span;
            # otherwise "<num> <unit>". Any other combination yields null and is filtered later.
            F.when(
                F.col('num_info') == '', F.col('unit_info')
            ).when(
                (F.col('unit_info') == 'x') | (F.col('unit_info') == 'by'), F.col('num_info')
            ).when(
                (F.col("num_info") != '') & (F.col("unit_info") != '') &
                (F.col("unit_info") != 'x') & (F.col("unit_info") != 'by'),
                F.concat_ws(' ', F.col("num_info"), F.col("unit_info"))
            ).alias('theme_label_en'),
            F.lit(1).alias('pattern_type'),
            F.col('num_info').alias('theme_label_num_info'),
            F.col('unit_info').alias('theme_label_unit_info')
        )

    def handle_contains_theme(self):
        """Drop labels already covered by a longer matched label (the A/AB containment case)."""
        df_st_join = self.df_st_theme_vertical.unionByName(self.df_st_match_topic_detail)
        df_st_join = df_st_join.filter(' theme_label_en is not null ')
        # All labels matched per search term
        df_st_pattern_list = df_st_join.groupBy(['st_key', 'search_term']).agg(
            F.collect_list("theme_label_en").alias("pattern_label_list")
        )
        df_st_join = df_st_join.join(
            df_st_pattern_list, on=['st_key', 'search_term'], how='left'
        )

        # UDF flags whether the label is contained in another matched label
        df_st_join = df_st_join.withColumn(
            'pattern_flag',
            self.u_theme_contain_judge(F.col('theme_label_en'), F.col('pattern_label_list'))
        ).drop('pattern_label_list')

        # Keep flag 1 only; flag 0 marks an "A" word already covered by a matched "AB" word
        self.df_save = df_st_join.filter('pattern_flag = 1')

    def handle_save(self):
        """Project ``df_save`` onto the output schema of ``dws_st_theme``."""
        # "ss" is second-of-minute; "SS" would be a fraction of a second.
        now_str = F.date_format(F.current_timestamp(), 'yyyy-MM-dd HH:mm:ss')
        self.df_save = self.df_save.select(
            F.col('st_key'),
            F.col('search_term'),
            F.col('theme_ch'),
            F.col('theme_en'),
            F.col('theme_label_ch'),
            F.col('theme_label_en'),
            now_str.alias('created_time'),
            now_str.alias('updated_time'),
            F.col('pattern_type'),
            self.u_handle_string_null_value(F.col('theme_label_num_info')).alias('theme_label_num_info'),
            F.col('theme_label_unit_info'),
            F.lit(self.site_name).alias('site_name'),
            F.lit(self.date_type).alias('date_type'),
            F.lit(self.date_info).alias('date_info')
        )


if __name__ == '__main__':
    site_name = CommonUtil.get_sys_arg(1, None)
    date_type = CommonUtil.get_sys_arg(2, None)
    date_info = CommonUtil.get_sys_arg(3, None)
    # Explicit checks rather than `assert`, which is stripped when Python runs with -O.
    for arg_value, error_message in (
            (site_name, "site_name 不能为空!"),
            (date_type, "date_type 不能为空!"),
            (date_info, "date_info 不能为空!"),
    ):
        if arg_value is None:
            raise ValueError(error_message)
    handle_obj = DwsStTheme(site_name=site_name, date_type=date_type, date_info=date_info)
    handle_obj.run()