import os
import sys


sys.path.append(os.path.dirname(sys.path[0]))
from utils.common_util import CommonUtil, DateTypes
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType
from pyspark.sql.dataframe import DataFrame
from yswg_utils.common_udf import udf_detect_phrase_reg


class DwtAbaLast365(object):

    def __init__(self, site_name, date_type, date_info):
        """
        Store the job parameters, validate ``date_type`` and obtain a Spark session.

        :param site_name: site identifier; used in the SQL filters and as a partition value
        :param date_type: must equal ``DateTypes.month.name`` or ``DateTypes.year.name``
        :param date_info: reference period; used as the end of the 12-month window and as a partition value
        :raises AssertionError: if ``date_type`` is not one of the two accepted values
        """
        self.site_name = site_name
        self.date_info = date_info
        self.date_type = date_type
        # The source table is always read at monthly granularity, whatever date_type was requested.
        self.date_type_original = DateTypes.month.name

        accepted_date_types = (DateTypes.month.name, DateTypes.year.name)
        assert date_type in accepted_date_types, "date_type 不合法！！"

        # Spark application name: <class>:<site>:<date_type>:<date_info>
        app_name = "{}:{}:{}:{}".format(
            type(self).__name__, self.site_name, self.date_type, self.date_info
        )
        self.spark = SparkUtil.get_spark_session(app_name)
        # Target Hive table written by run().
        self.hive_tb = "dwt_aba_last365"

    # def get_last_year_month(self, year_month):
    #     """
    #     获取同比年的上一年最新数据
    #     :param year_month:
    #     :return:
    #     """
    #     start = year_month
    #     end = CommonUtil.get_month_offset(start, -12)
    #     year = int(CommonUtil.safeIndex(end.split("-"), 0))
    #     flag = True
    #
    #     while flag:
    #         start = CommonUtil.get_month_offset(start, -1)
    #         end = CommonUtil.get_month_offset(end, -1)
    #         tmp = int(CommonUtil.safeIndex(end.split("-"), 0))
    #         flag = (tmp != (year - 1))
    #     return start

    #  对指定的行进行行转列
    def pivot_df(self, last_12_month: list, df_all: DataFrame, df_agg: DataFrame, group_col: str, pivot_col: str, agg_col_arr: list):
        """
        Pivot monthly rows into month-suffixed columns and attach them to the aggregate frame.

        :param last_12_month: pivot values, formatted so that the text after the last '-' is the month number
        :param df_all: detail rows containing ``group_col``, ``pivot_col`` and every column in ``agg_col_arr``
        :param df_agg: per-``group_col`` aggregate frame the pivoted columns are joined onto
        :param group_col: grouping / join key column name
        :param pivot_col: column whose values become column-name prefixes
        :param agg_col_arr: columns to spread out; the first value per (group, pivot value) is taken
        :return: ``df_agg`` inner-joined with the pivoted columns on ``group_col``
        """
        first_value_exprs = [F.first(name).alias(name) for name in agg_col_arr]

        # Pivot output columns look like: 2024-07_st_num|2024-06_st_num|...|2024-07_bsr_orders|...
        pivoted = (
            df_all.groupBy(group_col)
            .pivot(pivot_col, last_12_month)
            .agg(*first_value_exprs)
            .cache()
        )

        # Rename to <metric><month number>: st_num1|st_num2|...|bsr_orders1|bsr_orders2|...
        for year_month in last_12_month:
            month_no = int(year_month.split('-')[-1])
            for name in agg_col_arr:
                pivoted = pivoted.withColumnRenamed(f"{year_month}_{name}", f"{name}{month_no}")

        return df_agg.join(pivoted, group_col, "inner")

    # 语种处理
    def handle_calc_lang(self, df_all: DataFrame) -> DataFrame:
        """
        Add a ``lang`` column derived from ``search_term``.

        The word -> langs table is collected to the driver, turned into a dict and handed to
        ``udf_detect_phrase_reg``; the ``lang`` field of the UDF result is used, falling back
        to the literal ``"other"`` when that field is null.

        :param df_all: frame containing a ``search_term`` column
        :return: ``df_all`` with the extra ``lang`` column
        """
        sql = """
            select 
                word, 
                langs 
            from big_data_selection.tmp_lang_word_frequency;
        """
        # Build the lookup dict on the driver (later rows overwrite earlier ones for the same word).
        lang_word_map = {}
        for row in self.spark.sql(sql).collect():
            lang_word_map[row['word']] = row['langs']

        detect_lang = udf_detect_phrase_reg(lang_word_map)
        lang_col = F.coalesce(
            detect_lang(F.col("search_term")).getField("lang"),
            F.lit("other")
        )
        return df_all.withColumn("lang", lang_col)

    def run(self):
        """
        Aggregate the trailing 12 months of search-term analytics and write the result to Hive.

        Steps:
          1. Load the monthly rows of ``dwt_aba_st_analytics`` for the 12 months ending at
             ``self.date_info`` and left-join the monthly ``orders`` onto them.
          2. Aggregate per ``id``: sum/12 averages, sums, max-rows and flags.
          3. Pivot five monthly metrics into month-suffixed columns via ``pivot_df``.
          4. Join the rank from ``dwt_st_sv_last365``, derive ``is_ascending_text`` and ``lang``.
          5. Add quarterly totals and ``top_rank``, clear the target HDFS partition directory
             and append the result to ``self.hive_tb``.

        :return: None
        """
        # The past 12 months, starting with self.date_info itself (offset 0).
        last_12_month = [CommonUtil.get_month_offset(self.date_info, -offset) for offset in range(0, 12)]
        print(f"过去12个月为{last_12_month}")

        sql = f"""
            select 
                search_term,
                id,
                st_bsr_cate_1_id_new as category_id,
                bsr_orders,
                search_volume,
                st_ao_avg,
                st_ao_val_rate,
                price_avg,
                weight_avg,
                volume_avg,
                rating_avg,
                total_comments_avg,
                total_asin_num,
                aadd_proportion,
                sp_proportion,
                fbm_proportion,
                cn_proportion,
                amzon_proportion,
                top3_seller_orders,
                top3_seller_bsr_orders,
                top3_brand_orders,
                top3_brand_bsr_orders,
                page3_brand_num,
                page3_seller_num,
                new_bsr_orders_proportion,
                new_asin_proportion,
                supply_demand,
                max_num,
                most_proportion,
                gross_profit_fee_sea,
                gross_profit_fee_air,
                st_bsr_cate_current_id_new as category_current_id,
                color_proportion,
                max_num_asin,
                is_self_max_num_asin,
                multi_color_proportion,
                multi_size_proportion,
                is_new_market_segment,
                market_cycle_type,
                brand_monopoly,
                seller_monopoly,
                is_first_text,
                is_ascending_text,
                is_search_text,
                st_word_num,
                date_info,
                st_num
            from dwt_aba_st_analytics
            where site_name = '{self.site_name}'
              and date_type = '{self.date_type_original}'
              and date_info in ({CommonUtil.list_to_insql(last_12_month)})
        """
        df_all = self.spark.sql(sql).repartition(80, "search_term", "date_info").cache()

        # Use the instance attribute here, not a bare module-level name, so the class also
        # works when it is imported and driven from another module.
        if self.date_info < '2023-10':
            # Months before 2022-10 take their orders from dim_st_detail ('month_old' rows);
            # later months take them from dwt_aba_st_analytics.
            # NOTE(review): one of the two lists can be empty (e.g. old_list for '2023-09');
            # confirm CommonUtil.list_to_insql yields valid SQL for an empty list.
            old_list = [month for month in last_12_month if month < '2022-10']
            new_list = [month for month in last_12_month if month >= '2022-10']
            sql2 = f"""
                select 
                    search_term,
                    st_search_sum as orders,
                    date_info
                from dim_st_detail
                where site_name = '{self.site_name}'
                  and date_type = 'month_old'
                  and date_info in ({CommonUtil.list_to_insql(old_list)})
                union
                select 
                    search_term,
                    orders as orders,
                    date_info
                from dwt_aba_st_analytics
                where site_name = '{self.site_name}'
                  and date_type = 'month'
                  and date_info in ({CommonUtil.list_to_insql(new_list)})
            """
        else:
            sql2 = f"""
                select 
                    search_term,
                    orders as orders,
                    date_info
                from dwt_aba_st_analytics
                where site_name = '{self.site_name}'
                  and date_type = 'month'
                  and date_info in ({CommonUtil.list_to_insql(last_12_month)})
            """
        df_orders_all = self.spark.sql(sql2).repartition(80, "search_term", "date_info").cache()
        df_all = df_all.join(
            other=df_orders_all, on=["search_term", "date_info"], how="left"
        ).cache()
        df_orders_all.unpersist()

        # (source column, output alias) pairs aggregated as round(sum(source)/12,4).
        # The divisor is always 12, regardless of how many monthly rows an id actually has.
        avg_col_pairs = [
            ("st_ao_avg", "st_ao_avg"),
            ("st_ao_val_rate", "st_ao_val_rate"),
            ("price_avg", "price_avg"),
            ("weight_avg", "weight_avg"),
            ("volume_avg", "volume_avg"),
            ("rating_avg", "rating_avg"),
            ("aadd_proportion", "aadd_proportion"),
            ("sp_proportion", "sp_proportion"),
            ("fbm_proportion", "fbm_proportion"),
            ("cn_proportion", "cn_proportion"),
            ("amzon_proportion", "amzon_proportion"),
            ("top3_seller_orders", "top3_seller_orders"),
            ("top3_seller_bsr_orders", "top3_seller_bsr_orders"),
            ("top3_brand_orders", "top3_brand_orders"),
            ("top3_brand_bsr_orders", "top3_brand_bsr_orders"),
            ("page3_brand_num", "page3_brand_num"),
            ("page3_seller_num", "page3_seller_num"),
            ("new_bsr_orders_proportion", "new_asin_bsr_orders_avg_monopoly"),
            ("new_asin_proportion", "new_asin_num_avg_monopoly"),
            ("supply_demand", "supply_demand"),
            # average number of comments
            ("total_comments_avg", "total_comments_avg"),
            # average "most quantity" proportion
            ("most_proportion", "most_avg_proportion"),
            # colour proportion
            ("color_proportion", "color_proportion"),
            ("gross_profit_fee_sea", "gross_profit_fee_sea"),
            ("gross_profit_fee_air", "gross_profit_fee_air"),
            ("brand_monopoly", "brand_monopoly"),
            ("seller_monopoly", "seller_monopoly"),
            ("multi_size_proportion", "multi_size_avg_proportion"),
            ("multi_color_proportion", "multi_color_avg_proportion"),
        ]
        avg_exprs = [
            F.expr(f"round(sum({source})/12,4)").alias(target)
            for source, target in avg_col_pairs
        ]

        df_agg = df_all.groupBy("id").agg(
            F.first("search_term").alias("search_term"),
            F.first("category_id").alias("category_id"),
            F.first("category_current_id").alias("category_current_id"),
            *avg_exprs,

            # plain sums over the window
            F.sum(F.col("bsr_orders")).alias("bsr_orders"),
            F.sum(F.col("orders")).alias("orders"),
            F.sum(F.col("st_num")).alias("total_st_num"),
            F.sum(F.col("total_asin_num")).alias("total_asin_num"),

            # Row holding the largest max_num; structs are compared field by field from the first.
            F.max(F.struct("max_num", "max_num_asin", "is_self_max_num_asin")).alias("max_num_row"),

            # month with the highest bsr_orders
            F.max(F.struct("bsr_orders", "date_info")).alias("tmp_row_1"),
            # month with the highest orders
            F.max(F.struct("orders", "date_info")).alias("tmp_row_2"),

            # avg truncated to int: 1 only when every monthly value is 1
            # (assumes the flag is 0/1 -- TODO confirm).
            F.avg("is_new_market_segment").cast(IntegerType()).alias("is_new_market_segment"),
            # sum / 9.6 truncated to int: becomes 1 once the monthly flags add up to at least 9.6,
            # i.e. 10 of 12 months (80%) when the flag is 0/1.
            F.expr("sum(is_search_text) / 9.6").cast(IntegerType()).alias("is_search_text"),
            F.max("st_word_num").alias("st_word_num")
        ).cache()

        # Metrics spread into one column per month.
        agg_col_arr = [
            'st_num',
            'bsr_orders',
            'orders',
            'market_cycle_type',
            'search_volume',
        ]

        df_all = self.pivot_df(last_12_month, df_all, df_agg, "id", "date_info", agg_col_arr)
        # The month 12 months before date_info, used for the year-over-year rank.
        last_year_month = CommonUtil.get_month_offset(self.date_info, -12)
        print(f"12个月前是：{last_year_month}")

        # Search-volume rank for the current period and for the same period one year earlier.
        sql = f"""
            select 
                search_term_id,
                collect_set(rank)[0]      as rank,
                collect_set(last_rank)[0] as last_rank
            from (
                select 
                    search_term_id,
                    case date_info when '{self.date_info}' then sv_rank end as rank,
                    case date_info when '{last_year_month}' then sv_rank end as last_rank
                from dwt_st_sv_last365
                where site_name = '{self.site_name}'
                  and date_info in ('{self.date_info}', '{last_year_month}')
            ) tmp
            group by search_term_id;
        """
        df_st_total_rank = self.spark.sql(sql).na.fill({
            "last_rank": 0
        }).cache()

        # NOTE(review): last_rank is filled with 0 above, so rank / last_rank divides by zero
        # for terms without a prior-year rank; confirm the resulting flag value is the intended one.
        df_all = df_all.join(
            df_st_total_rank, df_all['id'].eqNullSafe(df_st_total_rank['search_term_id']), "left"
        ).withColumn(
            "is_ascending_text",
            F.expr("rank / last_rank <= 0.5").cast(IntegerType())
        )

        # language detection
        df_all = self.handle_calc_lang(df_all)

        # The numeric suffix is the calendar month, e.g. 2024-07 -> st_num7, 2023-12 -> st_num12.
        # Order: all 12 columns of one metric, then the next metric.
        monthly_cols = [
            F.col(f"{metric}{month}").cast(IntegerType())
            for metric in ("st_num", "orders", "bsr_orders", "market_cycle_type", "search_volume")
            for month in range(1, 13)
        ]

        df_all = df_all.select(
            df_all['id'],
            df_all['search_term'],
            F.col("category_id"),
            F.col("category_current_id"),
            F.col('rank').cast(IntegerType()),
            F.col("total_st_num"),

            *monthly_cols,

            # averages
            F.col("st_ao_avg"),
            F.col("st_ao_val_rate"),
            F.col("price_avg"),
            F.col("weight_avg"),
            F.col("volume_avg"),
            F.col("rating_avg"),
            F.col("total_comments_avg"),

            F.col("multi_size_avg_proportion"),
            F.col("multi_color_avg_proportion"),

            F.col("brand_monopoly"),
            F.col("seller_monopoly"),
            # averages
            F.col("most_avg_proportion"),
            F.col("supply_demand"),
            F.col("aadd_proportion"),
            F.col("sp_proportion"),
            F.col("fbm_proportion"),
            F.col("cn_proportion"),
            F.col("amzon_proportion"),
            F.col("top3_seller_orders"),
            F.col("top3_seller_bsr_orders"),
            F.col("top3_brand_orders"),
            F.col("top3_brand_bsr_orders"),
            F.col("page3_brand_num").cast(IntegerType()),
            F.col("page3_seller_num").cast(IntegerType()),
            F.col("new_asin_num_avg_monopoly"),
            F.col("new_asin_bsr_orders_avg_monopoly"),
            F.col("orders").cast(IntegerType()),
            F.col("bsr_orders"),

            F.col("gross_profit_fee_sea"),
            F.col("gross_profit_fee_air"),
            F.col("color_proportion"),

            F.col("total_asin_num"),

            # fields of the row with the largest max_num
            F.col("max_num_row").getField("max_num").alias("max_num"),
            F.col("max_num_row").getField("max_num_asin").alias("max_num_asin"),
            F.col("max_num_row").getField("is_self_max_num_asin").alias("is_self_max_num_asin"),

            F.col("tmp_row_1").getField("date_info").alias("max_bsr_orders_month"),
            F.col("tmp_row_2").getField("date_info").alias("max_orders_month"),

            F.col("is_new_market_segment").alias("is_new_market_segment"),

            # is_first_text reuses the is_new_market_segment value (0: no, 1: yes).
            F.col("is_new_market_segment").alias("is_first_text"),

            F.col("is_ascending_text").alias("is_ascending_text"),
            F.col("is_search_text").alias("is_search_text"),
            F.col("st_word_num").alias("st_word_num"),

            F.current_date().alias("updated_time").cast(StringType()),
            F.current_date().alias("created_time").cast(StringType()),

            F.lit(None).alias("usr_mask_type"),
            F.lit(None).alias("usr_mask_progress"),
            F.col("lang"),
            F.lit(self.site_name).alias("site_name"),
            F.lit(self.date_type).alias("date_type"),
            F.lit(self.date_info).alias("date_info")
        )

        # Calendar-quarter totals, bsr_orders first and then orders: q<N>_<metric> is the sum
        # of the three month columns of that quarter, with missing months counted as 0.
        for metric in ("bsr_orders", "orders"):
            for quarter in range(1, 5):
                first_month = 3 * quarter - 2
                quarter_sum = " + ".join(
                    f"coalesce({metric}{month},0)" for month in range(first_month, first_month + 3)
                )
                df_all = df_all.withColumn(f"q{quarter}_{metric}", F.expr(quarter_sum))

        # top_rank is a copy of rank kept for compatibility; nulls in both become 0.
        df_all = df_all.withColumn(
            "top_rank", F.col("rank")
        ).na.fill({
            "rank": 0,
            "top_rank": 0
        })

        # repartition before writing
        df_all = df_all.repartition(15)

        partition_by = ["site_name", "date_type", "date_info"]
        print(f"当前存储的表名为：{self.hive_tb},分区为{partition_by}", )
        hdfs_path = CommonUtil.build_hdfs_path(
            self.hive_tb,
            partition_dict={
                "site_name": self.site_name,
                "date_type": self.date_type,
                "date_info": self.date_info,
            }
        )
        # Clear the partition directory first so the append below does not duplicate a previous run.
        print(f"清除hdfs目录中:{hdfs_path}")
        HdfsUtils.delete_file_in_folder(hdfs_path)
        df_all.write.saveAsTable(name=self.hive_tb, format='hive', mode='append', partitionBy=partition_by)
        print("success")


if __name__ == '__main__':
    # Script entry point: command-line arguments 1..3 are site_name, date_type and date_info
    # (None is passed to get_sys_arg as the second argument, presumably the fallback value).
    # The three names are deliberately bound at module level.
    site_name, date_type, date_info = [CommonUtil.get_sys_arg(position, None) for position in (1, 2, 3)]
    DwtAbaLast365(site_name=site_name, date_type=date_type, date_info=date_info).run()
