import os
import socket
import sys
import threading
import logging
import time
import traceback
import uuid

import pandas as pd
import redis
import requests
from sqlalchemy import text

# Configure the root logger: timestamp, logger name, level and message, at INFO level.
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)
# NOTE(review): nothing in this file reads this variable; presumably it is consumed by a
# downstream library's PyArrow timezone handling — confirm it is still needed.
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
# Add the parent of sys.path[0] to the import path; presumably this is what lets the
# `utils` package imported below resolve when the file is run as a script.
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory

from utils.db_util import DbTypes, DBUtil, get_redis_h14


class ImgDownload(object):
    """Multi-threaded worker that downloads the images listed in a MySQL table.

    Each worker thread repeatedly claims a batch of rows under a Redis lock
    (state 1 -> 2), downloads every image to local disk and finally marks the
    rows as succeeded (state 3) or failed (state 4).
    """

    # Hosts on which images are stored under "/mnt"; on every other host the
    # "/mnt" part of the path is rewritten to "/home".
    MNT_HOSTS = ('hadoop5', 'hadoop6', 'hadoop7', 'hadoop8')
    # Seconds to wait before asking for the lock again when it is taken.
    LOCK_RETRY_SECONDS = 10
    # Seconds to wait before retrying after a database / redis failure.
    ERROR_RETRY_SECONDS = 20
    # Number of download attempts per image.
    DOWNLOAD_ATTEMPTS = 5

    def __init__(self, site_name='us', img_type="amazon_inv", thread_num=10, limit=200):
        """Create the database / redis clients and resolve the working table.

        :param site_name: site code used in table names and directory names.
        :param img_type: image source, one of 'amazon_self', 'amazon', 'amazon_inv'.
        :param thread_num: number of worker threads started by run_thread().
        :param limit: number of rows claimed per batch.
        """
        self.site_name = site_name
        self.img_type = img_type
        self.thread_num = thread_num
        self.limit = limit
        self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
        self.client_redis = get_redis_h14()
        self.hostname = socket.gethostname()
        self.first_local_dir, self.read_table = self.get_first_local_dir()
        # The table name doubles as the name of the redis lock.
        self.local_name = self.read_table

    def get_first_local_dir(self):
        """Return (root directory, table name) for the configured img_type.

        An unknown img_type yields two empty strings.
        """
        if self.img_type == 'amazon_self':
            first_local_dir = f"/mnt/data/img_data/amazon_self/{self.site_name}"
            image_table = f'{self.site_name}_self_asin_image'
        elif self.img_type == 'amazon':
            first_local_dir = f"/mnt/data/img_data/amazon/{self.site_name}"
            image_table = f'{self.site_name}_amazon_image'
        elif self.img_type == 'amazon_inv':
            first_local_dir = f"/mnt/data/img_data/amazon_inv/{self.site_name}"
            image_table = f'{self.site_name}_inv_img_info'
        else:
            first_local_dir = ""
            image_table = ""
        return first_local_dir, image_table

    def acquire_lock(self, lock_name, timeout=100):
        """Try to take the distributed lock.

        :param lock_name: redis key of the lock.
        :param timeout: expiry of the lock key in seconds.
        :return: (lock_acquired, lock_value); lock_acquired is truthy when the
            key was set and None when it already exists, lock_value is the
            random token that must be passed to release_lock().
        """
        lock_value = str(uuid.uuid4())
        lock_acquired = self.client_redis.set(lock_name, lock_value, nx=True, ex=timeout)
        return lock_acquired, lock_value

    def release_lock(self, lock_name, lock_value):
        """Delete the lock key, but only if it still holds our own token.

        :return: result of the Lua script (1 when the key was deleted, else 0).
        """
        script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        result = self.client_redis.eval(script, 1, lock_name, lock_value)
        return result

    @staticmethod
    def _format_id_tuple(id_tuple):
        """Render ids as the parenthesised list that follows ``IN`` in the SQL.

        A single id is rendered as ``('<id>')`` because the repr of a
        one-element tuple carries a trailing comma that is not valid SQL.
        """
        id_tuple = tuple(id_tuple)
        if len(id_tuple) == 1:
            return f"('{id_tuple[0]}')"
        return f"{id_tuple}"

    @staticmethod
    def img_download(img_url, img_path, img_name, timeout=30):
        """Download one image and write it to ``img_path + img_name``.

        :param img_url: url of the image.
        :param img_path: target directory (with trailing separator); created if missing.
        :param img_name: target file name.
        :param timeout: seconds passed to requests.get so that a stalled
            connection cannot block the worker thread forever.
        :return: True when the file was written, False after all attempts failed.
        """
        file_path = f"{img_path}{img_name}"
        for d_num in range(ImgDownload.DOWNLOAD_ATTEMPTS):
            try:
                response = requests.get(img_url, timeout=timeout)
                if response.status_code != 200:
                    continue
                # Create the directory up front; exist_ok makes this safe when
                # several threads create the same directory at the same time.
                if img_path:
                    os.makedirs(img_path, exist_ok=True)
                with open(file_path, 'wb') as file:
                    file.write(response.content)
                return True
            except Exception as e:
                # Per-image boundary: any failure is logged and retried so that
                # one bad image never aborts the whole batch.
                print(f"{d_num}次--下载图片失败, 图片路径:　{file_path}, 图片url: {img_url}, \n错误信息: {e, traceback.format_exc()}")
                time.sleep(2)
        return False

    def update_state(self, id_list, state, state_value="success"):
        """Set ``state`` on all rows whose id is in id_list, retrying until it works.

        :param id_list: ids to update; nothing happens when it is empty.
        :param state: new value of the state column.
        :param state_value: label used only in the log output.
        """
        if not id_list:
            return
        id_tuple = tuple(id_list)
        sql_update = f"UPDATE {self.read_table} SET state={state} WHERE id IN {self._format_id_tuple(id_tuple)};"
        while True:
            try:
                with self.engine_mysql.begin() as conn:
                    print(f"{state_value}--id_tuple: {len(id_tuple)}, {id_tuple[:10]}", )
                    print("sql_update:", sql_update[:150])
                    # text() matches the SELECT in read_data; SQLAlchemy 2.0
                    # no longer accepts a plain string here.
                    conn.execute(text(sql_update))
                break
            except Exception as e:
                print(f"更新状态错误: {e}", traceback.format_exc())
                time.sleep(self.ERROR_RETRY_SECONDS)
                # Rebuild both clients in case the connection went stale.
                self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
                self.client_redis = get_redis_h14()

    def _claim_batch(self):
        """Read up to ``limit`` rows in state 1 and move them to state 2."""
        with self.engine_mysql.begin() as conn:
            sql_read = text(f"SELECT id, img_id, img_type, img_url, id_segment FROM {self.read_table} WHERE state=1 LIMIT {self.limit};")
            df = pd.read_sql(sql=sql_read, con=self.engine_mysql)
            id_tuple = tuple(df["id"])
            print(f"sql_read: {sql_read}, {df.shape}", id_tuple[:10])
            if id_tuple:
                sql_update = f"UPDATE {self.read_table} SET state=2 WHERE id IN {self._format_id_tuple(id_tuple)};"
                print("sql_update:", sql_update[:150])
                conn.execute(text(sql_update))
        return df

    def read_data(self):
        """Claim the next batch of rows under the redis lock.

        Blocks (retrying) until the lock was obtained and the batch was read.

        :return: DataFrame with columns id, img_id, img_type, img_url,
            id_segment; empty when no row is left in state 1.
        """
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if not lock_acquired:
                    print(f"当前有其它进程占用redis的锁, 等待{self.LOCK_RETRY_SECONDS}秒继续获取数据")
                    time.sleep(self.LOCK_RETRY_SECONDS)
                    continue
                try:
                    print("self.hostname:", self.hostname)
                    df = self._claim_batch()
                finally:
                    # Always give the lock back, also when the database step
                    # failed, so other workers do not wait for the key to expire.
                    try:
                        self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                    except Exception as e:
                        # The key expires on its own; a failed release must not
                        # throw away a batch that is already marked as claimed.
                        print(f"释放redis锁失败: {e}", traceback.format_exc())
                return df
            except Exception as e:
                print(f"读取数据错误: {e}", traceback.format_exc())
                time.sleep(self.ERROR_RETRY_SECONDS)
                self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
                self.client_redis = get_redis_h14()

    def _local_img_dir(self, id_segment):
        """Return the directory (with trailing slash) for one id segment on this host."""
        img_path = f"{self.first_local_dir}/{id_segment}/"
        if self.hostname not in self.MNT_HOSTS:
            img_path = img_path.replace("/mnt", "/home")
        return img_path

    def handle_data(self, df, thread_id):
        """Download every image of a batch and record the outcome.

        :param df: batch as returned by read_data().
        :param thread_id: worker number, used only in the log output.
        """
        # 1. download the images
        img_success_id_list = []
        img_failed_id_list = []
        id_len = len(df)
        rows = zip(df["id_segment"], df["id"], df["img_id"], df["img_type"], df["img_url"])
        # enumerate gives the progress index directly instead of searching the
        # id list again for every row.
        for id_index, (id_segment, row_id, img_id, img_type, img_url) in enumerate(rows):
            img_path = self._local_img_dir(id_segment)
            img_name = f"{id_segment}_{row_id}_{img_id}_{img_type}.jpg"
            d_flag = self.img_download(img_url=img_url, img_path=img_path, img_name=img_name)
            print(f"self.hostname: {self.hostname}, 线程: {thread_id}, 是否成功: {d_flag}, id_index: {id_index}, 进度: {round(id_index/id_len * 100, 2)}%, img_path: {img_path}{img_name}")
            if d_flag:
                img_success_id_list.append(row_id)
            else:
                img_failed_id_list.append(row_id)
        # 2. write back the state: 3 = success, 4 = failed
        print(f"success: {len(img_success_id_list)}, failed: {len(img_failed_id_list)}")
        self.update_state(id_list=img_success_id_list, state=3, state_value="success")
        self.update_state(id_list=img_failed_id_list, state=4, state_value="failed")

    def save_data(self):
        """Placeholder hook called after each batch; currently does nothing."""
        pass

    def run(self, thread_id=1):
        """Worker loop: process batches until read_data() returns an empty one.

        :param thread_id: worker number, used only in the log output.
        """
        while True:
            try:
                df = self.read_data()
                if not df.shape[0]:
                    break
                self.handle_data(df=df, thread_id=thread_id)
                self.save_data()
            except Exception as e:
                print(e, traceback.format_exc())
                self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
                self.client_redis = get_redis_h14()
                time.sleep(self.ERROR_RETRY_SECONDS)

    def run_thread(self):
        """Start ``thread_num`` workers running run() and wait for all of them."""
        logging.info("所有线程处理开始")
        thread_list = []
        for thread_id in range(self.thread_num):
            thread = threading.Thread(target=self.run, args=(thread_id, ))
            thread_list.append(thread)
            thread.start()
        for thread in thread_list:
            thread.join()
        logging.info("所有线程处理完成")


if __name__ == '__main__':
    # Command line: <site_name> <img_type> <limit> <thread_num>
    #   site_name  - site code, e.g. "us"
    #   img_type   - image source type, e.g. "amazon_inv"
    #   limit      - number of rows read per batch
    #   thread_num - number of worker threads
    cli_args = sys.argv
    site_name, img_type = cli_args[1], cli_args[2]
    limit, thread_num = int(cli_args[3]), int(cli_args[4])
    handle_obj = ImgDownload(
        site_name=site_name,
        img_type=img_type,
        thread_num=thread_num,
        limit=limit,
    )
    handle_obj.run_thread()