import os
import sys

from pyspark.sql.types import ArrayType, FloatType, StructType, StructField, StringType

sys.path.append(os.path.dirname(sys.path[0]))
from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F


class VerifyRank(object):

    def __init__(self):
        self.spark = SparkUtil.get_spark_session(f"{self.__class__.__name__}")

    def run(self):
        sql = """
        select search_term, rank, date_info
        from ods_brand_analytics
        where site_name = 'us'
          and date_type = 'week'
          and date_info >= '2024-01'
          and rank < 100000
        """
        df_all = self.spark.sql(sql).repartition(40, 'search_term').cache()

        def leave_one_out_means(structs):
            # For each (rank, date_info) row in a group, compute the mean of
            # all the *other* ranks in the group (the leave-one-out mean).
            ranks = [x['rank'] for x in structs]
            date_infos = [x['date_info'] for x in structs]
            total_sum = sum(ranks)
            n = len(ranks)
            if n > 1:
                means = [round((total_sum - rank) / (n - 1), 2) for rank in ranks]
            else:
                # Single-element group: the leave-one-out mean is the value itself.
                # Cast to float so it matches the declared FloatType; returning a
                # Python int here would be nulled out by the UDF return schema.
                means = [float(ranks[0])]
            result = [{"means": mean, "date_info": date_info} for mean, date_info in zip(means, date_infos)]
            return result

        leave_one_out_means_udf = F.udf(leave_one_out_means, ArrayType(StructType([
            StructField("means", FloatType(), True),
            StructField("date_info", StringType(), True)
        ])))

        df_agg = df_all.groupBy("search_term").agg(
            F.collect_list(F.struct("rank", "date_info")).alias("collect_row")
            # F.collect_list("rank").alias("values")
        )
        df_agg = df_agg.withColumn(
            "collect_row",
            leave_one_out_means_udf(F.col("collect_row"))
        )

        def calc_quantiles(structs):
            values = [x['means'] for x in structs]
            values = sorted(values)  # sort the values within the group
            n = len(values)
            # Positions of Q1 and Q3 (based on the 25% and 75% positions)
            q1_index = int(n * 0.25)
            q3_index = int(n * 0.75)
            if n > 1:
                q1 = values[q1_index]
                q3 = values[q3_index]
            else:
                q1 = values[0]
                q3 = values[0]
            return [float(q1), float(q3)]

        quantile_udf = F.udf(calc_quantiles, ArrayType(FloatType()))

        # IQR-based outlier bounds; the unusually wide 100 * iqr factor keeps
        # only extreme deviations from the group's leave-one-out means.
        df_agg = df_agg.withColumn(
            "quantiles", quantile_udf(F.col("collect_row"))
        ).withColumn(
            "q1", F.col("quantiles")[0]
        ).withColumn(
            "q3", F.col("quantiles")[1]
        ).withColumn(
            "iqr", F.expr("q3 - q1")
        ).withColumn(
            "lower_bound", F.expr("q1 - 100 * iqr")
        ).withColumn(
            "upper_bound", F.expr("q3 + 100 * iqr")
        ).select(
            'search_term', 'collect_row', 'lower_bound', 'upper_bound'
        ).repartition(40, 'search_term')

        df_save = df_agg.withColumn(
            "filtered_collect_row",
            F.filter(
                "collect_row",
                lambda x: (x["means"] < F.col("lower_bound")) | (x["means"] > F.col("upper_bound"))
            )
        ).filter(
            F.size(F.col("filtered_collect_row")) > 0
        ).withColumn(
            "has_2024_08",
            F.exists(
                "filtered_collect_row",
                lambda x: x["date_info"].like("2024-08%")
            )
        ).filter(
            ~F.col("has_2024_08")  # drop rows whose outliers include '2024-08'
        ).select(
            'search_term', 'filtered_collect_row', 'lower_bound', 'upper_bound'
        )
        df_save.show(20, False)


if __name__ == '__main__':
    # Parsed for consistency with other jobs in this repo; run() currently
    # hardcodes site_name='us' and date_type='week' in the SQL, so these
    # arguments do not yet affect the query.
    site_name = CommonUtil.get_sys_arg(1, None)
    date_type = CommonUtil.get_sys_arg(2, None)
    date_info = CommonUtil.get_sys_arg(3, None)
    obj = VerifyRank()
    obj.run()
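
# ---------------------------------------------------------------------------
# Usage sketch (a hypothetical invocation: the script name "verify_rank.py"
# is an assumption, and the three positional args mirror the
# CommonUtil.get_sys_arg calls above):
#
#   spark-submit verify_rank.py us week 2024-01
#
# Worked example of the leave-one-out mean for a group with ranks
# [10, 20, 30]: each element is replaced by the mean of the other two, i.e.
# [(20 + 30) / 2, (10 + 30) / 2, (10 + 20) / 2] = [25.0, 20.0, 15.0].
# ---------------------------------------------------------------------------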