import os
import sys

from pyspark.storagelevel import StorageLevel
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from utils.templates import Templates
# from ..utils.templates import Templates
from pyspark.sql.types import StringType
# window functions for grouped ranking (used together with the udf below)
from pyspark.sql.window import Window
from pyspark.sql import functions as F


class DwdAsinAndStCounts(Templates):

    def __init__(self, site_name="us", date_type="week", date_info="2022-1"):
        super().__init__()
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        self.db_save_asin = "dwd_asin_counts"
        self.db_save_st = "dwd_st_counts"
        self.spark = self.create_spark_object(app_name=f"{self.db_save_asin}, {self.db_save_st}: {self.site_name}, {self.date_type}, {self.date_info}")
        # get_year_week_tuple() comes from Templates; judging by read_data() below it
        # also sets self.year_week_tuple
        self.df_date = self.get_year_week_tuple()
        # placeholder DataFrames, replaced in read_data()/handle_data()
        self.df_save_asin = self.spark.sql("select 1+1;")
        self.df_save_st = self.spark.sql("select 1+1;")
        # zero-row frames carrying the target tables' schemas, used to align pivot columns
        self.df_save_asin_std = self.spark.sql(f"select * from {self.db_save_asin} limit 0;")
        self.df_save_st_std = self.spark.sql(f"select * from {self.db_save_st} limit 0;")
        self.df_st_asin = self.spark.sql("select 1+1;")
        self.df_st_ao_val = self.spark.sql("select 1+1;")
        self.partitions_by = ['site_name', 'date_type', 'date_info']
        # register the UDF (usable from SQL by name and, via the returned callable, on DataFrames)
        self.u_ao_val_rate = self.spark.udf.register("u_ao_val_rate", self.udf_ao_val_rate, StringType())

    @staticmethod
    def udf_ao_val_rate(st_ao_rank, st2_counts):
        """Map a search term's ao_val rank to a percentile bucket label."""
        buckets = [
            (1, "top_1"), (2, "top_2"), (3, "top_3"), (4, "top_4"), (5, "top_5"),
            (10, "top_6~10"), (20, "top_11~20"), (30, "top_21~30"),
            (40, "top_31~40"), (50, "top_41~50"), (60, "top_51~60"),
            (70, "top_61~70"), (80, "top_71~80"), (90, "top_81~90"),
        ]
        for pct, label in buckets:
            if st_ao_rank < st2_counts / 100 * pct:
                return label
        if st_ao_rank <= st2_counts:
            return "top_91~100"
        return "top_xxx"

    def read_data(self):
        print("1. 读取dim_st_asin_info表")
        sql = f"select * from dim_st_asin_info where site_name='{self.site_name}' and date_type='week' and date_info in {self.year_week_tuple};"""
        print("sql:", sql)
        self.df_st_asin = self.spark.sql(sqlQuery=sql).cache()
        self.df_st_asin.show(10, truncate=False)
        self.df_st_asin = self.df_st_asin.drop_duplicates(["asin", "search_term", "data_type"])
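        # e.g. the same (asin, search_term, 'zr') combination appearing in several
        # weeks of year_week_tuple collapses to one row, so the counts below tally
        # distinct pairs rather than weekly occurrences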

    def handle_data(self):
        self.df_save_asin = self.handle_data_counts(cal_type="asin")
        self.df_save_st = self.handle_data_counts(cal_type="st")
        self.handle_ao_val()  # compute asin_ao_val and st_ao_val
        self.df_save_asin = self.df_save_asin.withColumn("date_type", F.lit(self.date_type))
        self.df_save_asin = self.df_save_asin.withColumn("date_info", F.lit(self.date_info))
        self.df_save_st = self.df_save_st.withColumn("date_type", F.lit(self.date_type))
        self.df_save_st = self.df_save_st.withColumn("date_info", F.lit(self.date_info))
        # self.df_save_asin.show(10, truncate=False)
        # self.df_save_st.show(10, truncate=False)
        # quit()

    def handle_data_counts(self, cal_type="asin"):
        print(f"2. 计算{cal_type}_counts")
        cal_type_complete = "search_term" if cal_type == "st" else cal_type
        self.df_st_asin = self.df_st_asin.withColumn(
            f"{cal_type}_data_type",
            F.concat(F.lit(f"{cal_type}_"), self.df_st_asin.data_type, F.lit("_counts"))
        )
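        # e.g. a row with data_type 'zr' and cal_type 'asin' gets
        # asin_data_type = "asin_zr_counts", so the pivot below yields one count
        # column per data_type (zr, sp, sb1, sb2, sb3, ...)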
        df = self.df_st_asin.groupby([cal_type_complete]). \
            pivot(f"{cal_type}_data_type").count()
        if cal_type == "asin":
            df = df.unionByName(self.df_save_asin_std, allowMissingColumns=True)
        else:
            self.df_save_st_std = self.df_save_st_std.drop("st_ao_val")
            df = df.unionByName(self.df_save_st_std, allowMissingColumns=True)

        df = df.fillna(0)
        # df.show(10, truncate=False)
        df = df.withColumn(
            f"{cal_type}_sb_counts",
            df[f"{cal_type}_sb1_counts"] + df[f"{cal_type}_sb2_counts"] + df[f"{cal_type}_sb3_counts"]
        )
        df = df.withColumn(
            f"{cal_type}_adv_counts",
            df[f"{cal_type}_sb_counts"] + df[f"{cal_type}_sp_counts"]
        )
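        # e.g. with sb1=1, sb2=0, sb3=2 and sp=3 (illustrative counts), the two
        # derived columns come out as sb_counts=3 and adv_counts=3+3=6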
        df = df.withColumn(f"site_name", F.lit(self.site_name))
        # df.show(10, truncate=False)
        return df

    def handle_ao_val(self):
        print("3. 计算asin_ao_val和st_ao_val")
        print("3.1 计算asin_ao_val")
        self.df_save_asin = self.df_save_asin.withColumn("asin_ao_val", self.df_save_asin.asin_adv_counts / self.df_save_asin.asin_zr_counts)
        self.df_save_asin = self.df_save_asin.fillna({"asin_ao_val": 0})
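        # e.g. asin_adv_counts=4 against asin_zr_counts=8 gives asin_ao_val=0.5
        # (illustrative numbers); a zero zr count makes the division null under
        # Spark's default non-ANSI mode, which the fillna above turns into 0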
        # pick the asins behind the latest week's keywords and derive ao_val from the
        # zr page_rank (after further thought, a plain dedup is enough)
        print("3.2 computing st_ao_val")
        # df_asin_ao_val = self.df_save_asin.select("asin", "asin_ao_val")
        # self.df_st_ao_val = self.df_st_asin. \
        #     filter("data_type='zr' and page_rank<=20"). \
        #     drop_duplicates(["search_term", "asin"])
        # self.df_st_ao_val = self.df_st_ao_val.join(df_asin_ao_val, on='asin', how='left')
        # window = Window.partitionBy(['search_term']). \
        #     orderBy(self.df_st_ao_val.asin_ao_val.desc())
        # self.df_st_ao_val = self.df_st_ao_val.withColumn("st_ao_rank", F.row_number().over(window))
        # # self.df_st_ao_val.show(50, truncate=False)
        # df_ao_val1 = self.df_st_ao_val.filter("st_ao_rank>3").groupby(['search_term']).agg({"asin_ao_val": "mean"})
        # df_ao_val1 = df_ao_val1.withColumnRenamed("avg(asin_ao_val)", "st_ao_val")
        # # df_ao_val1.show(10, truncate=False)
        # df_ao_val2 = self.df_st_ao_val.join(df_ao_val1.filter("st_ao_val=0"), on='search_term', how='inner'). \
        #     groupby(['search_term']).agg({"asin_ao_val": "mean"})
        # # df_ao_val2.show(10, truncate=False)
        # df_ao_val2 = df_ao_val2.withColumnRenamed("avg(asin_ao_val)", "st_ao_val")
        # # df_ao_val2.show(10, truncate=False)
        # df_ao_val = df_ao_val1.filter("st_ao_val>0").unionByName(df_ao_val2, allowMissingColumns=True)
        # # df_ao_val.show(10, truncate=False)
        # self.df_save_st = self.df_save_st.join(df_ao_val, on='search_term', how='left')

        # adjusted ao_val calculation: take the mean
        df_asin_ao = self.df_save_asin.select("asin", "asin_ao_val")
        df_st_asin = self.df_st_asin.drop_duplicates(["search_term", "asin"]).cache()
        df_st_asin = df_st_asin.join(
            df_asin_ao, on='asin', how='left'
        )
        df_st_ao = df_st_asin.groupby(['search_term']).agg(F.avg('asin_ao_val').alias("st_ao_val"))
        self.df_save_st = self.df_save_st.join(
            df_st_ao, on='search_term', how='left'
        )
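        # e.g. a search term matching three asins with ao_val 0.2, 0.4 and 0.6
        # (illustrative numbers) lands at st_ao_val = 0.4; terms whose asins all
        # lack an ao_val stay null here and are zero-filled below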

        # # earlier ao_val variant kept for reference (mean over zr rows only)
        # df_asin_ao_val = self.df_save_asin.select("asin", "asin_ao_val")
        # self.df_st_ao_val = self.df_st_asin. \
        #     filter("data_type='zr'"). \
        #     drop_duplicates(["search_term", "asin"])
        # self.df_st_ao_val = self.df_st_ao_val.join(df_asin_ao_val, on='asin', how='left')
        # # self.df_st_ao_val.filter("search_term='agujas dermapen 36 puntas'").show(100)
        # df_ao_val = self.df_st_ao_val.groupby(['search_term']).agg({"asin_ao_val": "mean"})
        # df_ao_val = df_ao_val.withColumnRenamed("avg(asin_ao_val)", "st_ao_val")
        # self.df_save_st = self.df_save_st.join(df_ao_val, on='search_term', how='left')

        # terms with no matching asin ao_val come back null from the left join above;
        # treat them as 0 so the two filters below do not silently drop them
        self.df_save_st = self.df_save_st.fillna({"st_ao_val": 0})
        df_save_st1 = self.df_save_st.filter("st_ao_val=0")
        df_save_st1 = df_save_st1.withColumn("st_ao_val_rank", F.lit(0))
        df_save_st1 = df_save_st1.withColumn("st_ao_val_rate", F.lit("top_0"))
        df_save_st2 = self.df_save_st.filter("st_ao_val>0")
        # global ordering (no partitionBy) pulls every row into a single partition
        window = Window.orderBy(df_save_st2.st_ao_val.asc())
        df_save_st2 = df_save_st2.withColumn("st_ao_val_rank", F.row_number().over(window))
        # df_save_st2.filter("st_ao_val_rank>10000").show(100, truncate=False)
        # st2_counts = df_save_st2.count()
        # df_save_st2 = df_save_st2.withColumn("st2_counts", F.lit(st2_counts))
        df_save_st2 = df_save_st2.withColumn(
            "st_ao_val_rate",
            self.u_ao_val_rate(
                "st_ao_val_rank", F.lit(df_save_st2.count())
            )
        )
        # df_save_st2.show(20, truncate=False)
        # df_save_st2.filter("st_ao_val_rate='top_xxx'").show(20, truncate=False)
        # df_save_st2.groupby(['st_ao_val_rate']).count().show(20, truncate=False)
        self.df_save_st = df_save_st1.unionByName(df_save_st2, allowMissingColumns=True)

    def save_data(self):
        self.reset_partitions(partitions_num=5)
        self.save_data_common(
            df_save=self.df_save_asin,
            db_save=self.db_save_asin,
            partitions_num=self.partitions_num,
            partitions_by=self.partitions_by
        )
        self.reset_partitions(partitions_num=1)
        self.save_data_common(
            df_save=self.df_save_st,
            db_save=self.db_save_st,
            partitions_num=self.partitions_num,
            partitions_by=self.partitions_by
        )


if __name__ == "__main__":
    site_name = sys.argv[1]  # arg 1: site
    date_type = sys.argv[2]  # arg 2: period type: week/4_week/month/quarter
    date_info = sys.argv[3]  # arg 3: year-week/year-month/year-quarter, e.g. 2022-1
    handle_obj = DwdAsinAndStCounts(site_name=site_name, date_type=date_type, date_info=date_info)
    handle_obj.run()
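
# Example invocation (a sketch; the script filename and any spark-submit options
# are assumptions, not taken from this repo):
#   spark-submit dwd_asin_and_st_counts.py us week 2022-1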