import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from utils.spark_util import SparkUtil
from utils.common_util import CommonUtil
from utils.hdfs_utils import HdfsUtils
from utils.db_util import DBUtil, DbTypes
from pyspark.sql import functions as F, Window
from pyspark.sql import DataFrame
from pyspark.sql.types import IntegerType


def handle_listen(checkpoint):
    """Stream the ``keyword_pcp_prod`` Kafka topic into ``handle_foreach_save_hive``.

    Builds a streaming Spark session, subscribes to the topic, keeps only the
    record ``timestamp`` and the ``value`` cast to a string, and hands every
    micro-batch to ``handle_foreach_save_hive``. The call blocks until the
    streaming query terminates; the Spark session is stopped afterwards, even
    when the query ends with an exception.

    :param checkpoint: name of the checkpoint sub-directory used under
        ``/tmp/spark_kafka/checkpoint/``.
    """
    topic_name = "keyword_pcp_prod"
    # NOTE(review): broker addresses and SASL credentials are hard-coded in
    # source; consider loading them from configuration or a secret store.
    options = {
        "kafka.sasl.jaas.config": """org.apache.kafka.common.security.plain.PlainLoginModule required username="consumer" password="J2#aLmPq7zX";""",
        "kafka.sasl.mechanism": "PLAIN",
        "kafka.security.protocol": "SASL_PLAINTEXT",
        "kafka.bootstrap.servers": "61.145.136.61:19092,61.145.136.61:29092,61.145.136.61:39092",
        "subscribe": topic_name,
        "failOnDataLoss": "false",
        "startingOffsets": "earliest",
        # "maxOffsetsPerTrigger": "500"
    }
    spark = SparkUtil.get_stream_spark("keyword_pcp_prod_consumer", use_lzo_compress=False)
    try:
        print(f"listen to topic {topic_name}")
        kafka_reader = spark.readStream.format("kafka").options(**options).load()
        kafka_reader = kafka_reader.select(
            F.col("timestamp").alias("timestamp"),
            F.col("value").cast("string").alias("value")
        )
        # Micro-batch processing: one batch is triggered every 100 seconds.
        writing_sink = kafka_reader.writeStream \
            .option("checkpointLocation", f"/tmp/spark_kafka/checkpoint/{checkpoint}") \
            .foreachBatch(handle_foreach_save_hive) \
            .trigger(processingTime='100 seconds').start()

        writing_sink.awaitTermination()
    finally:
        # Always release the session, including when awaitTermination() raises.
        spark.stop()


# def handle_foreach_batch(df: DataFrame, batch_id):
#     print(f"handle batch_id {batch_id}")
#     schema_of_json = F.schema_of_json("""
#     [
#         {
#             "adGroupId": "399165404207442",
#             "keywordId": "360139837936712",
#             "maxBid": 3.78,
#             "minBid": 1.93,
#             "suggestedBid": 2.57,
#             "type": "KEYWORD_BROAD_MATCH",
#             "value": "water bottles bulk"
#         }
#     ]
# """)
#     df_save = df.select(
#         F.explode(F.from_json(F.col("value"), schema_of_json)).alias("json"),
#         F.col("timestamp")
#     ) \
#         .filter(F.col("json").isNotNull()) \
#         .select(
#         F.col("json.adGroupId").alias("group_id"),
#         F.col("json.keywordId").alias("keyword_id"),
#         F.col("json.value").alias("keyword"),
#         F.col("json.type").alias("match_type"),
#         F.col('timestamp').alias("created_at"),
#         F.col("json.minBid").alias("min_bid"),
#         F.col("json.maxBid").alias("max_bid"),
#         F.col("json.suggestedBid").alias("suggested_bid"),
#     )
#     # conn_info = DBUtil.get_connection_info("mysql", "us")
#     conn_info = DBUtil.get_connection_info(DbTypes.mysql.name, "us")
#     count = df_save.count()
#     print(f"save success count = {count}")
#     sys.exit()
#     # Set batchsize to 5000 and rewriteBatchedStatements=true to speed up the inserts
#     # df_save.write.format('jdbc').options(
#     #     url=conn_info['url'] + "?rewriteBatchedStatements=true",
#     #     user=conn_info['username'],
#     #     password=conn_info['pwd'],
#     #     dbtable='st_pcp_history',
#     #     batchsize="5000"
#     # ).mode('append').save()
#     pass


def handle_foreach_save_hive(df: DataFrame, batch_id):
    """Parse one Kafka micro-batch and append it to the Hive table ``dim_st_pcp_history``.

    Each record's ``value`` is parsed as a JSON array of keyword-bid objects
    (the schema is inferred from the sample document below), exploded to one
    row per object and appended to the table, partitioned by ``date_info``.
    After a successful write the ORC files of today's partition are
    concatenated via ``CommonUtil.orctable_concatenate``.

    :param df: micro-batch DataFrame with the columns ``timestamp`` and ``value``.
    :param batch_id: micro-batch id supplied by ``foreachBatch``; only logged.
    """
    hive_tb = 'dim_st_pcp_history'
    # Sample payload used solely to infer the JSON schema of ``value``.
    schema_of_json = F.schema_of_json("""
    [
        {
            "adGroupId": "399165404207442",
            "keywordId": "360139837936712",
            "maxBid": 3.78,
            "minBid": 1.93,
            "siteId": "1",
            "suggestedBid": 2.57,
            "type": "KEYWORD_BROAD_MATCH",
            "value": "water bottles bulk"
        }
    ]
""")
    # The batch feeds three separate actions below (two counts and the write).
    # Persisting it avoids recomputing the batch, including re-reading the
    # Kafka input, for each of them.
    df.persist()
    try:
        print(f"handle batch_id {batch_id} df.count {df.count()}")
        df_save = df.select(
            F.explode(F.from_json(F.col("value"), schema_of_json)).alias("json"),
            F.col("timestamp")
        ) \
            .filter(F.col("json").isNotNull()) \
            .select(
            F.col("json.siteId").cast(IntegerType()).alias("site_id"),
            F.col("json.adGroupId").alias("group_id"),
            F.col("json.keywordId").alias("keyword_id"),
            F.col("json.value").alias("keyword"),
            F.col("json.type").alias("match_type"),
            F.date_format(F.col('timestamp'), 'yyyy-MM-dd HH:mm:ss').alias('created_at'),
            F.col("json.minBid").alias("min_bid"),
            F.col("json.maxBid").alias("max_bid"),
            F.col("json.suggestedBid").alias("suggested_bid"),
            F.date_format(F.col('timestamp'), 'yyyy-MM-dd').alias('date_info'),
        )
        # Rows without a partition value cannot be written to a partition.
        df_save = df_save.filter(F.col("date_info").isNotNull())
        # Collapse to a single Spark partition before writing
        # (presumably to limit the number of files produced per batch).
        df_save = df_save.repartition(1)
        count = df_save.count()
        if count == 0:
            # Nothing to write for this batch; the finally clause still unpersists.
            return
        df_save.write.saveAsTable(name=hive_tb, format='hive', mode='append', partitionBy=['date_info'])
        print(f"save success count = {count}")
    finally:
        df.unpersist()

    # NOTE(review): only today's partition is compacted, while date_info is
    # derived from the Kafka record timestamp; batches replaying older records
    # write to other partitions that are not compacted here - confirm intended.
    now_date = CommonUtil.format_now("%Y-%m-%d")
    # Keep the number of files in the partition <= 20.
    CommonUtil.orctable_concatenate(
        hive_table=hive_tb,
        partition_dict={
            "date_info": now_date
        },
        innerFlag=False,
        min_part_num=20,
        max_retry_time=5
    )

    # # Every 10 batches (see the condition below), merge the partition's small files; only ORC-format tables are supported
    # if batch_id > 0 and batch_id % 10 == 0:
    #     print(f"current batch_id {batch_id}")
    #     now_date = CommonUtil.format_now("%Y-%m-%d")
    #     path_exist = HdfsUtils.path_exist(
    #         CommonUtil.build_hdfs_path(hive_tb, partition_dict={"date_info": now_date})
    #     )
    #     if path_exist:
    #         # The partitions must be repaired first, before the fragments are concatenated
    #         CommonUtil.hive_cmd_exec(f"""msck repair table big_data_selection.{hive_tb};""")
    #         CommonUtil.hive_cmd_exec(f"""alter table big_data_selection.{hive_tb} partition (date_info = "{now_date}") concatenate;""")
    #     pass
    # pass


if __name__ == '__main__':
    # Reference launch command for this job (the string below is a no-op
    # literal kept purely as documentation; two alternative script paths
    # are listed on its last lines).
    """
    /opt/module/spark/bin/spark-submit  \
    --packages org.apache.spark:spark-sql-kafka-0-10_2.12:3.1.3 \
    --master yarn \
    --driver-memory 1g \
    --executor-memory 1g \
    --executor-cores 1 \
    --num-executors 1 \
    --queue default \
    /opt/module/spark/demo/py_demo/my_kafka/keyword_pcp_listener.py
    /tmp/wjc_py/my_kafka/keyword_pcp_listener.py
    """
    # The checkpoint directory is named after the consumed topic.
    handle_listen(checkpoint="keyword_pcp_prod")
