import os
import subprocess
import sys

import pandas as pd

from templates_mysql import TemplatesMysql


class ExportStInfo(TemplatesMysql):
    """Export weekly search-term info per site and back-fill brand analytics.

    For every site in ``site_name_list``, ``run`` does the following:

    1. Clears that week's rows from ``<site>_brand_st_info`` and runs the
       export shell script.
    2. Re-exports once more if the table is still empty.
    3. Copies ``st_ao_val`` / ``st_is_first_text`` into
       ``<site>_brand_analytics_<year>``.
    """

    def __init__(self, site_name_flag='all', year_week='2022-1'):
        """
        Args:
            site_name_flag: a single site code, or 'all' for
                us/uk/de/es/fr/it (the default).
            year_week: "<year>-<week>" string, e.g. "2022-1".

        Raises:
            ValueError: if ``year_week`` is not "<int>-<int>", or if a site
                code is not a plain ASCII identifier. Site codes are
                interpolated into SQL table names, so they must be
                restricted.
        """
        super().__init__()
        if site_name_flag == 'all':
            self.site_name_list = ['us', 'uk', 'de', 'es', 'fr', 'it']
        else:
            self.site_name_list = [site_name_flag]
        for site in self.site_name_list:
            # Table names cannot be bound as SQL parameters, so reject
            # anything that is not a plain identifier before it reaches a
            # query string or the export command line.
            if not (site.isascii() and site.isidentifier()):
                raise ValueError(f"invalid site name: {site!r}")
        self.year_week = year_week
        parts = self.year_week.split("-")
        if len(parts) != 2:
            # The old split()[0] / split()[-1] turned "2022" into week 2022.
            raise ValueError(
                f"year_week must look like '2022-1', got {year_week!r}"
            )
        self.year = int(parts[0])
        self.week = int(parts[1])
        # Export script, invoked as: <script> <site> <year> <week>
        self.path_sh = "/opt/module/spark/demo/py_demo/demo_sqoop/export_st_info.sh"
        self.df_table_counts = pd.DataFrame()

    def _connect(self, site_name):
        """Bind the instance to ``site_name``, connecting only when the site changes."""
        if getattr(self, "_engine_site", None) != site_name:
            self.site_name = site_name
            self.engine = self.mysql_connect()
            self._engine_site = site_name

    def export_data(self, site_name):
        """Delete ``site_name``'s rows for this week, then run the export script.

        Returns:
            The export script's exit code (0 on success).
        """
        # Previously this used whatever self.site_name/self.engine were left
        # over from the *previous* site (or crashed on the first one).
        self._connect(site_name)
        with self.engine.begin() as conn:
            # NOTE(review): filters on week only, not year. Confirm that
            # <site>_brand_st_info never holds more than one year's data.
            sql_delete = f"delete from {site_name}_brand_st_info where week={self.week}"
            conn.execute(sql_delete)
        print(f"开始导出{site_name}站点的数据")
        # Pass argv as a list (no shell) and surface failures instead of
        # discarding the exit status as os.system() did.
        result = subprocess.run(
            [self.path_sh, site_name, str(self.year), str(self.week)],
            check=False,
        )
        if result.returncode != 0:
            print(f"export script failed for {site_name} "
                  f"(exit code {result.returncode})")
        return result.returncode

    def check_data(self, site_name):
        """Count this week's rows for ``site_name``; export them if there are none.

        Returns:
            The row count observed before any export was triggered.
        """
        self._connect(site_name)
        sql_read = f"select count(*) as table_counts from {self.site_name}_brand_st_info where week={self.week}"
        self.df_table_counts = pd.read_sql(sql_read, con=self.engine)
        table_counts = int(self.df_table_counts["table_counts"].iloc[0])
        print("table_counts:", table_counts)
        if table_counts == 0:
            self.export_data(site_name=site_name)
        return table_counts

    def update_data(self):
        """Copy this week's st_info values into the yearly brand_analytics table.

        Uses the site and engine bound by the last ``export_data`` /
        ``check_data`` call. Both UPDATEs run in a single transaction.
        """
        with self.engine.begin() as conn:
            # NOTE(review): @week is not referenced by the statements below
            # (they inline the week). Kept in case something else on this
            # session relies on it; confirm before removing.
            conn.execute(f"set @week={self.week};")
            print(f"1. {self.site_name}--更新ao_val")
            sql_update_1 = f"""UPDATE {self.site_name}_brand_analytics_{self.year} a, {self.site_name}_brand_st_info b 
                                set a.ao_val=b.st_ao_val WHERE b.week={self.week} and b.st_ao_val>0 and a.id=b.st_brand_id;"""
            conn.execute(sql_update_1)
            print(f"2. {self.site_name}--更新is_first_text")
            # Only rows with rank <= 700000 and st_is_first_text = 1 are updated.
            sql_update_2 = f"""UPDATE {self.site_name}_brand_analytics_{self.year} a, {self.site_name}_brand_st_info b 
                                set a.is_first_text=b.st_is_first_text WHERE a.rank<=700000 and b.week={self.week} and b.st_is_first_text=1 and a.id=b.st_brand_id;"""
            conn.execute(sql_update_2)

    def run(self):
        """For each site: force a fresh export, re-export if still empty, then update."""
        for site_name in self.site_name_list:
            self.export_data(site_name=site_name)
            self.check_data(site_name=site_name)
            self.update_data()


if __name__ == '__main__':
    # Usage: <script> <site_name|all> <year-week>, e.g. `all 2022-1`
    if len(sys.argv) < 3:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit(f"usage: {sys.argv[0]} <site_name|all> <year-week>  (e.g. all 2022-1)")
    site_name_flag = sys.argv[1]  # arg 1: a site code, or 'all' for every site
    year_week = sys.argv[2]  # arg 2: year-week, e.g. 2022-1
    handle_obj = ExportStInfo(site_name_flag=site_name_flag,
                              year_week=year_week)
    handle_obj.run()