import os
import sys
import re
from functools import reduce

sys.path.append(os.path.dirname(sys.path[0]))  # add parent directory to the import path
from utils.hdfs_utils import HdfsUtils
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, IntegerType, DoubleType, MapType
from utils.spark_util import SparkUtil
from utils.common_util import CommonUtil, DateTypes

class DwsStBrandInfo(object):
    def __init__(self, site_name, date_type, date_info):
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        app_name = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.hive_table = f"dws_st_brand_info"
        self.hdfs_path = f"/home/{SparkUtil.DEF_USE_DB}/dws/{self.hive_table}/site_name={self.site_name}/date_type={self.date_type}/date_info={self.date_info}"
        self.partitions_num = CommonUtil.reset_partitions(site_name, 10)

        self.year_month = self.get_year_month()
        self.last_year_month = CommonUtil.get_month_offset(self.year_month, -1)

    # # Parse brand terms (earlier per-brand loop implementation, kept for reference)
    # def st_brand_label(self,brand_list):
    #     def udf_st_brand_label(search_term):
    #         match_brand = None
    #         label_type = 0
    #         for brand in brand_list:
    #             pattern = re.compile(r'\b(?:{})\b'.format(re.escape(str(brand))), flags=re.IGNORECASE)
    #             if bool(pattern.search(search_term)):
    #                 match_brand = str(brand)
    #                 label_type = 1
    #                 break
    #         return {"match_brand": match_brand, "label_type": label_type}
    #     return F.udf(udf_st_brand_label, MapType(StringType(), StringType(), True))

    # Parse brand terms: label each search term with the brand it matches
    def st_brand_label(self, brand_list):
        # single compiled alternation over the whole vocabulary; \b anchors matches to word boundaries
        pattern = re.compile(r'\b(?:{})\b'.format('|'.join([re.escape(str(x)) for x in brand_list])), flags=re.IGNORECASE)

        def udf_st_brand_label(search_term):
            match_brand = None
            label_type = 0
            if brand_list and search_term:
                result = pattern.search(search_term)
                if result:
                    match_brand = str(result.group())
                    label_type = 1
            # map values are strings to match the declared MapType(StringType(), StringType())
            return {"match_brand": match_brand, "label_type": str(label_type)}

        return F.udf(udf_st_brand_label, MapType(StringType(), StringType(), True))
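    # Example (hypothetical vocabulary ['nike', 'new balance']): the compiled pattern is
    # roughly r'\b(?:nike|new balance)\b' with IGNORECASE, so
    #   'Nike running shoes' -> {'match_brand': 'Nike', 'label_type': '1'}
    #   'unbranded shoes'    -> {'match_brand': None,   'label_type': '0'}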

    def get_year_month(self):
        # resolve the year-month that contains the current period
        if self.date_type == DateTypes.week.name:
            sql = f"select year_month from dim_date_20_to_30 where year_week='{self.date_info}'"
            df = self.spark.sql(sqlQuery=sql).toPandas()
            print(list(df.year_month)[0])
            return list(df.year_month)[0]
        elif self.date_type == DateTypes.month.name or self.date_type == DateTypes.month_week.name:
            return self.date_info
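    # Example (hypothetical values): for date_type='week', date_info='2023-18',
    # dim_date_20_to_30 maps the week to its containing month, e.g. '2023-05';
    # for 'month' / 'month_week' the date_info is already a year-month.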

    def run(self):
        sql = f"""
            select 
                search_term 
            from dim_st_detail 
            where site_name = '{self.site_name}' 
              and date_type = '{self.date_type}' 
              and date_info = '{self.date_info}'
        """
        df_st_detail = self.spark.sql(sqlQuery=sql)
        print("sql:", sql)
        # repartition by search_term to increase parallelism for the matching stage
        df_st_detail = df_st_detail.repartition(80, 'search_term')

        if self.date_type == DateTypes.week.name:
            # fetch the brand vocabulary; weekly runs use the current and previous monthly vocabularies
            sql = f"""
                select 
                    st_brand_name_lower as brand_name
                from dim_st_brand_info
                where site_name = '{self.site_name}'
                  and date_type = 'month'
                  and date_info in ('{self.last_year_month}','{self.year_month}')
                  and length(st_brand_name_lower) > 1
                  and black_flag = 0
            """
        elif self.date_type == DateTypes.month.name or self.date_type == DateTypes.month_week.name:
            sql = f"""
                select 
                    st_brand_name_lower as brand_name
                from dim_st_brand_info
                where site_name = '{self.site_name}'
                  and date_type = '{self.date_type}'
                  and date_info = '{self.date_info}'
                  and length(st_brand_name_lower) > 1
                  and black_flag = 0
            """
        df_st_brand = self.spark.sql(sqlQuery=sql)
        df_st_brand = df_st_brand.dropDuplicates(['brand_name'])
        print("sql:", sql)
        # collect the (small) brand vocabulary to the driver as a pandas DataFrame
        pd_df = df_st_brand.toPandas()
        # extract the brand names as a plain Python list for the UDF closure
        brand_list = pd_df["brand_name"].values.tolist()

        # apply the brand-matching UDF once, then unpack the resulting map column
        df_st_detail = df_st_detail.withColumn("st_map", self.st_brand_label(brand_list)(F.col("search_term")))
        df_st_detail = df_st_detail.withColumn("first_match_brand", F.col("st_map")["match_brand"])
        df_st_detail = df_st_detail.withColumn("st_brand_label", F.col("st_map")["label_type"])

        # select the output columns and append the partition fields
        df_save = df_st_detail.select(
            F.col('search_term'),
            F.col('first_match_brand'),
            F.col('st_brand_label').cast('int').alias('st_brand_label'),
            F.date_format(F.current_timestamp(), 'yyyy-MM-dd HH:mm:ss').alias('created_time'),
            F.date_format(F.current_timestamp(), 'yyyy-MM-dd HH:mm:ss').alias('updated_time'),
            F.lit(self.site_name).alias("site_name"),
            F.lit(self.date_type).alias("date_type"),
            F.lit(self.date_info).alias("date_info")
        )

        df_save = df_save.repartition(self.partitions_num)
        partition_by = ["site_name", "date_type", "date_info"]
        print(f"清除hdfs目录中.....{self.hdfs_path}")
        HdfsUtils.delete_file_in_folder(self.hdfs_path)
        print(f"当前存储的表名为:{self.hive_table},分区为{partition_by}")
        df_save.write.saveAsTable(name=self.hive_table, format='hive', mode='append', partitionBy=partition_by)
        print("success")


if __name__ == '__main__':
    site_name = CommonUtil.get_sys_arg(1, None)
    date_type = CommonUtil.get_sys_arg(2, None)
    date_info = CommonUtil.get_sys_arg(3, None)
    obj = DwsStBrandInfo(site_name, date_type, date_info)
    obj.run()
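# Example invocation (positional args: site_name, date_type, date_info; values illustrative):
#   spark-submit dws_st_brand_info.py us month 2023-05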