import sys
import os

sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录
from utils.db_connect import BaseUtils
from multiprocessing import Pool
import pandas as pd
import traceback
from queue import Queue
import threading
import json
import ast
import gzip
from utils.parse_search_term_xpath import ParseSearchTermUs
import socket

# Database connections, created once at import time and shared by the
# module-level functions below.
# Engine returned by BaseUtils().starrocks_connect(); used by
# db_read_data_common and get_search_term_html_count.
engine_strrocks = BaseUtils().starrocks_connect()

# Engine returned by BaseUtils().pg_connect(); used by sava_maxid_minid and
# get_sava_maxid_minid.
engine_pg14 = BaseUtils().pg_connect()
import time


class Parse_search_term_html():
    """Parse a batch of stored search-result pages and persist the outcome.

    A batch of ``search_term|-|-|-|-|-|html|-|-|-|-|-|id`` strings is pushed
    through worker threads that decompress the HTML and run
    ``ParseSearchTermUs`` on it.  The "sort" rows produced by the parser are
    appended to the ``de_brand_analytics_month_pyb_<year>`` table and the
    processed row ids are switched from ``state=2`` to ``state=3`` in the
    ``search_term_html_<yyyy>_<mm>`` table.

    Args:
        site_name: Site code passed through to the parser (e.g. ``'us'``).
        date_info: ``'YYYY-MM'`` string.  It is required in practice: the
            ``None`` default fails in ``__init__`` because the value is split
            on ``'-'``.
    """

    def __init__(self, site_name=None, date_info=None):
        self.site_name = site_name  # site code
        self.date_info = date_info
        self.read_size = 10
        self.search_term_html_queue = Queue()  # pending raw records
        self.spider_int = 0
        self.month = self.date_info.split('-')[-1]
        self.search_term_state_queue = Queue()  # ids whose HTML was parsed
        self.db_syn = f'search_term_html_{date_info.replace("-", "_")}'
        self.item_queue = Queue()
        self.buyBox_list = []

        self.engine_strrocks = BaseUtils().starrocks_connect()
        self.engine_pg = BaseUtils().pg_connect()
        # Aggregated parser output, one list per result category.
        self.zr_all_list = []
        self.sp_all_list = []
        self.sb_all_list = []
        self.ac_all_list = []
        self.bs_all_list = []
        self.er_all_list = []
        self.tr_all_list = []
        self.buy_text_list = []
        self.hr_list = []
        self.sort_all_list = []
        self.st_list = []  # raw per-page parser results, in completion order

    def decompress_bytes(self, input_bytes):
        """Return the UTF-8 text stored in a gzip-compressed payload.

        Args:
            input_bytes: Either ``bytes`` or the ``repr`` of a bytes object
                (``"b'...'"``); the latter is converted back with
                ``ast.literal_eval``, which only evaluates literals.

        Returns:
            The decompressed payload decoded as UTF-8.
        """
        if isinstance(input_bytes, str):
            input_bytes = ast.literal_eval(input_bytes)
        return gzip.decompress(input_bytes).decode('utf-8')

    def _collect_results(self, st_list):
        """Append one parser result (a 10-tuple of lists) to the aggregate lists."""
        (zr_list, sp_list, sb_list, ac_list, bs_list, er_list, tr_list,
         sort_list, buy_text_list, hr_list) = st_list
        self.zr_all_list.extend(zr_list)
        self.sp_all_list.extend(sp_list)
        self.sb_all_list.extend(sb_list)
        self.ac_all_list.extend(ac_list)
        self.bs_all_list.extend(bs_list)
        self.er_all_list.extend(er_list)
        self.tr_all_list.extend(tr_list)
        self.buy_text_list.extend(buy_text_list)
        self.hr_list.extend(hr_list)
        self.sort_all_list.extend(sort_list)

    def get_search_term_html(self):
        """Thread target: drain the HTML queue, parsing one record at a time.

        Every parsed record's id is put on ``search_term_state_queue`` so that
        ``db_change_state_common`` can mark it as done.  The method returns as
        soon as the queue is found empty.
        """
        # Local import: the module only imports ``Queue`` from ``queue``.
        from queue import Empty

        while True:
            # get_nowait() instead of empty()+get(): with many threads the
            # check-then-get pair can block forever when another thread takes
            # the last item in between.
            try:
                search_term_html_str = self.search_term_html_queue.get_nowait()
            except Empty:
                print('当前线程完成')
                break
            search_term_html_str_list = search_term_html_str.split('|-|-|-|-|-|')
            search_term = search_term_html_str_list[0]
            search_term_b_html = search_term_html_str_list[1]
            k_id = int(search_term_html_str_list[2])
            page = 1
            html_str = json.loads(search_term_b_html)
            search_term_html = self.decompress_bytes(html_str)  # gunzip the stored page
            parse_search_term = ParseSearchTermUs(page_source=search_term_html, driver=None,
                                                  search_term=search_term,
                                                  page=page, site_name=self.site_name)

            st_list = parse_search_term.run()
            self.st_list.append(st_list)
            # Aggregate only the result just produced; re-walking self.st_list
            # here would add every earlier result again on each iteration.
            self._collect_results(st_list)
            self.search_term_state_queue.put(k_id)

    def db_change_state_common(self):
        """Mark parsed ids as ``state=3`` in ``self.db_syn``, 1000 ids per UPDATE.

        Each UPDATE is retried every 15 seconds until it succeeds.
        """
        while True:
            search_term_list = []
            for _ in range(1000):
                if self.search_term_state_queue.empty():
                    break
                search_term_list.append(self.search_term_state_queue.get())
            print('state=3search_term_list:', len(search_term_list))
            if not search_term_list:
                # Nothing left on the queue: all ids have been written back.
                break
            while True:
                try:
                    with self.engine_strrocks.begin() as conn:
                        # A one-element tuple renders as "(x,)", which is not
                        # valid SQL, hence the separate single-id statement.
                        if len(search_term_list) == 1:
                            sql_update = f"UPDATE {self.db_syn} SET state=3 WHERE id={search_term_list[0]} AND state=2;"
                        else:
                            sql_update = f"UPDATE {self.db_syn} SET state=3 WHERE id IN {tuple(search_term_list)} AND state=2;"
                        conn.execute(sql_update)
                    break
                except Exception as e:
                    print(f"更改{self.db_syn}表错误", e, f"\n{traceback.format_exc()}")
                    time.sleep(15)
                    continue

    def init_list(self):
        """Reset all queues and result lists so the instance starts clean."""
        self.search_term_state_queue = Queue()
        self.search_term_html_queue = Queue()
        self.item_queue = Queue()
        self.buyBox_list = []
        self.buyBoxname_search_term_list = []
        self.bs_category_search_term_list_pg = []
        self.all_img_video_list = []
        self.search_term_variation_list = []
        self.zr_all_list = []
        self.sp_all_list = []
        self.sb_all_list = []
        self.ac_all_list = []
        self.bs_all_list = []
        self.er_all_list = []
        self.tr_all_list = []
        self.buy_text_list = []
        self.hr_list = []
        self.sort_all_list = []
        self.st_list = []
        self.df_asin_detail_simply_list = []

    def _build_being_sold_frame(self):
        """Build the de-duplicated DataFrame of "sort" rows for this batch."""
        df_being_sold = pd.DataFrame(data=self.sort_all_list,
                                     columns=['search_term', 'quantity_being_sold',
                                              'quantity_being_sold_str'])
        df_being_sold['month'] = self.month
        df_being_sold['date_info'] = self.date_info
        # The frame has no 'page' column; naming it in the subset raises
        # KeyError, which the retry loop in db_update_brand would repeat forever.
        df_being_sold.drop_duplicates(['search_term', 'quantity_being_sold'], inplace=True)
        return df_being_sold

    def db_update_brand(self):
        """Append the batch's "sort" rows to the yearly brand-analytics table.

        Does nothing when ``sort_all_list`` is empty.  Any failure is printed
        and the whole write is retried after 5 seconds with a fresh engine.
        """
        if not self.sort_all_list:
            return
        while True:
            try:
                self.engine_pg = BaseUtils().pg_connect()
                print(len(self.sort_all_list))
                df_being_sold = self._build_being_sold_frame()
                year_moth_list = self.date_info.split('-')
                print('year_moth_list::,', year_moth_list)
                # NOTE(review): the table prefix is hard-coded to 'de_'
                # regardless of self.site_name -- confirm this is intended.
                target_table = f'de_brand_analytics_month_pyb_{year_moth_list[0]}'
                print(f'存储表：：{target_table}')
                if df_being_sold.shape[0] > 0:
                    df_being_sold.to_sql(target_table, con=self.engine_pg,
                                         if_exists='append',
                                         index=False)
                break
            except Exception as e:
                print('db_update_brand::', e, f"\n{traceback.format_exc()}")
                time.sleep(5)
                continue

    def run_pol(self, search_term_html_list):
        """Process one batch: parse in 50 threads, store results, update state.

        Args:
            search_term_html_list: Records in the
                ``search_term|-|-|-|-|-|html|-|-|-|-|-|id`` format.
        """
        for search_term_html in search_term_html_list:
            self.search_term_html_queue.put(search_term_html)
        html_thread = [threading.Thread(target=self.get_search_term_html) for _ in range(50)]
        for parser_thread in html_thread:
            parser_thread.start()
        for parser_thread in html_thread:
            parser_thread.join()
        print(self.sort_all_list)
        self.db_update_brand()
        self.db_change_state_common()
        self.init_list()


def db_read_data_common(start_id, limit, site_name, db_search_term):
    """Claim the unprocessed rows of one id range and return them as records.

    Rows with ``state=1``, ``page = 2`` and the given ``site_name`` whose id
    lies in ``[start_id, start_id + limit - 1]`` are selected, switched to
    ``state=2`` in the same transaction and returned.

    Args:
        start_id: First id of the range (inclusive).
        limit: Number of consecutive ids in the range.
        site_name: Site code used in the WHERE clause.
        db_search_term: Name of the table to read from and update.

    Returns:
        A list of ``search_term|-|-|-|-|-|html|-|-|-|-|-|id`` strings, or an
        empty list when no row matches.  Any error is printed and the whole
        read is retried after 15 seconds.
    """
    while True:
        try:
            with engine_strrocks.begin() as conn:
                sql_read = f"SELECT id, search_term, html FROM {db_search_term} WHERE state=1 and page = 2 and site_name='{site_name}' and id BETWEEN {start_id} AND {start_id + limit - 1}"
                print(sql_read)
                rows = conn.execute(sql_read)
                df_read = pd.DataFrame(rows, columns=['id', 'search_term', 'html'])
                if df_read.empty:
                    return []
                # Build the IN list from plain ints.  Formatting a tuple of
                # NumPy scalars relies on their repr, which is
                # "np.int64(1)" on NumPy >= 2 and yields invalid SQL.
                id_csv = ', '.join(str(int(row_id)) for row_id in df_read.id)
                sql_update = f'UPDATE {db_search_term} set state=2 where id in ({id_csv});'
                conn.execute(sql_update)
                search_term_list = list(
                    df_read.search_term + '|-|-|-|-|-|' + df_read.html + '|-|-|-|-|-|' + df_read.id.astype("U"))
                return search_term_list
        except Exception as e:
            # NOTE: the message mentions 5s, the actual wait below is 15s.
            print("读取数据出bug并等待5s继续", e, f"\n{traceback.format_exc()}")
            time.sleep(15)
            continue


def worker(start_id, limit, site_name, date_info, table_name):
    """Pool task: read one id range and, if it holds rows, parse that batch.

    Args:
        start_id: First id of the range handled by this task.
        limit: Number of consecutive ids in the range.
        site_name: Site code forwarded to the reader and the parser.
        date_info: ``'YYYY-MM'`` string forwarded to the parser.
        table_name: Table the records are read from.
    """
    pending_records = db_read_data_common(start_id, limit, site_name, table_name)
    if not pending_records:
        return
    batch_parser = Parse_search_term_html(site_name=site_name, date_info=date_info)
    batch_parser.run_pol(pending_records)


def get_search_term_html_count(site_name, date_info, table_name):
    """Return ``(max_id, min_id)`` of the rows stored for ``site_name``.

    Args:
        site_name: Site code used in the WHERE clause.
        date_info: Not used by this function.
        table_name: Table to query.

    Returns:
        The largest and the smallest ``id`` as read from the first result row.
    """
    bounds_query = f"SELECT max(id), min(id) FROM {table_name} WHERE site_name='{site_name}'"
    print(bounds_query)
    bounds = pd.read_sql(bounds_query, con=engine_strrocks)
    max_id, min_id = bounds.iloc[0, 0], bounds.iloc[0, 1]
    for bound in (max_id, min_id):
        print(bound)
    return max_id, min_id


def get_ip_address():
    """Return the local (intranet) IP address of this machine.

    A UDP socket is connected to ``baidu.com`` and the address the socket was
    bound to locally is read back.  The host name has to be resolvable for the
    call to succeed.

    Returns:
        The local IP address as a string.
    """
    # The context manager closes the socket; previously it was left open.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('baidu.com', 0))
        ip = s.getsockname()[0]
    print('内网ip：', ip)
    return ip


def sava_maxid_minid(data_range_list, site_name):
    """Replace the content of ``<site_name>_search_term_html_maxid_minid``.

    The table is truncated first, then ``data_range_list`` is appended to it.

    Args:
        data_range_list: Two-element rows; the first element is stored in the
            ``maxid`` column and the second in ``batch_size``.
        site_name: Site code used as the table-name prefix.
    """
    range_table = f'{site_name}_search_term_html_maxid_minid'
    with engine_pg14.begin() as conn:
        conn.execute(f'truncate {range_table}')
    ranges_frame = pd.DataFrame(data=data_range_list, columns=['maxid', 'batch_size'])
    ranges_frame.to_sql(range_table, con=engine_pg14, if_exists='append', index=False)


def get_sava_maxid_minid(site_name):
    """Claim one pending id range from ``<site_name>_search_term_html_maxid_minid``.

    One row with ``state=1`` is selected ``for update`` and switched to
    ``state=2`` inside the same transaction.

    Args:
        site_name: Site code used as the table-name prefix.

    Returns:
        A list of ``maxid|-|-|-|-|-|batch_size`` strings for the claimed rows,
        or an empty list when no row with ``state=1`` is left.
    """
    range_table = f'{site_name}_search_term_html_maxid_minid'
    with engine_pg14.begin() as conn:
        sql_read = f"select id, maxid, batch_size from {range_table} where state=1 limit 1 for update"
        print(sql_read)
        rows = conn.execute(sql_read)
        df_read = pd.DataFrame(rows, columns=['id', 'maxid', 'batch_size'])
        if df_read.empty:
            return []
        # Explicit int formatting keeps the IN list valid for any row count;
        # formatting a tuple of NumPy scalars relies on their repr, which is
        # "np.int64(1)" on NumPy >= 2.
        id_csv = ', '.join(str(int(row_id)) for row_id in df_read.id)
        sql_update = f'UPDATE {range_table} set state=2 where id in ({id_csv});'
        conn.execute(sql_update)
        maxid_minid_list = list(df_read.maxid.astype("U") + '|-|-|-|-|-|' + df_read.batch_size.astype("U"))
        return maxid_minid_list


def split_task(start_id, limit, num_splits):
    """Split the range ``(start_id, limit)`` into ``num_splits`` sub-ranges.

    Every sub-range covers ``limit // num_splits`` ids except the last one,
    which also receives the remainder.

    Args:
        start_id: First id of the range.
        limit: Number of ids in the range.
        num_splits: Number of sub-ranges to produce.

    Returns:
        A list of ``(sub_start_id, sub_limit)`` tuples in ascending order.
    """
    chunk = limit // num_splits
    last_index = num_splits - 1
    return [
        (
            start_id + index * chunk,
            # The final piece absorbs whatever the integer division left over.
            limit - index * chunk if index == last_index else chunk,
        )
        for index in range(num_splits)
    ]


# Script entry point: computes the (start_id, row_count) ranges that cover all
# ids of the month's table for one site and prints them.  Persisting the ranges
# and the multiprocessing driver loop are currently commented out.
if __name__ == '__main__':
    site_name = 'us'  # site code
    date_info = '2025-01'  # year-month being processed
    batch_size = 100  # ids per batch (an older comment said 5000; the value is 100)
    num_processes = 3  # number of worker processes
    table_name = f"search_term_html_{date_info.replace('-', '_')}"
    # The ranges are only computed when this host's IP is not 192.168.200.210.
    if get_ip_address() != '192.168.200.210':
        max_id, min_id = get_search_term_html_count(site_name, date_info, table_name)
        total_data = max_id - min_id + 1  # total id span; +1 because both ends are inclusive
        # Build the id range handled by each batch.
        data_range = []
        start_id = min_id
        while start_id <= max_id:  # stops once start_id has moved past max_id
            # Each range spans at most batch_size ids.
            end_id = min(start_id + batch_size - 1, max_id)  # never go beyond max_id
            data_range.append((start_id, end_id - start_id + 1))  # (first id, number of ids in the range)
            start_id = end_id + 1  # the next range starts right after this one
        print(data_range, site_name)
        # with open(rf'D:\新建文件夹\requests_files/us_B09NW2R5HQ.txt', 'w', encoding='utf-8')as f:
        #     f.write(str(data_range))
        # sava_maxid_minid(data_range, site_name)

    # Disabled driver loop: claims one stored range at a time, splits it into
    # num_processes sub-tasks and runs them in a process pool.
    # while True:
    #     maxid_minid_list = get_sava_maxid_minid(site_name)
    #     if not maxid_minid_list:
    #         print("所有批次处理完成")
    #         break
    #     p = Pool(num_processes)
    #     print('maxid_minid_list::', maxid_minid_list)
    #     start_id_limit_list = []
    #     for start_limit in maxid_minid_list:
    #         start_id_limit = start_limit.split('|-|-|-|-|-|')
    #         start_id = int(start_id_limit[0])
    #         limit = int(start_id_limit[1])
    #         # split the task further into num_processes sub-tasks
    #         sub_tasks = split_task(start_id, limit, num_processes)
    #         for sub_start_id, sub_limit in sub_tasks:
    #             p.apply_async(worker, args=(sub_start_id, sub_limit, site_name, date_info, table_name))
    #     p.close()
    #     p.join()
    #     print('结束当前进程')
    #     time.sleep(3)
    #     # re-create the database connection
    #     engine_pg14 = BaseUtils().pg_connect()
    #     break
