import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))  # parent directory

from pyspark.sql import functions as F
from utils.spark_util import SparkUtil
from utils.common_util import CommonUtil
from utils.hdfs_utils import HdfsUtils


class DwtAsinRelatedTraffic(object):

    def __init__(self, site_name, date_type, date_info):
        super().__init__()
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        self.hive_tb = 'dwt_asin_related_traffic'
        self.partition_dict = {
            "site_name": site_name,
            "date_type": date_type,
            "date_info": date_info
        }
        self.hdfs_path = CommonUtil.build_hdfs_path(self.hive_tb, partition_dict=self.partition_dict)
        app_name = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.partitions_by = ['site_name', 'date_type', 'date_info']

        # Placeholder DataFrames; populated in read_data() / handle_data()
        self.df_dim_asin_related_traffic = self.spark.sql("select 1+1;")
        self.df_save = self.spark.sql("select 1+1;")

        # Numeric type code for each related-traffic column; used in
        # handle_data() to build related_type in parallel with related_asin
        self.col_num_index = {
            "four_star_above": 1,
            "brand_recommendation": 2,
            "similar_items": 3,
            "look_and_look": 4,
            "look_also_look": 5,
            "look_but_bought": 6,
            "bundle_bought": 7,
            "combination_bought": 8,
            "more_relevant": 9,
            "bought_and_bought": 10,
            "product_adv": 11,
            "brand_adv": 12
        }

    def read_data(self):
        print("读取dim_asin_related_traffic流量数据")
        sql = f"""
        select 
            asin, 
            four_star_above, 
            brand_recommendation, 
            similar_items, 
            look_and_look, 
            look_also_look, 
            look_but_bought, 
            bundle_bought, 
            combination_bought, 
            more_relevant, 
            bought_and_bought, 
            product_adv, 
            brand_adv 
        from dim_asin_related_traffic
        where site_name = '{self.site_name}'
            and date_type = '{self.date_type}'
            and date_info = '{self.date_info}';
        """
        self.df_dim_asin_related_traffic = self.spark.sql(sqlQuery=sql).cache()
        print("dim_asin_related_traffic数据如下:")
        self.df_dim_asin_related_traffic.show(10, True)

    # Aggregation
    def handle_data(self):
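        """
        Flatten the per-type related-traffic columns into two parallel,
        comma-separated columns:
          - related_asin: all related ASINs across the traffic types
          - related_type: one numeric code (see col_num_index) per ASIN in related_asin
        """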
        cols = [col for col in self.df_dim_asin_related_traffic.columns if col != 'asin']

        # Keep only 10-character tokens (standard ASIN length) in each column;
        # columns that end up empty are set to null so concat_ws skips them later
        for col in cols:
            self.df_dim_asin_related_traffic = self.df_dim_asin_related_traffic.withColumn(
                col, F.concat_ws(",", F.filter(F.split(F.col(col), ","), lambda x: (F.length(F.trim(x)) == 10)))
            ).withColumn(
                col, F.when(F.col(col) == "", None).otherwise(F.col(col))
            )

        # Concatenate the related-traffic ASINs of all types into one column
        self.df_dim_asin_related_traffic = self.df_dim_asin_related_traffic.withColumn(
            "related_asin", F.concat_ws(",", *[F.col(col) for col in cols])
        )

        # Using the col_num_index mapping, generate a code column with one entry per related ASIN
        for col in cols:
            num = self.col_num_index[col]
            self.df_dim_asin_related_traffic = self.df_dim_asin_related_traffic.withColumn(
                f"{col}_num", F.when(F.col(col).isNull(), F.lit(None))
                    .otherwise(F.concat_ws(",", F.array_repeat(F.lit(num), F.size(F.split(F.col(col), ",")))))
            )

        # Concatenate all code columns
        self.df_dim_asin_related_traffic = self.df_dim_asin_related_traffic.withColumn(
            "related_type", F.concat_ws(",", *[F.col(f"{col}_num") for col in cols])
        )
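        # Illustrative (hypothetical) row, with all other columns null:
        #   similar_items = "B0AAAAAAAA,B0BBBBBBBB", more_relevant = "B0CCCCCCCC"
        # produces:
        #   related_asin = "B0AAAAAAAA,B0BBBBBBBB,B0CCCCCCCC"
        #   related_type = "3,3,9"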

    # Persist data
    def save_data(self):
        self.df_save = self.df_dim_asin_related_traffic.select(
            'asin',
            'related_asin',
            'related_type',
            F.lit(self.site_name).alias('site_name'),
            F.lit(self.date_type).alias('date_type'),
            F.lit(self.date_info).alias('date_info')
        )
        print(f"清除hdfs目录中:{self.hdfs_path}")
        HdfsUtils.delete_file_in_folder(self.hdfs_path)
        print(f"当前存储的表名为:{self.hive_tb},分区为:{self.partitions_by}")
        self.df_save.repartition(40).write.saveAsTable(name=self.hive_tb, format='hive', mode='append', partitionBy=self.partitions_by)
        print("success")

    def run(self):
        # Read data
        self.read_data()
        # Aggregation
        self.handle_data()
        # Persist data
        self.save_data()


if __name__ == '__main__':
    site_name = sys.argv[1]
    date_type = sys.argv[2]
    date_info = sys.argv[3]
    handle_obj = DwtAsinRelatedTraffic(site_name=site_name, date_type=date_type, date_info=date_info)
    handle_obj.run()