"""
author: 方星钧(ffman)
description: Clean the "ods_brand_analytics" table for the six major sites: rename columns and flag hot-search terms, newly appearing terms, and rising terms
table_read_name: ods_brand_analytics
table_save_name: dim_st_detail
table_save_level: dim
version: 1.0
created_date: 2022-11-21
updated_date: 2022-11-22
"""


import os
import sys
import time

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录
from pyspark.storagelevel import StorageLevel
from utils.templates import Templates
# from ..utils.templates import Templates
from pyspark.sql.types import StringType, IntegerType
# Window functions used for grouped ranking
from pyspark.sql.window import Window
from pyspark.sql import functions as F


class DwtStDetail(Templates):

    def __init__(self, site_name='us', date_type="month", date_info='2022-1'):
        """Create the Spark session, derive the comparison periods and register UDFs.

        Args:
            site_name: Site code used to filter every source table.
            date_type: Granularity of ``date_info``; ``get_date_info_last``
                handles 'day', 'week', 'week_old', '4_week', 'month',
                'month_old' and 'last30day'.
            date_info: Identifier of the period being processed.
        """
        super().__init__()
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        self.db_save = 'dim_st_detail'
        app_name = f"{self.db_save}: {self.site_name}, {self.date_type}, {self.date_info}"
        self.spark = self.create_spark_object(app_name=app_name)

        # Inherited helpers. Presumably they populate self.year, self.month,
        # self.date_info_tuple and self.df_date, which are read below and in
        # the read_* methods -- TODO confirm against Templates.
        self.get_date_info_tuple()
        self.get_year_month_days_dict()

        # Previous period, and the same period one year earlier.
        self.date_info_last = self.get_date_info_last()
        previous_year = int(self.year) - 1
        self.date_info_last_year = self.date_info.replace(f"{self.year}", f"{previous_year}")

        def placeholder():
            # Trivial frame so every df_* attribute exists before read_data() runs.
            return self.spark.sql("select 1+1;")

        self.df_st_detail = placeholder()
        self.df_st_detail_last = placeholder()
        self.df_st_detail_last_year = placeholder()
        self.df_st_detail_days = placeholder()
        self.df_st_quantity = placeholder()
        self.df_st_rate = placeholder()
        self.df_asin_history = placeholder()
        self.df_save = placeholder()

        # Output layout.
        self.partitions_num = 3
        self.reset_partitions(partitions_num=self.partitions_num)
        self.partitions_by = ['site_name', 'date_type', 'date_info']

        # Python UDFs, registered under SQL-visible names as well.
        register = self.spark.udf.register
        self.u_handle_columns_asin123 = register(
            "u_handle_columns_asin123", self.udf_handle_columns_asin123, StringType())
        self.u_is_ascending = register("u_is_ascending", self.udf_is_ascending, IntegerType())
        self.u_is_search = register("u_is_search", self.udf_is_search, IntegerType())
        self.u_get_asin_top = register("u_get_asin_top", self.udf_get_asin_top, StringType())

    def get_date_info_last(self):
        # 获取上一个周期的日期数据
        df = self.df_date.toPandas()
        date_info_last = ''
        if self.date_type == 'day':
            date_id = tuple(df.loc[df.date == self.date_info].id)[0]
            df_loc = df.loc[df.id == (date_id - 1)]
            date_info_last = tuple(df_loc.date)[0]
        if self.date_type in ['week', 'week_old']:
            date_id = tuple(df.loc[(df[f'year_week'] == self.date_info) & (
                    df.week_day == 1)].id)[0]
            df_loc = df.loc[df.id == (date_id - 7)]
            date_info_last = tuple(df_loc[f'year_week'])[0]
        if self.date_type == '4_week':
            date_id = tuple(df.loc[(df[f'year_week'] == self.date_info) & (
                    df.week_day == 1)].id)[0]
            df_loc = df.loc[df.id == (date_id - 28)]
            date_info_last = tuple(df_loc.year_week)[0]
        if self.date_type in ['month', 'month_old']:
            if int(self.month) == 1:
                date_info_last = f"{int(self.year) - 1}-12"
            else:
                month = f"0{int(self.month) - 1}" if int(self.month) <= 10 else f"{int(self.month) - 1}"
                date_info_last = f"{self.year}-{month}"
        if self.date_type == 'last30day':
            date_id = tuple(df.loc[df.date == self.date_info].id)[0]
            df_loc = df.loc[df.id == (date_id - 30)]
            date_info_last = tuple(df_loc.date)[0]
        print("date_info_last:", date_info_last)
        return date_info_last

    @staticmethod
    def udf_handle_columns_asin123(asin):
        if len(str(asin)) == 10:
            return asin
        else:
            return None

    @staticmethod
    def udf_is_ascending(x):
        if x >= 0.5:
            return 1
        else:
            return 0

    @staticmethod
    def udf_is_search(x):
        if x >= 0.8:
            return 1
        else:
            return 0

    @staticmethod
    def udf_get_asin_top(asin1, value1, asin2, value2, asin3, value3, flag):
        """通过分享转化比大小顺序找到对应的asin顺序,从而找到bs分类id"""
        if max(value1, value2, value3) == value1:
            asin_top1 = asin1
            if max(value2, value3) == value2:
                asin_top2 = asin2
                asin_top3 = asin3
            else:
                asin_top2 = asin3
                asin_top3 = asin2
        elif max(value1, value2, value3) == value2:
            asin_top1 = asin2
            if max(value1, value3) == value1:
                asin_top2 = asin1
                asin_top3 = asin3
            else:
                asin_top2 = asin3
                asin_top3 = asin1
        else:
            asin_top1 = asin3
            if max(value1, value2) == value1:
                asin_top2 = asin1
                asin_top3 = asin2
            else:
                asin_top2 = asin2
                asin_top3 = asin1
        if flag == 1:
            return asin_top1
        elif flag == 2:
            return asin_top2
        else:
            return asin_top3

    def read_data_st_detail(self):
        """Load the ``ods_brand_analytics`` partitions needed by this run.

        Populates:
            df_st_detail: every column of the current partition, plus a
                constant ``st_history_counts`` column (28 for '4_week', 1 for
                'day', otherwise ``len(self.date_info_tuple)``).
            df_st_detail_last: ``search_term`` and ``rank_last`` of the
                previous period (``self.date_info_last``).
            df_st_detail_days: per ``search_term``, the number of distinct
                daily ``date_info`` values in ``self.date_info_tuple`` on
                which the term appears (``st_appear_history_counts``).
            df_st_detail_last_year: ``search_term`` with a constant
                ``st_is_new_market_segment`` of 0, from the same period one
                year earlier.

        If the previous-period partition is empty, the ODS import shell script
        is run for it and the process then exits; the in-process refresh was
        noted by the original author as not working yet. Any other failure
        while reading the previous period is only printed.
        """
        print("1.1 读取ods_brand_analytics表")
        # Current period
        sql = f"select * from ods_brand_analytics where site_name='{self.site_name}' and date_type='{self.date_type}' and date_info='{self.date_info}'"
        print("sql:", sql)
        self.df_st_detail = self.spark.sql(sql).cache()
        self.df_st_detail.show(10, truncate=False)
        # Previous period
        try:
            sql = f"select search_term, rank as rank_last from ods_brand_analytics where site_name='{self.site_name}' and date_type='{self.date_type}' and date_info='{self.date_info_last}'"
            print("sql:", sql)
            self.df_st_detail_last = self.spark.sql(sql).cache()
            if self.df_st_detail_last.count() == 0:
                print("date_info_last对应分区没有数据:", self.date_info_last)
                os.system(
                    rf"/mnt/run_shell/spark_shell/ods/spark_ods_brand_analytics.sh {self.site_name} {self.date_type} {self.date_info_last}")
                # Recreate the Spark session to refresh table metadata.
                print("重新建立spark对象,刷新元数据对应关系,等待10s,暂时发现不可行")
                self.spark = self.create_spark_object(
                    app_name=f"{self.db_save}: {self.site_name}, {self.date_type}, {self.date_info}")
                time.sleep(10)
                self.df_st_detail_last = self.spark.sql(sql).cache()
                self.df_st_detail_last.show(10, truncate=False)
                # sys.exit() rather than the interactive-only quit() helper.
                # SystemExit is not an Exception subclass, so the handler
                # below does not swallow it.
                sys.exit()
        except Exception as e:
            # NOTE(review): on failure df_st_detail_last keeps whatever it held
            # before; later steps that select search_term / rank_last from it
            # will then fail.
            print("error:", e)

        # Daily partitions that make up the current period. Build the IN list
        # by hand: formatting a one-element tuple directly yields "('x',)",
        # which is not valid SQL.
        date_infos = self.date_info_tuple
        if isinstance(date_infos, str):
            in_list = date_infos
        else:
            in_list = "(" + ", ".join(repr(d) for d in date_infos) + ")"
        sql = f"select * from ods_brand_analytics where site_name='{self.site_name}' and date_type='day' and date_info in {in_list}"
        print("sql:", sql)
        self.df_st_detail_days = self.spark.sql(sql).cache()
        self.df_st_detail_days = self.df_st_detail_days.drop_duplicates(['search_term', 'date_info'])
        self.df_st_detail_days = self.df_st_detail_days.groupby(['search_term']).agg(
            F.count('date_info').alias('st_appear_history_counts'))
        # Denominator for the appearance ratio computed in handle_st_search.
        st_history_counts = 28 if self.date_type == '4_week' else 1 if self.date_type == 'day' else len(self.date_info_tuple)
        self.df_st_detail = self.df_st_detail.withColumn('st_history_counts', F.lit(st_history_counts))
        self.df_st_detail_days.show(10, truncate=False)
        # Same period one year earlier
        sql = f"select search_term, 0 as st_is_new_market_segment from ods_brand_analytics where site_name='{self.site_name}' and date_type='{self.date_type}' and date_info='{self.date_info_last_year}';"
        print("sql:", sql)
        self.df_st_detail_last_year = self.spark.sql(sql).cache()
        self.df_st_detail_last_year.show(10, truncate=False)

    def read_data_st_quantity(self):
        print("1.2 读取ods_st_quantity_being_sold表")
        if self.site_name in ['us'] and (
                (int(self.month) >= 12 and int(self.year) >= 2022) or (int(self.year) >= 2023)):
            sql = f"select search_term, quantity_being_sold as st_quantity_being_sold from ods_st_quantity_being_sold " \
                  f"where site_name='{self.site_name}'  and date_type='day' and date_info in {self.date_info_tuple};"
        else:
            sql = f"select search_term, quantity_being_sold as st_quantity_being_sold from ods_brand_analytics " \
                  f"where site_name='{self.site_name}'  and date_type='week' and date_info >= '2022-01';"
        print("sql:", sql)
        self.df_st_quantity = self.spark.sql(sqlQuery=sql).cache()
        self.df_st_quantity.show(10, truncate=False)

    def read_data_st_search_num(self):
        print("1.3 读取ods_rank_search_rate_repeat表")
        if (int(self.year) <= 2022 and int(self.month) <= 8) or int(self.year) <= 2021:
            params = f"date_info='2022-08'"
        else:
            params = f"date_info='{self.year}-{self.month}'"
        sql = f"select rank, search_num as st_search_num, rate as st_search_rate, search_sum as st_search_sum " \
              f"from ods_rank_search_rate_repeat where site_name='{self.site_name}' and date_type='month' and  {params};"
        print("sql:", sql)
        self.df_st_rate = self.spark.sql(sql).cache()
        self.df_st_rate.show(10, truncate=False)
        if self.df_st_rate.count() == 0:
            sql = f"select rank, search_num as st_search_num, rate as st_search_rate, search_sum as st_search_sum, date_info " \
                  f"from ods_rank_search_rate_repeat where site_name='{self.site_name}';"
            print("sql:", sql)
            self.df_st_rate = self.spark.sql(sql).cache()
            print("self.df_st_rate开窗前:", self.df_st_rate.count())
            window = Window.partitionBy(["rank"]).orderBy(
                self.df_st_rate.date_info.desc()
            )
            self.df_st_rate = self.df_st_rate.withColumn("date_info_rank", F.row_number().over(window=window)). \
                filter("date_info_rank=1")
            print("self.df_st_rate开窗后:", self.df_st_rate.count())
            self.df_st_rate = self.df_st_rate.drop("date_info_rank", "date_info")
            self.df_st_rate.show(10, truncate=False)

    def read_data_asin_history(self):
        print("1.4 读取dim_cal_asin_history_detail表")
        sql = f"select asin, bsr_cate_1_id, bsr_cate_current_id " \
              f"from dim_cal_asin_history_detail where site_name='{self.site_name}';"
        print("sql:", sql)
        self.df_asin_history = self.spark.sql(sql).cache()
        self.df_asin_history.show(10, truncate=False)

    def read_data(self):
        self.read_data_st_detail()
        self.read_data_st_quantity()
        self.read_data_st_search_num()
        self.read_data_asin_history()

    def handle_data(self):
        self.handle_columns_null()
        self.handle_st_click_and_conversion()
        self.handle_st_rate_and_quantity()
        self.handle_st_first()
        self.handle_st_ascending()
        self.handle_st_search()
        self.handle_st_top3_asin()
        self.handle_st_new_market_segment()
        self.handle_columns_renamed()
        self.df_save = self.df_st_detail
        self.df_save = self.df_save.drop_duplicates(['search_term'])
        print("self.df_save.columns:", self.df_save.columns)
        # self.df_save.show(10, truncate=False)
        # quit()

    def handle_columns_null(self):
        """Sanitise the share and ASIN columns of ``self.df_st_detail``.

        Click/conversion shares that are not ``>= 0`` (which includes NULL)
        become 0.0. ASIN values are passed through ``u_handle_columns_asin123``,
        which keeps only values whose string form is 10 characters long.
        """
        share_columns = [
            f"{kind}_share{slot}"
            for kind in ("click", "conversion")
            for slot in (1, 2, 3)
        ]
        for name in share_columns:
            share = self.df_st_detail[name]
            self.df_st_detail = self.df_st_detail.withColumn(
                name, F.when(share >= 0, share).otherwise(0.0)
            )
        for name in ("asin1", "asin2", "asin3"):
            self.df_st_detail = self.df_st_detail.withColumn(
                name, self.u_handle_columns_asin123(name)
            )

    def handle_st_click_and_conversion(self):
        """Add the summed click share and summed conversion share of the three ASIN slots."""
        print("关键词的点击率和转化率求和")
        df = self.df_st_detail
        click_total = df["click_share1"] + df["click_share2"] + df["click_share3"]
        conversion_total = (
            df["conversion_share1"] + df["conversion_share2"] + df["conversion_share3"]
        )
        df = df.withColumn("st_click_share_sum", click_total)
        df = df.withColumn("st_conversion_share_sum", conversion_total)
        self.df_st_detail = df

    def handle_st_rate_and_quantity(self):
        """Attach search-volume figures (by rank) and the mean quantity being sold (by term).

        ``self.df_st_quantity`` is reduced to one row per ``search_term``: the
        mean of its positive ``st_quantity_being_sold`` values. Both lookups
        are left-joined onto ``self.df_st_detail``.
        """
        print("关键词的在售商品数,搜索量,转化率,销量(月度)")
        positive_quantities = self.df_st_quantity.filter("st_quantity_being_sold > 0")
        self.df_st_quantity = positive_quantities.groupby(['search_term']).agg(
            F.avg("st_quantity_being_sold").alias("st_quantity_being_sold")
        )
        with_rates = self.df_st_detail.join(self.df_st_rate, on=['rank'], how='left')
        self.df_st_detail = with_rates.join(self.df_st_quantity, on=['search_term'], how='left')

    def handle_st_first(self):
        """Add ``st_is_first_text``: 1 when the term is absent from the previous period, else 0.

        Also adds a constant ``st_is_first_text`` = 0 column to
        ``self.df_st_detail_last``.
        """
        print("新出词(当前天/周/4周/月/季度,同比前1天/周/4周/月/季度,第1次出现)")
        self.df_st_detail_last = self.df_st_detail_last.withColumn("st_is_first_text", F.lit(0))
        seen_before = self.df_st_detail_last.select("search_term", "st_is_first_text")
        joined = self.df_st_detail.join(seen_before, on='search_term', how='left')
        # Terms without a match in the previous period are new.
        self.df_st_detail = joined.fillna({"st_is_first_text": 1})

    def handle_st_ascending(self):
        """Add the rank-improvement ratio and the derived ``st_is_ascending_text`` flag.

        ``st_is_ascending_text_rate`` is ``(rank_last - rank) / rank_last``,
        with a missing ``rank_last`` filled as 0 and a NULL ratio filled as -1.
        The flag comes from ``u_is_ascending``. ``rank_last`` is dropped again.
        """
        print("上升词(相邻2天/周/月/季度,上升超过50%的排名)")
        previous_ranks = self.df_st_detail_last.select("search_term", "rank_last")
        joined = self.df_st_detail.join(previous_ranks, on='search_term', how='left')
        joined = joined.na.fill({'rank_last': 0})
        improvement = (joined["rank_last"] - joined["rank"]) / joined["rank_last"]
        rated = joined.withColumn("st_is_ascending_text_rate", improvement)
        rated = rated.na.fill({'st_is_ascending_text_rate': -1})
        flagged = rated.withColumn(
            "st_is_ascending_text", self.u_is_ascending(rated["st_is_ascending_text_rate"])
        )
        self.df_st_detail = flagged.drop("rank_last")

    def handle_st_search(self):
        """Add the appearance ratio and the derived ``st_is_search_text`` flag.

        ``st_is_search_text_rate`` is ``st_appear_history_counts /
        st_history_counts``, with a missing appearance count filled as 0.
        The flag comes from ``u_is_search``.
        """
        print("热搜词(历史出现占比>=80%)")
        joined = self.df_st_detail.join(self.df_st_detail_days, on='search_term', how='left')
        joined = joined.fillna({'st_appear_history_counts': 0})
        appearance_rate = joined["st_appear_history_counts"] / joined["st_history_counts"]
        rated = joined.withColumn("st_is_search_text_rate", appearance_rate)
        self.df_st_detail = rated.withColumn(
            "st_is_search_text", self.u_is_search(rated["st_is_search_text_rate"])
        )

    def handle_st_top3_asin(self):
        """Rank each term's three ASINs by conversion share and derive the term's BSR category.

        Adds ``st_top_asin1`` .. ``st_top_asin3`` (ASINs ordered by conversion
        share via ``u_get_asin_top``), then ``st_bsr_cate_1_id`` and
        ``st_bsr_cate_current_id``. The category is taken from the
        best-ranked top ASIN that has a ``bsr_cate_1_id`` in
        ``self.df_asin_history``.

        The previous implementation ordered candidates by ``type`` only.
        ``type`` is never NULL, so the first row was always the top-1 ASIN
        and the category stayed NULL whenever that ASIN had no history row,
        even if the second or third ASIN had one.
        """
        print("关键词的top3asin--匹配关键词的bsr一级分类id")
        share_args = (
            "asin1", "conversion_share1",
            "asin2", "conversion_share2",
            "asin3", "conversion_share3",
        )
        for position in (1, 2, 3):
            self.df_st_detail = self.df_st_detail.withColumn(
                f"st_top_asin{position}",
                self.u_get_asin_top(*share_args, F.lit(position)),
            )

        # One candidate row per (search_term, top position).
        candidates = None
        for position in (1, 2, 3):
            part = self.df_st_detail.select("search_term", f"st_top_asin{position}"). \
                withColumnRenamed(f"st_top_asin{position}", "asin").withColumn("type", F.lit(position))
            if candidates is None:
                candidates = part
            else:
                candidates = candidates.unionByName(part, allowMissingColumns=True)
        candidates = candidates.join(self.df_asin_history, on='asin', how="left")

        # Candidates with a known category first, then by top position. When
        # no candidate has a category this still picks the top-1 row, with
        # NULL categories.
        window = Window.partitionBy(["search_term"]).orderBy(
            F.col("bsr_cate_1_id").isNull().asc(),
            F.col("type").asc(),
        )
        best = candidates.withColumn("type_rank", F.row_number().over(window=window)). \
            filter("type_rank=1")
        best = best.drop("type_rank", "type", "asin")

        self.df_st_detail = self.df_st_detail.join(best, on="search_term", how="left")
        self.df_st_detail = self.df_st_detail.withColumnRenamed("bsr_cate_1_id", "st_bsr_cate_1_id")
        self.df_st_detail = self.df_st_detail.withColumnRenamed("bsr_cate_current_id", "st_bsr_cate_current_id")

    def handle_st_new_market_segment(self):
        """Add ``st_is_new_market_segment``.

        The flag is 0 for terms present in the same period one year earlier
        and 1 for terms that were absent. Terms whose ``rank`` is not
        ``<= 1000000`` are forced to 0.
        """
        print("判断关键词是否属于新细分市场,排名超过100w的关键词为否")
        joined = self.df_st_detail.join(
            self.df_st_detail_last_year, on=['search_term'], how='left'
        )
        # No match last year means the term is new.
        joined = joined.fillna({"st_is_new_market_segment": 1})
        within_rank_limit = joined["rank"] <= 1000000
        self.df_st_detail = joined.withColumn(
            "st_is_new_market_segment",
            F.when(within_rank_limit, joined["st_is_new_market_segment"]).otherwise(F.lit(0)),
        )

    def handle_columns_renamed(self):
        self.df_st_detail = self.df_st_detail.drop('id', 'created_time', 'updated_time', 'product_title1', 'product_title2', 'product_title3', 'quantity_being_sold')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('rank', 'st_rank')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('asin1', 'st_asin1')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('click_share1', 'st_click_share1')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('conversion_share1', 'st_conversion_share1')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('asin2', 'st_asin2')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('click_share2', 'st_click_share2')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('conversion_share2', 'st_conversion_share2')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('asin3', 'st_asin3')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('click_share3', 'st_click_share3')
        self.df_st_detail = self.df_st_detail.withColumnRenamed('conversion_share3', 'st_conversion_share3')
        # self.df_save.show(10, truncate=False)


if __name__ == '__main__':
    # Positional command-line arguments:
    #   1: site name
    #   2: date type (day / week / 4_week / month / quarter)
    #   3: date info (year-month-day / year-week / year-month / year-quarter), e.g. 2022-1
    cli_args = {
        "site_name": sys.argv[1],
        "date_type": sys.argv[2],
        "date_info": sys.argv[3],
    }
    DwtStDetail(**cli_args).run()