import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))

from utils.common_util import CommonUtil, DateTypes
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType


class DwtAbaLast365(object):
    """For each search-term ``id`` seen in the last 12 months, aggregate the
    distinct month numbers in which it appeared in ``dwt_aba_st_analytics``.

    The result is printed via ``DataFrame.show`` (this script produces no
    persisted output in the visible code).
    """

    def __init__(self, site_name, date_type, date_info):
        """
        :param site_name: site identifier, interpolated into the SQL filter.
        :param date_type: requested date granularity; must be ``'month'`` or
            ``'year'``.
        :param date_info: anchor date; the 12 months ending at this date are
            scanned. Assumed ``YYYY-MM`` shaped — TODO confirm with callers.
        :raises ValueError: if ``date_type`` is not ``'month'`` or ``'year'``.
        """
        self.site_name = site_name
        self.date_info = date_info
        self.date_type = date_type
        # The source table is always read at month granularity, regardless
        # of the requested date_type.
        self.date_type_original = DateTypes.month.name
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, silently disabling this validation.
        if date_type not in (DateTypes.month.name, DateTypes.year.name):
            raise ValueError("date_type 不合法!!")
        app_name = f"{self.__class__.__name__}:{self.site_name}:{self.date_type}:{self.date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)

    def run(self):
        """Query the last 12 months and show the per-id month aggregation."""
        # The 12 months ending at date_info, inclusive (offsets 0 .. -11).
        last_12_month = [
            CommonUtil.get_month_offset(self.date_info, -i) for i in range(12)
        ]
        print(f"过去12个月为{last_12_month}")
        sql = f"""
            select search_term, id, date_info
            from dwt_aba_st_analytics
            where site_name = '{self.site_name}'
              and date_type = '{self.date_type_original}'
              and date_info in ({CommonUtil.list_to_insql(last_12_month)})
        """
        df_all = self.spark.sql(sql)
        df_all = df_all.groupBy("id").agg(
            F.first("search_term").alias("search_term"),
            # All months in which the term appeared: take the month part of
            # date_info (assumes a YYYY-MM shape — TODO confirm), dedupe via
            # collect_set, sort ascending, and join as e.g. "1,3,11".
            F.concat_ws(
                ",",
                F.sort_array(
                    F.collect_set(
                        F.split(F.col("date_info"), "-")[1].cast(IntegerType())
                    )
                ),
            ).alias("total_appear_month"),
        ).cache()
        df_all.show(20, False)


if __name__ == '__main__':
    site_name = CommonUtil.get_sys_arg(1, None)
    date_type = CommonUtil.get_sys_arg(2, None)
    date_info = CommonUtil.get_sys_arg(3, None)
    obj = DwtAbaLast365(site_name=site_name, date_type=date_type, date_info=date_info)
    obj.run()