# keyword_check.py: flag keywords that contain a known brand name.
import os
import re
import sys

sys.path.append(os.path.dirname(sys.path[0]))

from utils.spark_util import SparkUtil
from utils.db_util import DBUtil
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, MapType


class KeywordCheck(object):
    """Flag keywords from a text file on HDFS that contain a known brand name.

    Keywords are read one per line, trimmed and lower-cased, then matched
    against the brand dictionary from PostgreSQL ``brand_alert_erp`` after
    removing entries listed in the MySQL brand blacklist
    (``match_character_dict`` rows with ``match_type = '品牌词库黑名单'``).
    """

    # Default location of the keyword file (one keyword per line).
    DEFAULT_KEYWORD_PATH = 'hdfs://hadoop15:8020/home/big_data_selection/tmp/关键词.txt'

    def __init__(self):
        app_name = type(self).__name__
        self.spark = SparkUtil.get_spark_session(app_name)

    @staticmethod
    def _build_brand_matcher(brand_list):
        """Return a plain Python function that finds a brand inside a search term.

        Args:
            brand_list: iterable of brand strings. ``None``/blank entries are
                ignored and duplicates are collapsed.

        Returns:
            A function ``f(search_term) -> dict`` returning
            ``{"match_brand": <matched text or None>, "label_type": 1 or 0}``.
            Matching is case-insensitive and only accepts whole tokens (the
            brand may not be glued to other letters/digits/underscores).
        """
        # An empty alternative would match at every position and flag every
        # keyword, and None would crash re.escape, so drop both up front.
        brands = {brand for brand in brand_list if brand and brand.strip()}
        # Longest brand first so the most specific one wins at a given position;
        # ties are broken alphabetically so the result does not depend on DB row order.
        ordered = sorted(brands, key=lambda b: (-len(b), b))

        pattern = None
        if ordered:
            # Lookarounds instead of \b: \b only works when the brand starts and
            # ends with a word character, so brands such as "c++" could never match.
            pattern = re.compile(
                r'(?<!\w)(?:{})(?!\w)'.format('|'.join(re.escape(b) for b in ordered)),
                flags=re.IGNORECASE,
            )

        def label_search_term(search_term):
            if pattern is None or not search_term:
                return {"match_brand": None, "label_type": 0}
            found = pattern.search(search_term)
            if found is None:
                return {"match_brand": None, "label_type": 0}
            return {"match_brand": found.group(), "label_type": 1}

        return label_search_term

    def st_brand_label(self, brand_list):
        """Build a Spark UDF that tags a search term with the brand it contains.

        Args:
            brand_list: list of brand strings to look for.

        Returns:
            A UDF producing a ``map<string,string>`` with keys ``match_brand``
            (matched text, or null) and ``label_type`` (1 when a brand matched,
            else 0). The map's value type is declared as string.
        """
        return F.udf(
            self._build_brand_matcher(brand_list),
            MapType(StringType(), StringType(), True),
        )

    def _read_keywords(self, hdfs_path):
        """Read the keyword text file into a single trimmed, lower-cased ``keyword`` column."""
        df_keyword = self.spark.read.text(hdfs_path)
        df_keyword = df_keyword.withColumnRenamed(
            "value", "keyword"
        ).withColumn(
            'keyword',
            F.lower(F.trim('keyword'))
        ).cache()
        print("关键词如下:")
        df_keyword.show(10)
        return df_keyword

    def _read_brands(self):
        """Load the trimmed, lower-cased brand dictionary from PostgreSQL."""
        # NOTE(review): GROUP BY names the raw column, so brands differing only in
        # case/whitespace may appear more than once; the matcher de-duplicates them.
        sql = """
            select 
                lower(trim(brand_name)) as brand_name 
            from brand_alert_erp 
            where brand_name is not null 
            group by brand_name
        """
        con_info = DBUtil.get_connection_info("postgresql_cluster", "us")
        df_brand = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=con_info['url'],
            pwd=con_info['pwd'],
            username=con_info['username'],
            query=sql
        ).cache()
        print("品牌词如下:")
        df_brand.show(10)
        return df_brand

    def _read_brand_blacklist(self):
        """Load the brand blacklist (``character_name``, ``black_flag``) from MySQL."""
        sql = """
            select 
                lower(trim(character_name)) as character_name, 
                1 as black_flag 
            from match_character_dict 
            where match_type = '品牌词库黑名单'
            group by character_name
        """
        con_info = DBUtil.get_connection_info("mysql", "us")
        df_brand_black = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=con_info["url"],
            pwd=con_info["pwd"],
            username=con_info["username"],
            query=sql
        ).cache()
        print("品牌词黑名单如下:")
        df_brand_black.show(10)
        return df_brand_black

    def run(self, hdfs_path=None):
        """Match keywords against the non-blacklisted brand dictionary and print hits.

        Args:
            hdfs_path: keyword file to read; defaults to ``DEFAULT_KEYWORD_PATH``.
        """
        df_keyword = self._read_keywords(hdfs_path or self.DEFAULT_KEYWORD_PATH)
        df_brand = self._read_brands()
        df_brand_black = self._read_brand_blacklist()

        # Remove blacklisted brands (exact match on the normalised text).
        df_brand = df_brand.join(
            df_brand_black, df_brand['brand_name'] == df_brand_black['character_name'], 'left_anti'
        ).cache()
        df_brand.show(10)

        # Only the brand strings are needed on the driver; collect() avoids a pandas round-trip.
        brand_list = [row['brand_name'] for row in df_brand.collect()]
        brand_udf = self.st_brand_label(brand_list)

        # Materialise the UDF result once in a helper column and unpack it, rather
        # than referencing the UDF expression separately for each output column.
        df_save = df_keyword.withColumn(
            'brand_label', brand_udf(F.col('keyword'))
        ).withColumn(
            'brand_name', F.col('brand_label')['match_brand']
        ).withColumn(
            'brand_flag', F.col('brand_label')['label_type']
        ).drop('brand_label')
        df_save.filter('brand_name is not null').show(truncate=False)

        # Persisting the result is currently disabled; uncomment to append to Hive:
        # df_save.write.saveAsTable(name='tmp_keyword_check', format='hive', mode='append')
        # print("success")


if __name__ == '__main__':
    # Script entry point: create the job (which opens the Spark session) and run it.
    KeywordCheck().run()