# test_array.py
import os
import sys

from pyspark.sql import functions as F
from pyspark.sql.types import MapType, StringType

from utils.common_util import CommonUtil
from utils.templates import Templates

sys.path.append(os.path.dirname(sys.path[0]))


def udf_col_12month(last_12_month, row_list):
    for month in last_12_month:
        print(month)

    for row in row_list:
        label = row['label']
        age = row['age']
        print(label)
        print(age)

    print(type(last_12_month))
    print(type(row_list))
    return {
        "1", "dahjo",
        "2", "dahjo",
        "3", "dahjo",
        "4", "dahjo",
    }
    pass


if __name__ == '__main__':
    tmp = Templates()
    spark = tmp.create_spark_object("test_df")
    udf_col_12month_reg = spark.udf.register("udf_col_12month", udf_col_12month,
                                             MapType(StringType(), StringType(), valueContainsNull=True))

    df_save = spark.createDataFrame(
        [
            (3, "Tools & Home Improvement›Power & Hand Tools›Hand Tools›Magnetic Sweepers", 1),
            (2, "Tools & Home Improvement›Power & Hand Tools›Hand Tools›Magnetic Sweepers", 1),
            (1, "Tools & Home Improvement›Power & Hand Tools›Hand Tools›Magnetic Sweepers", 1),
            (2, "Tools & Home Improvement›Power & Hand Tools›Hand Tools›Magnetic Sweepers", 1),
        ],
        ('id', 'label', "age")
    )

    df_save.show(n=100, truncate=False)

    df_save = df_save.withColumn("tmpArr", F.array(list(map(lambda item: F.lit(item), [1, 2, 3, 4, 5]))))
    df_save = df_save.withColumn("tmpStruct", F.struct(F.col("label"), F.col("age")))

    df_save = df_save.groupBy("id").agg(

        udf_col_12month_reg(
            CommonUtil.arr_to_spark_col([1, 2, 3, 4, 5]),
            F.collect_list(F.struct(F.col("label"), F.col("age")))
        ).alias("resultMap")

    )

    df_save.show(truncate=False)
    print("success")