import os
import sys
import re

sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the module search path

from pyspark.sql.window import Window
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql import DataFrame
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from utils.common_util import CommonUtil


class DimAsinRelatedTraffic(object):

    def __init__(self, site_name, date_type, date_info):
        super().__init__()
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        self.hive_tb = 'dim_asin_related_traffic'
        self.partition_dict = {
            "site_name": site_name,
            "date_type": date_type,
            "date_info": date_info
        }
        self.hdfs_path = CommonUtil.build_hdfs_path(self.hive_tb, partition_dict=self.partition_dict)
        app_name = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.partitions_by = ['site_name', 'date_type', 'date_info']

        # Placeholder DataFrames; each is reassigned by the processing steps in run()
        self.df_asin_detail = self.spark.sql("select 1+1;")
        self.df_self_asin_detail = self.spark.sql("select 1+1;")
        self.df_result_list_json = self.spark.sql("select 1+1;")
        self.df_together_asin = self.spark.sql("select 1+1;")
        self.df_sp_initial_seen_asins_json = self.spark.sql("select 1+1;")
        self.df_sp_4stars_initial_seen_asins_json = self.spark.sql("select 1+1;")
        self.df_sp_delivery_initial_seen_asins_json = self.spark.sql("select 1+1;")
        self.df_compare_similar_asin_json = self.spark.sql("select 1+1;")
        self.df_bundles_this_asins_json = self.spark.sql("select 1+1;")
        self.df_save = self.spark.sql("select 1+1;")

        self.u_categorize_flow = F.udf(self.categorize_flow, StringType())
        self.u_merge_df = F.udf(self.merge_df, StringType())
        self.u_repair_json = F.udf(self.repair_json, StringType())

    @staticmethod
    def repair_json(json_str):
        """修复指定字段的数组格式"""
        if not json_str:
            return json_str

        # Match three cases: 1) an already-formatted array, 2) a quoted string, 3) an unquoted value
        pattern = re.compile(
            r'("Brand in this category on Amazon"\s*:\s*)(\[[^\]]+\]|"([^"]+)"|([^,{}"]+))'
        )

        def replace_func(m):
            # Already an array (group(2) starts with '['): return the whole match unchanged
            if m.group(2).startswith('['):
                return m.group(0)

            # Quoted string or unquoted value: split on commas and rebuild as a JSON array
            raw_value = m.group(3) or m.group(4)
            items = [v.strip() for v in raw_value.split(",") if v.strip()]
            joined = '","'.join(items)
            return f'{m.group(1)}["{joined}"]'

        return pattern.sub(replace_func, json_str)
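
    # Illustrative example (input is hypothetical): repair_json rewrites a
    # comma-separated scalar value as a JSON array and leaves an existing
    # array untouched:
    #   repair_json('{"Brand in this category on Amazon": "BrandA, BrandB"}')
    #   -> '{"Brand in this category on Amazon": ["BrandA","BrandB"]}'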

    @staticmethod
    def merge_df(col1, col2):
        if not col1 or col1.strip() == "":
            return col2 if (col2 and col2.strip()) else None
        if not col2 or col2.strip() == "":
            return col1 if (col1 and col1.strip()) else None

        list1 = list(set(x.strip() for x in col1.split(",") if x.strip()))
        list2 = list(set(x.strip() for x in col2.split(",") if x.strip()))
        combined = list(set(list1 + list2))
        return ",".join(combined) if combined else None

    @staticmethod
    def categorize_flow(key):
        key_lower = key.lower()
        if key_lower == '4 stars and above':
            return "four_star_above"
        elif key_lower in ('brands you might like', 'more to consider from our brands', 'similar brands on amazon',
                           'exclusive items from our brands', 'more from frequently bought brands'):
            return "brand_recommendation"
        elif key_lower in ('products related to this item', 'based on your recent views', 'customer also bought',
                           'deals on related products', 'similar items in new arrivals', 'top rated similar items',
                           'compare with similar items', 'discover similar items'):
            return "similar_items"
        elif key_lower in ('customers who viewed this item also viewed', 'customers frequently viewed'):
            return "look_and_look"
        elif key_lower.startswith('customers also'):
            return "look_also_look"
        elif key_lower.startswith('what other items do customers buy after viewing this item'):
            return "look_but_bought"
        elif key_lower in ('make it a bundle', 'bundles with this item'):
            return "bundle_bought"
        elif key_lower in ('buy it with', 'frequently bought together'):
            return "combination_bought"
        elif key_lower in ('more items to explore', 'based on your recent shopping trends',
                           'related products with free delivery on eligible orders') \
                or key_lower.startswith('explore more'):
            return "more_relevant"
        elif key_lower == 'customers who bought this item also bought':
            return "bought_and_bought"
        elif key_lower == 'sponsored products related to this item':
            return "product_adv"
        elif key_lower in ('brands related to this category on amazon', 'brand in this category on amazon'):
            return "brand_adv"
        else:
            return "other"

    @staticmethod
    def other_json_handle(
            df: DataFrame,
            json_column: str,
            asin_key: str,
            output_column: str
    ) -> DataFrame:
        """
        从JSON数组字段中提取特定键的值并去重
        参数:
            df: 输入的DataFrame
            json_column: 包含JSON数组的列名
            asin_key: 要从JSON对象中提取的键名
            output_column: 输出结果的列名
        返回:
            包含去重后值的DataFrame(只有一列)
        """
        return df.withColumn(
            'json_array', F.from_json(F.col(json_column), ArrayType(MapType(StringType(), StringType())))
        ).withColumn(
            "exploded_item", F.explode("json_array")
        ).withColumn(
            "flow_asin", F.col(f"exploded_item.{asin_key}")
        ).filter(
            F.col('flow_asin').isNotNull() & (F.col('flow_asin') != "") & (F.length(F.col('flow_asin')) == 10)
        ).groupBy("asin").agg(
            F.concat_ws(",", F.collect_set("flow_asin")).alias(f"{output_column}")
        )
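
    # Sketch of the transformation (data hypothetical):
    #   input row:  asin='B0EXAMPLE1',
    #               json='[{"seen_asins": "B0AAAAAAA1"}, {"seen_asins": "B0AAAAAAA2"}]'
    #   output row: asin='B0EXAMPLE1', <output_column>='B0AAAAAAA1,B0AAAAAAA2'
    # Values that are null, empty, or not exactly 10 characters (the ASIN length)
    # are filtered out.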

    def read_data(self):
        print("读取ods_asin_detail中流量相关数据")
        sql = f"""
        select 
            asin, 
            together_asin, 
            sp_initial_seen_asins_json, 
            sp_4stars_initial_seen_asins_json, 
            sp_delivery_initial_seen_asins_json, 
            compare_similar_asin_json, 
            result_list_json, 
            bundles_this_asins_json, 
            updated_at 
        from ods_asin_related_traffic 
        where site_name='{self.site_name}' and date_type='{self.date_type}' and date_info='{self.date_info}' and asin is not null;
        """
        self.df_asin_detail = self.spark.sql(sqlQuery=sql)

        print("读取ods_self_asin_related_traffic数据")
        sql = f"""
        select 
            asin, 
            together_asin, 
            sp_initial_seen_asins_json, 
            sp_4stars_initial_seen_asins_json, 
            sp_delivery_initial_seen_asins_json, 
            compare_similar_asin_json, 
            result_list_json, 
            null as bundles_this_asins_json, 
            updated_at 
        from ods_self_asin_related_traffic where site_name='{self.site_name}' and asin is not null;
        """
        self.df_self_asin_detail = self.spark.sql(sqlQuery=sql)

        # Union the two sources and keep the most recent record per asin;
        # reference updated_at by name so it resolves against the unioned plan
        window = Window.partitionBy(['asin']).orderBy(F.col('updated_at').desc_nulls_last())
        self.df_asin_detail = self.df_asin_detail.unionByName(
            self.df_self_asin_detail, allowMissingColumns=False
        ).withColumn(
            "dt_rank", F.row_number().over(window=window)
        ).filter("dt_rank=1").drop("updated_at", "dt_rank").cache()
        print("详情数据如下:")
        self.df_asin_detail.show(10, True)
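
        # After the union + dedup, each asin keeps only its most recent row
        # (timestamps hypothetical):
        #   ('B0EXAMPLE1', ..., updated_at='2024-01-02')
        #   ('B0EXAMPLE1', ..., updated_at='2024-01-01')
        #   -> only the '2024-01-02' row survives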

    # Parse the result_list_json field
    def handle_result_list_json(self):
        json_schema = ArrayType(MapType(StringType(), ArrayType(StringType())))
        self.df_result_list_json = self.df_asin_detail.filter(
            F.col('result_list_json').isNotNull()
        ).select(
            'asin', F.from_json(self.u_repair_json(F.col('result_list_json')), json_schema).alias('parsed_json')
        ).withColumn(
            "kv", F.explode("parsed_json")
        ).select(
            "asin", F.explode("kv").alias("key", "value")
        ).withColumn(
            "category", self.u_categorize_flow(F.col("key"))
        ).filter(
            F.col("category") != "other"
        ).withColumn(
            "distinct_values", F.array_distinct("value")
        ).filter(
            F.expr("size(distinct_values) > 0")
        ).select(
            'asin', 'category', 'distinct_values'
        ).groupBy(["asin", "category"]).agg(
            F.concat_ws(",", F.array_distinct(F.flatten(F.collect_list("distinct_values")))).alias("values")
        ).groupBy("asin") \
            .pivot("category") \
            .agg(F.first("values")) \
            .cache()
        print("处理result_list_json字段结果如下:")
        self.df_result_list_json.show(10, True)
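
        # After the pivot, df_result_list_json holds one row per asin and one
        # column per traffic category, e.g. (values hypothetical):
        #   asin       | similar_items | combination_bought | four_star_above | ...
        #   B0EXAMPLE1 | "B0A...,B0B..." | null             | "B0C..."        | ...
        # Categories that never occur for an asin come out as null.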

    # Handle the other traffic fields
    def handle_other_field(self):
        # Handle the sp_initial_seen_asins_json field
        self.df_sp_initial_seen_asins_json = self.df_asin_detail\
            .select('asin', 'sp_initial_seen_asins_json')\
            .filter(F.col('sp_initial_seen_asins_json').isNotNull())
        self.df_sp_initial_seen_asins_json = self.other_json_handle(
            df=self.df_sp_initial_seen_asins_json,
            json_column='sp_initial_seen_asins_json',
            asin_key='seen_asins',
            output_column='similar_items'
        ).cache()
        print("处理sp_initial_seen_asins_json字段结果如下:")
        self.df_sp_initial_seen_asins_json.show(10, True)

        # Handle the sp_4stars_initial_seen_asins_json field
        self.df_sp_4stars_initial_seen_asins_json = self.df_asin_detail\
            .select('asin', 'sp_4stars_initial_seen_asins_json')\
            .filter(F.col('sp_4stars_initial_seen_asins_json').isNotNull())
        self.df_sp_4stars_initial_seen_asins_json = self.other_json_handle(
            df=self.df_sp_4stars_initial_seen_asins_json,
            json_column='sp_4stars_initial_seen_asins_json',
            asin_key='seen_asins',
            output_column='four_star_above'
        ).cache()
        print("处理sp_4stars_initial_seen_asins_json字段结果如下:")
        self.df_sp_4stars_initial_seen_asins_json.show(10, True)

        # Handle the sp_delivery_initial_seen_asins_json field
        self.df_sp_delivery_initial_seen_asins_json = self.df_asin_detail\
            .select('asin', 'sp_delivery_initial_seen_asins_json')\
            .filter(F.col('sp_delivery_initial_seen_asins_json').isNotNull())
        self.df_sp_delivery_initial_seen_asins_json = self.other_json_handle(
            df=self.df_sp_delivery_initial_seen_asins_json,
            json_column='sp_delivery_initial_seen_asins_json',
            asin_key='seen_asins',
            output_column='more_relevant'
        ).cache()
        print("处理sp_delivery_initial_seen_asins_json字段结果如下:")
        self.df_sp_delivery_initial_seen_asins_json.show(10, True)

        # Handle the compare_similar_asin_json field
        self.df_compare_similar_asin_json = self.df_asin_detail\
            .select('asin', 'compare_similar_asin_json')\
            .filter(F.col('compare_similar_asin_json').isNotNull())
        self.df_compare_similar_asin_json = self.other_json_handle(
            df=self.df_compare_similar_asin_json,
            json_column='compare_similar_asin_json',
            asin_key='compare_asin',
            output_column='similar_items'
        ).cache()
        print("处理compare_similar_asin_json字段结果如下:")
        self.df_compare_similar_asin_json.show(10, True)

        # Handle the bundles_this_asins_json field
        self.df_bundles_this_asins_json = self.df_asin_detail\
            .select('asin', 'bundles_this_asins_json')\
            .filter(F.col('bundles_this_asins_json').isNotNull())
        self.df_bundles_this_asins_json = self.other_json_handle(
            df=self.df_bundles_this_asins_json,
            json_column='bundles_this_asins_json',
            asin_key='bundles_Asins',
            output_column='bundle_bought'
        ).cache()
        print("处理bundles_this_asins_json字段结果如下:")
        self.df_bundles_this_asins_json.show(10, True)

        # Handle the together_asin field
        self.df_together_asin = self.df_asin_detail.select('asin', 'together_asin').filter(
            F.col('together_asin').isNotNull()
        ).withColumnRenamed(
            'together_asin', 'combination_bought'
        ).cache()
        print("处理together_asin字段结果如下:")
        self.df_together_asin.show(10, True)

    # Merge all DataFrames into one
    def handle_merge_df(self):
        all_merge_df = [self.df_together_asin, self.df_sp_initial_seen_asins_json,
                        self.df_sp_4stars_initial_seen_asins_json, self.df_sp_delivery_initial_seen_asins_json,
                        self.df_compare_similar_asin_json, self.df_bundles_this_asins_json]
        main_df = self.df_result_list_json
        for df in all_merge_df:
            for col in set(df.columns) - {"asin"}:
                if col in main_df.columns:
                    # Same category produced by two sources: join on a temporary
                    # column name, then merge the two comma-separated lists
                    df = df.withColumnRenamed(col, f'{col}_tmp')
                    main_df = main_df.join(df, "asin", "full") \
                        .withColumn(col, self.u_merge_df(F.col(col), F.col(f"{col}_tmp"))) \
                        .drop(f"{col}_tmp")
                else:
                    main_df = main_df.join(df, "asin", "full")

        self.df_save = main_df.cache()
        print("最终合并结果如下:")
        self.df_save.show(10, True)

        self.df_asin_detail.unpersist()
        self.df_result_list_json.unpersist()
        self.df_together_asin.unpersist()
        self.df_sp_initial_seen_asins_json.unpersist()
        self.df_sp_4stars_initial_seen_asins_json.unpersist()
        self.df_sp_delivery_initial_seen_asins_json.unpersist()
        self.df_compare_similar_asin_json.unpersist()
        self.df_bundles_this_asins_json.unpersist()

    # Persist the data
    def save_data(self):
        # Align the DataFrame columns with the Hive table schema
        hive_tb_cols = [f.name for f in self.spark.table(self.hive_tb).schema]
        for col in hive_tb_cols:
            if col not in self.df_save.columns:
                self.df_save = self.df_save.withColumn(col, F.lit(None))

        # Populate the partition columns
        self.df_save = self.df_save.withColumn(
            'site_name', F.lit(self.site_name)
        ).withColumn(
            'date_type', F.lit(self.date_type)
        ).withColumn(
            'date_info', F.lit(self.date_info)
        ).select(*hive_tb_cols).replace('', None)

        print(f"清除hdfs目录中:{self.hdfs_path}")
        HdfsUtils.delete_file_in_folder(self.hdfs_path)
        print(f"当前存储的表名为:{self.hive_tb},分区为:{self.partitions_by}")
        self.df_save.repartition(40).write.saveAsTable(name=self.hive_tb, format='hive', mode='append', partitionBy=self.partitions_by)
        print("success")

    def run(self):
        # Read source data
        self.read_data()
        # Parse the result_list_json field
        self.handle_result_list_json()
        # Handle the other traffic fields
        self.handle_other_field()
        # Merge all DataFrames
        self.handle_merge_df()
        # Persist the data
        self.save_data()


if __name__ == '__main__':
    site_name = sys.argv[1]
    date_type = sys.argv[2]
    date_info = sys.argv[3]
    handle_obj = DimAsinRelatedTraffic(site_name=site_name, date_type=date_type, date_info=date_info)
    handle_obj.run()
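
# Example invocation (script filename and argument values are hypothetical):
#   spark-submit dim_asin_related_traffic.py us month 2024-01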