import os
import sys

# Make the project root importable before pulling in the utils package.
sys.path.append(os.path.dirname(sys.path[0]))

from pyspark.sql import functions as F
from pyspark.sql.types import MapType, StringType

from utils.common_util import CommonUtil
from utils.templates import Templates


def udf_col_12month(last_12_month, row_list):
    """UDF taking an array column and a collected list of structs.

    Prints its inputs for debugging and returns a fixed map matching the
    registered MapType(StringType(), StringType()) return type.
    """
    for month in last_12_month:
        print(month)
    for row in row_list:
        # Each element of row_list is a Row built from F.struct("label", "age").
        print(row['label'])
        print(row['age'])
    print(type(last_12_month))
    print(type(row_list))
    # Must be a dict (not a set) to satisfy the declared MapType.
    return {
        "1": "dahjo",
        "2": "dahjo",
        "3": "dahjo",
        "4": "dahjo",
    }


if __name__ == '__main__':
    tmp = Templates()
    spark = tmp.create_spark_object("test_df")

    # Register the UDF; the returned callable can be used directly in the
    # DataFrame API, and the name "udf_col_12month" is also usable in SQL.
    udf_col_12month_reg = spark.udf.register(
        "udf_col_12month",
        udf_col_12month,
        MapType(StringType(), StringType(), valueContainsNull=True),
    )

    df_save = spark.createDataFrame(
        [
            (3, "Tools & Home Improvement›Power & Hand Tools›Hand Tools›Magnetic Sweepers", 1),
            (2, "Tools & Home Improvement›Power & Hand Tools›Hand Tools›Magnetic Sweepers", 1),
            (1, "Tools & Home Improvement›Power & Hand Tools›Hand Tools›Magnetic Sweepers", 1),
            (2, "Tools & Home Improvement›Power & Hand Tools›Hand Tools›Magnetic Sweepers", 1),
        ],
        ('id', 'label', 'age'),
    )
    df_save.show(n=100, truncate=False)

    # Demo columns; both are dropped again by the groupBy below.
    df_save = df_save.withColumn("tmpArr", F.array(*[F.lit(i) for i in [1, 2, 3, 4, 5]]))
    df_save = df_save.withColumn("tmpStruct", F.struct(F.col("label"), F.col("age")))

    # Pass a literal array column plus the collected (label, age) structs
    # for each id into the UDF.
    df_save = df_save.groupBy("id").agg(
        udf_col_12month_reg(
            CommonUtil.arr_to_spark_col([1, 2, 3, 4, 5]),
            F.collect_list(F.struct(F.col("label"), F.col("age"))),
        ).alias("resultMap")
    )
    df_save.show(truncate=False)
    print("success")