import os
import sys
import traceback

sys.path.append(os.path.dirname(sys.path[0]))
import numpy as np
import pandas as pd
from sqlalchemy import create_engine, text
from yswg_utils.common_udf import parse_weight_str


class CleanWeight:
    """Parse ``weight_str`` for one site/week of ASIN details and store the result.

    Rows are read from ``{site}_asin_detail_{year}_{week}`` and written to
    ``{site}_asin_weight_{year}_{week}`` (see ``read_data`` / ``save_data``).
    """

    def __init__(self, site_name='us', year=2023, week=18):
        """Configure the job and create the read/save engines.

        :param site_name: site code (e.g. 'us', 'de'); used in table and database names.
        :param year: year of the weekly tables.
        :param week: week number; stored in ``self.week`` as a two-digit, zero-padded string.
        """
        self.site_name = site_name
        self.year = year
        self.week = self._format_week(week)
        # Database connection parameters.
        # NOTE(review): credentials are hard-coded in source; consider loading them from
        # the environment / a config file. Passwords are inserted into the URL verbatim and
        # SQLAlchemy URL-decodes that part, so '%', '@', '#' etc. must be percent-encoded.
        # The MySQL password already looks encoded ('%40' == '@'); verify the PG one.
        self.db_params = {
            "pg_us": {
                "host": "192.168.10.216",
                "port": 5432,
                "dbname": "selection" if self.site_name == 'us' else f"selection_{self.site_name}",
                "user": "postgres",
                "password": "T#4$4%qPbR7mJx"
            },
            "mysql_others": {
                "host": "rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com",
                "port": 3306,
                "dbname": "selection" if self.site_name == 'us' else f"selection_{self.site_name}",
                "user": "adv_test",
                "password": "Yswg%40XP_test"
            }
        }
        self.engine_read, self.engine_save = self.create_connection()

    @staticmethod
    def _format_week(week):
        """Return ``week`` as a two-digit zero-padded string (5, '5' and '05' all give '05')."""
        return f'{int(week):02d}'

    @staticmethod
    def _chunk_size(n_rows, parts=10):
        """Return a ``to_sql`` chunk size writing ``n_rows`` in about ``parts`` batches.

        Never returns 0: pandas rejects ``chunksize=0``, which ``n_rows // parts`` yields
        for small frames.
        """
        return max(1, n_rows // parts)

    @staticmethod
    def _connection_url(driver, params):
        """Build a SQLAlchemy URL string from one ``db_params`` entry (password used verbatim)."""
        return (f"{driver}://{params['user']}:{params['password']}"
                f"@{params['host']}:{params['port']}/{params['dbname']}")

    @staticmethod
    def get_weight(weight_str, site_name):
        # Logic was extracted into the shared helper parse_weight_str; this is a thin wrapper.
        return parse_weight_str(weight_str, site_name)

    def create_connection(self):
        """Create the database engines.

        Reads come from PostgreSQL for 'us' from 2023 week 18 onward (any week from 2024),
        otherwise from MySQL. Writes always go to PostgreSQL.

        :return: tuple ``(engine_read, engine_save)``.
        """
        year, week = int(self.year), int(self.week)
        read_from_pg = self.site_name == 'us' and ((week >= 18 and year >= 2023) or year >= 2024)
        if read_from_pg:
            read_url = self._connection_url("postgresql+psycopg2", self.db_params['pg_us'])
        else:
            read_url = self._connection_url("mysql+pymysql", self.db_params['mysql_others'])
        save_url = self._connection_url("postgresql+psycopg2", self.db_params['pg_us'])
        return create_engine(read_url), create_engine(save_url)

    def read_data(self):
        """Load ``asin, weight, weight_str`` from the site's weekly asin_detail table.

        'us' table names use the unpadded week number; other sites use the zero-padded one.
        Rows with a NULL ``weight_str`` are not filtered out.
        """
        print("开始读取数据")
        week_params = f"{int(self.week)}" if self.site_name == 'us' else f"{self.week}"
        sql = f"select asin, weight, weight_str from {self.site_name}_asin_detail_{self.year}_{week_params};"
        print("sql:", sql)
        return pd.read_sql(sql, con=self.engine_read)

    def handle_data(self):
        """Parse ``weight_str`` into a numeric weight and a weight type.

        :return: DataFrame with columns ``asin``, ``weight`` (float64, floored at 0.001,
            NaN when unparsed), ``weight_str`` (lower-cased, NaN when missing),
            ``weight_type`` and ``date_info`` ('<year>-<week>'). An empty source yields an
            empty frame with those columns.
        """
        df = self.read_data()
        if df.empty:
            print("site_name, year, week:", self.site_name, self.year, self.week, "数据为空，退出")
            # Actually stop here, as the message says; continuing used to crash below.
            return df.reindex(columns=['asin', 'weight', 'weight_str', 'weight_type', 'date_info'])
        print("df.shape:", df.shape)
        print("开始处理数据")
        # Remember real NULLs before str() turns them into 'none' / 'nan' strings.
        missing_str = df['weight_str'].isna()
        df['weight_str'] = df['weight_str'].apply(lambda x: str(x).lower())
        get_weight, site_name = self.get_weight, self.site_name
        parsed = [get_weight(weight_str, site_name) for weight_str in df['weight_str']]
        # Each result is expected to be a (weight, weight_type) pair; a bare None becomes
        # an all-missing row, as Series.apply(pd.Series) produced.
        expanded = pd.DataFrame(
            [(None, None) if result is None else result for result in parsed],
            index=df.index,
            columns=['weight', 'weight_type'],
        )
        df['weight'] = expanded['weight']
        df['weight_type'] = expanded['weight_type']
        weight = df['weight'].mask(df['weight'].astype(str) == 'none')
        # Floor tiny/non-positive weights at 0.001; clip leaves NaN untouched.
        df['weight'] = weight.astype("float64").clip(lower=0.001)
        df['weight_str'] = df['weight_str'].mask(missing_str | (df['weight_str'] == 'none'))
        df['date_info'] = f'{self.year}-{self.week}'
        return df

    def save_data(self):
        """Replace the contents of ``{site}_asin_weight_{year}_{week}`` with freshly parsed rows.

        The TRUNCATE and the insert share one transaction, so a failed insert rolls back to
        the previous contents. Nothing is truncated or written when the source is empty.
        """
        df = self.handle_data()
        if df.empty:
            return
        table_name = f"{self.site_name}_asin_weight_{self.year}_{self.week}"
        print("开始存储数据: 先清空对应week的分区表")
        print(df.weight_type.value_counts(dropna=False))
        with self.engine_save.begin() as conn:
            sql = f"truncate {table_name};"
            print("清空sql:", sql)
            # text() is required by SQLAlchemy 2.0, which no longer executes plain strings.
            conn.execute(text(sql))
            df.to_sql(table_name, con=conn, if_exists='append', index=False,
                      chunksize=self._chunk_size(df.shape[0]))


def _backfill(site_name_list=('us', 'de', 'uk', 'es', 'fr', 'it'), week_list=(16, 17, 18, 19), year=2023):
    """Run CleanWeight for every (week, site) pair, logging and skipping failures.

    Manual backfill helper (previously unreachable code after the CLI run);
    the command-line entry point does not call it.
    """
    for week in week_list:
        for site_name in site_name_list:
            try:
                CleanWeight(site_name=site_name, year=year, week=week).save_data()
            except Exception as e:
                print("error_info:", traceback.format_exc(), e)
                if site_name == site_name_list[-1] and week == week_list[-1]:
                    print("不满足运行条件，结束")


if __name__ == '__main__':
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} <site_name> <year> <week>")
    site_name = sys.argv[1]  # arg 1: site, e.g. us / de / uk
    year = int(sys.argv[2])  # arg 2: year, e.g. 2023
    week = int(sys.argv[3])  # arg 3: week number, e.g. 18
    handle_obj = CleanWeight(site_name=site_name, year=year, week=week)
    handle_obj.save_data()
