import os
import sys

from sqlalchemy import create_engine
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import TimestampType
from pyspark import pandas as ps
import pandas as pd
from collections import OrderedDict
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates import Templates


class EsBrandAnalytics(Templates):
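    """Sync one dwt_st_info partition (Hive) into Elasticsearch and PostgreSQL.

    Reads dwt_st_info for a given site/date_type/date_info, renames the st_* columns
    to the field names expected by the {site}_brand_analytics_* ES index, writes the
    result to ES, and appends the same data (without the ES id) to a per-period PG table.
    """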

    def __init__(self, site_name='us', date_type="week", date_info='2023-01'):
        self.site_name = site_name
        self.date_type = date_type
        self.date_info = date_info
        self.table_name = "dwt_st_info"
        self.db_name = self.table_name
        self.year, self.week = date_info.split("-")
        # self.spark = self.create_spark_object(app_name=f"{self.db_name}: {self.site_name}, {self.date_type}, {self.date_info}")
        # self.get_date_info_tuple()
        # For weekly data the index is per-year (e.g. us_brand_analytics_2023); other date types keep the date_type suffix.
        self.es_table_name = f"{self.site_name}_brand_analytics_{self.date_type}".replace("week", f"{self.year}")
        if self.date_type == '4_week':
            self.es_table_name = f"{self.site_name}_brand_analytics_{self.date_type}"
        if self.site_name == 'us':
            self.engine = create_engine(
                f'mysql+pymysql://adv_yswg:HmRCMUjt03M33Lze@rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com:3306/selection?charset=utf8mb4')  # , pool_recycle=3600
            self.es_port = '9200'
        else:
            if self.site_name in ['uk', 'de']:
                self.es_port = '9201'
            else:
                self.es_port = '9202'
            self.engine = create_engine(
                f'mysql+pymysql://adv_yswg:HmRCMUjt03M33Lze@rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com:3306/selection_{self.site_name}?charset=utf8mb4')  # , pool_recycle=3600
        self.df_read = object()   # placeholder (not used in this job)
        self.df_spark = object()  # placeholder, populated in read_data()
        # Elasticsearch connection settings
        self.es_url = '120.79.147.190'
        self.es_user = 'elastic'
        self.es_pass = 'selection2021.+'
        # self.es_resource = f'{self.site_name}_test2/_doc'

        # Create the Spark session
        print(f"Currently syncing: {self.table_name}, {self.site_name}-{self.year}-{self.week}")
        self.spark = SparkSession.builder. \
            appName(f"{self.table_name}:, {self.site_name}-{self.year}-{self.week}"). \
            config("spark.sql.warehouse.dir", f"hdfs://hadoop5:8020/home/big_data_selection"). \
            config("spark.metastore.uris", "thrift://hadoop4:9083"). \
            config("spark.network.timeout", 1000000). \
            config("spark.sql.parquet.compression.codec", "lzo"). \
            enableHiveSupport(). \
            getOrCreate()
        self.spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        self.spark.sql('''set mapred.output.compress=true''')
        self.spark.sql('''set hive.exec.compress.output=true''')
        self.spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        self.spark.sql(f"use big_data_selection;")
        self.partition_type = "dt"

    def read_data(self):
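        """Load the dwt_st_info partition for this site/date_type/date_info, deduplicate by search_term, and repartition."""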
        sql = f"select * from {self.table_name} where site_name='{self.site_name}' and date_type='{self.date_type}' and date_info = '{self.date_info}';"
        print("sql:", sql)
        self.df_spark = self.spark.sql(sqlQuery=sql).cache()
        # self.df_spark = self.df_spark.cache()
        self.df_spark.show(10, truncate=False)
        #self.df_spark = self.df_spark.withColumn("created_time", self.df_spark.created_time.cast(TimestampType()))
        #self.df_spark = self.df_spark.withColumn("updated_time", self.df_spark.updated_time.cast(TimestampType()))
        self.df_spark = self.df_spark.dropDuplicates(["search_term"])
        print("self.df_spark.count:", self.df_spark.count())
        # self.df_spark = ps.from_pandas(self.df_spark).to_spark()
        print("分区数1:", self.df_spark.rdd.getNumPartitions())
        self.df_spark = self.df_spark.repartition(25)
        print("分区数2:", self.df_spark.rdd.getNumPartitions())

    def df_renamed(self):
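        """Rename st_* columns to the field names used in the ES index and add the dt/id columns."""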
        # self.df_spark = self.df_spark.withColumnRenamed("st_brand_id", "id")
        self.df_spark = self.df_spark.withColumnRenamed("st_rank", "rank")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin1", "asin1")
        # self.df_spark = self.df_spark.withColumnRenamed("st_product_title1", "product_title1")
        self.df_spark = self.df_spark.withColumnRenamed("st_click_share1", "click_share1")
        self.df_spark = self.df_spark.withColumnRenamed("st_conversion_share1", "conversion_share1")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin2", "asin2")
        # self.df_spark = self.df_spark.withColumnRenamed("st_product_title2", "product_title2")
        self.df_spark = self.df_spark.withColumnRenamed("st_click_share2", "click_share2")
        self.df_spark = self.df_spark.withColumnRenamed("st_conversion_share2", "conversion_share2")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin3", "asin3")
        # self.df_spark = self.df_spark.withColumnRenamed("st_product_title3", "product_title3")
        self.df_spark = self.df_spark.withColumnRenamed("st_click_share3", "click_share3")
        self.df_spark = self.df_spark.withColumnRenamed("st_conversion_share3", "conversion_share3")
        self.df_spark = self.df_spark.withColumnRenamed("st_click_share_sum", "click_share_sum")
        self.df_spark = self.df_spark.withColumnRenamed("st_conversion_share_sum", "conversion_share_sum")
        self.df_spark = self.df_spark.withColumnRenamed("st_is_first_text", "is_first_text")
        self.df_spark = self.df_spark.withColumnRenamed("st_is_ascending_text_rate", "is_ascending_text_rate")
        self.df_spark = self.df_spark.withColumnRenamed("st_is_ascending_text", "is_ascending_text")
        self.df_spark = self.df_spark.withColumnRenamed("st_is_search_text_rate", "is_search_text_rate")
        self.df_spark = self.df_spark.withColumnRenamed("st_is_search_text", "is_search_text")
        self.df_spark = self.df_spark.withColumnRenamed("st_quantity_being_sold", "quantity_being_sold")
        self.df_spark = self.df_spark.withColumnRenamed("st_ao_val", "ao_val")
        self.df_spark = self.df_spark.withColumnRenamed("st_ao_val_rank", "ao_val_rank")
        self.df_spark = self.df_spark.withColumnRenamed("st_ao_val_rate", "ao_val_rate")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin_bs_orders_sum", "bsr_orders")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin_orders_sum", "asin_orders")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin_counts", "asin_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin_new_counts", "asin_new_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin_new_counts_rate", "asin_new_counts_rate")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin_new_orders_sum", "asin_new_orders")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin_new_orders_rate", "asin_new_orders_rate")
        self.df_spark = self.df_spark.withColumnRenamed("st_bsr_cate_1_id", "category_id")
        self.df_spark = self.df_spark.withColumnRenamed("st_bsr_cate_current_id", "category_current_id")
        self.df_spark = self.df_spark.withColumnRenamed("st_search_sum", "orders")
        self.df_spark = self.df_spark.withColumnRenamed("st_zr_counts", "zr_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_sp_counts", "sp_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_sb_counts", "sb_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_sb1_counts", "sb1_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_sb2_counts", "sb2_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_sb3_counts", "sb3_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_adv_counts", "adv_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_ac_counts", "ac_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_bs_counts", "bs_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_er_counts", "er_counts")
        self.df_spark = self.df_spark.withColumnRenamed("st_tr_counts", "tr_counts")
        self.df_spark = self.df_spark.withColumnRenamed("asin1_price", "price1")
        self.df_spark = self.df_spark.withColumnRenamed("asin1_rating", "rating1")
        self.df_spark = self.df_spark.withColumnRenamed("asin1_total_comments", "total_comments1")
        self.df_spark = self.df_spark.withColumnRenamed("st_asin1_bs_orders", "bs_orders1")
        self.df_spark = self.df_spark.withColumnRenamed("st_is_new_market_segment", "is_new_market_segment")
        # if self.date_type == '4_week':
        #     self.df_spark = self.df_spark.withColumn("dt", F.lit(f"{self.year}_{int(self.week)}"))
        self.df_spark = self.df_spark.withColumn("dt", F.lit(f"{self.year}_{int(self.week)}"))
        self.df_spark = self.df_spark.withColumn("id", F.concat(F.lit(f"{self.year}{int(self.week)}"), self.df_spark.rank))

    def save_data(self):
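        """Write the renamed DataFrame to the Elasticsearch index via the elasticsearch-spark connector."""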
        # Write the result to Elasticsearch
        options = OrderedDict()
        options['es.nodes'] = self.es_url
        options['es.port'] = self.es_port
        options['es.net.http.auth.user'] = self.es_user
        options['es.net.http.auth.pass'] = self.es_pass
        options['es.mapping.id'] = "id"
        # options['es.mapping.id'] = "search_term"
        options['es.resource'] = f'{self.es_table_name}/_doc'
        # ES connection timeout (default 1m)
        # options['es.http.timeout'] = '10000m'
        options['es.nodes.wan.only'] = 'true'
        # Default is 3 retries; a negative value retries forever (use with care)
        # options['es.batch.write.retry.count'] = '15'
        # Default retry wait is 10s
        # options['es.batch.write.retry.wait'] = '60'
        # The following options cap the byte size or entry count of a single bulk write (use one or the other)
        # options['es.batch.size.bytes'] = '20mb'
        # options['es.batch.size.entries'] = '20000'
        self.df_spark = self.df_spark.withColumn(
            "st_zr_page1_in_title_rate",
            F.round(self.df_spark.st_zr_page1_in_title_counts / self.df_spark.st_zr_page1_counts, 4)
        )
        self.df_spark = self.df_spark.drop("site_name", "st_asin_top1", "st_asin_top2", "st_asin_top3", "date_type", "date_info")
        print("self.df_spark.columns:", self.df_spark.columns)
        self.df_spark.write.format('org.elasticsearch.spark.sql').options(**options).mode('append').save()

    def connection_pg(self):
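        """Create a SQLAlchemy engine for the per-site PostgreSQL selection database."""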
        PG_CONN_DICT = {
            "pg_port": 5432,
            "pg_db": "selection",
            "pg_user": "postgres",
            "pg_pwd": "fazAqRRVV9vDmwDNRNb593ht5TxYVrfTyHJSJ3BS",
            "pg_host": "192.168.10.216",
        }
        if self.site_name == 'us':
            db = 'selection'
        else:
            db = f'selection_{self.site_name}'
        self.engine_pg = create_engine(
            f"postgresql+psycopg2://{PG_CONN_DICT['pg_user']}:{PG_CONN_DICT['pg_pwd']}@{PG_CONN_DICT['pg_host']}:{PG_CONN_DICT['pg_port']}/{db}",
            encoding='utf-8')
        return self.engine_pg

    def save_data_to_pg(self):
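        """Append the current DataFrame (minus the ES id column) to the per-period PG table."""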
        print("开始同步pg")
        self.connection_pg()
        self.df_spark = self.df_spark.drop("id")
        self.df_spark = self.df_spark.withColumn('dt', F.lit(self.date_info))
        df_save = self.df_spark.toPandas()
        df_save.to_sql(f"aba_year_{self.date_type}_{self.date_info.replace('-', '_')}_old", con=self.engine_pg, index=False, if_exists="append")

    def run(self):
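        """Read, rename, and write to ES and PG in sequence."""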
        self.read_data()
        self.df_renamed()
        self.save_data()
        self.save_data_to_pg()


if __name__ == '__main__':
    site_name = sys.argv[1]  # arg 1: site
    date_type = sys.argv[2]  # arg 2: date type (week/month/quarter)
    date_info = sys.argv[3]  # arg 3: date value, e.g. 2023-01
    # handle_obj = EsBrandAnalytics(site_name=site_name, year=year)
    handle_obj = EsBrandAnalytics(site_name=site_name, date_type=date_type, date_info=date_info)
    handle_obj.run()
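
    # Example invocation (the script filename and connector jar are illustrative, not defined in this repo):
    #   spark-submit --jars elasticsearch-spark-30_2.12-<version>.jar es_brand_analytics.py us week 2023-01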