import os
import sys
import pymysql


sys.path.append(os.path.dirname(sys.path[0]))
from utils.common_util import CommonUtil
from pyspark.sql import SparkSession
from utils.adv_doris_stream_load import stream_load


def get_db_engine():
    """Open a pymysql connection to the Doris instance used by this job.

    Connects to 192.168.10.218:9030, database ``adv``, as ``root`` with an
    empty password, speaking the MySQL protocol through pymysql.

    Returns:
        An open pymysql connection. The caller owns it and is responsible
        for closing it.
    """
    # Doris connection parameters, passed straight to pymysql.connect().
    return pymysql.connect(
        host="192.168.10.218",
        port=9030,
        user="root",
        password="",
        database="adv",
    )


# Spark JDBC options for reading from Doris over the MySQL protocol.
# Same host/port/database/user as the pymysql connection used for the DDL.
_DORIS_JDBC_OPTIONS = {
    "url": "jdbc:mysql://192.168.10.218:9030/adv",
    "driver": "com.mysql.jdbc.Driver",
    "user": "root",
    "password": "",
}


def _read_jdbc(spark, dbtable):
    """Return a Spark DataFrame for ``dbtable`` read from Doris over JDBC.

    ``dbtable`` is passed verbatim as the JDBC ``dbtable`` option; in this
    script it is always a parenthesised subquery with an alias
    (``(SELECT ...) as t``).
    """
    return spark.read.format("jdbc").options(dbtable=dbtable, **_DORIS_JDBC_OPTIONS).load()


def _execute_each(cursor, sql):
    """Execute every ``;``-separated statement in ``sql`` with its own call.

    pymysql does not enable multi-statement mode by default, so one
    ``execute()`` carrying several statements is not reliable; one call per
    statement also makes an exception point at the statement that failed.
    Splitting on ``;`` is only safe because these scripts contain no ``;``
    inside string literals.
    """
    for statement in sql.split(";"):
        statement = statement.strip()
        if statement:
            cursor.execute(statement)


def _row_to_line(row, delimiter=","):
    r"""Render one Spark ``Row`` as a ``delimiter``-separated text line.

    SQL NULLs (``None`` -- e.g. produced by an unmatched LEFT JOIN) are
    written as ``\N``, the NULL marker of Doris CSV stream load; plain
    ``str(None)`` would load the literal text ``None`` instead.
    """
    return delimiter.join("\\N" if value is None else str(value) for value in row)


def _build_merged_df(spark):
    """Read the source tables from Doris and return the joined DataFrame.

    The result is ``a LEFT JOIN b USING (campaignId) LEFT JOIN c USING
    (targetId)`` for the subqueries a, b, c below. Each SQL text is printed
    before it is used.
    """
    # b: distinct (userName, campaignId) pairs.
    sql2 = """
    (SELECT userName, campaignId FROM mysql_adv.advertising_manager_us.user_campaign_id GROUP BY userName, campaignId) as t
    """
    print("----------------------------------------------------")
    print(sql2)

    # c, part 1: targets whose expression contains "asin=", mapped to an
    # adGroupId through product_manual_target. The `%%` is a leftover from
    # pandas/pymysql escaping; Spark passes it through verbatim, and in LIKE it
    # matches the same strings as a single `%`.
    sql3 = """
    (SELECT
        t1.targetId AS targetId,
        t2.adGroupId AS adGroupId
    FROM
        mysql_adv.advertising_manager_us.us_product_target_report t1
    LEFT JOIN mysql_adv.advertising_manager_us.product_manual_target t2
    ON t1.targetId = t2.targetId
    WHERE t1.targetingExpression LIKE '%%asin=%%') as t
    """
    print("----------------------------------------------------")
    print(sql3)

    # c, part 2: every other target except "asin-expanded=" ones, mapped
    # through product_target.
    sql4 = """
    (SELECT
        t3.targetId AS targetId,
        t4.adGroupId AS adGroupId
    FROM
        mysql_adv.advertising_manager_us.us_product_target_report t3
    LEFT JOIN mysql_adv.advertising_manager_us.product_target t4
    ON t3.targetId = t4.targetId
    WHERE t3.targetingExpression NOT LIKE '%%asin=%%'
    AND t3.targetingExpression NOT LIKE '%%asin-expanded=%%') as t
    """
    print("----------------------------------------------------")
    print(sql4)

    # Each DataFrame is consumed exactly once (by the single collect() in
    # exec_sql), so caching them would only cost executor memory.
    df_b = _read_jdbc(spark, sql2)
    df_c = _read_jdbc(spark, sql3).unionByName(_read_jdbc(spark, sql4))

    # a: the customer search-term report, from startDate 20200000 onward.
    query = '''
    (SELECT
        startDate,
        site,
        impressions,
        clicks,
        cost,
        updated_at,
        created_at,
        account,
        attributedConversions7d,
        attributedSales7d,
        targetingExpression,
        `query`,
        targetingText,
        targetId,
        campaignId,
        id,
        adGroupId,
        adGroupName
    FROM
        mysql_adv.advertising_manager_us.us_customer_search_target_report
    WHERE
        startDate >= 20200000) as t
    '''
    print("----------------------------------------------------")
    print(query)
    df_a = _read_jdbc(spark, query)
    print("----------------------------------------------------")
    print("该分区读取完成")

    # Join a, b and c. The output column order (join keys first) is what the
    # downstream stream load receives, so do not reorder these joins casually.
    merged_df = df_a.join(df_b, 'campaignId', 'left').join(df_c, 'targetId', 'left')
    print("----------------------------------------------------")
    print("该分区join完成")
    return merged_df


def exec_sql(connection):
    """Rebuild ``sp_customer_search_target_report`` in Doris.

    Pipeline:
      1. TRUNCATE the staging table ``sp_customer_search_target_report_copy``.
      2. Read the search-term report and the campaign/target lookups through
         Spark JDBC, then LEFT JOIN them (see ``_build_merged_df``).
      3. Collect the result to the driver, serialise it as comma-separated
         lines and pass it to ``stream_load``. Per the log message, this fills
         the staging table.
      4. Swap the staging and live tables by renaming them, then ANALYZE the
         live table.

    Args:
        connection: an open pymysql connection to Doris. It is always closed
            before this function returns, including when a step raises.
    """
    from contextlib import closing

    sql1 = """
    TRUNCATE table sp_customer_search_target_report_copy;
    """
    # Swap the staging and live tables through a temporary name.
    # NOTE(review): these are three separate DDL statements, so the swap is not
    # atomic. Doris's `ALTER TABLE ... REPLACE WITH TABLE ...
    # PROPERTIES('swap' = 'true')` could do it in one step -- consider it.
    sql5 = """
    ALTER TABLE sp_customer_search_target_report_copy RENAME sp_customer_search_target_report_tmp;

    ALTER TABLE sp_customer_search_target_report RENAME sp_customer_search_target_report_copy;

    ALTER TABLE sp_customer_search_target_report_tmp RENAME sp_customer_search_target_report;
    """
    # Refresh statistics on the new live table for the query planner.
    sql6 = """
    ANALYZE TABLE sp_customer_search_target_report;
    """

    # closing() guarantees the cursor, then the connection, are released even
    # when a step fails (the original leaked both on error).
    with closing(connection), closing(connection.cursor()) as cursor:
        spark = (
            SparkSession.builder.appName("adv:sp_customer_search_target_report")
            .config("spark.network.timeout", 10000000)
            .getOrCreate()
        )
        try:
            print("----------------------------------------------------")
            print(sql1)
            _execute_each(cursor, sql1)
            print("清空copy表完成")

            merged_df = _build_merged_df(spark)

            # The whole result is collected to the driver and sent as one payload.
            # NOTE(review): values are neither quoted nor escaped, so a text column
            # (e.g. `query`, targetingText, adGroupName) containing "," or a
            # newline would shift columns in Doris -- confirm how stream_load
            # configures the column separator.
            result_string = "\n".join(_row_to_line(row) for row in merged_df.collect())
            stream_load(result_string)
            print("写入copy表完成")

            print("----------------------------------------------------")
            print(sql5)
            _execute_each(cursor, sql5)
            print("交换表名完成")

            print("----------------------------------------------------")
            print(sql6)
            _execute_each(cursor, sql6)
            print("ANALYZE完成")
        finally:
            spark.stop()

if __name__ == '__main__':
    # Open the Doris connection; exec_sql() is responsible for closing it.
    doris_connection = get_db_engine()
    print("----------------------------------------------------")
    print("成功获取Doris的engine")

    # Rebuild the report table: truncate staging, load, swap names, analyze.
    exec_sql(doris_connection)

    # Reached only if everything above succeeded: send the completion notice.
    CommonUtil.send_wx_msg(
        ["chenyuanjie"],
        "【adv:sp_customer_search_target_report导入成功】",
        "悉知",
    )
