import ast
import json
import os
import sys
import threading
import time
import traceback
import uuid

import pandas as pd
import logging

import redis
from sqlalchemy import text

# Extend the import path so the project-local `utils` package below can be found;
# the directory is hard-coded, so this must run before the `utils` import.
sys.path.append("/opt/module/spark-3.2.0-bin-hadoop3.2/demo/py_demo/")
from utils.db_util import DbTypes, DBUtil
# Configure root logging at import time: timestamp, logger name, level and message, INFO and above.
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)


class ImgToh7(object):

    def __init__(self, site_name='us', thread_num=10, limit=100, img_type='amazon_self'):
        """
        Store the job settings and open the Doris engine and the Redis client.

        site_name: site code; used in table names, directory names and the lock key.
        thread_num: number of worker threads started by run_thread().
        limit: batch size used by the read queries.
        img_type: image source type ('amazon_self', 'amazon_inv' or 'amazon').
        """
        self.site_name, self.img_type = site_name, img_type
        self.thread_num, self.limit = thread_num, limit
        # Key of the Redis lock used while a batch is being claimed.
        self.local_name = f"{self.site_name}_img_to_h7"
        # Table name depends on site_name and img_type, both set above.
        self.img_table = self.get_table_name_and_dir_name()
        self.df_self_asin_image = pd.DataFrame()
        self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
        # NOTE(review): Redis host and password are hard-coded in source — consider moving them to configuration.
        self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='yswg2023')

    def get_table_name_and_dir_name(self):
        """
        Return the image table name for the configured site and image type.

        The name is '<site_name>_<suffix>' where the suffix is chosen by img_type;
        an empty string is returned for an unrecognised img_type.
        """
        suffix_by_type = (
            ('amazon_self', 'self_asin_img'),
            ('amazon_inv', 'inv_img'),
            ('amazon', 'amazon_img'),
        )
        for known_type, suffix in suffix_by_type:
            if self.img_type == known_type:
                return f'{self.site_name}_{suffix}'
        return ""

    def get_dir_name(self, img_unique_value=''):
        """
        Return the directory in which the image for `img_unique_value` is stored.

        'amazon_self' / 'amazon': six nested levels made of the 1..6 character
        prefixes of the value. 'amazon_inv': one level named after the part of the
        value before the first underscore. Any other img_type yields ''.
        """
        if self.img_type not in ('amazon_self', 'amazon', 'amazon_inv'):
            return ""
        root = rf"/mnt/data/img_data/{self.img_type}/{self.site_name}"
        if self.img_type == 'amazon_inv':
            return rf"{root}/{img_unique_value.split('_')[0]}"
        # Prefixes of growing length: v[:1]/v[:2]/.../v[:6]
        prefix_levels = "/".join(f"{img_unique_value[:depth]}" for depth in range(1, 7))
        return rf"{root}/{prefix_levels}"

    def acquire_lock(self, lock_name, timeout=10):
        """
        Try to take the distributed lock stored under `lock_name` in Redis.

        lock_name: key of the lock; ideally the same as the task name.
        timeout: expiry of the lock key, in seconds.
        Returns (acquired, token): `acquired` is the result of the Redis SET call
        (True when the key was set, None when it already existed); `token` is the
        random value written to the key, needed later by release_lock().
        """
        token = str(uuid.uuid4())
        # nx=True: only set when the key does not exist; ex: let the key expire on its own.
        acquired = self.client.set(lock_name, token, nx=True, ex=timeout)
        return acquired, token

    def release_lock(self, lock_name, lock_value):
        """
        Release the distributed lock, but only if it still holds `lock_value`.

        lock_name: key of the lock.
        lock_value: token returned by acquire_lock().
        Returns the script result passed back by the Redis client: the result of
        the "del" call when the token matched, otherwise 0.
        """
        # The get-compare-delete sequence is sent as a single Lua script.
        compare_and_delete = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        return self.client.eval(compare_and_delete, 1, lock_name, lock_value)

    def read_data(self):
        """
        Claim the next batch of pending images and return it.

        While holding the Redis lock named self.local_name, the method reads the
        rows of selection.<img_table> that have state=1 and whose id lies between
        min(id) and min(id)+limit, then updates their state:
          * rows with image data     -> state=2 (being parsed/stored)
          * rows with NULL image_file -> state=9 (no image data)

        Returns (df_img_file_not_null, df_img_file_null), two DataFrames with the
        columns id, img_unique and image_file.

        The method retries forever: when the lock is held by someone else it waits
        5 seconds; on any error it prints the traceback, waits 5 seconds and
        rebuilds the Doris engine before trying again.
        """
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name, timeout=100)
                if not lock_acquired:
                    print("当前有其它进程占用redis的锁, 等待5秒继续获取数据")
                    time.sleep(5)  # wait 5s before trying the lock again
                    continue
                try:
                    with self.engine_doris.begin() as conn:
                        # Every sub-query targets the configured table (the id sub-query
                        # used to be hard-coded to selection.us_inv_img).
                        # NOTE(review): BETWEEN is inclusive, so the id window spans
                        # limit+1 ids — confirm whether that is intended.
                        sql_read = f"""
                            SELECT id, img_unique, image_file from selection.{self.img_table} WHERE id in
                            (
                            select id from selection.{self.img_table} WHERE state=1 and id
                            BETWEEN (SELECT min(id) from selection.{self.img_table} WHERE state=1 )
                            and (SELECT min(id)+{self.limit} from selection.{self.img_table} WHERE state=1 )
                            );
                        """
                        df = pd.read_sql(sql_read, con=self.engine_doris)
                        df_img_file_null = df.loc[df.image_file.isna()]
                        df_img_file_not_null = df.loc[~df.image_file.isna()]
                        img_unique_tuple = tuple(df.img_unique)
                        print(f"sql_read: {sql_read}, {df.shape}, df_img_file_null.shape: {df_img_file_null.shape}", img_unique_tuple[:10])
                        # Mark the claimed rows: 2 = being parsed/stored, 9 = image_file is NULL.
                        state_updates = (
                            (2, tuple(df_img_file_not_null.img_unique)),
                            (9, tuple(df_img_file_null.img_unique)),
                        )
                        for state, img_uniques in state_updates:
                            if not img_uniques:
                                continue
                            # A one-element tuple prints as ('x',), which is not valid SQL.
                            in_clause = f"('{img_uniques[0]}')" if len(img_uniques) == 1 else f"{img_uniques}"
                            sql_update = f"update selection.{self.img_table} set state={state} where img_unique in {in_clause};"
                            print(f"sql_update--{state}:", sql_update[:100])
                            conn.execute(sql_update)
                finally:
                    # Release only after the transaction block has exited (committed or
                    # rolled back), and also on errors, so other workers neither see
                    # uncommitted state nor wait for the 100s expiry.
                    try:
                        self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                    except Exception as release_error:
                        # The key expires by itself; a failed release must not discard
                        # a batch that has already been committed.
                        logging.warning("failed to release redis lock %s: %s", self.local_name, release_error)
                return df_img_file_not_null, df_img_file_null
            except Exception as e:
                print(f"读取数据错误: {e}", traceback.format_exc())
                time.sleep(5)
                self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
                continue

    def read_data_old(self):
        with self.engine_doris.begin() as conn:
            sql_update = f"""
                UPDATE {self.img_table}
                SET state = 2
                WHERE id in (
                    SELECT id
                    FROM {self.img_table}
                    WHERE state = 1
                    LIMIT {self.limit}
                );
            """
            print(f"sql_update: {sql_update}")
            conn.execute(sql_update)
            sql_read = f"select id, img_unique, image_file from {self.img_table} where state=2 limit {self.limit};"
            # self.df_self_asin_image = pd.read_sql(sql_read, con=self.engine_doris)
            a = conn.execute(sql_read)
            df = pd.DataFrame(a, columns=['id', 'img_unique', 'image_file'])
            id_tuple = tuple(df.id)
            print(f"sql_read: {sql_read}, {df.shape}", id_tuple[:10])
        return df

            # if id_tuple:
            #     id_tuple_str = f"({id_tuple[0]})" if len(id_tuple) == 1 else f"{id_tuple}"
            #     sql_update = f"update {self.site_name}_self_asin_image set state=2 where id in {id_tuple_str};"
            #     conn.execute(sql_update)

    def handle_data(self, df):
        img_unique_list = list(df.img_unique)
        img_str_list = list(df.image_file)
        print(f"{len(img_unique_list)}, {img_unique_list[:10]}")
        for img_unique_value, img_str in zip(img_unique_list, img_str_list):
            # print(f"img_unique_value, img_str: {img_unique_value, img_str}")
            # print(f"img_str: {type(img_str)}, {img_str[:20]}")
            # img_str = json.loads(img_str)
            input_bytes = ast.literal_eval(img_str)
            # input_bytes = img_str
            dir_name = self.get_dir_name(img_unique_value=img_unique_value)
            # 确保目录存在
            os.makedirs(dir_name, exist_ok=True)
            file_name = rf"{dir_name}/{img_unique_value}.jpg"
            # print(f"file_name: {file_name}")
            with open(file_name, 'wb') as f:
                f.write(input_bytes)
            # break

    def save_data(self, df):
        """
        Mark every row of `df` as stored by setting state=3.

        df: DataFrame with an img_unique column; its values select the rows to update.
        Nothing is executed, and no transaction is opened, when `df` has no rows.
        """
        img_unique_tuple = tuple(df.img_unique)
        if not img_unique_tuple:
            return
        # A one-element tuple prints as ('x',), which is not valid SQL.
        img_unique_tuple_str = f"('{img_unique_tuple[0]}')" if len(img_unique_tuple) == 1 else f"{img_unique_tuple}"
        # Qualify the table with the `selection` schema, exactly as read_data() does,
        # so the rows marked state=2 there are the rows updated here.
        sql_update = f"update selection.{self.img_table} set state=3 where img_unique in {img_unique_tuple_str};"
        print(f"sql_update: {sql_update[:100]}")
        with self.engine_doris.begin() as conn:
            conn.execute(sql_update)

    def run(self):
        """
        Worker loop: claim a batch, write its images to disk, mark it as stored.

        The loop ends when read_data() returns no rows at all. Any exception is
        printed with its traceback, the Doris engine is rebuilt and the loop
        continues after a 10 second pause.
        """
        while True:
            try:
                df, df_null = self.read_data()
                if not (df.shape[0] or df_null.shape[0]):
                    break
                # A batch made only of NULL-image rows leaves `df` empty, so both
                # calls below do nothing for it and the loop simply moves on.
                self.handle_data(df)
                self.save_data(df)
            except Exception as e:
                print(f"error: {e}", traceback.format_exc())
                # Rebuild the same kind of engine __init__ created (Doris); this used
                # to request an SRS engine, pointing later queries at another database.
                self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
                time.sleep(10)

    def run_thread(self):
        """
        Start `thread_num` threads that each execute run(), wait for all of them
        to finish, then log a completion message.
        """
        workers = [threading.Thread(target=self.run) for _ in range(self.thread_num)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        logging.info("所有线程处理完成")


if __name__ == '__main__':
    # Usage: <script> <site_name> <img_type> <limit> <thread_num>
    cli_args = sys.argv
    site_name = cli_args[1]        # argument 1: site
    img_type = cli_args[2]         # argument 2: image source type
    limit = int(cli_args[3])       # argument 3: rows read per batch, e.g. 1000
    thread_num = int(cli_args[4])  # argument 4: number of threads, e.g. 5
    handle_obj = ImgToh7(
        site_name=site_name,
        img_type=img_type,
        limit=limit,
        thread_num=thread_num,
    )
    handle_obj.run_thread()