import os
import sys

# Make the project root importable before pulling in project modules;
# in the original the path append came after the import, which would
# fail whenever utils/ is not already on sys.path.
sys.path.append(os.path.dirname(sys.path[0]))

from utils.templates import Templates
from pyspark.sql import functions as F

if __name__ == '__main__':
    tmp = Templates()
    spark = tmp.create_spark_object('test_df')

    # Toy input: one row per (day, id) with a label and an age.
    df_all = spark.createDataFrame(
        [
            ('2022', 3, 'name1', 1),
            ('2023', 2, 'name2', 1),
            ('2022', 1, 'name3', 1),
            ('2023', 2, 'name4', 1),
        ],
        ('day', 'id', 'label', 'age')
    )
    df_all.show(n=100, truncate=False)

    # df_save = df_save.withColumn("tmpArr", F.array([F.lit(item) for item in [1, 2, 3, 4, 5]]))
    # df_save = df_save.withColumn("tmpStruct", F.struct(F.col("label"), F.col("age")))

    # Explicit pivot values: listing them up front spares Spark an extra pass
    # over the data to discover them, and values with no matching rows
    # (e.g. '3033') simply come back as null columns.
    arr = ['2022', '2023', '2024', '2025', '3033']

    # df_agg = df_all.groupBy("id").agg(F.first("label"))

    # One row per id, with an age/label column pair for each pivoted day value.
    df_tmp1 = df_all.groupBy('id') \
        .pivot('day', arr) \
        .agg(
            F.first('age').alias('age'),
            F.first('label').alias('label')
        )
    df_tmp1.show(truncate=False)

    # for index in range(0, len(arr)):
    #     df_tmp1 = df_tmp1.withColumnRenamed(arr[index], f"age{index + 1}")
    #     df_agg = df_agg.join(df_tmp1, "id", "left")
    # df_agg.show(truncate=False)

    print("success")
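
    # Note on the commented rename loop above: because the pivot uses
    # multiple aggregations, Spark names the output columns
    # "<pivot_value>_<alias>" (e.g. "2022_age", "2022_label"), so renaming
    # by the bare pivot value arr[index] would not match anything. Below is
    # a minimal sketch of flattening those generated names; the "ageN" /
    # "labelN" target names are illustrative assumptions, not part of the
    # original script.
    # for index, day in enumerate(arr):
    #     df_tmp1 = df_tmp1.withColumnRenamed(f"{day}_age", f"age{index + 1}") \
    #         .withColumnRenamed(f"{day}_label", f"label{index + 1}")
    # df_tmp1.show(truncate=False)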