import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))

import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, MapType

from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil


class DwtStMarket(object):
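    """
    Classify each search term's market cycle from its last 12 months of search volume:
    1 = seasonal, 2 = steadily growing, 3 = steadily declining, 4 = general market.
    """
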
    def __init__(self, site_name, date_type, date_info):
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        app_name = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.last_12_month = self.get_last_12_month()

    @staticmethod
    def udf_market_fun_reg(last_12_month: list):
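        """
        Build a UDF that, for one search term's collected (date_info, st_search_num, st_rank) rows,
        returns a map<string, string> with per-month ranks and search volumes, month-over-month
        growth rates for the most recent 6 months, the (max - mean) / mean ratio and the
        resulting market_cycle_type.
        """
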
        def udf_market_fun(rows: list):
            # Map each month (date_info) to its position in the window: "0" = most recent ... "11" = oldest
            index_map = {val: str(i) for i, val in enumerate(last_12_month)}

            # Map each position back to that month's input row
            row_map = {index_map[row['date_info']]: row for row in rows}

            # Default classification: 4 = general market
            market_cycle_type = 4
            # (max - mean) / mean of the monthly search volumes, used for the seasonality check
            max_avg_val = 0

            result_val = {}
            rate_list = []
            for key in row_map:
                # month_seq 1 = the most recent month, 12 = the oldest
                month_seq = int(key) + 1
                row = row_map.get(key)

                vol_val = row['st_search_num']
                rank_val = row['st_rank']
                if rank_val is not None:
                    result_val[f"rank_{month_seq}"] = str(rank_val)

                if vol_val is not None:
                    result_val[f"rank_{month_seq}_search_volume"] = str(vol_val)

                # Month-over-month growth rate for each of the most recent 6 months
                if month_seq <= 6:
                    last_month_row = row_map.get(str(int(key) + 1))
                    if last_month_row is not None:
                        last_vol_val = last_month_row['st_search_num']
                        if last_vol_val is not None and vol_val is not None and float(last_vol_val) != 0:
                            # growth rate = (this month - previous month) / previous month
                            rate = round((float(vol_val) - float(last_vol_val)) / float(last_vol_val), 4)
                            result_val[f"month_growth_rate_{month_seq}"] = str(rate)
                            rate_list.append(rate)

            # Monthly search volumes across the whole window
            values = [r['st_search_num'] for r in rows]
            # Mean over 12 months (months missing from the data effectively count as zero)
            avg_val = round(np.sum(values) / 12, 4)
            # Peak monthly volume
            max_val = np.max(values)
            if avg_val != 0:
                max_avg_val = round((max_val - avg_val) / avg_val, 4)

            if max_avg_val > 0.8:
                # 1 Seasonal market: with the 12 monthly volumes sorted descending and peak h1,
                #   the market is seasonal if (h1 - mean) / mean > 0.8
                market_cycle_type = 1
            elif len([rate for rate in rate_list if rate >= 0.05]) >= 5:
                # 2 Steadily growing market: 5-6 of the last 6 months grew by more than 5% month over month
                market_cycle_type = 2
            elif len([rate for rate in rate_list if rate <= -0.05]) >= 5:
                # 3 Steadily declining market: 5-6 of the last 6 months declined by more than 5% month over month
                market_cycle_type = 3
            else:
                # 4 General market
                market_cycle_type = 4

            result_val['market_cycle_type'] = str(market_cycle_type)
            result_val['max_avg_val'] = str(max_avg_val)
            # last_12_month[0] is the reference month (date_info); last_12_month[11] is the oldest in the window
            result_val['start_month'] = str(last_12_month[0])
            result_val['end_month'] = str(last_12_month[11])

            return result_val

        return F.udf(udf_market_fun, MapType(StringType(), StringType()))

    def get_last_12_month(self):
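        """Return date_info and the 11 preceding months, most recent first."""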
        last_12_month = []
        for i in range(0, 12):
            last_12_month.append(CommonUtil.get_month_offset(self.date_info, -i))
        return last_12_month

    def run(self):
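        """
        Load the last 12 months of search-term detail, run the market-cycle UDF per search term
        and show the result (nothing is persisted here).
        """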
        sql = f"""
            select 
                search_term,
                st_search_num,
                date_info,
                st_rank
            from dim_st_detail
            where site_name = '{self.site_name}'
            and date_type = '{self.date_type}'
            and date_info in ({CommonUtil.list_to_insql(self.last_12_month)})
            and st_search_num is not null
            and search_term = 'iphone 16 pro screen protector';
        """
        df_st_detail = self.spark.sql(sql).cache()
        df_st_detail.show(10, truncate=False)
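        # Collect each term's monthly rows into a list and apply the market-classification UDF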
        df_st_detail = df_st_detail.groupby('search_term').agg(
            self.udf_market_fun_reg(self.last_12_month)(
                F.collect_list(F.struct("date_info", "st_search_num", "st_rank"))
            ).alias("cal_map")
        )
        df_save = df_st_detail.select(
            F.col("search_term"),
            F.col("cal_map").getField("start_month").alias("start_month"),
            F.col("cal_map").getField("end_month").alias("end_month"),
            F.col("cal_map").getField("market_cycle_type").alias("market_cycle_type"),
            F.col("cal_map").getField("rank_1_search_volume").alias("rank_1_search_volume"),
            F.col("cal_map").getField("rank_2_search_volume").alias("rank_2_search_volume"),
            F.col("cal_map").getField("rank_3_search_volume").alias("rank_3_search_volume"),
            F.col("cal_map").getField("rank_4_search_volume").alias("rank_4_search_volume"),
            F.col("cal_map").getField("rank_5_search_volume").alias("rank_5_search_volume"),
            F.col("cal_map").getField("rank_6_search_volume").alias("rank_6_search_volume"),
            F.col("cal_map").getField("rank_7_search_volume").alias("rank_7_search_volume"),
            F.col("cal_map").getField("rank_8_search_volume").alias("rank_8_search_volume"),
            F.col("cal_map").getField("rank_9_search_volume").alias("rank_9_search_volume"),
            F.col("cal_map").getField("rank_10_search_volume").alias("rank_10_search_volume"),
            F.col("cal_map").getField("rank_11_search_volume").alias("rank_11_search_volume"),
            F.col("cal_map").getField("rank_12_search_volume").alias("rank_12_search_volume"),
            F.col("cal_map").getField("max_avg_val").alias("max_avg_val")
        )
        df_save.show()


if __name__ == '__main__':
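    # CLI args: site_name, date_type and date_info (a month); exact formats follow CommonUtil conventions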
    site_name = CommonUtil.get_sys_arg(1, None)
    date_type = CommonUtil.get_sys_arg(2, None)
    date_info = CommonUtil.get_sys_arg(3, None)
    obj = DwtStMarket(site_name, date_type, date_info)
    obj.run()