import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
import functools
from utils.db_util import DbTypes
from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil
from utils.redis_utils import RedisUtils
from utils.DorisHelper import DorisHelper
from pyspark.sql.window import Window
from pyspark.sql import functions as F, DataFrame

# The six sites to be synchronized
site_names = ['us', 'uk', 'it', 'de', 'es', 'fr']


def save_to_redis_list(iterator, redis_key: str, batch: int, ttl: int = -1):
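    """LPUSH an iterator of JSON strings onto a Redis list in pipelined batches.

    Flushes the pipeline every `batch` commands, flushes the remainder after
    the loop, and sets an expiry on the key when `ttl` is positive. Designed
    to be used with foreachPartition via functools.partial.
    """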
    try:
        redis_cli = RedisUtils.get_redis_client_by_type(db_type='microservice')
        cnt = 0
        pipeline = redis_cli.pipeline()
        for json_row in iterator:
            pipeline.lpush(redis_key, json_row)
            cnt += 1
            if cnt % batch == 0:
                pipeline.execute()
        if cnt % batch != 0:
            pipeline.execute()
        pipeline.close()
        if ttl > 0:
            redis_cli.expire(redis_key, ttl)
        redis_cli.close()
    except Exception as e:
        # Log the failure instead of silently swallowing it; otherwise a
        # partition that fails to write would disappear without a trace
        print(f"save_to_redis_list failed for key {redis_key}: {e}", file=sys.stderr)


def save_to_kafka(all_df: DataFrame):
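    """Publish every row to the self_asin_detail Kafka topic, keyed by
    concat(site, asin) with the JSON-serialized row as the value. Not called
    from export()/export_all(); assumed to be kept as an alternative sink.
    """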
    # df_kafka = all_df.repartition(10)
    kafka_target = {
        "kafka.bootstrap.servers": "218.17.154.146:27092,218.17.154.146:27093,218.17.154.146:27094",
        "kafka.security.protocol": "SASL_PLAINTEXT",
        "kafka.sasl.mechanism": "PLAIN",
        "kafka.sasl.jaas.config": "org.apache.kafka.common.security.plain.PlainLoginModule required username='producer' password='R8@xY3pL!qz';",
        "topic": "self_asin_detail",
    }
    all_df.selectExpr("CAST(concat(site,asin) AS STRING) AS key", "to_json(struct(*)) AS value") \
        .write \
        .format("kafka") \
        .options(**kafka_target) \
        .save()


def save_to_doris(df_all: DataFrame):
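    """Append the DataFrame to Doris table advertising_manager.sync_amazon_item_day.

    Rewrites the site code into the full Amazon marketplace name and uses a
    partial-column JSON sink, so doris.write.fields must match the select
    order exactly.
    """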
    df_all = df_all.selectExpr("""
    case when site = 'us' then 'Amazon.com'
    when site = 'uk' then 'Amazon.co.uk'
    when site = 'de' then 'Amazon.de'
    when site = 'fr' then 'Amazon.fr'
    when site = 'it' then 'Amazon.it'
    when site = 'es' then 'Amazon.es'
    when site = 'ca' then 'Amazon.ca'
    when site = 'jp' then 'Amazon.jp'
    when site = 'mx' then 'Amazon.com.mx'
    when site = 'nl' then 'Amazon.nl'
    when site = 'be' then 'Amazon.com.be'
    when site = 'tr' then 'Amazon.com.tr'
    when site = 'se' then 'Amazon.se'
    when site = 'pl' then 'Amazon.pl'
    when site = 'ae' then 'Amazon.ae'
    when site = 'au' then 'Amazon.com.au'
    else site end as site
    """,
                               "asin",
                               "account_name",
                               "rating",
                               "total_comments",
                               "volume",
                               "weight",
                               "category",
                               "`rank`",
                               "video_url",
                               "add_url",
                               "material",
                               "img_type",
                               "qa_num",
                               "brand",
                               "node_id",
                               "one_star",
                               "two_star",
                               "three_star",
                               "four_star",
                               "five_star",
                               "low_star",
                               "asin_type",
                               "is_coupon",
                               "other_seller_name",
                               "buy_sales",
                               "updated_at",
                               "img_num"
                               )
    write_fields = ",".join(df_all.schema.fieldNames())

    connection_info = DorisHelper.get_connection_info("adv")
    options = {
        "doris.fenodes": f"{connection_info['ip']}:{connection_info['http_port']}",
        "user": connection_info['user'],
        "password": connection_info['pwd'],
        # "doris.table.identifier": "advertising_manager_test.test_doris",
        "doris.table.identifier": "advertising_manager.sync_amazon_item_day",
        # The field order here must stay fixed
        "doris.write.fields": write_fields,
        # Partial-column update
        "doris.sink.properties.partial_columns": "true",
        "doris.sink.properties.format": "json"
    }
    df_all.write.format("doris") \
        .options(**options) \
        .mode("append") \
        .save()


def export():
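    """Daily export for each site: pull rows updated between yesterday and
    tomorrow from MySQL, keep the latest row per (site, asin), push the JSON
    rows to the self_asin_detail:{day} Redis list (24h TTL) and to Doris,
    then run check_total().
    """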
    spark = SparkUtil.get_spark_session("self_asin_redis:export")
    day = CommonUtil.get_sys_arg(1, CommonUtil.format_now("%Y-%m-%d"))
    last_day = CommonUtil.get_day_offset(day, -1)
    next_day = CommonUtil.get_day_offset(day, 1)
    # Delete any existing key first
    redis_key = f"self_asin_detail:{day}"
    client = RedisUtils.get_redis_client_by_type(db_type='microservice')
    if client.exists(redis_key):
        client.delete(redis_key)
    client.close()

    for site_name in site_names:
        query = f"""
        select asin,    
               coalesce(site, '{site_name}') as site,
               coalesce(rating, 0) as rating,
               total_comments,
               volume,
               round(weight,4) as weight,
               category,
               `rank`,
               video_url,
               add_url,
               material,
               img_type,
               qa_num,
               brand,
               node_id,
               one_star,
               two_star,
               three_star,
               four_star,
               five_star,
               low_star,
               asin_type,
               is_coupon,
               account_name,
               other_seller_name,
               buy_sales,
               img_num,
               date_format(updated_at, '%Y-%m-%d %H:%i:%S') updated_at
        from {site_name}_self_asin_detail
        where updated_at >= '{last_day}'
          and updated_at <= '{next_day}'
            """
        print(query)
        asin_df = SparkUtil.read_jdbc(spark, DbTypes.mysql.name, site_name, query=query)
        # Window by updated_at and keep only the latest row per (site, asin)
        latest_window = Window.partitionBy("site", "asin").orderBy(F.col("updated_at").desc())
        asin_df = asin_df.withColumn("row_number", F.row_number().over(latest_window)) \
            .where("row_number = 1") \
            .drop("row_number")

        # Fill default values
        asin_df = na_fill(asin_df).cache()
        asin_df.toJSON().foreachPartition(functools.partial(save_to_redis_list, batch=5000, redis_key=redis_key, ttl=3600 * 24))
        print(f"{site_name}:redis:success")
        save_to_doris(asin_df)
        print(f"{site_name}:doris:success")
        print("success all")
    check_total()
    pass


def na_fill(asin_df):
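    """Fill NULL columns with defaults: 0 for numeric fields, '' for strings."""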
    return asin_df.na.fill({
        "rating": 0,
        "total_comments": 0,
        "volume": "",
        "weight": 0,
        "category": "",
        "rank": 0,
        "video_url": "",
        "add_url": "",
        "material": "",
        "img_type": 0,
        "qa_num": 0,
        "brand": "",
        "node_id": "",
        "one_star": 0,
        "two_star": 0,
        "three_star": 0,
        "four_star": 0,
        "five_star": 0,
        "low_star": 0,
        "asin_type": 0,
        "is_coupon": 0,
        "account_name": "",
        "other_seller_name": "",
        "buy_sales": "",
        "img_num": 0
    })


def check_total():
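    """Record today's list length in the self_asin_detail:stat hash and alert
    via CommonUtil.send_wx_msg when it deviates from the 10-day average by
    more than 50%.
    """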
    day = CommonUtil.get_sys_arg(1, CommonUtil.format_now("%Y-%m-%d"))
    redis_key = f"self_asin_detail:{day}"
    # The last 10 days, including today
    days = [CommonUtil.get_day_offset(day, -i) for i in range(0, 10)]

    redis_cli = RedisUtils.get_redis_client_by_type(db_type='microservice', decode_responses=True)
    count_now = redis_cli.llen(redis_key)
    stat_key = "self_asin_detail:stat"
    redis_cli.hset(stat_key, day, count_now)

    import numpy as np
    # Average list length over the last 10 days, ignoring missing or zero entries
    counts = [int(x) for x in redis_cli.hmget(stat_key, days) if x is not None and int(x) > 0]
    avg = np.average(np.array(counts))
    # Alert when today's count deviates from the 10-day average by more than 50%
    if count_now < avg * 0.5 or count_now > avg * 1.5:
        CommonUtil.send_wx_msg(['wujicang', 'leichao', 'hezhe'], title='Data sync warning',
                               content=f"Internal asin sync queue [{redis_key}] holds {count_now} records, "
                                       f"while the 10-day average is {int(avg)}; please check for anomalies!")

    redis_cli.close()


def export_all():
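    """Full export: for each site, select the latest snapshot per asin (by max
    id) from MySQL and rebuild the self_asin_detail:all_lastest Redis list
    with a 7-day TTL.
    """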
    spark = SparkUtil.get_spark_session("self_asin_redis:export")
    day = CommonUtil.get_sys_arg(1, CommonUtil.format_now("%Y-%m-%d"))
    day10_before = CommonUtil.get_day_offset(day, -10)
    # Delete any existing key first
    redis_key = "self_asin_detail:all_lastest"
    client = RedisUtils.get_redis_client_by_type(db_type='microservice')
    if client.exists(redis_key):
        client.delete(redis_key)
    client.close()

    for site_name in site_names:
        query = f"""
            select asin,
                   coalesce(site, '{site_name}') as site,
                   coalesce(rating, 0) as rating,
                   total_comments,
                   volume,
                   weight,
                   category,
                   `rank`,
                   video_url,
                   add_url,
                   material,
                   img_type,
                   qa_num,
                   brand,
                   node_id,
                   one_star,
                   two_star,
                   three_star,
                   four_star,
                   five_star,
                   low_star,
                   asin_type,
                   is_coupon,
                   account_name,
                   other_seller_name,
                   buy_sales,
                   img_num,
                   date_format(updated_at, '%Y-%m-%d %H:%i:%S') updated_at
            from (
                     select max(id) as max_id
                     from {site_name}_self_asin_detail
                     group by asin
                 ) tmp1
                     inner join {site_name}_self_asin_detail tmp2 on tmp1.max_id = tmp2.id
            """
        asin_df = SparkUtil.read_jdbc(spark, DbTypes.mysql.name, site_name, query=query)
        # Fill default values
        asin_df = na_fill(asin_df)
        asin_df.toJSON().foreachPartition(functools.partial(save_to_redis_list, batch=1000, redis_key=redis_key, ttl=3600 * 24 * 7))
        print(f"{site_name}:success")
    print("success all")


if __name__ == '__main__':
    export()
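    # Note: export_all() is defined but not wired up here; it is assumed to be
    # run manually when the full self_asin_detail:all_lastest list needs rebuilding.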