import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from pyspark.sql import SparkSession
from utils.db_util import DBUtil


class SparkUtil(object):
    """
    Spark-related utility class.
    """
    # Default database to connect to
    DEF_USE_DB = "big_data_selection"

    __home_path__ = "hdfs://nameservice1:8020/home"
    __meta_store_uri__ = "thrift://hadoop16:9083"

    current_spark = None

    @classmethod
    def get_spark_session(cls, app_name: str = "test", use_db: str = "big_data_selection",
                          use_lzo_compress: bool = True) -> SparkSession:
        """
        :param app_name: Spark application name
        :param use_db: Hive database to switch to after the session is created
        :param use_lzo_compress: whether to enable LZO compression when saving output; defaults to True
        :return: a Hive-enabled SparkSession
        """
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            enableHiveSupport(). \
            getOrCreate()
        # Ship the shared utility package to the executors
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql('''set mapred.output.compress=true''')
            spark.sql('''set hive.exec.compress.output=true''')
            spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark

    @classmethod
    def get_spark_sessionV2(cls, app_name: str = "test", use_db: str = "big_data_selection"):
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.driver.maxResultSize", "10g"). \
            enableHiveSupport(). \
            getOrCreate()
        # Add yswg_utils.zip to the executors' Python path
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql(f"use {use_db};")
        return spark

    @classmethod
    def get_spark_sessionV3(cls, app_name: str = "test", use_db: str = "big_data_selection"):
        spark = SparkSession.builder. \
            master("yarn"). \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.driver.memory", "5g"). \
            config("spark.executor.memory", "8g"). \
            config("spark.executor.cores", "4"). \
            config("spark.yarn.queue", "spark"). \
            enableHiveSupport(). \
            getOrCreate()
        # Add yswg_utils.zip to the executors' Python path
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql('''set hive.compute.query.using.stats = false''')
        spark.sql(f"use {use_db};")
        return spark

    @classmethod
    def get_spark_sessionV4(cls, app_name: str = "test", use_db: str = "big_data_selection",
                            use_lzo_compress: bool = True):
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            config("spark.sql.shuffle.partitions", 100). \
            config("spark.shuffle.memoryFraction", "0.4"). \
            config("spark.shuffle.spill.compress", "true"). \
            config("spark.shuffle.file.buffer", "5m"). \
            config("spark.streaming.stopGracefullyOnShutdown", "true"). \
            enableHiveSupport(). \
            getOrCreate()
        # Ship the shared utility package to the executors
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql('''set mapred.output.compress=true''')
            spark.sql('''set hive.exec.compress.output=true''')
            spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark

    @classmethod
    def get_stream_spark(cls, app_name: str, use_db: str = "big_data_selection", use_lzo_compress: bool = True):
        spark = SparkSession.builder. \
            appName(f"{app_name}"). \
            config("spark.sql.warehouse.dir", f"{SparkUtil.__home_path__}/{use_db}"). \
            config("spark.metastore.uris", SparkUtil.__meta_store_uri__). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.orc.mergeSchema", True). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            config("spark.driver.maxResultSize", "10g"). \
            config("spark.sql.autoBroadcastJoinThreshold", -1). \
            config("spark.sql.shuffle.partitions", 100). \
            config("spark.shuffle.spill.compress", "true"). \
            config("spark.shuffle.file.buffer", "5m"). \
            config("spark.shuffle.memoryFraction", "0.2"). \
            config("spark.memory.fraction", "0.7"). \
            config("spark.memory.storageFraction", "0.6"). \
            config("spark.yarn.am.nodeLabelExpression", "stream"). \
            config("spark.yarn.executor.nodeLabelExpression", "stream"). \
            config("spark.executor.heartbeatInterval", "60s"). \
            config("spark.storage.blockManagerSlaveTimeoutMs", "600s"). \
            config("spark.memory.offHeap.enabled", "true"). \
            config("spark.memory.offHeap.size", "8g"). \
            enableHiveSupport(). \
            getOrCreate()
        # Ship the shared utility package to the executors
        spark.sparkContext.addPyFile(path="hdfs://nameservice1:8020/lib/yswg_utils.zip")
        spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        if use_lzo_compress:
            spark.sql('''set mapred.output.compress=true''')
            spark.sql('''set hive.exec.compress.output=true''')
            spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        spark.sql(f"use {use_db};")
        cls.current_spark = spark
        return spark

    @staticmethod
    def read_jdbc(session: SparkSession, dbtype: str, site_name: str, query: str):
        conn_info = DBUtil.get_connection_info(dbtype, site_name)
        return SparkUtil.read_jdbc_query(
            session=session,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=query
        )

    @staticmethod
    def read_jdbc_query(session: SparkSession, url: str, pwd: str, username: str, query: str):
        """
        Read data directly over JDBC.
        :param session: an existing SparkSession
        :param url: JDBC connection URL
        :param pwd: password
        :param username: user name
        :param query: SQL query to push down
        :return: the resulting DataFrame
        """
        # The query must not end with a semicolon
        query = query.strip()
        assert not query.endswith(";"), "The SQL query must not end with a semicolon, please check!"
        return session.read.format("jdbc") \
            .option("url", url) \
            .option("user", username) \
            .option("password", pwd) \
            .option("query", query) \
            .load()