import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))

from pyspark.sql.types import StringType, MapType
from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F
import numpy as np


class DwtStMarket(object):
    """
    Classifies the market cycle of each search term over its last 12 months of
    monthly search volume: seasonal (1), sustained growth (2), sustained
    decline (3) or general (4).
    """

    def __init__(self, site_name, date_type, date_info):
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        app_name = f"{self.__class__.__name__}:{site_name}:{date_type}:{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.last_12_month = self.get_last_12_month()

    @staticmethod
    def udf_market_fun_reg(last_12_month: list):
        def udf_market_fun(rows: list):
            # map each month (date_info) to its position: "0" = current month, "11" = oldest
            index_map = {}
            for i, val in enumerate(last_12_month):
                index_map[val] = str(i)

            # index the collected rows by month position
            row_map = {}
            for row in rows:
                row_map[index_map[row['date_info']]] = row

            # default classification: general market
            market_cycle_type = 4
            # classify the market
            max_avg_val = 0
            result_val = {}
            rate_list = []
            for key in row_map:
                month_seq = int(key) + 1
                row = row_map.get(key)
                vol_val = row['st_search_num']
                rank_val = row['st_rank']
                if rank_val is not None:
                    result_val[f"rank_{month_seq}"] = str(rank_val)
                if vol_val is not None:
                    result_val[f"rank_{month_seq}_search_volume"] = str(vol_val)
                # month-over-month search volume growth rate for the 6 most recent months
                if month_seq <= 6:
                    # the previous month sits one position further back
                    last_month_row = row_map.get(str(int(key) + 1))
                    if last_month_row is not None:
                        last_vol_val = last_month_row['st_search_num']
                        if last_vol_val is not None and vol_val is not None:
                            # growth rate = (this month - previous month) / previous month
                            rate = round((float(vol_val) - float(last_vol_val)) / float(last_vol_val), 4)
                            result_val[f"month_growth_rate_{month_seq}"] = str(rate)
                            rate_list.append(rate)

            value = [row['st_search_num'] for row in rows]
            # 12-month average (months without data effectively count as zero)
            avg_val = round(np.sum(value) / 12, 4)
            # highest monthly search volume
            max_val = np.max(value)
            max_avg_val = round((max_val - avg_val) / avg_val, 4)
            if max_avg_val > 0.8:
                # 1 seasonal market: take the highest of the last 12 monthly search
                # volumes, h1; if (h1 - average) / average > 0.8 the market is seasonal
                market_cycle_type = 1
            else:
                # 2 sustained-growth market: at least 5 of the last 6 months have a
                # month-over-month growth rate above 5%
                if len([rate for rate in rate_list if rate >= 0.05]) >= 5:
                    market_cycle_type = 2
                # 3 sustained-decline market: at least 5 of the last 6 months have a
                # month-over-month growth rate below -5%
                elif len([rate for rate in rate_list if rate <= -0.05]) >= 5:
                    market_cycle_type = 3
                else:
                    market_cycle_type = 4

            result_val['market_cycle_type'] = str(market_cycle_type)
            result_val['max_avg_val'] = str(max_avg_val)
            # start_month is the most recent month (date_info); end_month is the oldest of the 12
            result_val['start_month'] = str(last_12_month[0])
            result_val['end_month'] = str(last_12_month[11])
            return result_val

        return F.udf(udf_market_fun, MapType(StringType(), StringType()))

    def get_last_12_month(self):
        # the 12 months ending at date_info, most recent first
        last_12_month = []
        for i in range(0, 12):
            last_12_month.append(CommonUtil.get_month_offset(self.date_info, -i))
        return last_12_month

    def run(self):
        # note: the query is restricted to a single search term, apparently for testing
        sql = f"""
        select search_term, st_search_num, date_info, st_rank
        from dim_st_detail
        where site_name = '{self.site_name}'
          and date_type = '{self.date_type}'
          and date_info in ({CommonUtil.list_to_insql(self.last_12_month)})
          and st_search_num is not null
          and search_term = 'iphone 16 pro screen protector'
        """
        df_st_detail = self.spark.sql(sql).cache()
        df_st_detail.show(10, truncate=False)

        df_st_detail = df_st_detail.groupby('search_term').agg(
            self.udf_market_fun_reg(self.last_12_month)(
                F.collect_list(F.struct("date_info", "st_search_num", "st_rank"))
            ).alias("cal_map")
        )

        df_save = df_st_detail.select(
            F.col("search_term"),
            F.col("cal_map").getField("start_month").alias("start_month"),
            F.col("cal_map").getField("end_month").alias("end_month"),
            F.col("cal_map").getField("market_cycle_type").alias("market_cycle_type"),
F.col("cal_map").getField("rank_1_search_volume").alias("rank_1_search_volume"), F.col("cal_map").getField("rank_2_search_volume").alias("rank_2_search_volume"), F.col("cal_map").getField("rank_3_search_volume").alias("rank_3_search_volume"), F.col("cal_map").getField("rank_4_search_volume").alias("rank_4_search_volume"), F.col("cal_map").getField("rank_5_search_volume").alias("rank_5_search_volume"), F.col("cal_map").getField("rank_6_search_volume").alias("rank_6_search_volume"), F.col("cal_map").getField("rank_7_search_volume").alias("rank_7_search_volume"), F.col("cal_map").getField("rank_8_search_volume").alias("rank_8_search_volume"), F.col("cal_map").getField("rank_9_search_volume").alias("rank_9_search_volume"), F.col("cal_map").getField("rank_10_search_volume").alias("rank_10_search_volume"), F.col("cal_map").getField("rank_11_search_volume").alias("rank_11_search_volume"), F.col("cal_map").getField("rank_12_search_volume").alias("rank_12_search_volume"), F.col("cal_map").getField("max_avg_val").alias("max_avg_val") ) df_save.show() if __name__ == '__main__': site_name = CommonUtil.get_sys_arg(1, None) date_type = CommonUtil.get_sys_arg(2, None) date_info = CommonUtil.get_sys_arg(3, None) obj = DwtStMarket(site_name, date_type, date_info) obj.run()