import os
import sys
import re

sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from utils.templates import Templates
from pyspark.sql import functions as F
from pyspark.sql.types import StringType


class MatchSensitive(Templates):
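    """Match sensitive words against item descriptions.

    Reads the word list from dim_sensitive and descriptions from a local
    text file, then tags each description with every sensitive word it
    contains, joined by '||'.
    """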

    def __init__(self):
        super().__init__()
        self.db_save = 'dwd_sensitive_match'
        self.spark = self.create_spark_object(app_name=self.db_save)
        # Placeholder DataFrames, overwritten later by read_data()/handle_data()
        self.df_save = self.spark.sql("select 1+1")
        self.df_sensitive = self.spark.sql("select 1+1")
        self.description_df = self.spark.sql("select 1+1")
        self.exploded_df = self.spark.sql("select 1+1")
        self.partitions_num = 5  # output partition count, presumably consumed by the Templates base class

        # Custom UDF registration (kept for reference; the UDF is built on demand instead)
        # self.find_sensitive = self.spark.udf.register("find_sensitive", self.find_sensitive, StringType())

    @staticmethod
    def find_sensitive(ele_list: list):
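        """Build a column UDF that extracts any word from ``ele_list``.

        The pattern is compiled once on the driver and captured by the
        closure, so executors reuse it row by row instead of recompiling.
        E.g. (hypothetical word list) find_sensitive(["gun", "knife"])
        returns a UDF mapping "a gun and a knife" to "gun||knife"
        (order not guaranteed, since duplicates are dropped via a set).
        """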
        # Negative lookbehind: skip words directly preceded by +, *, -, %, . or '
        pattern = re.compile(r"(?<![+*\-%.'])\b({})\b".format('|'.join(re.escape(x) for x in ele_list)),
                             flags=re.IGNORECASE)
        def udf_find_sensitive(match_text):
            # Guard against NULL descriptions before matching
            if match_text is None:
                return None
            matches = pattern.findall(match_text)
            return '||'.join(set(matches)) if matches else None
        return F.udf(udf_find_sensitive, StringType())


    def read_data(self):
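        """Load the sensitive-word table and the local description file into DataFrames."""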
        # Read dim_sensitive to get the sensitive-word list
        sql = """
        select
            sensitive
        from
            dim_sensitive
        """
        print(sql)
        self.df_sensitive = self.spark.sql(sqlQuery=sql).cache()

        with open("/home/chenyuanjie/data/description.txt", 'r', encoding='utf-8') as file:
            description_list = file.readlines()

        # Convert to a DataFrame
        self.description_df = self.spark.createDataFrame(
            [(description,) for description in description_list], ["description"])
        self.description_df.show()


    def handle_data(self):
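        """Tag each description with the sensitive words it contains, '||'-separated."""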

        # --- Earlier attempts, kept for reference ---

        # Attempt 1: concatenate per-word LIKE matches
        # self.df_save = self.description_df.withColumn("matched", F.concat(*[F.when(F.col("description").like(f"%{sensitive}%"),
        #                F.lit(sensitive)) for sensitive in self.df_sensitive.select("sensitive").rdd.flatMap(lambda x: x).collect()],
        #                F.lit("||")))
        # self.df_save.show()

        # Attempt 2: explode descriptions into words, then join against df_sensitive
        # self.exploded_df = self.description_df.withColumn("word", F.explode(F.split(F.col("description"), " ")))
        # self.df_save = self.exploded_df.join(self.df_sensitive, self.exploded_df.word == self.df_sensitive.sensitive)

        # Attempt 3: build a single regex pattern and use regexp_extract
        # sensitive_words = self.df_sensitive.select('sensitive').rdd.flatMap(lambda x: x).collect()
        # escaped_sensitive_words = [re.escape(word) for word in sensitive_words]
        # pattern = f"({'|'.join(escaped_sensitive_words)})"
        # self.df_save = self.description_df.withColumn("sensitive", F.regexp_extract(F.col("description"), pattern, 0))

        # Group by the original description and join the matched words with ||
        # self.df_save = self.df_save.groupBy("description").agg(F.collect_list("sensitive").alias("sensitive_list"))
        # self.df_save = self.df_save.withColumn("matched", F.concat_ws("||", "sensitive_list"))

        # Strip trailing newlines left over from readlines()
        self.description_df = self.description_df.withColumn(
            "description", F.regexp_replace(F.col("description"), r"\r\n|\n", ""))

        # Collect the sensitive-word column into a local Python list on the driver
        sensitive_words = [row["sensitive"] for row in self.df_sensitive.collect()]

        # Build the regex UDF once from the local word list, then apply it per row
        self.df_save = self.description_df.withColumn(
            "matched", self.find_sensitive(sensitive_words)(F.col("description")))

        # Keep only the needed columns
        # self.df_save = self.df_save.select("description", "matched")

        self.df_save.show()


if __name__ == '__main__':
    handle_obj = MatchSensitive()
    handle_obj.run()