verify_rank.py 3.87 KB
Newer Older
chenyuanjie committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
import os
import sys

from pyspark.sql.types import ArrayType, FloatType, StructType, StructField, StringType

sys.path.append(os.path.dirname(sys.path[0]))
from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F


class VerifyRank(object):
    """Ad-hoc check that prints search terms whose weekly rank history has outliers.

    For every ``search_term`` read from ``ods_brand_analytics`` the job
    computes, per week, the mean rank of all the *other* weeks
    (leave-one-out mean), derives very wide IQR-based fences from those
    means, and shows the terms that have at least one mean outside the fences.
    """

    def __init__(self):
        # Name the Spark application after the class. The original passed the
        # un-interpolated literal "{self.__class__.__name__}" (missing f-prefix).
        self.spark = SparkUtil.get_spark_session(self.__class__.__name__)

    @staticmethod
    def _leave_one_out_means(structs):
        """Return, for each row, the mean rank of all the other rows in the group.

        :param structs: sequence of mappings/Rows exposing ``rank`` and ``date_info``.
        :return: list of ``{"means": float, "date_info": ...}`` in input order,
            or ``None`` when the group is empty/NULL.

        Pure Python with no reference to ``self`` or the Spark session, so it
        can be shipped to executors as a UDF and unit-tested without Spark.
        """
        if not structs:
            return None

        ranks = [row['rank'] for row in structs]
        date_infos = [row['date_info'] for row in structs]

        count = len(ranks)
        if count > 1:
            total = sum(ranks)
            means = [round((total - rank) / (count - 1), 2) for rank in ranks]
        else:
            # A lone observation has no "other" rows, so it stands in for its
            # own mean. It must be converted to float: a Python int returned
            # for a FloatType field is turned into NULL by PySpark, which
            # would then break the quantile step downstream.
            means = [float(ranks[0])]

        return [
            {"means": mean, "date_info": date_info}
            for mean, date_info in zip(means, date_infos)
        ]

    @staticmethod
    def _calc_quantiles(structs):
        """Return ``[q1, q3]`` of the ``means`` values using index-based quartiles.

        :param structs: sequence of mappings/Rows exposing ``means``.
        :return: two-element list of floats, or ``None`` for an empty/NULL group.

        The quartiles are the sorted values at positions ``int(n * 0.25)`` and
        ``int(n * 0.75)`` (no interpolation). Both indices are 0 when ``n == 1``,
        so a single-element group needs no special case.
        """
        if not structs:
            return None

        # Sort the values of the group before picking the quartile positions.
        values = sorted(row['means'] for row in structs)
        count = len(values)
        q1 = values[int(count * 0.25)]
        q3 = values[int(count * 0.75)]
        return [float(q1), float(q3)]

    def run(self):
        """Build the outlier report and print the first 20 rows to stdout."""
        sql = """
            select 
                search_term,
                rank,
                date_info
            from ods_brand_analytics
            where site_name = 'us'
              and date_type = 'week'
              and date_info >= '2024-01'
              and rank < 100000
        """
        # The frame is consumed exactly once below, so it is not cached.
        df_all = self.spark.sql(sql).repartition(40, 'search_term')

        leave_one_out_means_udf = F.udf(
            self._leave_one_out_means,
            ArrayType(StructType([
                StructField("means", FloatType(), True),
                StructField("date_info", StringType(), True)
            ]))
        )
        quantile_udf = F.udf(self._calc_quantiles, ArrayType(FloatType()))

        # One row per search term holding all of its (rank, date_info) pairs,
        # then replaced by the per-week leave-one-out means.
        df_agg = df_all.groupBy("search_term").agg(
            F.collect_list(F.struct("rank", "date_info")).alias("collect_row")
        ).withColumn(
            "collect_row", leave_one_out_means_udf(F.col("collect_row"))
        )

        # Fences are q1 - 100 * IQR and q3 + 100 * IQR.
        df_agg = df_agg.withColumn(
            "quantiles", quantile_udf(F.col("collect_row"))
        ).withColumn(
            "q1", F.col("quantiles")[0]
        ).withColumn(
            "q3", F.col("quantiles")[1]
        ).withColumn(
            "iqr", F.expr("q3 - q1")
        ).withColumn(
            "lower_bound", F.expr("q1 - 100 * iqr")
        ).withColumn(
            "upper_bound", F.expr("q3 + 100 * iqr")
        ).select(
            'search_term', 'collect_row', 'lower_bound', 'upper_bound'
        ).repartition(40, 'search_term')

        df_save = df_agg.withColumn(
            # Keep only the entries whose mean falls outside the fences.
            "filtered_collect_row",
            F.filter(
                "collect_row",
                lambda x: (x["means"] < F.col("lower_bound")) | (x["means"] > F.col("upper_bound"))
            )
        ).filter(
            F.size(F.col("filtered_collect_row")) > 0
        ).withColumn(
            "has_2024_08",
            F.exists(
                "filtered_collect_row",
                lambda x: x["date_info"].like("2024-08%")
            )
        ).filter(
            # Drop search terms whose flagged entries include a '2024-08' date.
            ~F.col("has_2024_08")
        ).select(
            'search_term', 'filtered_collect_row', 'lower_bound', 'upper_bound'
        )

        df_save.show(20, False)


if __name__ == '__main__':
    # Script entry point. The three positional command-line arguments are
    # read here but are not passed on: VerifyRank() takes no arguments and
    # its query uses fixed filter values.
    site_name = CommonUtil.get_sys_arg(1, None)
    date_type = CommonUtil.get_sys_arg(2, None)
    date_info = CommonUtil.get_sys_arg(3, None)

    # Build the job and run it straight away.
    VerifyRank().run()