import os
import sys

from pyspark.sql.types import MapType, StringType

from utils.db_util import DBUtil

sys.path.append(os.path.dirname(sys.path[0]))

from utils.common_util import CommonUtil
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F, Window
from yswg_utils.common_udf import parse_bsr_url

"""
nsr分类树解析
"""


class DimNsrCategoryTree(object):
    """Build the ``dim_nsr_category_tree`` Hive table for one site.

    Reads the ``<site>_new_releases`` category rows from MySQL, derives
    ``category_id`` / ``category_parent_id`` / ``category_first_id`` from each
    row's ``nodes_num`` and ``path`` via ``parse_bsr_url``, keeps one row per
    (category_id, category_parent_id), and writes the result into the
    ``site_name=<site>`` partition of the Hive table.
    """

    def __init__(self, site_name):
        """
        :param site_name: site code (e.g. ``us``); selects the MySQL source
            schema/table and the target Hive partition.
        """
        self.site_name = site_name
        app_name = f"{self.__class__.__name__}:{site_name}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.hive_tb = "dim_nsr_category_tree"
        # Spark UDF wrapping parse_bsr_url; yields a string -> string map.
        self.parse_bsr_url_reg = F.udf(parse_bsr_url, MapType(StringType(), StringType()))

    def run(self):
        """Read, transform and write this site's NSR category tree."""
        sql = self._build_query()
        print("======================查询mysql数据源sql如下======================")
        print(sql)

        df_save = self._transform(self._read_source(sql))
        self._save(df_save)
        print("success")

    def _build_query(self):
        """Return the MySQL query that selects this site's new-releases category rows.

        The ``us`` site lives in the ``selection`` schema; every other site in
        ``selection_<site>``.
        """
        suffix = '' if self.site_name == 'us' else f"_{self.site_name}"
        return f"""
        select ch_name,
               en_name,
               nodes_num,
               path,
               updated_at,
               is_show,
               one_category_id,
               and_en_name,
               leaf_node,
               delete_time
        from selection{suffix}.{self.site_name}_new_releases
"""

    def _read_source(self, sql):
        """Run ``sql`` over JDBC and return the (cached) result DataFrame."""
        # The connection is always looked up with "us"; the per-site schema is
        # chosen inside the query itself.  NOTE(review): presumably every site
        # schema lives on that same MySQL instance -- confirm.
        conn_info = DBUtil.get_connection_info("mysql", "us")
        return SparkUtil.read_jdbc_query(
            session=self.spark,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=sql,
        ).cache()

    def _transform(self, df):
        """Parse category ids out of each row and keep one row per (category_id, category_parent_id)."""
        df = df.withColumn("tmp_map", self.parse_bsr_url_reg(F.col("nodes_num"), F.col("path")))

        df = df.select(
            F.col("tmp_map").getField("category_id").alias("category_id"),
            F.col("tmp_map").getField("category_parent_id").alias("category_parent_id"),
            F.col("tmp_map").getField("category_first_id").alias("category_first_id"),
            F.col("en_name"),
            F.col("ch_name"),
            F.col("nodes_num"),
            F.col("path"),
            F.col("and_en_name"),
            F.col("leaf_node"),
            F.col("updated_at"),
            F.col("delete_time"),
            F.lit(self.site_name).alias("site_name"),
        )

        # Dedup: within each key, rows with a null delete_time rank first, then
        # the latest delete_time.  updated_at (newest first) breaks remaining
        # ties so the surviving row is deterministic across runs.
        window = Window.partitionBy(['category_id', 'category_parent_id']).orderBy(
            F.col("delete_time").desc_nulls_first(),
            F.col("updated_at").desc_nulls_last(),
        )
        return (
            df.withColumn("row_number", F.row_number().over(window))
            .where("row_number == 1")
            .drop("row_number")
        )

    def _save(self, df):
        """Replace the ``site_name=<site>`` partition of the Hive table with ``df``."""
        partition_dict = {
            "site_name": self.site_name,
        }

        # Collapse to a single output file first, THEN sort inside that one
        # partition.  Sorting before repartition(1) is wasted work: the
        # repartition shuffle does not preserve the preceding global order.
        df = df.repartition(1).sortWithinPartitions('nodes_num', 'category_first_id', 'category_id')

        hdfs_path = CommonUtil.build_hdfs_path(self.hive_tb, partition_dict)
        # Remove the partition's existing files; together with mode='append'
        # below this replaces the partition's contents.
        HdfsUtils.delete_file_in_folder(hdfs_path)

        partition_by = list(partition_dict.keys())
        print(f"当前存储的表名为：{self.hive_tb},分区为{partition_by}")
        df.write.saveAsTable(name=self.hive_tb, format='hive', mode='append', partitionBy=partition_by)


if __name__ == '__main__':
    # Site code is the first positional CLI argument (e.g. "us").
    site_name = CommonUtil.get_sys_arg(1, None)
    if not site_name:
        # Fail fast before a Spark session is started: without a site the job
        # would query "selection_None.None_new_releases".
        raise SystemExit("usage: <script> <site_name>")
    obj = DimNsrCategoryTree(site_name)
    obj.run()
