import json
import os
import re
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from utils.db_util import DBUtil
from utils.common_util import CommonUtil, DateTypes
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F
from pyspark.sql.types import MapType, StringType, DecimalType, IntegerType, Row, DoubleType

"""
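Search terms: compute the average length/width/height => compute the basic profit-margin data.
Depends on the dwd_st_asin_measure and dim_asin_detail tables.
Output table: dwd_st_volume_fba.
Supports every site and every date type.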

#  更新利润率
update us_aba_last_30_day
set gross_profit_fee_sea = tb_2.gross_profit_fee_sea,
	gross_profit_fee_air=tb_2.gross_profit_fee_air

from us_aba_profit_gross_last30day tb_2
where tb_2.search_term_id = id
"""


def get_Fba_Fee(longVal: float, width: float, high: float, weight: float):
    """
    Look up the FBA fee tier for a package.

    Tiers are tried from the smallest size envelope to the largest and the
    first one that matches wins.  Callers in this file pass dimensions that
    were converted from inches and a weight converted from pounds, so the
    thresholds are presumably centimetres and grams -- confirm against the
    fee table this was transcribed from.

    :param longVal: longest side of the package
    :param width: middle side of the package
    :param high: shortest side of the package
    :param weight: package weight
    :return: tuple ``(fee_type, fba_fee)``; ``(0, 0)`` when no tier matches
    """
    # (upper weight bound, fee) per tier, ascending; the lower bound of each
    # tier is the upper bound of the previous one.
    small_envelope = ((113.5, 3.22), (227, 3.4), (340.5, 3.58), (454, 3.77))
    standard_envelope = (
        (113.5, 3.86), (227, 4.08), (340.5, 4.24), (454, 4.75), (681, 5.4),
        (908, 5.69), (1135, 6.1), (1362, 6.39), (9080, 7.33),
    )

    # Tiers 1-4: small envelope.
    if longVal <= 36 and width <= 28 and high <= 1.6:
        for offset, (max_weight, fee) in enumerate(small_envelope):
            if weight <= max_weight:
                return (1 + offset, fee)

    # Tiers 5-13: standard envelope (also catches small items that are too heavy).
    if longVal <= 43 and width <= 34 and high <= 19:
        for offset, (max_weight, fee) in enumerate(standard_envelope):
            if weight <= max_weight:
                return (5 + offset, fee)

    # Tiers 14-17: oversize, decided by length plus girth (length + 2 * (width + height)).
    if longVal <= 152.4 and (longVal + 2 * (width + high)) <= 330.2 and weight <= 31780:
        return (14, 10.15)
    if longVal <= 274.32 and (longVal + 2 * (width + high)) <= 419.1 and weight <= 68100:
        return (15, 19.47)
    if longVal <= 274.32 and (longVal + 2 * (width + high)) > 419.1 and weight <= 68100:
        return (16, 90.81)
    if longVal > 274.32 and (longVal + 2 * (width + high)) > 419.1 and weight > 68100:
        return (17, 159.32)

    # Nothing matched.
    return (0, 0)


class DwdStVolumeFba(object):

    def __init__(self, site_name, date_type, date_info):
        """
        Store the run parameters, open the Spark session and wrap the
        static helper methods as Spark UDFs.

        :param site_name: site the job runs for
        :param date_type: date granularity of the run
        :param date_info: date value of the run
        """
        self.site_name = site_name
        self.date_info = date_info
        self.date_type = date_type

        # The Spark application is named after the class and the run parameters.
        session_name = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(session_name)

        # Wrap the static methods as UDFs; each call returns a new callable.
        decimal_10_4 = DecimalType(10, 4)
        self.udf_parse_volume_reg = F.udf(self.udf_parse_volume, MapType(StringType(), decimal_10_4))
        self.udf_calc_avg_reg = F.udf(self.udf_calc_avg, decimal_10_4)
        self.udf_calc_profit_reg = F.udf(self.udf_calc_profit, MapType(StringType(), StringType()))

        # Target Hive table.
        self.hive_tb = "dwd_st_volume_fba"


    @staticmethod
    def udf_parse_volume(asin_volume: str, site_name: str):
        """
        解析
        :param asin_volume:
        :return:
        """

        def safeIndex(list: list, index: int, default: object = None):
            if index <= len(list) - 1:
                return list[index]
            return default

        def to_float(obj, default: object = None):
            try:
                return float(obj)
            except:
                return default

        resultArr = []
        if asin_volume is not None:
            pattern = r"([0-9.]+)"
            resultArr = re.findall(pattern, asin_volume, flags=re.IGNORECASE)

        #  倒序
        resultArr.sort(reverse=True)

        longVal = to_float(safeIndex(resultArr, 0, None), None)
        width = to_float(safeIndex(resultArr, 1, None), None)
        height = to_float(safeIndex(resultArr, 2, None), None)

        if site_name == 'us':
            #  英寸
            rate = 2.54
        else:
            rate = 1

        return {
            "long": None if longVal is None else round(rate * longVal, 4),
            "width": None if width is None else round(rate * width, 4),
            "height": None if height is None else round(rate * height, 4),
        }

    @staticmethod
    def udf_calc_avg(
            val1_col=0,
            val2_col=0,
            val3_col=0,
            val1_count=0,
            val2_count=0,
            val3_count=0
    ):
        val1_col = val1_col or 0
        val2_col = val2_col or 0
        val3_col = val3_col or 0
        val1_count = val1_count or 0
        val2_count = val2_count or 0
        val3_count = val3_count or 0

        arr = [(val1_count, val1_col), (val2_count, val2_col), (val3_count, val3_col)]

        # 从小到大排序
        arr.sort(key=lambda it: it[0], reverse=False)

        val1, count1 = arr[0]
        val2, count2 = arr[1]
        val3, count3 = arr[2]
        if (count1 == count2 and count2 == count3):
            return round((val1 + val2 + val3) / 3, 2)
        if (count2 / count3 >= 0.3):
            return round((val2 + val3) / 2, 2)
        else:
            return val3

    @staticmethod
    def udf_sort_val():
        """
        Build a UDF that re-orders three dimension values.

        :return: a Spark UDF taking three values and returning a map with the
            largest under "long", the middle one under "width" and the
            smallest under "height"; None inputs count as 0
        """

        def udf_sort_inner(val1, val2, val3):
            # Largest to smallest.
            longest, middle, shortest = sorted((val1 or 0, val2 or 0, val3 or 0), reverse=True)
            return {"long": longest, "width": middle, "height": shortest}

        map_of_decimals = MapType(StringType(), DecimalType())
        return F.udf(udf_sort_inner, map_of_decimals)

    @staticmethod
    def udf_calc_fba_reg(config_dict: dict):
        """
        Build a UDF that computes the referral fee (commission) for a price
        from the per-category configuration.

        Each configuration row carries a ``calc_type`` and a JSON document
        ``config_json``.  Supported calculation types:
        * '价格分段' (tiered by price): ``max(min, price * rate_1)`` while
          ``price <= break_val``, otherwise ``price * rate_2``;
        * '佣金分段' (tiered by commission): the same, but the switch happens
          on ``price * rate_1 <= break_val``;
        * '最小限制' (minimum fee): ``max(min, price * rate)``;
        * '固定比率' (fixed rate): ``price * rate``.
        Categories without a configuration row use a flat 15%.

        Changes over the previous version: the local variable no longer
        shadows the builtin ``min``; the price is coerced to float so a
        Decimal price does not raise when multiplied by a float rate; the
        result is always a float, which is what the ``DoubleType`` return
        type expects.

        :param config_dict: mapping of category id (as str) to a row dict
            with the keys 'calc_type' and 'config_json'
        :return: Spark UDF ``(one_categoy_id, price) -> double``; yields None
            for a None price or an unknown ``calc_type``
        """
        # Commission rate used when a category has no configuration row.
        default_rate = 0.15

        def udf_calc_fba(one_categoy_id: str, price):
            if price is None:
                return None
            price = float(price)

            config_row = config_dict.get(str(one_categoy_id))
            if config_row is None:
                return price * default_rate

            calc_type = config_row['calc_type']
            calc_json = json.loads(config_row['config_json'])

            if calc_type in ('价格分段', '佣金分段'):
                min_fee = calc_json['min']
                rate_1 = calc_json['rate_1']
                rate_2 = calc_json['rate_2']
                break_val = calc_json['break_val']

                # The two tiered types differ only in what is compared
                # against the break value: the price or the commission.
                basis = price if calc_type == '价格分段' else price * rate_1
                if basis <= break_val:
                    fee = max(min_fee, price * rate_1)
                else:
                    fee = price * rate_2
            elif calc_type == '最小限制':
                fee = max(calc_json['min'], price * calc_json['rate'])
            elif calc_type == '固定比率':
                fee = price * calc_json['rate']
            else:
                # Unknown calculation type: no fee can be derived.
                return None

            return float(fee)

        return F.udf(udf_calc_fba, DoubleType())

    @staticmethod
    def udf_calc_profit(long, width, high, weight, tmp_cost_all_sea, price):
        """
        Work out the FBA fee for a package and, when the product would lose
        money, try to reach a cheaper tier by "folding" the package.

        A fold halves the longest side and doubles the shortest one, then
        re-orders the three sides.  Folding is attempted only when all other
        costs plus the FBA fee exceed the price, at most 4 times, and only
        while the longest side is at least 43.  If a fold makes the halved
        length no longer exceed the doubled height, or 4 folds still leave
        the longest side above 43, the original dimensions and fee are kept.

        :param long: longest side, None counts as 0
        :param width: middle side, None counts as 0
        :param high: shortest side, None counts as 0
        :param weight: package weight, None counts as 0
        :param tmp_cost_all_sea: all costs except the FBA fee, None counts as 0
        :param price: sale price, None counts as 0
        :return: dict with the resulting "long"/"width"/"high", the
            "fee_type" and "fba_fee" that go with them, and the original
            dimensions under "long_before"/"width_before"/"high_before"
        """
        fold_threshold = 43
        max_folds = 4

        orig_long, orig_width, orig_high = long or 0, width or 0, high or 0
        pkg_weight = weight or 0
        other_costs = tmp_cost_all_sea or 0
        sale_price = price or 0

        orig_fee_type, orig_fee = get_Fba_Fee(orig_long, orig_width, orig_high, pkg_weight)

        # Working copy that the folding loop updates.
        cur_long, cur_width, cur_high = orig_long, orig_width, orig_high
        cur_fee_type, cur_fee = orig_fee_type, orig_fee
        folds = 0
        gave_up = False

        if other_costs + orig_fee > sale_price:
            while cur_long >= fold_threshold and folds < max_folds:
                halved_long = cur_long / 2
                doubled_high = cur_high * 2
                # Once half the length is no longer above the doubled height,
                # this fold is the last one and its result is discarded below.
                gave_up = halved_long <= doubled_high

                cur_long, cur_width, cur_high = sorted([halved_long, cur_width, doubled_high], reverse=True)
                folds += 1

                cur_fee_type, cur_fee = get_Fba_Fee(cur_long, cur_width, cur_high, pkg_weight)

                if gave_up:
                    break

        if gave_up or (folds == max_folds and cur_long > fold_threshold):
            # Folding did not help: fall back to the untouched package.
            cur_long, cur_width, cur_high = orig_long, orig_width, orig_high
            cur_fee_type, cur_fee = orig_fee_type, orig_fee

        return {
            "long": cur_long,
            "width": cur_width,
            "high": cur_high,
            "fee_type": cur_fee_type,
            "fba_fee": cur_fee,
            "long_before": orig_long,
            "width_before": orig_width,
            "high_before": orig_high
        }

    def get_repartition_num(self):
        """
        Pick the number of output partitions (files) from ``self.date_type``.

        :return: 1 for day; 2 for week, month and last30day; 10 otherwise
        """
        if self.date_type == DateTypes.day.name:
            return 1

        two_file_types = (
            DateTypes.week.name,
            DateTypes.month.name,
            DateTypes.last30day.name,
        )
        return 2 if self.date_type in two_file_types else 10

    def run(self):
        """
        Build the per-search-term volume / FBA fee / gross-profit table and
        write it to the Hive table ``self.hive_tb``.

        Steps:
        1. join search-term/ASIN pairs with the search-term categories, the
           ASIN details and the search-term keys;
        2. parse each ASIN's volume text into long / width / height;
        3. per search term: average price and weight, plus a bucketed average
           for each of the three dimensions;
        4. read the fee, advertising, return-ratio and cost configuration
           over JDBC;
        5. compute the individual cost items, the FBA fee and the sea / air
           gross profit ratios;
        6. replace the target partition on HDFS and append the result.

        Fixes over the previous version:
        * ``df_all.drop("tmp_row")`` discarded its result; it is now assigned;
        * the aggregation read ``categoy_id`` / ``current_categoy_id``, which
          do not exist in the query result; it now reads the query's
          ``category_first_id`` / ``category_id`` and aliases them to the
          names the rest of the pipeline uses;
        * the final log line printed the literal text ``self.hive_tb``
          instead of the table name;
        * the three copy-pasted pivot blocks are one local helper.
        """
        sql = f"""
    select dsam.search_term,
       osk.st_key       as search_term_id,
       dsam.asin,
       asin_volume,
       asin_weight,
       st_bsr_cate_1_id_new       as category_first_id,
       st_bsr_cate_current_id_new as category_id,
       asin_price
from (
         select search_term,
                asin
         from dwd_st_asin_measure
         where date_type = '{CommonUtil.get_rel_date_type('dwd_st_asin_measure', self.date_type)}'
           and date_info = '{self.date_info}'
           and site_name = '{self.site_name}'
     ) dsam
         left join
     (
         select search_term,
                st_bsr_cate_1_id_new,
                st_bsr_cate_current_id_new
         from dim_st_detail
     where date_type = '{CommonUtil.get_rel_date_type('dim_st_detail', self.date_type)}'
           and date_info = '{self.date_info}'
           and site_name = '{self.site_name}'
     ) dsd on dsd.search_term = dsam.search_term
         left join
     (
         select asin,
                asin_weight,
                asin_volume,
                asin_price
         from dim_asin_detail
             where date_type = '{CommonUtil.get_rel_date_type('dim_asin_detail', self.date_type)}'
           and date_info = '{self.date_info}'
           and site_name = '{self.site_name}'
     ) dad on dad.asin = dsam.asin
         inner join
     (
         select search_term,
                st_key
         from ods_st_key
         where site_name = '{self.site_name}'
     ) osk on osk.search_term = dsam.search_term
                """
        print("======================查询sql如下======================")
        print(sql)
        df_all = self.spark.sql(sql)

        # Parse the volume text into long / width / height columns.
        df_all = df_all.withColumn("tmp_row", self.udf_parse_volume_reg(F.col("asin_volume"), F.lit(self.site_name)))
        df_all = df_all.withColumn("long", F.col("tmp_row").getField("long"))
        df_all = df_all.withColumn("width", F.col("tmp_row").getField("width"))
        df_all = df_all.withColumn("height", F.col("tmp_row").getField("height"))
        # DataFrames are immutable: drop() returns a new frame that must be kept.
        df_all = df_all.drop("tmp_row")

        # One row per search term: categories, average price and average weight.
        st_agg_one = df_all.groupBy("search_term_id") \
            .agg(
            # The query exposes the categories as category_first_id / category_id;
            # downstream code (and the config tables) use the "categoy" spelling.
            F.first("category_first_id").alias("categoy_id"),
            F.first("category_id").alias("current_categoy_id"),
            F.first("search_term").alias("search_term"),
            F.avg("asin_price").alias("price"),
            # pounds => grams; TODO: sites other than us
            F.expr("round(avg(asin_weight * 453.59237),4)").alias("weight")
        ).select(
            F.col("search_term"),
            F.col("search_term_id"),
            F.col("categoy_id"),
            F.col("current_categoy_id"),
            F.col("price"),
            F.col("weight")
        )

        def bucketed_avg(dim: str):
            """Per search term, the combined average of column ``dim``.

            Values are split into the buckets (0, 50], (50, 100] and
            (100, 500]; anything else is ignored.  The per-bucket averages
            and row counts are combined by ``udf_calc_avg``.
            """
            flag = F.when(F.expr(f"0 < {dim} and {dim} <= 50"), F.lit(1)) \
                .when(F.expr(f"50 < {dim} and {dim} <= 100"), F.lit(2)) \
                .when(F.expr(f"100 < {dim} and {dim} <= 500"), F.lit(3)) \
                .otherwise(None)
            return df_all.withColumn("flag", flag) \
                .groupBy(F.col("search_term_id")) \
                .pivot("flag", [1, 2, 3]) \
                .agg(F.avg(dim).cast(DecimalType(10, 2)).alias("val"), F.count("flag").alias("row_count")) \
                .select(
                F.col("search_term_id"),
                self.udf_calc_avg_reg(
                    F.col("1_val"), F.col("2_val"), F.col("3_val"),
                    F.col("1_row_count"), F.col("2_row_count"), F.col("3_row_count")
                ).alias(dim)
            ).cache()

        long_df = bucketed_avg("long")
        width_df = bucketed_avg("width")
        height_df = bucketed_avg("height")

        st_volume_info = st_agg_one \
            .join(long_df, "search_term_id") \
            .join(width_df, "search_term_id") \
            .join(height_df, "search_term_id") \
            .select(
            st_agg_one["search_term"],
            st_agg_one["search_term_id"],
            st_agg_one["categoy_id"],
            st_agg_one["current_categoy_id"],
            st_agg_one["price"],
            st_agg_one["weight"],
            long_df["long"],
            width_df["width"],
            height_df["height"].alias("high")
        ).fillna({
            "long": 0,
            "width": 0,
            "high": 0,
            "categoy_id": 0,
            "current_categoy_id": 0,
            "price": 0,
            "weight": 0
        })

        # NOTE: the configuration is always read through the "us" PostgreSQL
        # connection and from us_* tables, whatever self.site_name is.
        conn_info = DBUtil.get_connection_info("postgresql", "us")
        # Referral-fee configuration, advertising ratio and the overall
        # average return ratio, one row per first-level category.
        config_sql1 = f"""
        select categoy_id as one_categoy_id,
               categoy_name,
               referral_fee_formula,
               upfc.calc_type,
               upfc.config_json,
               adv::decimal(10, 3),
               (
                   select round(avg(return_ratio) / 100, 2)::decimal(10, 3)
                   from us_aba_profit_category_insights
               )          as return_ratio
        from us_profit_fba_config upfc
                 left join us_profit_adv upa on upa.category = upfc.categoy_name    
        """
        df_profit_join = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=config_sql1
        ).cache()

        # Driver-side lookup used by the referral-fee UDF, keyed by category id as str.
        calc_profit_dict = {str(row['one_categoy_id']): row.asDict() for row in df_profit_join.collect()}

        # Product cost ratio per category.
        config_sql2 = f"""
        select rel_category_id                         as rel_category_id,
               ubc.one_category_id                     as one_category_id,
               round(avg(cost), 2)::decimal(10, 3)     as cost_rate,
               round(avg(avg_cost), 2)::decimal(10, 3) as avg_cost_rate
    from (
             select upcn.name                                    as header_name,
                    (replace(cost, '%', '')::decimal) / 100      as cost,
                    (replace(avg_cost, '%', '') ::decimal) / 100 as avg_cost,
                    first_name,
                    last_name,
                    coalesce(tmp2.id, tmp1.id, upcn.first_id)    as rel_category_id
             from public.us_profit_cost_new upcn
                      left join (
                 select min(id)                   as id,
                        replace(en_name, ' ', '') as name
                 from us_bs_category
                 where nodes_num = 2
                 group by en_name
             ) tmp1 on upcn.first_name = tmp1.name
                      left join (
                 select min(id)                   as id,
                        replace(en_name, ' ', '') as name
                 from us_bs_category
                 where nodes_num > 2
                 group by en_name
             ) tmp2 on upcn.last_name = tmp2.name
         ) tmp
             inner join us_bs_category ubc on ubc.id = tmp.rel_category_id
    group by one_category_id, rel_category_id
        """
        df_cost_join = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=config_sql2
        ).cache()

        # Fallbacks for search terms whose category has no configuration:
        # average cost ratio
        def_cost_rate = df_cost_join.select(F.avg("cost_rate").cast(DecimalType(10, 3)).alias("cost_rate")).first()[0]
        # average advertising ratio
        def_adv = df_profit_join.select(F.avg("adv").cast(DecimalType(10, 3)).alias("adv")).first()[0]
        # average return ratio
        def_return_ratio = df_profit_join.select(F.avg("return_ratio").cast(DecimalType(10, 3)).alias("return_ratio")).first()[0]
        # first-leg air freight rate (applied as weight * rate / 1000)
        freight_air_rate = 8.55
        # first-leg sea freight rate (applied as weight * rate / 1000)
        freight_sea_rate = 2.06

        df_save = st_volume_info \
            .join(df_profit_join, st_volume_info['categoy_id'].eqNullSafe(df_profit_join['one_categoy_id']), "left") \
            .join(df_cost_join, st_volume_info['current_categoy_id'].eqNullSafe(df_cost_join['rel_category_id']), "left")

        df_save = df_save.select(
            st_volume_info["search_term"],
            st_volume_info["search_term_id"],
            st_volume_info["categoy_id"],
            st_volume_info["current_categoy_id"],
            st_volume_info["price"],
            st_volume_info["long"],
            st_volume_info["width"],
            st_volume_info["high"],
            st_volume_info["weight"],
            df_profit_join["return_ratio"],
            df_profit_join["adv"],
            df_cost_join['cost_rate']
        )
        df_save = df_save.fillna({
            # no category match: fall back to the averages computed above
            "return_ratio": float(def_return_ratio),
            "cost_rate": float(def_cost_rate),
            "adv": float(def_adv),
        }).cache()

        # Referral fee (commission) from the per-category configuration.
        df_save = df_save.withColumn("referral_fee",
                                     F.round(self.udf_calc_fba_reg(calc_profit_dict)(F.col("categoy_id"), F.col("price")), 4))
        # first leg, sea freight
        df_save = df_save.withColumn("ocean_freight", F.expr(f"weight * {freight_sea_rate} /1000"))
        # first leg, air freight
        df_save = df_save.withColumn("air_delivery_fee", F.expr(f"weight * {freight_air_rate} /1000"))
        # operating costs: fixed 5% of the average price
        df_save = df_save.withColumn("operating_costs", F.expr(f"price * 0.05"))
        # product cost
        df_save = df_save.withColumn("costs", F.expr(f" price * cost_rate "))
        # advertising cost
        df_save = df_save.withColumn("advertise", F.expr(f" price * adv  "))
        # return ratio * price = refund amount (the column is overwritten with the amount)
        df_save = df_save.withColumn("return_ratio", F.expr(f" price * return_ratio "))
        # every cost except the FBA fee
        df_save = df_save.withColumn("tmp_cost_all_sea",
                                     F.expr("ocean_freight  + referral_fee  + return_ratio  + costs + advertise + operating_costs"))

        # FBA fee, possibly after folding the package.
        df_save = df_save.withColumn("tmp_row", self.udf_calc_profit_reg(
            F.col("long"),
            F.col("width"),
            F.col("high"),
            F.col("weight"),
            F.col("tmp_cost_all_sea"),
            F.col("price"),
        ))

        df_save = df_save.withColumn("fba_fee", F.col("tmp_row").getField("fba_fee"))

        # Gross profit ratio with sea freight as the first leg.
        df_save = df_save.withColumn("gross_profit_fee_sea",
                                     F.expr(
                                         "(price-(ocean_freight +referral_fee+return_ratio +costs+advertise+operating_costs + fba_fee))/price")
                                     .cast(DecimalType(10, 3)))

        # Gross profit ratio with air freight as the first leg.
        df_save = df_save.withColumn("gross_profit_fee_air",
                                     F.expr(
                                         "(price-(air_delivery_fee +referral_fee+return_ratio +costs+advertise+operating_costs + fba_fee))/price")
                                     .cast(DecimalType(10, 3)))

        df_save = df_save.select(
            F.col("search_term"),
            F.col("search_term_id"),
            F.col("categoy_id"),
            F.col("current_categoy_id"),
            F.col("weight").cast(DecimalType(10, 3)),
            F.col("price").cast(DecimalType(10, 3)),
            F.col("referral_fee").cast(DecimalType(10, 3)),

            # dimensions before folding
            F.col("tmp_row").getField("long_before").cast(DecimalType(10, 3)).alias("long_before"),
            F.col("tmp_row").getField("width_before").cast(DecimalType(10, 3)).alias("width_before"),
            F.col("tmp_row").getField("high_before").cast(DecimalType(10, 3)).alias("high_before"),
            # dimensions and fee after folding
            F.col("tmp_row").getField("long").cast(DecimalType(10, 3)).alias("longs"),
            F.col("tmp_row").getField("width").cast(DecimalType(10, 3)).alias("width"),
            F.col("tmp_row").getField("high").cast(DecimalType(10, 3)).alias("high"),
            F.col("tmp_row").getField("fba_fee").cast(DecimalType(10, 3)).alias("fba_fee"),
            F.col("tmp_row").getField("fee_type").cast(IntegerType()).alias("fee_type"),

            F.col("return_ratio"),
            F.col("ocean_freight"),
            F.col("air_delivery_fee"),
            F.col("operating_costs"),
            F.col("costs"),
            F.col("advertise"),
            F.col("gross_profit_fee_sea"),
            F.col("gross_profit_fee_air"),

            F.col("cost_rate"),
            F.lit(self.site_name).alias("site_name"),
            F.lit(self.date_type).alias("date_type"),
            F.lit(self.date_info).alias("date_info")
        ).where("weight < 10000000")

        # Adjust the number of output files.
        df_save = df_save.repartition(self.get_repartition_num())
        partition_dict = {
            "site_name": self.site_name,
            "date_type": self.date_type,
            "date_info": self.date_info,
        }
        # Remove the existing partition directory first so the append below
        # effectively replaces it.
        hdfs_path = CommonUtil.build_hdfs_path(self.hive_tb, partition_dict)
        HdfsUtils.delete_hdfs_file(hdfs_path)
        partition_by = list(partition_dict.keys())
        print(f"当前存储的表名为:{self.hive_tb},分区为{partition_by}")
        df_save.write.saveAsTable(name=self.hive_tb, format='hive', mode='append', partitionBy=partition_by)
        print("success")


if __name__ == '__main__':
    # Positional command-line arguments: <site_name> <date_type> <date_info>.
    site_name, date_type, date_info = (CommonUtil.get_sys_arg(position, None) for position in (1, 2, 3))
    DwdStVolumeFba(site_name, date_type, date_info).run()