1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Demo script: build a small Spark DataFrame and pivot it by day/year.

Groups rows by ``id`` and pivots on the ``day`` column over a fixed list of
years, taking the first ``age`` and ``label`` per (id, year) cell.
"""
import os
import sys

# Must run BEFORE the local import below: it adds the parent of the script's
# directory to the module search path so ``utils.templates`` can resolve.
sys.path.append(os.path.dirname(sys.path[0]))

from utils.templates import Templates  # noqa: E402  (requires the sys.path tweak above)
from pyspark.sql import functions as F  # noqa: E402

if __name__ == '__main__':
    tmp = Templates()
    # presumably returns a SparkSession named "test_df" — confirm against Templates
    spark = tmp.create_spark_object("test_df")

    # Sample data; column order is (day, id, label, age).
    df_all = spark.createDataFrame(
        [
            ('2022', 3, "name1", 1),
            ('2023', 2, "name2", 1),
            ('2022', 1, "name3", 1),
            ('2023', 2, "name4", 1),
        ],
        ("day", 'id', 'label', "age"),
    )
    df_all.show(n=100, truncate=False)

    # Pivot values are pinned explicitly so Spark skips the extra pass that
    # discovers distinct ``day`` values; years absent from the data produce
    # all-null columns. NOTE(review): '3033' looks like a typo for '2033' —
    # confirm the intended pivot years.
    arr = ['2022', '2023', '2024', '2025', '3033']

    # One output column pair per pivot value: <year>_age, <year>_label.
    df_tmp1 = (
        df_all.groupBy("id")
        .pivot("day", arr)
        .agg(
            F.first("age").alias("age"),
            F.first("label").alias("label"),
        )
    )
    df_tmp1.show(truncate=False)
    print("success")