import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from utils.common_util import CommonUtil
from utils.hdfs_utils import HdfsUtils
from pyspark.sql.types import StringType, ArrayType
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F
from pyspark.sql.window import Window

"""
Compute word frequency from the merchantwords keyword data
"""


def is_number(s):
    """
    Check whether a string represents a number (integer or decimal)
    :param s:
    :return:
    """
    import re
    return re.match(r"^-?\d+(\.\d+)?$", s) is not None


class DimMWordFrequency:

    def __init__(self, site_name, batch):
        assert site_name is not None, "site_name must not be None"
        assert batch is not None, "batch must not be None"
        self.site_name = site_name
        self.batch = batch
        pass

    @staticmethod
    def word_tokenize(keyword: str):
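        """
        Strip digits and punctuation from a raw keyword, tokenize it with NLTK,
        and drop numeric or punctuation-only tokens from the result.
        """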
        import re
        keyword = keyword.replace("+", " ")
        keyword = re.sub(r'(\d+\.?\d*|-|"|,|,|?|\?|/|、|\(|\))', '', keyword).strip()

        from nltk.tokenize import word_tokenize
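        # word_tokenize relies on NLTK's 'punkt' tokenizer data, which must be
        # available on every Spark executor (e.g. pre-installed or nltk.download('punkt'))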
        result = word_tokenize(keyword, "english")
        # Tokens to filter out (whitespace and punctuation)
        filter_arr = [
            " ", "\t", "\r", "\n", "(", ")", ",", ",", "[", "]", "、", "-", ":", "&", "|", "+", "``", "''", "'", "\""
        ]
        return list(filter(lambda x: not is_number(x) and x not in filter_arr, result))

    @staticmethod
    def word_stem(keyword: str):
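        """
        Reduce a single word to its Snowball (Porter2) stem, e.g. "running" -> "run".
        """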
        from nltk.stem.snowball import SnowballStemmer
        stemmer = SnowballStemmer("english", ignore_stopwords=False)
        return stemmer.stem(keyword)

    def run(self):
        """
        Word-level frequency table
        :return:
        """
        spark = SparkUtil.get_spark_session("DimWordFrequency")
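        # Wrap the tokenizer as a Spark UDF returning array<string> so each keyword can be exploded into tokens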
        udf_word_tokenize = F.udf(self.word_tokenize, ArrayType(StringType()))

        keywords_all = spark.sql(
            f"select keyword from dwt_merchantwords_st_detail where site_name='{self.site_name}' and batch ='{self.batch}'").cache()
        df_all = keywords_all.withColumn("word", F.explode(udf_word_tokenize(F.col("keyword"))))
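        # Count occurrences of each token and rank by descending frequency; ties share a rank (F.rank)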
        df_all = df_all.groupby(F.col("word")) \
            .agg(F.count("word").alias("frequency")) \
            .withColumn("rank", F.rank().over(Window.orderBy(F.col("frequency").desc()))) \
            .orderBy(F.col("rank").asc()) \
            .select(
            F.col("word"),
            F.col("frequency"),
            F.col("rank"),
            F.lit(self.site_name).alias("site_name"),
            F.lit(self.batch).alias("batch")
        )

        hive_tb = 'dim_m_word_frequency'
        partition_dict = {
            "site_name": self.site_name,
            "batch": self.batch,
        }
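        # save_or_update_table is a project helper; the trailing True flag presumably
        # overwrites any existing data in the (site_name, batch) partition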
        CommonUtil.save_or_update_table(
            spark,
            hive_tb,
            partition_dict,
            df_all,
            True
        )
        print("success")
        pass

    def run_stem(self):
        """
        Word-stem-level frequency table
        :return:
        """
        spark = SparkUtil.get_spark_session("DimWordFrequency")
        udf_word_tokenize = F.udf(self.word_tokenize, ArrayType(StringType()))
        udf_word_stem = F.udf(self.word_stem, StringType())

        keywords_all = spark.sql("select keyword from dwt_merchantwords_st_detail where site_name='us'").cache()
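        # NOTE: the stem-level job is hard-coded to the 'us' site and does not filter by batch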
        df_all = keywords_all.withColumn("word", F.explode(udf_word_tokenize(F.col("keyword"))))
        df_all = df_all.groupby(F.col("word")) \
            .agg(F.count("word").alias("frequency"))
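        # First pass counts each surface word; the second pass below stems the words
        # and re-aggregates, summing frequencies and collecting the variants per stem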

        df_all = df_all.withColumn("word_stem", udf_word_stem(F.col("word")))

        df_all = df_all.groupby(F.col("word_stem")) \
            .agg(
            F.sum("frequency").alias("frequency"),
            F.concat_ws(',', F.collect_set(F.col("word"))).alias('word_list')
        ) \
            .withColumn("rank", F.rank().over(Window.orderBy(F.col("frequency").desc()))) \
            .orderBy(F.col("rank").asc()) \
            .select(
            F.col("word_stem"),
            F.col("word_list"),
            F.col("frequency"),
            F.col("rank"),
            F.lit("us").alias("site_name")
        )

        hive_tb = 'dim_m_word_stem_frequency'
        # Dedupe: clear the existing 'us' partition before writing
        partition_dict = {
            "site_name": "us"
        }
        hdfs_path = CommonUtil.build_hdfs_path(hive_tb, partition_dict)
        HdfsUtils.delete_hdfs_file(hdfs_path)
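        # With the old partition files removed from HDFS, the append below effectively
        # rewrites the 'us' partition instead of duplicating it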
        partition_by = list(partition_dict.keys())
        print(f"Saving to table: {hive_tb}, partitioned by: {partition_by}")
        df_all.write.saveAsTable(name=hive_tb, format='hive', mode='append', partitionBy=partition_by)
        print("success")
        pass

    pass


if __name__ == '__main__':
    site_name = CommonUtil.get_sys_arg(1, 'us')
    batch = CommonUtil.get_sys_arg(2, '2024-1')
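    # CLI args: site_name (default 'us') and batch (default '2024-1'); only the word-level
    # table is built here, so call handle_obj.run_stem() separately for the stem-level table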
    handle_obj = DimMWordFrequency(site_name, batch)
    handle_obj.run()
    pass