import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from pyspark.sql import SparkSession
from utils.db_util import DBUtil


class SparkUtil(object):
    """
    Spark-related utility class.
    """

    # default database to use
    DEF_USE_DB = "big_data_selection"

    __home_path__ = "hdfs://nameservice1:8020/home"
    __meta_store_uri__ = "thrift://hadoop16:9083"
    current_spark = None

    @classmethod
    def get_spark_session(cls, app_name: str = "test",
                          use_db: str = "big_data_selection",
                          use_lzo_compress: bool = True) -> SparkSession:
        """
        :param app_name:
        :param use_db:
        :param use_lzo_compress:  是否输出save的时候启用lzo压缩 默认是true
        :return:
        """
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            enableHiveSupport(). \
            getOrCreate()
        # distribute yswg_utils.zip so its modules are importable on the executors
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")

        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql('''set mapred.output.compress=true''')
            spark.sql('''set hive.exec.compress.output=true''')
            spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark

    @classmethod
    def get_spark_sessionV2(cls, app_name: str = "test", use_db: str = "big_data_selection"):
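        """
        Lighter variant of get_spark_session: same warehouse / metastore wiring,
        but without the ORC / parquet / LZO output settings and without caching
        the session on cls.current_spark.
        """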
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.driver.maxResultSize", "10g"). \
            enableHiveSupport(). \
            getOrCreate()
        # attach yswg_utils.zip
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql(f"use {use_db};")
        return spark

    @classmethod
    def get_spark_sessionV3(cls, app_name: str = "test", use_db: str = "big_data_selection"):
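        """
        Variant that targets YARN explicitly with a fixed resource profile
        (5g driver / 8g executor memory, 4 cores per executor, queue "spark");
        it also disables hive.compute.query.using.stats so aggregate queries
        read the data instead of possibly stale metastore statistics.
        """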
        spark = SparkSession.builder. \
            master("yarn"). \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.driver.memory", "5g"). \
            config("spark.executor.memory", "8g"). \
            config("spark.executor.cores", "4"). \
            config("spark.yarn.queue", "spark"). \
            enableHiveSupport(). \
            getOrCreate()
        # attach yswg_utils.zip
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql('''set hive.compute.query.using.stats = false''')
        spark.sql(f"use {use_db};")
        return spark

    @classmethod
    def get_spark_sessionV4(cls, app_name: str = "test", use_db: str = "big_data_selection", use_lzo_compress: bool = True):
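        """
        Like get_spark_session, with additional shuffle tuning: 100 shuffle
        partitions, compressed spills, a 5m shuffle file buffer and graceful
        shutdown for streaming jobs.
        """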
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            config("spark.sql.shuffle.partitions", 100). \
            config("spark.shuffle.memoryFraction", "0.4"). \
            config("spark.shuffle.spill.compress", "true"). \
            config("spark.shuffle.file.buffer", "5m"). \
            config("spark.streaming.stopGracefullyOnShutdown", "true"). \
            enableHiveSupport(). \
            getOrCreate()
        # distribute yswg_utils.zip so its modules are importable on the executors
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")

        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql('''set mapred.output.compress=true''')
            spark.sql('''set hive.exec.compress.output=true''')
            spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark

    @classmethod
    def get_stream_spark(cls, app_name: str, use_db: str = "big_data_selection", use_lzo_compress: bool = True):
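        """
        Session for long-running / streaming jobs: the AM and executors are
        pinned to YARN nodes labelled "stream", heartbeat and block-manager
        timeouts are raised, and 8g of off-heap memory is enabled.
        """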
        spark = SparkSession.builder.\
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            config("spark.sql.shuffle.partitions", 100). \
            config("spark.shuffle.memoryFraction", "0.4"). \
            config("spark.shuffle.spill.compress", "true"). \
            config("spark.shuffle.file.buffer", "5m"). \
            config("spark.memory.fraction", "0.8"). \
            config("spark.memory.storageFraction", "0.2"). \
            config("spark.shuffle.memoryFraction", "0.2"). \
            config("spark.yarn.am.nodeLabelExpression", "stream") .\
            config("spark.yarn.executor.nodeLabelExpression", "stream"). \
            config("spark.executor.heartbeatInterval", "60s"). \
            config("spark.storage.blockManagerSlaveTimeoutM", "600s"). \
            config("spark.memory.offHeap.enabled", "true"). \
            config("spark.memory.offHeap.size", "8g"). \
            config("spark.memory.fraction", "0.7"). \
            config("spark.memory.storageFraction", "0.6"). \
            enableHiveSupport(). \
            getOrCreate()

        # distribute yswg_utils.zip so its modules are importable on the executors
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")

        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql('''set mapred.output.compress=true''')
            spark.sql('''set hive.exec.compress.output=true''')
            spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark

    @staticmethod
    def read_jdbc(session: SparkSession, dbtype: str, site_name: str, query: str):
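        """
        Resolve the JDBC connection info (url / username / password) for the
        given database type and site via DBUtil, then run the query through
        read_jdbc_query.
        """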
        conn_info = DBUtil.get_connection_info(dbtype, site_name)
        return SparkUtil.read_jdbc_query(
            session=session,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=query
        )

    @staticmethod
    def read_jdbc_query(session: SparkSession, url: str, pwd: str, username: str, query: str):
        """
        直接读取jdbc数据
        :param session:
        :param url:
        :param pwd:
        :param username:
        :param query:
        :return:
        """
        #  query 末尾不能是分号
        query = query.strip()
        assert not query.endswith(";"), "sql 末尾不能带有分号,请检查!!!"
        return session.read.format("jdbc") \
            .option("url", url) \
            .option("user", username) \
            .option("password", pwd) \
            .option("query", query) \
            .load()
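

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the utility itself).
# The query and the "mysql" / "us" arguments below are hypothetical examples;
# valid values depend on what DBUtil.get_connection_info supports.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    spark = SparkUtil.get_spark_session(app_name="spark_util_demo")
    # the session has already switched to the selection database,
    # so plain Hive SQL works directly
    spark.sql("show tables").show()

    # reading from an external database through the JDBC helpers:
    # df = SparkUtil.read_jdbc(spark, dbtype="mysql", site_name="us",
    #                          query="select 1 as probe")
    # df.show()
    spark.stop()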