import json
import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from utils.common_util import CommonUtil, DateTypes
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F, DataFrame
from pyspark.sql.types import MapType, StringType, DecimalType, IntegerType, DoubleType, ArrayType
from yswg_utils.common_udf import parse_asin_volume_str, sort_volume, get_Fba_Fee

"""
搜索词 计算=>平均长宽高 => 计算基本利润率相关数据
依赖 dwd_st_asin_measure 表 dim_asin_detail 表
输出为 Dwd_st_volume_fba
支持所有站点 日期类型

#  更新利润率
update us_aba_last_30_day
set gross_profit_fee_sea = tb_2.gross_profit_fee_sea,
	gross_profit_fee_air=tb_2.gross_profit_fee_air

from us_aba_profit_gross_last30day tb_2
where tb_2.search_term_id = id
"""


class DwdStVolumeFba(object):
    """
    Build one partition of ``dwd_st_volume_fba`` for (site_name, date_type, date_info).

    For every search term it averages the price, weight and dimensions of its
    ASINs, estimates referral fee, first-leg freight (sea / air), operating
    cost, product cost, advertising, refunds and the FBA fee, and derives the
    sea / air gross profit margins.
    """

    def __init__(self, site_name, date_type, date_info):
        self.site_name = site_name
        self.date_info = date_info
        self.date_type = date_type
        app_name = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)

        # SQL function used by the query in run(); its result is read as
        # index 0/1/2 => long/width/height there.
        self.spark.udf.register("udf_sort_volume", sort_volume, ArrayType(DoubleType()))
        self.udf_transfer_weight_reg = F.udf(self.udf_transfer_weight, DoubleType())
        self.udf_transfer_volume_val_reg = F.udf(self.udf_transfer_volume_val, DoubleType())
        # NOTE(review): udf_calc_avg returns a Python float while this UDF is
        # declared DecimalType; Python UDFs do not convert float -> Decimal.
        # It is not used by run() - confirm the return type before using it.
        self.udf_calc_avg_reg = F.udf(self.udf_calc_avg, DecimalType(10, 4))
        self.udf_calc_profit_reg = F.udf(self.udf_calc_profit, MapType(StringType(), StringType()))
        self.hive_tb = "dwd_st_volume_fba"

    @staticmethod
    def udf_transfer_weight(weight, weight_type):
        """
        Convert a weight to grams.

        :param weight: numeric weight value, or None.
        :param weight_type: unit name: 'pounds', 'grams', 'kilograms' or 'ounces'.
        :return: weight in grams as a float rounded to 4 decimals, or None when
            ``weight`` is None.
        :raises Exception: for any other (including missing) unit.
        """
        if weight is None:
            return None

        # Rates are kept local (not on the class) so that shipping this
        # function to Spark executors does not drag the class along.
        grams_per_unit = {
            'pounds': 453.59237,
            'grams': 1,
            'kilograms': 1000,
            'ounces': 28.349523125,
        }
        rate = grams_per_unit.get(weight_type)
        if rate is None:
            raise Exception("weight_type 异常请检查!!")

        # float(): a DoubleType Python UDF that returns an int yields null in Spark.
        return round(float(weight) * rate, 4)

    @staticmethod
    def udf_transfer_volume_val(val, asin_volume_type):
        """
        Convert a length to centimetres.

        :param val: numeric length value, or None.
        :param asin_volume_type: unit name: 'inches', 'cm', 'mm' or 'm'; any other
            value (including None) is treated as already being in cm.
        :return: length in cm as a float rounded to 4 decimals, or None when
            ``val`` is None.
        """
        if val is None:
            return None

        # 1 mm = 0.1 cm (dividing by 10, not multiplying).
        cm_per_unit = {'inches': 2.54, 'cm': 1, 'mm': 0.1, 'm': 100}
        rate = cm_per_unit.get(asin_volume_type, 1)

        # float(): a DoubleType Python UDF that returns an int yields null in Spark.
        return round(float(val) * rate, 4)

    @staticmethod
    def udf_calc_avg(
            val1_col,
            val2_col,
            val3_col,
            val1_count,
            val2_count,
            val3_count
    ):
        """
        Combine three candidate values, weighted by how often each occurs.

        - all three counts equal: mean of the three values (2 decimals);
        - otherwise, if second-highest count / highest count >= 0.3: mean of the
          two most frequent values (2 decimals);
        - otherwise: the most frequent value, unrounded.

        None inputs are treated as 0.
        """
        val1_col = val1_col or 0
        val2_col = val2_col or 0
        val3_col = val3_col or 0
        val1_count = val1_count or 0
        val2_count = val2_count or 0
        val3_count = val3_count or 0

        # (count, value) pairs, sorted ascending by count.
        arr = [(val1_count, val1_col), (val2_count, val2_col), (val3_count, val3_col)]
        arr.sort(key=lambda it: it[0])

        count1, val1 = arr[0]
        count2, val2 = arr[1]
        count3, val3 = arr[2]
        if count1 == count2 == count3:
            return round((val1 + val2 + val3) / 3, 2)
        # Guard against division by zero when the largest count is 0.
        if count3 != 0 and count2 / count3 >= 0.3:
            return round((val2 + val3) / 2, 2)
        return val3

    @staticmethod
    def get_fba_calc_dict(df_profit_config: "DataFrame"):
        """
        Build the referral-fee configuration lookup from the profit config table.

        :param df_profit_config: rows with at least category_first_id, calc_type
            and fba_config_json (a JSON string).
        :return: ``{str(category_first_id): {"calc_type": ..., "fba_config_json": dict}}``
            using the first non-null-calc_type row per first-level category.
        """
        calc_profit_list_df = df_profit_config.where("calc_type is not null") \
            .groupby("category_first_id").agg(
            F.first(F.col("calc_type")).alias("calc_type"),
            F.first(F.col("fba_config_json")).alias("fba_config_json"),
        )

        return {
            str(row['category_first_id']): {
                "calc_type": row['calc_type'],
                "fba_config_json": json.loads(row['fba_config_json']),
            }
            for row in calc_profit_list_df.collect()
        }

    @staticmethod
    def _calc_fba_fee(config_dict, category_first_id, price):
        """
        Compute the referral fee for one price according to its category config.

        :param config_dict: lookup produced by get_fba_calc_dict.
        :param category_first_id: first-level category id (compared as str).
        :param price: item price, or None.
        :return: the fee as a float; None when price is None or when the
            configured calc_type is not recognised.
        """
        if price is None:
            return None
        config_row = config_dict.get(str(category_first_id))
        if config_row is None:
            # Default commission rate is 0.15 when the category has no config.
            return float(price * 0.15)

        calc_type = config_row['calc_type']
        calc_json = config_row['fba_config_json']
        # float() everywhere: JSON numbers may be ints, and a DoubleType Python
        # UDF that returns an int yields null in Spark.
        if calc_type == '价格分段':
            # Tiered on price: up to break_val use max(min, price*rate_1), above it price*rate_2.
            min_fee = calc_json['min']
            rate_1 = calc_json['rate_1']
            rate_2 = calc_json['rate_2']
            break_val = calc_json['break_val']
            if price <= break_val:
                return float(max(min_fee, price * rate_1))
            return float(price * rate_2)
        if calc_type == '佣金分段':
            # Tiered on the rate_1 commission amount instead of on the price.
            min_fee = calc_json['min']
            rate_1 = calc_json['rate_1']
            rate_2 = calc_json['rate_2']
            break_val = calc_json['break_val']
            if price * rate_1 <= break_val:
                return float(max(min_fee, price * rate_1))
            return float(price * rate_2)
        if calc_type == '最小限制':
            # Flat rate with a minimum fee.
            rate = calc_json['rate']
            min_fee = calc_json['min']
            return float(max(min_fee, price * rate))
        if calc_type == '固定比率':
            # Flat rate.
            rate = calc_json['rate']
            return float(price * rate)
        return None

    @staticmethod
    def udf_calc_fba_reg(config_dict: dict):
        """
        Return a Spark UDF ``(category_first_id, price) -> referral fee (double)``.

        :param config_dict: lookup produced by get_fba_calc_dict; captured by the UDF.
        """
        # Capture the plain function rather than the class so the UDF closure
        # pickles without the class (and its driver-only dependencies).
        calc_fee = DwdStVolumeFba._calc_fba_fee

        def udf_calc_fba(category_first_id, price):
            return calc_fee(config_dict, category_first_id, price)

        return F.udf(udf_calc_fba, DoubleType())

    @staticmethod
    def udf_calc_profit(long, width, high, weight, tmp_cost_all_sea, price):
        """
        Compute the FBA fee, trying to "fold" oversized packages when cost exceeds price.

        When (non-FBA costs + FBA fee) > price, the longest side is halved and
        the shortest side doubled, re-sorting the dimensions each time, for up
        to 4 rounds while the longest side is >= 43 (presumably cm, as run()
        converts dimensions to cm). The original dimensions/fee are restored if
        a fold makes half the longest side <= twice the shortest side, or if
        after 4 folds the longest side is still > 43.
        NOTE(review): the folded fee is not compared with the original fee.

        None inputs are treated as 0.

        :return: dict with keys long, width, high, fee_type, fba_fee (final) and
            long_before, width_before, high_before (input dimensions).
        """
        longVal_before = long or 0
        width_before = width or 0
        high_before = high or 0
        weight_before = weight or 0
        tmpCost = tmp_cost_all_sea or 0
        price_val = price or 0

        fee_type_before, fba_fee_before = get_Fba_Fee(longVal_before, width_before, high_before, weight_before)

        cost_sum = tmpCost + fba_fee_before

        count = 0

        long_result = longVal_before
        width_result = width_before
        high_result = high_before
        fee_type_result = fee_type_before
        fba_fee_result = fba_fee_before
        breakFlag = False

        val_43 = 43
        if cost_sum > price_val:
            while long_result >= val_43 and count < 4:
                tmpVal1 = long_result / 2
                tmpVal2 = high_result * 2
                # Folding no longer shrinks the package: stop (and revert below).
                breakFlag = tmpVal1 <= tmpVal2

                tmp_list = sorted([tmpVal1, width_result, tmpVal2], reverse=True)
                long_result, width_result, high_result = tmp_list

                count = count + 1

                fee_type_result, fba_fee_result = get_Fba_Fee(long_result, width_result, high_result, weight_before)

                if breakFlag:
                    break

        if breakFlag or (count == 4 and long_result > val_43):
            long_result = longVal_before
            width_result = width_before
            high_result = high_before
            fee_type_result = fee_type_before
            fba_fee_result = fba_fee_before

        return {
            "long": long_result,
            "width": width_result,
            "high": high_result,
            "fee_type": fee_type_result,
            "fba_fee": fba_fee_result,
            "long_before": longVal_before,
            "width_before": width_before,
            "high_before": high_before
        }

    def get_repartition_num(self):
        """
        Number of output files to write, chosen by date_type.

        :return: 1 for day; 2 for week, month and last30day; 10 otherwise.
        """
        if self.date_type == DateTypes.day.name:
            return 1
        if self.date_type == DateTypes.week.name:
            return 2
        if self.date_type == DateTypes.month.name:
            return 2
        if self.date_type == DateTypes.last30day.name:
            return 2
        return 10

    def run(self):
        """
        Build and write the dwd_st_volume_fba partition for this job.

        Loads the profit config, joins search-term/ASIN facts, normalises units
        (cm / grams), averages per search term, computes every cost component
        and the sea/air gross margins, deletes the partition's existing HDFS
        files and appends the new rows.
        """
        # Profit-related configuration.
        # NOTE(review): always read for site 'us', regardless of self.site_name - confirm intended.
        profit_config_sql = f"""
            select category_id,
                   category_first_id,
                   cost,
                   avg_cost,
                   calc_type,
                   fba_config_json,
                   adv,
                   return_ratio / 100 as return_ratio
            from dim_profit_config
            where site_name = 'us'
"""
        df_profit_config = self.spark.sql(profit_config_sql).cache()
        calc_profit_dict = self.get_fba_calc_dict(df_profit_config)

        # Search term -> ASIN facts (dimensions, weight, category, price).
        sql = f"""
select dsam.search_term,
       osk.st_key                 as search_term_id,
       dsam.asin,
       sorted_volume[0]           as long,
       sorted_volume[1]           as width,
       sorted_volume[2]           as height,
       asin_volume_type,
       asin_weight,
       asin_weight_type,
       st_bsr_cate_1_id_new       as category_first_id,
       st_bsr_cate_current_id_new as category_id,
       asin_price
from (
         select search_term,
                asin
         from dwd_st_asin_measure
         where date_type = '{CommonUtil.get_rel_date_type('dwd_st_asin_measure', self.date_type)}'
           and date_info = '{self.date_info}'
           and site_name = '{self.site_name}'
     ) dsam
         left join
     (
         select search_term,
                st_bsr_cate_1_id_new,
                st_bsr_cate_current_id_new
         from dim_st_detail
     where date_type = '{CommonUtil.get_rel_date_type('dim_st_detail', self.date_type)}'
           and date_info = '{self.date_info}'
           and site_name = '{self.site_name}'
     ) dsd on dsd.search_term = dsam.search_term
         left join
        -- 使用 事实表的四分位价格
     (
         select search_term,
                round(st_price_avg, 6 ) as asin_price
         from dwd_st_measure
             where date_type = '{CommonUtil.get_rel_date_type('dwd_st_measure', self.date_type)}'
           and date_info = '{self.date_info}'
           and site_name = '{self.site_name}'
     ) dad on dad.search_term = dsam.search_term
         left join
     (
         select asin,
                asin_weight,
                asin_weight_type,
                udf_sort_volume(asin_length,asin_width,asin_height) sorted_volume,
                asin_volume_type
         from dim_asin_stable_info
         where site_name = '{self.site_name}'
     ) dssi on dssi.asin = dsam.asin
         inner join
     (
         select search_term,
                st_key
         from ods_st_key
         where site_name = '{self.site_name}'
     ) osk on osk.search_term = dsam.search_term
                """
        print("======================查询sql如下======================")
        print(sql)
        df_all = self.spark.sql(sql)

        # Unit conversion: dimensions to cm, weight to grams.
        df_all = df_all.withColumn("long", self.udf_transfer_volume_val_reg(F.col("long"), F.col("asin_volume_type")))
        df_all = df_all.withColumn("width", self.udf_transfer_volume_val_reg(F.col("width"), F.col("asin_volume_type")))
        df_all = df_all.withColumn("height", self.udf_transfer_volume_val_reg(F.col("height"), F.col("asin_volume_type")))
        df_all = df_all.withColumn("weight", self.udf_transfer_weight_reg(F.col("asin_weight"), F.col("asin_weight_type")))

        # One row per search term with averaged price / weight / dimensions.
        st_volume_info = df_all.groupBy("search_term_id") \
            .agg(
            F.first("search_term").alias("search_term"),
            F.first("category_first_id").alias("category_first_id"),
            F.first("category_id").alias("category_id"),
            F.avg("asin_price").alias("price"),
            F.avg(F.col('weight')).alias("weight"),
            F.avg(F.col('long')).alias("long"),
            F.avg(F.col('width')).alias("width"),
            F.avg(F.col('height')).alias("high"),
        )

        # Fallbacks for search terms whose category has no config row:
        # average cost ratio
        def_cost_rate = df_profit_config.select(F.avg("cost").cast(DecimalType(10, 3)).alias("cost_rate")).first()[0]
        # average advertising ratio
        def_adv = df_profit_config.select(F.avg("adv").cast(DecimalType(10, 3)).alias("adv")).first()[0]
        # average refund ratio
        def_return_ratio = df_profit_config.select(F.avg("return_ratio").cast(DecimalType(10, 3)).alias("return_ratio")).first()[0]
        # First-leg air freight rate (applied per 1000 units of weight, i.e. per kg)
        freight_air_rate = 8.55
        # First-leg sea freight rate (applied per 1000 units of weight, i.e. per kg)
        freight_sea_rate = 2.06

        # Config joined on the first-level category.
        df_category_first_id_config = df_profit_config.groupby("category_first_id").agg(
            F.max("adv").alias("adv"),
            F.max("return_ratio").alias("return_ratio"),
        )

        # Config joined on the leaf category.
        df_category_id_config = df_profit_config.groupby("category_id").agg(
            F.max("cost").alias("cost_rate")
        )

        df_save = st_volume_info \
            .join(df_category_first_id_config, ['category_first_id'], "left") \
            .join(df_category_id_config, ['category_id'], "left")

        df_save = df_save.select(
            st_volume_info["search_term"],
            st_volume_info["search_term_id"],
            st_volume_info["category_first_id"],
            st_volume_info["category_id"],
            st_volume_info["price"],
            st_volume_info["long"],
            st_volume_info["width"],
            st_volume_info["high"],
            st_volume_info["weight"],
            df_category_first_id_config["return_ratio"],
            df_category_first_id_config["adv"],
            df_category_id_config['cost_rate']
        )
        df_base = df_save.fillna({
            # No category match: fall back to the averages.
            "return_ratio": float(def_return_ratio),
            "cost_rate": float(def_cost_rate),
            "adv": float(def_adv),
        }).cache()
        df_save = df_base

        # Referral fee from the per-category config.
        df_save = df_save.withColumn("referral_fee",
                                     F.round(self.udf_calc_fba_reg(calc_profit_dict)(F.col("category_first_id"), F.col("price")), 4))
        # First leg, sea
        df_save = df_save.withColumn("ocean_freight", F.expr(f"weight * {freight_sea_rate} /1000"))
        # First leg, air
        df_save = df_save.withColumn("air_delivery_fee", F.expr(f"weight * {freight_air_rate} /1000"))
        # Operating cost: fixed 5% of the average price
        df_save = df_save.withColumn("operating_costs", F.expr("price * 0.05"))
        # Product cost
        df_save = df_save.withColumn("costs", F.expr(" price * cost_rate "))
        # Advertising cost
        df_save = df_save.withColumn("advertise", F.expr(" price * adv  "))
        # Refund amount = refund ratio * price (the column now holds an amount, not a ratio)
        df_save = df_save.withColumn("return_ratio", F.expr(" price * return_ratio "))
        # Every cost except the FBA fee (sea freight variant)
        df_save = df_save.withColumn("tmp_cost_all_sea",
                                     F.expr("ocean_freight  + referral_fee  + return_ratio  + costs + advertise + operating_costs"))

        # FBA fee (with the package-folding heuristic)
        df_save = df_save.withColumn("tmp_row", self.udf_calc_profit_reg(
            F.col("long"),
            F.col("width"),
            F.col("high"),
            F.col("weight"),
            F.col("tmp_cost_all_sea"),
            F.col("price"),
        ))

        df_save = df_save.withColumn("fba_fee", F.col("tmp_row").getField("fba_fee"))

        df_save = df_save.withColumn("gross_profit_fee_sea",
                                     F.expr(
                                         "(price-(ocean_freight +referral_fee+return_ratio +costs+advertise+operating_costs + fba_fee))/price")
                                     .cast(DecimalType(10, 3)))

        df_save = df_save.withColumn("gross_profit_fee_air",
                                     F.expr(
                                         "(price-(air_delivery_fee +referral_fee+return_ratio +costs+advertise+operating_costs + fba_fee))/price")
                                     .cast(DecimalType(10, 3)))

        df_save = df_save.select(
            F.col("search_term"),
            F.col("search_term_id"),
            F.col("category_first_id"),
            F.col("category_id"),
            F.col("weight").cast(DecimalType(10, 3)),
            F.col("price").cast(DecimalType(10, 3)),
            F.col("referral_fee").cast(DecimalType(10, 3)),

            # Dimensions before folding
            F.col("tmp_row").getField("long_before").cast(DecimalType(10, 3)).alias("long_before"),
            F.col("tmp_row").getField("width_before").cast(DecimalType(10, 3)).alias("width_before"),
            F.col("tmp_row").getField("high_before").cast(DecimalType(10, 3)).alias("high_before"),
            # Dimensions after folding
            F.col("tmp_row").getField("long").cast(DecimalType(10, 3)).alias("longs"),
            F.col("tmp_row").getField("width").cast(DecimalType(10, 3)).alias("width"),
            F.col("tmp_row").getField("high").cast(DecimalType(10, 3)).alias("high"),
            F.col("tmp_row").getField("fba_fee").cast(DecimalType(10, 3)).alias("fba_fee"),
            F.col("tmp_row").getField("fee_type").cast(IntegerType()).alias("fee_type"),

            F.col("return_ratio"),
            F.col("ocean_freight"),
            F.col("air_delivery_fee"),
            F.col("operating_costs"),
            F.col("costs"),
            F.col("advertise"),
            F.col("gross_profit_fee_sea"),
            F.col("gross_profit_fee_air"),

            F.col("cost_rate"),
            F.lit(self.site_name).alias("site_name"),
            F.lit(self.date_type).alias("date_type"),
            F.lit(self.date_info).alias("date_info")
        ).where("weight < 10000000")  # also drops rows with null weight

        # Number of output files
        df_save = df_save.repartition(self.get_repartition_num())
        partition_dict = {
            "site_name": self.site_name,
            "date_type": self.date_type,
            "date_info": self.date_info,
        }

        df_save = CommonUtil.format_df_with_template(self.spark, df_save, self.hive_tb, True)
        hdfs_path = CommonUtil.build_hdfs_path(self.hive_tb, partition_dict)
        # Delete the partition's existing files first so the append below replaces it.
        HdfsUtils.delete_hdfs_file(hdfs_path)
        partition_by = list(partition_dict.keys())
        print(f"当前存储的表名为::{self.hive_tb},分区为{partition_by}")
        df_save.write.saveAsTable(name=self.hive_tb, format='hive', mode='append', partitionBy=partition_by)

        # Release the cached inputs now that the write has completed.
        df_base.unpersist()
        df_profit_config.unpersist()
        print("success")


if __name__ == '__main__':
    # CLI: <site_name> <date_type> <date_info>; each defaults to None when absent.
    site_name, date_type, date_info = (CommonUtil.get_sys_arg(position, None) for position in (1, 2, 3))
    DwdStVolumeFba(site_name, date_type, date_info).run()