# test_pivot.py — scratch script exercising PySpark groupBy().pivot() with explicit values.
import os
import sys

# Make the project root importable BEFORE any project-local imports.
# (Previously this append ran after `from utils.templates import ...`,
# which defeats its purpose when the script is launched from elsewhere.)
sys.path.append(os.path.dirname(sys.path[0]))

from pyspark.sql import functions as F

from utils.templates import Templates

if __name__ == '__main__':
    # Build a Spark session via the project's Templates helper.
    tmp = Templates()
    spark = tmp.create_spark_object("test_df")

    # Small fixture: columns are (day, id, label, age).
    # Note id=2 appears twice — F.first() below keeps only one row per pivot cell.
    df_all = spark.createDataFrame(
        [
            ('2022', 3, "name1", 1),
            ('2023', 2, "name2", 1),
            ('2022', 1, "name3", 1),
            ('2023', 2, "name4", 1),
        ],
        ("day", 'id', 'label', "age")
    )

    df_all.show(n=100, truncate=False)

    # Explicit pivot values. Passing them to pivot() skips the extra
    # distinct() job Spark would otherwise run, and fixes the output column
    # order. Values with no matching rows ('2024', '2025', '3033') still
    # produce columns, filled with null.
    arr = ['2022', '2023', '2024', '2025', '3033']

    # One row per id; with multiple aggregations, Spark names the pivoted
    # columns "<value>_age" and "<value>_label" for each value in `arr`.
    df_tmp1 = df_all.groupBy("id") \
        .pivot("day", arr).agg(
        F.first("age").alias("age"),
        F.first("label").alias("label")
    )

    df_tmp1.show(truncate=False)

    print("success")