import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from pyspark.sql.types import StringType, MapType
from utils.common_util import CommonUtil, DateTypes
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F
from utils.templates import Templates
import numpy as np

"""
获取搜索词 市场周期类型 相关指标
依赖 dim_st_detail 表 输出为 Dwt_st_market
"""


class DwtStMarket(Templates):
    def __init__(self, site_name, date_type, date_info):
        """
        Store the job parameters, obtain the Spark session and pre-create the
        DataFrame attributes that read_data() and handle_data() overwrite.

        :param site_name: site identifier used to filter the source tables
        :param date_type: period type of the partition being processed
        :param date_info: period value; anchor of the 12-month window
        """
        super().__init__()
        self.site_name, self.date_type, self.date_info = site_name, date_type, date_info

        # The session is named after the job class and its three parameters.
        session_label = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(session_label)

        self.last_12_month = self.get_last_12_month()

        # Each attribute gets its own placeholder DataFrame until the real data is loaded.
        for placeholder_attr in ("df_st", "df_st_detail", "df_st_id", "df_joined", "df_save"):
            setattr(self, placeholder_attr, self.spark.sql("select 1+1;"))

    @staticmethod
    def udf_market_fun_reg(last_12_month: list):
        """
        Build the Spark UDF that derives the market-cycle metrics of one search term.

        :param last_12_month: the 12 ``date_info`` values of the analysis window.
            Index 0 is reported as month 1 and as ``start_month``; index 11 is
            reported as ``end_month``.
        :return: a UDF that takes the collected ``(date_info, st_search_num,
            st_rank)`` structs of one search term and returns a
            ``map<string, string>`` with the keys ``rank_N``,
            ``rank_N_search_volume``, ``month_growth_rate_N``,
            ``market_cycle_type``, ``max_avg_val``, ``start_month`` and
            ``end_month``. Keys whose value cannot be computed are omitted.

        market_cycle_type values:
            1 seasonal market: (highest monthly volume - average) / average > 0.8
            2 continuous growth: at least 5 monthly growth rates >= 5%
            3 continuous decline: at least 5 monthly growth rates <= -5%
            4 ordinary market (default)
            5 seasonal + continuous growth
            6 seasonal + continuous decline
        """
        # The month -> position lookup depends only on the window, so it is built
        # once here rather than once per search term inside the UDF.
        index_map = {month: str(i) for i, month in enumerate(last_12_month)}

        def udf_market_fun(rows: list):
            # Position ("0" .. "11") -> row of that month.
            row_map = {index_map[row['date_info']]: row for row in rows}

            result_val = {}
            rate_list = []
            for key, row in row_map.items():
                month_seq = int(key) + 1
                vol_val = row['st_search_num']
                rank_val = row['st_rank']

                if rank_val is not None:
                    result_val[f"rank_{month_seq}"] = str(rank_val)
                if vol_val is not None:
                    result_val[f"rank_{month_seq}_search_volume"] = str(vol_val)

                # Month-over-month growth rate: (this month - previous month) / previous month.
                # NOTE(review): `< 6` yields rates for months 1-5 only, so
                # month_growth_rate_6 is never produced although the rule speaks of
                # "5-6 of the last 6 months" — confirm whether `<= 6` was intended.
                if month_seq < 6 and vol_val is not None:
                    # The previous month sits one position further into the window.
                    last_month_row = row_map.get(str(month_seq))
                    last_vol_val = None if last_month_row is None else last_month_row['st_search_num']
                    if last_vol_val is not None:
                        baseline = float(last_vol_val)
                        # A zero baseline has no defined growth rate; skipping it avoids
                        # a ZeroDivisionError that would fail the whole job.
                        if baseline != 0:
                            rate = round((float(vol_val) - baseline) / baseline, 4)
                            result_val[f"month_growth_rate_{month_seq}"] = str(rate)
                            rate_list.append(rate)

            # Defaults: ordinary market, no peak-over-average ratio.
            market_cycle_type = 4
            max_avg_val = 0

            volumes = [row['st_search_num'] for row in rows if row['st_search_num'] is not None]
            if volumes:
                # The sum is always divided by 12, whatever the number of months present.
                avg_val = round(np.sum(volumes) / 12, 4)
                max_val = np.max(volumes)
                # With a zero average the ratio is undefined; keep the default.
                if avg_val != 0:
                    max_avg_val = round((max_val - avg_val) / avg_val, 4)

            growing = sum(1 for rate in rate_list if rate >= 0.05) >= 5
            declining = sum(1 for rate in rate_list if rate <= -0.05) >= 5

            if max_avg_val > 0.8:
                # Seasonal market, optionally combined with a growth / decline trend.
                if growing:
                    market_cycle_type = 5
                elif declining:
                    market_cycle_type = 6
                else:
                    market_cycle_type = 1
            elif growing:
                market_cycle_type = 2
            elif declining:
                market_cycle_type = 3

            result_val['market_cycle_type'] = str(market_cycle_type)
            result_val['max_avg_val'] = str(max_avg_val)
            result_val['start_month'] = str(last_12_month[0])
            result_val['end_month'] = str(last_12_month[11])

            return result_val

        return F.udf(udf_market_fun, MapType(StringType(), StringType()))

    def get_repartition_num(self):
        """
        Pick the number of partitions used when writing the result.

        :return: 4 for last365day, 3 for month / month_week, 2 for month_old,
            10 for any other date_type
        """
        current_type = self.date_type
        if current_type == DateTypes.last365day.name:
            partition_count = 4
        elif current_type == DateTypes.month.name or current_type == DateTypes.month_week.name:
            partition_count = 3
        elif current_type == DateTypes.month_old.name:
            partition_count = 2
        else:
            partition_count = 10
        return partition_count

    def get_last_12_month(self):
        """
        List the 12 values returned by CommonUtil.get_month_offset for
        date_info with offsets 0, -1, ..., -11, in that order.

        :return: list of 12 month values
        """
        return [CommonUtil.get_month_offset(self.date_info, -back) for back in range(12)]

    def read_data(self):
        """
        Load the three source DataFrames and print a 10-row preview of each.

        df_st: search terms of the current site / date_type / date_info
        df_st_detail: search volume and rank per search term for the 12-month window
        df_st_id: search term -> integer id for the current site
        """

        def load_by_term(query):
            # All sources share one layout: 40 partitions on search_term, marked for caching.
            return self.spark.sql(query).repartition(40, 'search_term').cache()

        term_query = f"""
            select 
                search_term
            from dwd_st_measure
            where site_name = '{self.site_name}'
            and date_type = '{self.date_type}'
            and date_info = '{self.date_info}';
        """
        self.df_st = load_by_term(term_query)
        self.df_st.show(10, truncate=True)

        # Rows without a search volume are excluded at the source.
        detail_query = f"""
            select 
                search_term,
                st_search_num,
                date_info,
                st_rank
            from dim_st_detail
            where site_name = '{self.site_name}'
            and date_type = '{self.date_type}'
            and date_info in ({CommonUtil.list_to_insql(self.last_12_month)})
            and st_search_num is not null;
        """
        self.df_st_detail = load_by_term(detail_query)
        self.df_st_detail.show(10, truncate=True)

        id_query = f"""
            select 
                cast(st_key as integer) as search_term_id, 
                search_term 
            from ods_st_key 
            where site_name ='{self.site_name}';
        """
        self.df_st_id = load_by_term(id_query)
        self.df_st_id.show(10, truncate=True)

    def handle_data(self):
        """
        Compute the market metrics per search term and shape the output DataFrame.

        The current search terms are inner-joined with their monthly detail rows,
        grouped by search term and passed through the market UDF, which returns a
        map column ``cal_map``. The result is inner-joined with the search term ids
        and every map entry is exposed as its own column in df_save.
        """
        self.df_joined = self.df_st.join(self.df_st_detail, on='search_term', how='inner')

        market_udf = self.udf_market_fun_reg(self.last_12_month)
        monthly_structs = F.collect_list(F.struct("date_info", "st_search_num", "st_rank"))
        self.df_joined = self.df_joined.groupby('search_term').agg(
            market_udf(monthly_structs).alias("cal_map")
        )

        # Map keys to flatten, in output column order.
        map_keys = ["start_month", "end_month", "market_cycle_type"]
        map_keys += [f"rank_{month}" for month in range(1, 13)]
        map_keys += [f"rank_{month}_search_volume" for month in range(1, 13)]
        map_keys += [f"month_growth_rate_{month}" for month in range(1, 7)]
        map_keys.append("max_avg_val")

        cal_map = F.col("cal_map")
        metric_columns = [cal_map.getField(map_key).alias(map_key) for map_key in map_keys]

        with_ids = self.df_joined.join(self.df_st_id, on='search_term', how='inner')
        self.df_save = with_ids.select(
            F.col("search_term_id"),
            F.col("search_term_id").alias("id"),
            F.col("search_term"),
            *metric_columns,
            # Partition columns of the target table.
            F.lit(self.site_name).alias("site_name"),
            F.lit(self.date_type).alias("date_type"),
            F.lit(self.date_info).alias("date_info")
        )

    def save_data(self):
        """
        Write df_save to the hive table dwt_st_market.

        HdfsUtils.delete_file_in_folder is first called on the partition directory
        of this run (presumably so a re-run does not duplicate rows — confirm
        against HdfsUtils). The data is then repartitioned and appended,
        partitioned by site_name / date_type / date_info.
        """
        target_table = "dwt_st_market"
        partition_cols = ["site_name", "date_type", "date_info"]

        # Partition directory, e.g. site_name=<..>/date_type=<..>/date_info=<..>
        partition_dir = "/".join(f"{col}={getattr(self, col)}" for col in partition_cols)
        hdfs_path = f"/home/{SparkUtil.DEF_USE_DB}/dwt/{target_table}/{partition_dir}"

        print(f"清除hdfs目录中.....{hdfs_path}")
        HdfsUtils.delete_file_in_folder(hdfs_path)

        print("当前存储的表名为:", target_table)
        self.df_save = self.df_save.repartition(self.get_repartition_num())
        writer = self.df_save.write
        writer.saveAsTable(name=target_table, format='hive', mode='append', partitionBy=partition_cols)


if __name__ == '__main__':
    # Positional arguments 1-3 are site_name, date_type and date_info;
    # None is handed to get_sys_arg as its second argument for each.
    job_args = [CommonUtil.get_sys_arg(position, None) for position in (1, 2, 3)]
    DwtStMarket(*job_args).run()