import os
import re
import sys

sys.path.append(os.path.dirname(sys.path[0]))

from utils.spark_util import SparkUtil
from utils.db_util import DBUtil
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, MapType


class KeywordCheck(object):

    def __init__(self):
        app_name = self.__class__.__name__
        self.spark = SparkUtil.get_spark_session(app_name)

    def st_brand_label(self, brand_list):
        # Build one case-insensitive, word-boundary pattern over the whole brand list.
        pattern = re.compile(
            r'\b(?:{})\b'.format('|'.join(re.escape(x) for x in brand_list)),
            flags=re.IGNORECASE
        )

        def udf_st_brand_label(search_term):
            match_brand = None
            label_type = 0
            if brand_list and search_term:
                result = pattern.search(search_term)
                if result:
                    match_brand = str(result.group())
                    label_type = 1
            # Values are returned as strings because the UDF's declared return type is
            # MapType(StringType(), StringType()).
            return {"match_brand": match_brand, "label_type": str(label_type)}

        return F.udf(udf_st_brand_label, MapType(StringType(), StringType(), True))

    def run(self):
        # Read the keyword file (path kept as-is; the file name means "keywords.txt").
        hdfs_path = 'hdfs://hadoop15:8020/home/big_data_selection/tmp/关键词.txt'
        df_keyword = self.spark.read.text(hdfs_path)
        df_keyword = df_keyword.withColumnRenamed(
            "value", "keyword"
        ).withColumn(
            'keyword', F.lower(F.trim('keyword'))
        ).cache()
        print("Keywords:")
        df_keyword.show(10)

        # Read the brand dictionary.
        sql1 = """
        select lower(trim(brand_name)) as brand_name
        from brand_alert_erp
        where brand_name is not null
        group by brand_name
        """
        con_info = DBUtil.get_connection_info("postgresql_cluster", "us")
        df_brand = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=con_info['url'],
            pwd=con_info['pwd'],
            username=con_info['username'],
            query=sql1
        ).cache()
        print("Brand names:")
        df_brand.show(10)

        # Read the brand blacklist (the match_type value means "brand dictionary blacklist").
        sql2 = """
        select lower(trim(character_name)) as character_name, 1 as black_flag
        from match_character_dict
        where match_type = '品牌词库黑名单'
        group by character_name
        """
        con_info = DBUtil.get_connection_info("mysql", "us")
        df_brand_black = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=con_info["url"],
            pwd=con_info["pwd"],
            username=con_info["username"],
            query=sql2
        ).cache()
        print("Brand blacklist:")
        df_brand_black.show(10)

        # Drop blacklisted entries from the brand dictionary.
        df_brand = df_brand.join(
            df_brand_black, df_brand['brand_name'] == df_brand_black['character_name'], 'left_anti'
        ).cache()
        df_brand.show(10)

        # df_save = df_keyword.join(
        #     df_brand, df_keyword['keyword'] == df_brand['brand_name'], 'left'
        # ).select(
        #     'keyword', 'brand_flag', 'black_flag'
        # ).fillna({
        #     'brand_flag': 0,
        #     'black_flag': 0
        # })

        # Collect the brand dictionary to the driver as a Python list.
        pd_df = df_brand.toPandas()
        brand_list = pd_df["brand_name"].values.tolist()

        # Apply the brand-matching UDF and unpack the returned map into columns.
        df_map = self.st_brand_label(brand_list)(df_keyword['keyword'])
        df_save = df_keyword.withColumn(
            'brand_name', df_map['match_brand']
        ).withColumn(
            'brand_flag', df_map['label_type']
        )
        df_save.filter('brand_name is not null').show(truncate=False)
        # df_save.write.saveAsTable(name='tmp_keyword_check', format='hive', mode='append')
        # print("success")


if __name__ == '__main__':
    obj = KeywordCheck()
    obj.run()
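
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original job): a small, Spark-free
# check of the word-boundary matching that st_brand_label relies on. The
# brand list and search terms below are hypothetical sample values; the
# helper is defined but never called by the pipeline above.
# ---------------------------------------------------------------------------
def _demo_brand_regex():
    sample_brands = ["nike", "new balance"]  # hypothetical brand dictionary
    sample_terms = ["Nike running shoes", "nikelab hoodie", "new balance 990"]
    pattern = re.compile(
        r'\b(?:{})\b'.format('|'.join(re.escape(b) for b in sample_brands)),
        flags=re.IGNORECASE
    )
    for term in sample_terms:
        result = pattern.search(term.lower())
        # "nikelab hoodie" should NOT match, because \b requires a word boundary
        # right after the brand name.
        print(term, "->", result.group() if result else None)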