# spark_util.py
import os
import sys

# Make the project root importable so that utils.db_util resolves.
sys.path.append(os.path.dirname(sys.path[0]))
from pyspark.sql import SparkSession
from utils.db_util import DBUtil


class SparkUtil(object):
    """
    Spark helper class: Hive-enabled session factories and JDBC readers.
    """

    # Default database to connect to.
    DEF_USE_DB = "big_data_selection"

    # Cluster endpoints: HDFS warehouse root and Hive metastore URI.
    __home_path__ = "hdfs://nameservice1:8020/home"
    __meta_store_uri__ = "thrift://hadoop16:9083"
    # The most recently created session, cached for convenience.
    current_spark = None

    @classmethod
    def get_spark_session(cls, app_name: str = "test",
                          use_db: str = "big_data_selection",
                          use_lzo_compress: bool = True) -> SparkSession:
        """
        :param app_name:
        :param use_db:
        :param use_lzo_compress:  是否输出save的时候启用lzo压缩 默认是true
        :return:
        """
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            enableHiveSupport(). \
            getOrCreate()
        # Ship the shared yswg_utils.zip package to the executors.
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")

        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql("set mapred.output.compress=true")
            spark.sql("set hive.exec.compress.output=true")
            spark.sql("set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec")
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark
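
    # Usage sketch (hypothetical job name; assumes the metastore and HDFS
    # endpoints configured above are reachable from where this runs):
    #   spark = SparkUtil.get_spark_session(app_name="my_job")
    #   spark.sql("show tables").show()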

    @classmethod
    def get_spark_sessionV2(cls, app_name: str = "test", use_db: str = "big_data_selection") -> SparkSession:
        """Lighter variant of get_spark_session without the ORC/LZO output settings."""
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.driver.maxResultSize", "10g"). \
            enableHiveSupport(). \
            getOrCreate()
        # Ship the shared yswg_utils.zip package to the executors.
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql(f"use {use_db};")
        return spark

    @classmethod
    def get_spark_sessionV3(cls, app_name: str = "test", use_db: str = "big_data_selection") -> SparkSession:
        """YARN variant with explicit driver/executor resources, submitted to the 'spark' queue."""
        spark = SparkSession.builder. \
            master("yarn"). \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.driver.memory", "5g"). \
            config("spark.executor.memory", "8g"). \
            config("spark.executor.cores", "4"). \
            config("spark.yarn.queue", "spark"). \
            enableHiveSupport(). \
            getOrCreate()
        # Ship the shared yswg_utils.zip package to the executors.
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql("set hive.compute.query.using.stats=false")
        spark.sql(f"use {use_db};")
        return spark

    @classmethod
    def get_spark_sessionV4(cls, app_name: str = "test", use_db: str = "big_data_selection", use_lzo_compress: bool = True) -> SparkSession:
        """Variant of get_spark_session with additional shuffle tuning for heavier jobs."""
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            config("spark.sql.shuffle.partitions", 100). \
            config("spark.shuffle.memoryFraction", "0.4"). \
            config("spark.shuffle.spill.compress", "true"). \
            config("spark.shuffle.file.buffer", "5m"). \
            config("spark.streaming.stopGracefullyOnShutdown", "true"). \
            enableHiveSupport(). \
            getOrCreate()
        # Ship the shared yswg_utils.zip package to the executors.
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")

        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql("set mapred.output.compress=true")
            spark.sql("set hive.exec.compress.output=true")
            spark.sql("set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec")
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark

    @classmethod
    def get_stream_spark(cls, app_name: str, use_db: str = "big_data_selection", use_lzo_compress: bool = True) -> SparkSession:
        """Streaming variant: pins the AM and executors to the 'stream' node label and enables off-heap memory."""
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            config("spark.sql.shuffle.partitions", 100). \
            config("spark.shuffle.memoryFraction", "0.4"). \
            config("spark.shuffle.spill.compress", "true"). \
            config("spark.shuffle.file.buffer", "5m"). \
            config("spark.memory.fraction", "0.8"). \
            config("spark.memory.storageFraction", "0.2"). \
            config("spark.shuffle.memoryFraction", "0.2"). \
            config("spark.yarn.am.nodeLabelExpression", "stream") .\
            config("spark.yarn.executor.nodeLabelExpression", "stream"). \
            config("spark.executor.heartbeatInterval", "60s"). \
            config("spark.storage.blockManagerSlaveTimeoutM", "600s"). \
            config("spark.memory.offHeap.enabled", "true"). \
            config("spark.memory.offHeap.size", "8g"). \
            config("spark.memory.fraction", "0.7"). \
            config("spark.memory.storageFraction", "0.6"). \
            enableHiveSupport(). \
            getOrCreate()

        # Ship the shared yswg_utils.zip package to the executors.
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")

        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql("set mapred.output.compress=true")
            spark.sql("set hive.exec.compress.output=true")
            spark.sql("set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec")
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark
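
    # Usage sketch (assumes the YARN cluster defines a node label named "stream",
    # as required by the nodeLabelExpression settings above):
    #   spark = SparkUtil.get_stream_spark(app_name="my_stream_job")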

    @staticmethod
    def read_jdbc(session: SparkSession, dbtype: str, site_name: str, query: str):
        """Read a JDBC query using connection info resolved by DBUtil."""
        conn_info = DBUtil.get_connection_info(dbtype, site_name)
        return SparkUtil.read_jdbc_query(
            session=session,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=query
        )
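
    # Usage sketch (hypothetical dbtype/site_name values; valid ones are whatever
    # DBUtil.get_connection_info accepts):
    #   df = SparkUtil.read_jdbc(spark, dbtype="mysql", site_name="us",
    #                            query="select id, name from some_table limit 10")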

    @staticmethod
    def read_jdbc_query(session: SparkSession, url: str, pwd: str, username: str, query: str):
        """
        直接读取jdbc数据
        :param session:
        :param url:
        :param pwd:
        :param username:
        :param query:
        :return:
        """
        #  query 末尾不能是分号
        query = query.strip()
        assert not query.endswith(";"), "sql 末尾不能带有分号,请检查!!!"
        return session.read.format("jdbc") \
            .option("url", url) \
            .option("user", username) \
            .option("password", pwd) \
            .option("query", query) \
            .load()
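

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the JDBC URL and the credentials below
    # are placeholders and must be replaced with real connection details.
    spark = SparkUtil.get_spark_session(app_name="spark_util_demo")
    df = SparkUtil.read_jdbc_query(
        session=spark,
        url="jdbc:mysql://host:3306/db",
        pwd="password",
        username="user",
        query="select 1 as probe",
    )
    df.show()
    spark.stop()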