# verify_rank.py
import os
import sys

from pyspark.sql.types import ArrayType, FloatType, StructType, StructField, StringType

sys.path.append(os.path.dirname(sys.path[0]))
from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F


class VerifyRank(object):
    """Flag search terms whose weekly rank series contains extreme outliers.

    Pipeline: pull (search_term, rank, date_info) rows for the US site,
    compute the leave-one-out mean of ranks per search term, derive an
    IQR-based band (with a deliberately huge 100*IQR multiplier) over those
    means, keep only terms with at least one mean outside the band, and
    drop terms whose outliers include any week in 2024-08.
    """

    def __init__(self):
        # BUG FIX: the f-prefix was missing, so the Spark app name was the
        # literal text "{self.__class__.__name__}" instead of "VerifyRank".
        self.spark = SparkUtil.get_spark_session(f"{self.__class__.__name__}")

    def run(self):
        """Run the outlier-detection pipeline and show the first 20 results."""
        # Weekly US rank data from 2024 onward; very large ranks are cut off.
        sql = """
            select 
                search_term,
                rank,
                date_info
            from ods_brand_analytics
            where site_name = 'us'
              and date_type = 'week'
              and date_info >= '2024-01'
              and rank < 100000
        """
        df_all = self.spark.sql(sql).repartition(40, 'search_term').cache()

        def leave_one_out_means(structs):
            """For each (rank, date_info) element, return the mean of all the
            OTHER ranks in the group (leave-one-out), rounded to 2 decimals,
            paired with its date_info."""
            ranks = [x['rank'] for x in structs]
            date_infos = [x['date_info'] for x in structs]

            total_sum = sum(ranks)
            n = len(ranks)
            if n > 1:
                means = [round((total_sum - rank) / (n - 1), 2) for rank in ranks]
            else:
                # BUG FIX: cast to float — returning an int for a FloatType
                # struct field makes Spark serialize the value as null.
                means = [float(ranks[0])]
            return [{"means": mean, "date_info": date_info}
                    for mean, date_info in zip(means, date_infos)]

        leave_one_out_means_udf = F.udf(leave_one_out_means, ArrayType(StructType([
            StructField("means", FloatType(), True),
            StructField("date_info", StringType(), True)
        ])))

        # One row per search term holding all its (rank, date_info) pairs.
        df_agg = df_all.groupBy("search_term").agg(
            F.collect_list(F.struct("rank", "date_info")).alias("collect_row")
        )
        df_agg = df_agg.withColumn(
            "collect_row", leave_one_out_means_udf(F.col("collect_row"))
        )

        def calc_quantiles(structs):
            """Return [Q1, Q3] of the leave-one-out means.

            Index-based (non-interpolated) quantiles: the values at the 25%
            and 75% positions of the sorted means.
            """
            values = sorted(x['means'] for x in structs)
            n = len(values)
            q1_index = int(n * 0.25)
            q3_index = int(n * 0.75)
            if n > 1:
                q1 = values[q1_index]
                q3 = values[q3_index]
            else:
                q1 = q3 = values[0]
            return [float(q1), float(q3)]

        quantile_udf = F.udf(calc_quantiles, ArrayType(FloatType()))

        # IQR band with a huge multiplier (100x) so only extreme outliers
        # fall outside [lower_bound, upper_bound].
        df_agg = df_agg.withColumn(
            "quantiles", quantile_udf(F.col("collect_row"))
        ).withColumn(
            "q1", F.col("quantiles")[0]
        ).withColumn(
            "q3", F.col("quantiles")[1]
        ).withColumn(
            "iqr", F.expr("q3 - q1")
        ).withColumn(
            "lower_bound", F.expr("q1 - 100 * iqr")
        ).withColumn(
            "upper_bound", F.expr("q3 + 100 * iqr")
        ).select(
            'search_term', 'collect_row', 'lower_bound', 'upper_bound'
        ).repartition(40, 'search_term')

        # Keep terms with at least one out-of-band mean, then drop any term
        # whose outliers include a 2024-08 week.
        df_save = df_agg.withColumn(
            "filtered_collect_row",
            F.filter(
                "collect_row",
                lambda x: (x["means"] < F.col("lower_bound")) | (x["means"] > F.col("upper_bound"))
            )
        ).filter(
            F.size(F.col("filtered_collect_row")) > 0
        ).withColumn(
            "has_2024_08",
            F.exists(
                "filtered_collect_row",
                lambda x: x["date_info"].like("2024-08%")
            )
        ).filter(
            ~F.col("has_2024_08")  # exclude rows whose outliers fall in '2024-08'
        ).select(
            'search_term', 'filtered_collect_row', 'lower_bound', 'upper_bound'
        )

        df_save.show(20, False)


if __name__ == '__main__':
    # CLI args are read for interface consistency with sibling jobs, but
    # VerifyRank does not consume them (the query is hard-coded to us/week).
    site_name = CommonUtil.get_sys_arg(1, None)
    date_type = CommonUtil.get_sys_arg(2, None)
    date_info = CommonUtil.get_sys_arg(3, None)
    VerifyRank().run()