import os
import sys

# Make the project root importable so the utils package resolves
sys.path.append(os.path.dirname(sys.path[0]))
import functools
from utils.db_util import DbTypes
from utils.spark_util import SparkUtil
from utils.redis_utils import RedisUtils
from utils.DorisHelper import DorisHelper
from pyspark.sql.window import Window
from pyspark.sql import functions as F, DataFrame

# The six major sites to be synced
site_names = ['us', 'uk', 'it', 'de', 'es', 'fr']


def save_to_redis_list(iterator, redis_key: str, batch: int, ttl: int = -1):
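    """Push each JSON row from a partition iterator onto a Redis list.

    Commands are pipelined and flushed every `batch` rows; a positive
    `ttl` (in seconds) sets an expiry on the key once the partition is done.
    """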
    redis_cli = RedisUtils.get_redis_client_by_type(db_type='microservice')
    cnt = 0
    pipeline = redis_cli.pipeline()
    for json_row in iterator:
        pipeline.lpush(redis_key, json_row)
        cnt += 1
        if cnt % batch == 0:
            pipeline.execute()
    # Flush whatever is left over from the final partial batch
    if cnt % batch != 0:
        pipeline.execute()
    pipeline.close()
    if ttl > 0:
        redis_cli.expire(redis_key, ttl)
    redis_cli.close()


def save_to_kafka(all_df: DataFrame):
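    """Publish the DataFrame to the self_asin_detail Kafka topic as JSON,
    keyed by concat(site, asin) so records for the same ASIN always hash
    to the same partition."""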
    # df_kafka = all_df.repartition(10)
    kafka_target = {
        "kafka.bootstrap.servers": "218.17.154.146:27092,218.17.154.146:27093,218.17.154.146:27094",
        "kafka.security.protocol": "SASL_PLAINTEXT",
        "kafka.sasl.mechanism": "PLAIN",
        "kafka.sasl.jaas.config": "org.apache.kafka.common.security.plain.PlainLoginModule required username='producer' password='R8@xY3pL!qz';",
        "topic": "self_asin_detail",
    }
    all_df.selectExpr("CAST(concat(site,asin) AS STRING) AS key", "to_json(struct(*)) AS value") \
        .write \
        .format("kafka") \
        .options(**kafka_target) \
        .save()


def save_to_doris(df_all: DataFrame):
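    """Map site codes to Amazon domain names, then append the rows to
    advertising_manager.sync_amazon_item_day in Doris as a partial-column
    update."""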
    df_all = df_all.selectExpr("""
    case when site = 'us' then 'Amazon.com'
    when site = 'uk' then 'Amazon.co.uk'
    when site = 'de' then 'Amazon.de'
    when site = 'fr' then 'Amazon.fr'
    when site = 'it' then 'Amazon.it'
    when site = 'es' then 'Amazon.es'
    when site = 'ca' then 'Amazon.ca'
    when site = 'jp' then 'Amazon.jp'
    when site = 'mx' then 'Amazon.com.mx'
    when site = 'nl' then 'Amazon.nl'
    when site = 'be' then 'Amazon.com.be'
    when site = 'tr' then 'Amazon.com.tr'
    when site = 'se' then 'Amazon.se'
    when site = 'pl' then 'Amazon.pl'
    when site = 'ae' then 'Amazon.ae'
    when site = 'au' then 'Amazon.com.au'
    else site end as site
    """,
                               "asin",
                               "account_name",
                               "rating",
                               "total_comments",
                               "volume",
                               "weight",
                               "category",
                               "`rank`",
                               "video_url",
                               "add_url",
                               "material",
                               "img_type",
                               "qa_num",
                               "brand",
                               "node_id",
                               "one_star",
                               "two_star",
                               "three_star",
                               "four_star",
                               "five_star",
                               "low_star",
                               "asin_type",
                               "is_coupon",
                               "other_seller_name",
                               "buy_sales",
                               "updated_at",
                               "img_num"
                               )
    write_fields = ",".join(df_all.schema.fieldNames())

    connection_info = DorisHelper.get_connection_info("adv")
    options = {
        "doris.fenodes": f"{connection_info['ip']}:{connection_info['http_port']}",
        "user": connection_info['user'],
        "password": connection_info['pwd'],
        # "doris.table.identifier": "advertising_manager_test.test_doris",
        "doris.table.identifier": "advertising_manager.sync_amazon_item_day",
        # The field order here must stay fixed
        "doris.write.fields": write_fields,
        # Partial-column update
        "doris.sink.properties.partial_columns": "true",
        "doris.sink.properties.format": "json"
    }
    df_all.write.format("doris") \
        .options(**options) \
        .mode("append") \
        .save()


def export():
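    """For each site, pull ASIN detail rows updated in the export window
    from MySQL, keep the latest record per (site, asin), fill defaults,
    and push the rows into a Redis list as JSON."""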
    spark = SparkUtil.get_spark_session("self_asin_redis:export")
    redis_key = "self_asin_detail:2024-11-25"
    for site_name in site_names:
        query = f"""
        select asin,
               coalesce(site, '{site_name}') as site,
               coalesce(rating, 0) as rating,
               total_comments,
               volume,
               round(weight,4) as weight,
               category,
               `rank`,
               video_url,
               add_url,
               material,
               img_type,
               qa_num,
               brand,
               node_id,
               one_star,
               two_star,
               three_star,
               four_star,
               five_star,
               low_star,
               asin_type,
               is_coupon,
               account_name,
               other_seller_name,
               buy_sales,
               img_num,
               date_format(updated_at, '%Y-%m-%d %H:%i:%S') updated_at
        from {site_name}_self_asin_detail
        where updated_at >= '2024-11-22'
          and updated_at <= '2024-11-26'
            """
        asin_df = SparkUtil.read_jdbc(spark, DbTypes.mysql.name, site_name, query=query)
        # Window over (site, asin) ordered by updated_at and keep only the latest row
        asin_df = asin_df.withColumn(
            "row_number",
            F.row_number().over(Window.partitionBy("site", "asin").orderBy(F.col("updated_at").desc()))
        ).where("row_number = 1").drop("row_number")

        # Fill default values for nullable columns
        asin_df = na_fill(asin_df).cache()
        asin_df.toJSON().foreachPartition(
            functools.partial(save_to_redis_list, batch=5000, redis_key=redis_key, ttl=3600 * 24))
        print(f"{site_name}:redis:success")
    print("success all")


def na_fill(asin_df):
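    """Replace NULLs with type-appropriate defaults so every JSON payload
    carries a complete set of fields."""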
    return asin_df.na.fill({
        "rating": 0,
        "total_comments": 0,
        "volume": "",
        "weight": 0,
        "category": "",
        "rank": 0,
        "video_url": "",
        "add_url": "",
        "material": "",
        "img_type": 0,
        "qa_num": 0,
        "brand": "",
        "node_id": "",
        "one_star": 0,
        "two_star": 0,
        "three_star": 0,
        "four_star": 0,
        "five_star": 0,
        "low_star": 0,
        "asin_type": 0,
        "is_coupon": 0,
        "account_name": "",
        "other_seller_name": "",
        "buy_sales": "",
        "img_num": 0
    })


if __name__ == '__main__':
    export()