import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from utils.common_util import CommonUtil

from utils.spark_util import SparkUtil
from pyspark.sql import functions as F, Window, DataFrame
from pyspark.sql.types import IntegerType, StringType, MapType

"""
merchantwords 搜索词数据集成并去重输出
"""


class DwtMerchantwordsStDetail(object):

    def __init__(self, site_name, batch):
        assert site_name is not None, "site_name must not be None"
        assert batch is not None, "batch must not be None"

        self.site_name = site_name
        self.batch = batch
        app_name = f"{self.__class__.__name__}:{self.site_name}::{self.batch}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.hive_tb = 'dwt_merchantwords_st_detail'
        self.udf_check_utf8_and_convert_reg = F.udf(self.check_utf8_and_convert, StringType())

    @staticmethod
    def check_utf8_and_convert(word: str):
        """
        判断一个英文字符是否是utf-8编码的 即不包含乱码字符及中文 ,如果包含的话则从gbk编码转为utf-8编码
        :param str:
        :return:
        """
        import re
        # 检查是否包含有中文 如果有中文的话大概率是乱码 从gbk转为utf-8
        pattern = re.compile(r'[\u4e00-\u9fa5]')
        if pattern.match(word) is not None:
            try:
                return word.encode("gbk").decode("utf-8")
            except:
                return word
        else:
            return word
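
    # Illustrative example (assumed mojibake case, not taken from the data): a UTF-8
    # keyword mis-decoded as GBK shows up with CJK characters, and the GBK -> UTF-8
    # round-trip above recovers the original text, e.g.
    #   DwtMerchantwordsStDetail.check_utf8_and_convert("caf茅")             -> "café"
    #   DwtMerchantwordsStDetail.check_utf8_and_convert("electric scooter")  -> "electric scooter"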

    @staticmethod
    def udf_detect_phrase_reg(lang_word_map):
        def detect_phrase(phrase: str):
            import re
            # Replace '+' with spaces so the phrase can be tokenised
            phrase = re.sub(r'(\+)', ' ', phrase).strip()
            # Tokenise and keep only tokens with at least two characters
            from nltk.tokenize import word_tokenize
            word_list = list(filter(lambda x: len(x) >= 2, word_tokenize(phrase, "english")))
            # Per language: total frequency of matched tokens and the tokens themselves
            tmp_map = {
                "en": {"frequency": 0, "word": []},
                "fr": {"frequency": 0, "word": []},
                "es": {"frequency": 0, "word": []},
                "de": {"frequency": 0, "word": []},
            }
            for word in word_list:
                lang_rank_map: dict = lang_word_map.get(word)
                if lang_rank_map is not None:
                    for lang in lang_rank_map.keys():
                        frequency = lang_rank_map[lang]
                        tmp_map[lang]["frequency"] = tmp_map[lang]["frequency"] + frequency
                        tmp_map[lang]["word"].append(word)

            # Pick the language with the most matched tokens, breaking ties by total frequency
            lang, hint_word_map = sorted(tmp_map.items(), key=lambda it: (len(it[1]['word']), it[1]['frequency']), reverse=True)[0]

            if hint_word_map['frequency'] == 0:
                return {"lang": None, "hint_word": None}
            else:
                hint_word_list = hint_word_map['word']
                hint_word = " ".join(hint_word_list)
                if len(hint_word) <= 2:
                    return {"lang": None, "hint_word": None}
                return {"lang": lang, "hint_word": hint_word}

        return F.udf(detect_phrase, MapType(StringType(), StringType()))
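
    # Illustrative output of detect_phrase (assuming both tokens are mapped to "en" in
    # tmp_lang_word_frequency):
    #   detect_phrase("electric+scooter") -> {"lang": "en", "hint_word": "electric scooter"}
    # Phrases with no mapped tokens return {"lang": None, "hint_word": None}.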

    def handle_calc_lang(self, df_all: DataFrame) -> DataFrame:
        lang_word_list = self.spark.sql("""
            select word, langs
            from big_data_selection.tmp_lang_word_frequency
        """).collect()
        # Collect into a plain dict: word -> {lang: frequency}
        lang_word_map = {row['word']: row['langs'] for row in lang_word_list}

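        # Tag each keyword with its detected language; keywords with no match fall back to "other"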
        df_lang_all = df_all.withColumn("lang",
                                        F.coalesce(self.udf_detect_phrase_reg(lang_word_map)(F.col("keyword")).getField("lang"),
                                                   F.lit("other")))
        return df_lang_all

    def handle_update(self, date_info):
        df_append = self.spark.sql(f"""
                select keyword,
                       volume,
                       avg_3m,
                       avg_12m,
                       depth,
                       results_count,
                       sponsored_ads_count,
                       page_1_reviews,
                       appearance,
                       last_seen,
                       update_time,
                       'csv' as source_type,
                       null as api_json,
                       'us' as site_name
                from ods_merchantwords_st_detail_append
                where date_info = '{date_info}'
""").cache()
        # Make sure there is actually new data to append
        assert df_append.count() > 0, "Nothing to update; please check whether the data in ods_merchantwords_st_detail_append is abnormal"

        df_exist = self.spark.sql(f"""
            select keyword,
                   volume,
                   avg_3m,
                   avg_12m,
                   depth,
                   results_count,
                   sponsored_ads_count,
                   page_1_reviews,
                   appearance,
                   last_seen,
                   update_time,
                   source_type,
                   api_json,
                   site_name
            from dwt_merchantwords_st_detail
""").cache()

        df_all = df_exist.unionByName(df_append)
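        # handle_save deduplicates by keyword and keeps the most recently updated row,
        # so appended rows override existing rows that share the same keyword.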
        self.handle_save(df_all)

    def handle_save(self, df_all: DataFrame):
        # Repair mojibake in keywords and trim whitespace
        df_all = df_all.withColumn("keyword", F.trim(self.udf_check_utf8_and_convert_reg(F.col("keyword"))))
        # Window by keyword, newest update_time first
        df_all = df_all.withColumn("row_number", F.row_number().over(window=Window.partitionBy(['keyword']).orderBy(
            F.col("update_time").desc()
        )))
        # Deduplicate: keep only the newest row per keyword
        df_all = df_all.where("row_number == 1 ")
        df_all = df_all.drop(F.col("row_number")).cache()
        # Language detection
        df_all = self.handle_calc_lang(df_all)
        # Repartition before writing
        df_all = df_all.repartition(60)

        # Partition values of the target Hive table
        partition_dict = {
            "site_name": self.site_name,
            "batch": self.batch,
        }
        # Save to, or update, the target table
        CommonUtil.save_or_update_table(
            spark_session=self.spark,
            hive_tb_name=self.hive_tb,
            partition_dict=partition_dict,
            df_save=df_all,
            drop_exist_tmp_flag=False
        )
        print("success")

    def run(self):
        tb_ods = "tmp_merchantwords_st_detail_2024"
        if self.site_name != 'us':
            tb_ods = f"{self.site_name}_{tb_ods}"

        df_all = self.spark.sql(f"""
        select  
            keyword,
            volume,
            avg_3m,
            avg_12m,
            depth,
            results_count,
            sponsored_ads_count,
            page_1_reviews,
            appearance,
            last_seen,
            update_time,
            source_type,
            api_json,
            site_name
        from {tb_ods}
""")
        # Infer the schema of api_json from a sample payload
        schema_api = F.schema_of_json("""
            {
                "historyArray": [1, 2],
                "country": 1,
                "avg12Month": 1015775,
                "last_seen": 20230801,
                "totalReviewsFirstPage": 62592,
                "avg3Month": 1249957,
                "volume": 134090,
                "categoryIdArray": ["toys-and-games"],
                "depth": 3,
                "appearance": "Evergreen",
                "popularity": 5,
                "resultsCount": 1000,
                "sponsoredCount": -1,
                "keyword": "electric scooter",
                "inLatestBatch": true
            }
        """)
        #         schema_csv = F.schema_of_json("""
        # {"volume":"9","depth":"8","appearance":"Rediscovered","phrase":"weaning rings for goats","sponsored ads":"0","page 1 reviews":"1914","3m avg":"3","12m avg":"1","results":"1"}
        #             """)

        df_all = df_all.withColumn("json_api", F.from_json(F.col("api_json"), schema_api))

        df_all = df_all.select(
            F.col("keyword").alias("keyword"),
            # Prefer the API value; fall back to the flat CSV column when it is missing
            F.coalesce(F.col("json_api.volume"), F.col("volume")).cast(IntegerType()).alias("volume"),
            F.coalesce(F.col("json_api.avg3Month"), F.col("avg_3m")).cast(IntegerType()).alias("avg_3m"),
            F.coalesce(F.col("json_api.avg12Month"), F.col("avg_12m")).cast(IntegerType()).alias("avg_12m"),
            F.coalesce(F.col("json_api.depth"), F.col("depth")).cast(IntegerType()).alias("depth"),
            F.coalesce(F.col("json_api.resultsCount"), F.col("results_count")).cast(IntegerType()).alias("results_count"),
            F.coalesce(F.col("json_api.sponsoredCount"), F.col("sponsored_ads_count")).cast(IntegerType()).alias(
                "sponsored_ads_count"),
            F.coalesce(F.col("json_api.totalReviewsFirstPage"), F.col("page_1_reviews")).cast(IntegerType()).alias("page_1_reviews"),
            F.coalesce(F.col("json_api.appearance"), F.col("appearance")).alias("appearance"),
            F.coalesce(F.col("json_api.last_seen")).alias("last_seen"),

            F.col("update_time").alias("update_time"),
            F.col("source_type"),
            F.col("api_json"),
            F.lit(self.site_name).alias("site_name"),
            F.lit(self.batch).alias("batch")
        ).cache()
        # Save (deduplication and language detection happen in handle_save)
        self.handle_save(df_all)


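# Example invocations (illustrative; the script file name is assumed):
#   full build:          spark-submit dwt_merchantwords_st_detail.py us 2024-1
#   incremental update:  spark-submit dwt_merchantwords_st_detail.py update us 2024-05-01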
if __name__ == '__main__':
    update_flag = CommonUtil.get_sys_arg(1, None)
    if update_flag == 'update':
        site_name = CommonUtil.get_sys_arg(2, 'us')
        date_info = CommonUtil.get_sys_arg(3, CommonUtil.format_now('%Y-%m-%d'))
        DwtMerchantwordsStDetail('us', '2024-1').handle_update(date_info)
    else:
        site_name = CommonUtil.get_sys_arg(1, 'us')
        batch = CommonUtil.get_sys_arg(2, '2024-1')
        DwtMerchantwordsStDetail(site_name, batch).run()