import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))

from utils.spark_util import SparkUtil
from utils.DorisHelper import DorisHelper
from utils.common_util import CommonUtil
from utils.hdfs_utils import HdfsUtils
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, BooleanType
from datetime import datetime, timedelta
from pyspark.sql import functions as F

class DwtAiAsinAll(object):
    """Build ``dwt_ai_asin_all``: per-ASIN stability, periodicity and trend flags.

    For every ASIN of one site in ``dim_ai_asin_base`` the job combines:

    * the launch time read from Doris (``selection.<site>_asin_latest_detail``);
    * ``asin_bought_month_flag`` from ``dwd_ai_asin_add`` over the last 11 months
      (-> ``is_stable_flag``);
    * monthly sales (``asin_bought_month``) from ``dwt_flow_asin`` for the last
      12 months and the 12 months before that (-> peak value / peak months,
      ``is_periodic_flag``, ``is_ascending_flag``).

    The result replaces the files of the site's partition of ``dwt_ai_asin_all``
    (files are deleted first, then the data is appended).
    """

    def __init__(self, site_name="us", date_type="month", date_info="2024-10"):
        """Create the Spark session and precompute the month windows.

        :param site_name: site code used in table filters and as output partition.
        :param date_type: period type filter for the source tables (e.g. ``month``).
        :param date_info: reference month, e.g. ``2024-10``.
        :raises ValueError: if ``dim_date_20_to_30`` has no date for ``date_info``.
        """
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        app_name = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)

        # Month windows relative to date_info (offset 0 is date_info itself).
        # Last year: date_info and the 11 months before it.
        self.last_12_month = [CommonUtil.get_month_offset(self.date_info, -i) for i in range(0, 12)]
        # The 11 most recent months (used for the stability flag).
        self.last_11_month = [CommonUtil.get_month_offset(self.date_info, -i) for i in range(0, 11)]
        # Previous year: offsets 12..23.
        self.before_12_month = [CommonUtil.get_month_offset(self.date_info, -i) for i in range(12, 24)]

        # Launch-time cut-off: last day of date_info minus 360 days. ASINs launched
        # on or after this date are considered too new to be periodic (handle_data).
        last_day = self.spark.sql(
            f"""SELECT max(`date`) AS last_day FROM dim_date_20_to_30 WHERE year_month = '{self.date_info}'"""
        ).collect()[0]['last_day']
        if last_day is None:
            # max() over no rows yields NULL; fail with a clear message instead of a TypeError.
            raise ValueError(f"dim_date_20_to_30 has no date for year_month = '{self.date_info}'")
        if isinstance(last_day, str):
            last_day = datetime.strptime(last_day, '%Y-%m-%d')
        self.launch_time_base_date = (last_day + timedelta(days=-360)).strftime('%Y-%m-%d')
        print('self.launch_time_base_date为：' + self.launch_time_base_date)

        # Placeholders; these DataFrames are reassigned in read_data/handle_data/save_data.
        self.df_base_asin = self.spark.sql("select 1+1;")
        schema = StructType([
            StructField("asin", StringType(), True),
            StructField("launch_time", StringType(), True)
        ])
        # Empty, typed seed that the paged Doris reads are unioned onto.
        self.df_asin_launch_time = self.spark.createDataFrame([], schema)
        self.df_asin_bought_flag = self.spark.sql("select 1+1;")
        self.df_asin_bought_last_12_month = self.spark.sql("select 1+1;")
        self.df_asin_bought_before_12_month = self.spark.sql("select 1+1;")
        self.df_save = self.spark.sql("select 1+1;")

        self.udf_is_periodic_flag = F.udf(self.check_month_close, IntegerType())
        self.udf_is_consecutive_flag = F.udf(self.is_consecutive, BooleanType())

    @staticmethod
    def is_close_month(m1, m2):
        """Return True when two month numbers are equal or adjacent.

        Adjacent means ``|m1 - m2| <= 1``, plus the December/January pair
        (12, 1) across the year boundary. ``None`` on either side gives False.
        """
        if m1 is None or m2 is None:
            return False
        if abs(m1 - m2) <= 1:
            return True
        if {m1, m2} == {1, 12}:  # cross-year special case
            return True
        return False

    @staticmethod
    def check_month_close(new_list, old_list):
        """Return 1 if any month in ``new_list`` is equal/adjacent to any in ``old_list``, else 0.

        Used as an IntegerType UDF; a ``None`` list yields 0.
        """
        if new_list is None or old_list is None:
            return 0
        hit = any(DwtAiAsinAll.is_close_month(n, o) for n in new_list for o in old_list)
        return 1 if hit else 0

    @staticmethod
    def is_consecutive(month_list):
        """Return True if the month numbers (1-12) form one unbroken run on the yearly cycle.

        Runs may wrap across the year end, e.g. ``[11, 12, 1, 2]``. Months are
        deduplicated and placed on a 12-month circle. Walking the circle, a set
        of months is a single run exactly when there is at most one "gap" (a
        step that is not +1 month). A non-wrapping run like ``[3, 4, 5]`` has one
        gap (5 -> 3); a full year has none. Empty/None or single-month input
        counts as consecutive.
        """
        if not month_list or len(month_list) <= 1:
            return True

        months = sorted(set(month_list))
        n = len(months)
        gaps = sum(
            1 for i in range(n)
            if (months[(i + 1) % n] - months[i]) % 12 != 1
        )
        return gaps <= 1

    def run(self):
        """Execute the job: read inputs, compute flags, write the partition."""
        self.read_data()
        self.handle_data()
        self.save_data()

    def _read_asin_launch_time(self):
        """Page through Doris launch times and keep only ASINs present in ``df_base_asin``.

        The stop condition is evaluated on the raw Doris page, not on the joined
        result, so a page whose ASINs do not match ``df_base_asin`` does not end
        the scan early. Deduplication is done once over the full union.
        """
        batch_size = 50000000
        offset = 0
        while True:
            sql2 = f"""
            select asin, asin_launch_time as launch_time from selection.{self.site_name}_asin_latest_detail 
            order by asin limit {batch_size} OFFSET {offset}
            """
            df_page = DorisHelper.spark_import_with_sql(self.spark, sql2)
            if df_page.count() == 0:
                break
            self.df_asin_launch_time = self.df_asin_launch_time.unionByName(
                df_page.join(self.df_base_asin, 'asin', 'inner')
            )
            offset += batch_size
        self.df_asin_launch_time = self.df_asin_launch_time.distinct().withColumn(
            'launch_time_base_date', F.lit(self.launch_time_base_date)
        ).cache()

    def read_data(self):
        """Load and cache every input DataFrame for this run."""
        # Full ASIN set of the ASIN info base for this site.
        sql1 = f"""
        select asin from dim_ai_asin_base where site_name = '{self.site_name}'
        """
        self.df_base_asin = self.spark.sql(sqlQuery=sql1).repartition(40, 'asin').cache()
        print("ASIN信息库基础数据如下：")
        self.df_base_asin.show(10, truncate=True)

        # Launch time per ASIN (paged read from Doris).
        self._read_asin_launch_time()
        print(f"asin最新上架时间数据如下：{self.df_asin_launch_time.count()}")
        self.df_asin_launch_time.show(10, truncate=True)

        # Monthly-sales flag from dwd_ai_asin_add over the last 11 months.
        sql3 = f"""
        select 
            asin, 
            asin_bought_month_flag 
        from dwd_ai_asin_add
        where site_name = '{self.site_name}'
          and date_type = '{self.date_type}'
          and date_info in ({CommonUtil.list_to_insql(self.last_11_month)})
        """
        self.df_asin_bought_flag = self.spark.sql(sqlQuery=sql3).repartition(40, 'asin').cache()
        print("dwd_ai_asin_add月销标识数据如下：")
        self.df_asin_bought_flag.show(10, truncate=True)

        # Monthly sales for the last 12 months, restricted to base ASINs.
        sql4 = f"""
        select 
            asin, 
            asin_bought_month as bought_month, 
            date_info 
        from dwt_flow_asin
        where site_name = '{self.site_name}'
          and date_type = '{self.date_type}'
          and date_info in ({CommonUtil.list_to_insql(self.last_12_month)})
        """
        self.df_asin_bought_last_12_month = self.spark.sql(sqlQuery=sql4).join(
            self.df_base_asin, 'asin', 'inner'
        ).cache()
        print("ASIN信息库近一年月销数据如下：")
        self.df_asin_bought_last_12_month.show(10, truncate=True)

        # Monthly sales for the 12 months before that, restricted to base ASINs.
        sql5 = f"""
        select 
            asin, 
            asin_bought_month as bought_month, 
            date_info 
        from dwt_flow_asin
        where site_name = '{self.site_name}'
          and date_type = '{self.date_type}'
          and date_info in ({CommonUtil.list_to_insql(self.before_12_month)})
        """
        self.df_asin_bought_before_12_month = self.spark.sql(sqlQuery=sql5).join(
            self.df_base_asin, 'asin', 'inner'
        ).cache()
        print("ASIN信息库前一年月销数据如下：")
        self.df_asin_bought_before_12_month.show(10, truncate=True)

    @staticmethod
    def _peak_months(df_bought, max_col, months_col):
        """Per ASIN: the highest ``bought_month`` (as ``max_col``) and the sorted
        month numbers taken from ``date_info`` in which it occurred (as ``months_col``)."""
        df_max = df_bought.groupBy('asin').agg(F.max('bought_month').alias(max_col))
        return df_bought.join(
            df_max, 'asin', 'left'
        ).filter(
            F.col('bought_month') == F.col(max_col)
        ).withColumn(
            # date_info is split on '-' and its second part is the month number.
            'month', F.split(F.col('date_info'), '-')[1].cast('int')
        ).groupBy('asin', max_col).agg(
            F.array_sort(F.collect_list('month')).alias(months_col)
        ).cache()

    def handle_data(self):
        """Compute is_stable_flag, peak months, is_periodic_flag and is_ascending_flag."""
        # Stable product: asin_bought_month_flag == 2 in every month of the 11-month
        # window (2 presumably means "unchanged" - confirm against dwd_ai_asin_add).
        stable_months = len(self.last_11_month)
        self.df_asin_bought_flag = self.df_asin_bought_flag.groupBy('asin').agg(
            F.sum(F.when(F.col('asin_bought_month_flag') == 2, 1).otherwise(0)).alias('sum_flag')
        ).withColumn(
            'is_stable_flag', F.when(F.col('sum_flag') == stable_months, 1).otherwise(0)
        ).cache()

        # Peak sales value and the months it occurred in, for both yearly windows.
        self.df_asin_bought_last_12_month = self._peak_months(
            self.df_asin_bought_last_12_month, 'max_bought_last_12_month', 'max_month_last_12_month'
        )
        self.df_asin_bought_before_12_month = self._peak_months(
            self.df_asin_bought_before_12_month, 'max_bought_before_12_month', 'max_month_before_12_month'
        )

        # Periodic product: peak months of the two years are equal or adjacent (+/- 1 month).
        # 1. Launched on/after the cut-off (less than ~1 year on sale) -> not periodic.
        #    (String comparison; assumes launch_time starts with YYYY-MM-DD - confirm.)
        # 2. More than 6 peak months in the last year -> not periodic.
        # 3. More than 3 peak months that are not one consecutive run -> not periodic.
        # Trend (is_ascending_flag, periodic ASINs only): 1 if last year's peak is
        # higher (or equal but spans more months), 2 if lower (or fewer months),
        # 3 if both are equal, otherwise 0.
        bought_diff = F.col('max_bought_last_12_month') - F.col('max_bought_before_12_month')
        size_diff = F.size('max_month_last_12_month') - F.size('max_month_before_12_month')
        self.df_base_asin = self.df_base_asin.join(
            self.df_asin_launch_time, 'asin', 'left'
        ).join(
            self.df_asin_bought_last_12_month, on='asin', how='left'
        ).join(
            self.df_asin_bought_before_12_month, on='asin', how='left'
        ).withColumn(
            'is_periodic_flag',
            F.when(
                F.col('launch_time') >= F.col('launch_time_base_date'), F.lit(0)
            ).when(
                F.size('max_month_last_12_month') > 6, F.lit(0)
            ).when(
                (F.size('max_month_last_12_month') > 3)
                & (~self.udf_is_consecutive_flag(F.col('max_month_last_12_month'))),
                F.lit(0)
            ).otherwise(
                self.udf_is_periodic_flag(F.col('max_month_last_12_month'), F.col('max_month_before_12_month'))
            )
        ).withColumn(
            'is_ascending_flag',
            F.when(
                F.col('is_periodic_flag') == 1,
                F.when(bought_diff > 0, F.lit(1))
                .when(bought_diff < 0, F.lit(2))
                .when(
                    bought_diff == 0,
                    F.when(size_diff > 0, F.lit(1))
                    .when(size_diff < 0, F.lit(2))
                    .when(size_diff == 0, F.lit(3))
                    .otherwise(F.lit(0))
                ).otherwise(F.lit(0))
            ).otherwise(F.lit(0))
        ).cache()

    def save_data(self):
        """Select the output columns and replace the site's partition of dwt_ai_asin_all."""
        # Normalize output columns; peak-month arrays are stored as comma-joined strings.
        self.df_save = self.df_base_asin.join(
            self.df_asin_bought_flag, 'asin', 'left'
        ).select(
            F.col("asin"),
            F.col("launch_time"),
            F.col("launch_time_base_date"),
            F.col("is_stable_flag"),
            F.col("max_bought_last_12_month"),
            F.array_join(F.col("max_month_last_12_month"), ",").alias("max_month_last_12_month"),
            F.col("max_bought_before_12_month"),
            F.array_join(F.col("max_month_before_12_month"), ",").alias("max_month_before_12_month"),
            F.col("is_periodic_flag"),
            F.col("is_ascending_flag"),
            F.lit(self.site_name).alias("site_name")
        ).fillna({
            'is_stable_flag': 0
        }).repartition(2).cache()

        # Delete the partition's existing files, then append -> effectively an overwrite.
        partition_by = ["site_name"]
        hive_tb = "dwt_ai_asin_all"
        hdfs_path = CommonUtil.build_hdfs_path(
            hive_tb,
            partition_dict={
                "site_name": self.site_name
            }
        )
        HdfsUtils.delete_file_in_folder(hdfs_path)
        print(f"正在进行数据存储，当前存储的表名为：{hive_tb}，存储路径：{hdfs_path}")
        self.df_save.write.saveAsTable(name=hive_tb, format='hive', mode='append', partitionBy=partition_by)

        print("success!")


if __name__ == "__main__":
    # CLI: <script> <site_name> <date_type> <date_info>, e.g. "us month 2024-10".
    if len(sys.argv) < 4:
        # Fail with a usage line instead of a bare IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} <site_name> <date_type> <date_info>")
    site_name = sys.argv[1]
    date_type = sys.argv[2]
    date_info = sys.argv[3]
    handle_obj = DwtAiAsinAll(site_name=site_name, date_type=date_type, date_info=date_info)
    handle_obj.run()
