test_udf.py 1.6 KB
Newer Older
chenyuanjie committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
import os
import sys
from pyspark.sql.functions import udf
from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F
from pyspark.sql.types import MapType, StringType

from utils.db_conf import build_es_option
from utils.templates import Templates

sys.path.append(os.path.dirname(sys.path[0]))


@udf(returnType=MapType(StringType(), StringType()))
def test_udf(arr: list):
    """
    Demo UDF that receives a two-dimensional array (a list of
    [date, year] pairs collected per group with collect_list/array)
    and returns a map for that group.

    :param arr: list of two-element lists, each ``[date, year]``
    :return: a ``map<string, string>`` with demo fields
    """
    print(type(arr))

    for item in arr:
        date = item[0]
        year = item[1]
        print(date)
        print(year)

    # The declared returnType is MapType(StringType(), StringType()), so
    # every value must be a string. The original returned the int 1 for
    # "filed1", which Spark cannot encode as StringType (it surfaces as
    # null or a runtime encoding error) — fixed to "1".
    return {
        "filed1": "1",
        "filed2": "filed2",
        "filed3": "filed3",
    }


if __name__ == '__main__':
    tmp = Templates()
    spark = tmp.create_spark_object("test_df")
    # NOTE: spark.sql() parses a single statement and rejects a trailing
    # ';' with a ParseException, so the terminator after "limit 10" has
    # been removed from the query text.
    sql = """
            select id,
               date,
               year,
               quarter,
               month,
               week,
               day,
               week_day,
               year_month
        from big_data_selection.dim_date_20_to_30
        limit 10
    """
    df = spark.sql(sql)
    # df.show(10)

    # Collect [date, year] pairs per year_month so the UDF receives a
    # two-dimensional array (list of two-element arrays) per group.
    df = df.groupBy(["year_month"]) \
        .agg(F.collect_list(F.array([df['date'], df['year']])).alias("list"))

    # Apply the demo UDF; it yields a map<string,string> column.
    df = df.withColumn("lastMap", test_udf(F.col("list")))

    # Flatten the map into one column per key (keys as emitted by test_udf).
    df = df.select(
        "year_month",
        F.col("lastMap").getField("filed1").alias("filed1"),
        F.col("lastMap").getField("filed2").alias("filed2"),
        F.col("lastMap").getField("filed3").alias("filed3"),
    )
    df.show()

    print("success")