import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from pyspark.sql.types import MapType, StringType
from utils.db_util import DBUtil
from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F, Window
from yswg_utils.common_udf import parse_bsr_url

"""
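bsr分类树解析 — BSR category tree parsing: builds the dim_bsr_category_tree Hive table from the MySQL <site>_bs_category table.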
"""


class DimBsrCategoryTree(object):
    """Build the ``dim_bsr_category_tree`` Hive table for one site.

    Pipeline:
      1. read ``selection{suffix}.{site}_bs_category`` from MySQL;
      2. parse ``nodes_num`` / ``path`` into category ids with ``parse_bsr_url``;
      3. de-duplicate rows per (category_id, category_parent_id);
      4. resolve the "real" first-level category id, honouring redirects;
      5. overwrite the current partition and the dated history partition;
      6. for the ``us`` site only, alert when a live category resolves to
         more than one first-level id.
    """

    def __init__(self, site_name):
        """Create the Spark session and the URL-parsing UDF.

        :param site_name: site code, e.g. ``'us'``; selects the source MySQL
            database/table and the target Hive partition.
        """
        self.site_name = site_name
        app_name = f"{self.__class__.__name__}:{site_name}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.hive_tb = "dim_bsr_category_tree"
        # UDF returning a string->string map; run() reads its category_id,
        # category_parent_id and category_first_id keys.
        self.parse_bsr_url_reg = F.udf(parse_bsr_url, MapType(StringType(), StringType()))

    @staticmethod
    def _selection_db_suffix(site_name):
        """Return the MySQL ``selection`` database suffix for *site_name*.

        The ``us`` site lives in plain ``selection``; every other site in
        ``selection_<site>``.
        """
        return '' if site_name == 'us' else f"_{site_name}"

    def run(self):
        """Rebuild the current and history BSR category tree tables for this site."""
        suffix = self._selection_db_suffix(self.site_name)

        sql = f"""
        select ch_name,
               en_name,
               nodes_num,
               path,
               updated_at,
               is_show,
               one_category_id,
               and_en_name,
               leaf_node,
               delete_time,
               redirect_first_id
        from selection{suffix}.{self.site_name}_bs_category
"""
        print("======================查询mysql数据源sql如下======================")
        print(sql)

        # NOTE(review): the connection is always the "us" MySQL instance; the
        # site is selected only through the database/table name above. Confirm
        # that this is intended.
        conn_info = DBUtil.get_connection_info("mysql", "us")
        df_save = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=sql
        ).cache()

        # Parse the URL-derived fields into explicit id columns.
        df_save = df_save.withColumn("tmp_map", self.parse_bsr_url_reg(F.col("nodes_num"), F.col("path")))
        df_save = df_save.withColumn("category_id", F.col("tmp_map").getField("category_id"))
        df_save = df_save.withColumn("category_parent_id", F.col("tmp_map").getField("category_parent_id"))
        df_save = df_save.withColumn("category_first_id", F.col("tmp_map").getField("category_first_id"))

        # A redirect that points at the category's own first-level id is no
        # redirect at all: set it to null.
        df_save = df_save.withColumn(
            "redirect_first_id",
            F.when(F.col("redirect_first_id").eqNullSafe(F.col("category_first_id")), None)
            .otherwise(F.col("redirect_first_id"))
        )

        # 1 when the row is redirected to another first-level category, else 0.
        df_save = df_save.withColumn("redirect_flag", F.when(F.col("redirect_first_id").isNotNull(), 1).otherwise(0))

        # De-duplicate per (category_id, category_parent_id). Nulls sort first,
        # so a non-deleted row (delete_time is null) wins over deleted ones;
        # among deleted rows the most recently deleted one wins.
        dedup_window = Window.partitionBy(['category_id', 'category_parent_id']) \
            .orderBy(F.col("delete_time").desc_nulls_first())
        df_save = df_save.withColumn("row_number", F.row_number().over(window=dedup_window))
        df_save = df_save.where("row_number == 1")

        # Real first-level id per live category: the redirect target if any,
        # otherwise the parsed first-level id. collect_set has no defined
        # order, so if a category has several distinct values the pick is
        # arbitrary; _check_redirect_consistency reports such cases for "us".
        rel_first_df = df_save.where("delete_time is null").groupby("category_id").agg(
            F.collect_set(F.coalesce("redirect_first_id", "category_first_id")).getItem(0).alias("rel_first_id")
        ).cache()

        df_save = df_save.join(rel_first_df, on=['category_id'], how='left').select(
            F.col("category_id"),
            F.col("category_parent_id"),
            F.col("category_first_id"),
            df_save['en_name'].alias("en_name"),
            F.col("ch_name"),
            F.col("nodes_num"),
            F.col("path"),
            F.col("and_en_name"),
            F.col("leaf_node"),
            F.col("updated_at"),
            F.col("delete_time"),
            F.col("redirect_flag"),
            F.col("redirect_first_id"),
            # Deleted categories are absent from rel_first_df; fall back to the
            # parsed first-level id so rel_first_id is never null.
            F.coalesce("rel_first_id", "category_first_id").alias('rel_first_id'),
            # Deprecated column (作废): always null.
            F.lit(None).alias("full_name"),
            F.lit(self.site_name).alias("site_name"),
        )

        partition_dict = {
            "site_name": self.site_name,
        }

        # Presumably intended to write a single, ordered output file.
        df_save = df_save.repartition(1)
        df_save = df_save.sort(['nodes_num', 'category_first_id', 'category_id'])

        # Overwrite the current table partition.
        CommonUtil.save_or_update_table(
            spark_session=self.spark,
            hive_tb_name=self.hive_tb,
            partition_dict=partition_dict,
            df_save=df_save
        )

        # Snapshot today's non-deleted rows into the history table.
        day_now = CommonUtil.format_now("%Y-%m-%d")
        df_save_history = df_save.filter("delete_time is null")
        df_save_history = df_save_history.withColumn("date_info", F.lit(day_now))
        CommonUtil.save_or_update_table(
            spark_session=self.spark,
            hive_tb_name="dim_bsr_category_tree_history",
            partition_dict={
                "site_name": self.site_name,
                "date_info": day_now
            },
            df_save=df_save_history
        )

        self._check_redirect_consistency()

    def _check_redirect_consistency(self):
        """Alert when a live ``us`` category maps to several first-level ids.

        Does nothing (returns immediately) for any site other than ``us``.
        For ``us`` it queries the just-written table for non-deleted
        category_ids that have more than one distinct
        ``coalesce(redirect_first_id, category_first_id)``. If there are any,
        it sends a message through ``CommonUtil.send_wx_msg``.
        """
        print("检查数据是否正确中。。。。")
        # Bug fix: previously read the module-global ``site_name``, which only
        # exists when this file is executed as a script.
        if self.site_name != 'us':
            return

        check_sql = f"""
                select category_id, category_first_id, redirect_first_id, path, delete_time
                from {self.hive_tb}
                where site_name = 'us'
                  and delete_time is null
                  and category_id in (
                    select category_id
                    from (
                             select category_id,
                                    collect_set(coalesce(redirect_first_id, category_first_id)) rel_set1
                             from {self.hive_tb}
                             where site_name = 'us'
                               and delete_time is null
                             group by category_id
                         ) tmp
                    where size(rel_set1) > 1
                )
                order by category_id
        """
        check_df = self.spark.sql(check_sql)
        count = check_df.count()
        if count > 0:
            CommonUtil.send_wx_msg(['wujicang', 'pengyanbing'], "bsr榜单数据检查异常",
                                   f"{self.site_name}榜单树，存在{count}条数据未校验是否重定向！，请检查！！")


if __name__ == "__main__":
    # First CLI argument is the site code (e.g. "us"); kept as a module-level
    # name because other code in this file may read it.
    site_name = CommonUtil.get_sys_arg(1, None)
    DimBsrCategoryTree(site_name).run()
