# author : wangrui
# date : 2024/3/25 9:52
import time

from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from sqlalchemy import create_engine, text
import pandas as pd
import concurrent.futures
import json
import threading
from googletrans import Translator
from requests.exceptions import RequestException
# Raise Pillow's pixel-count limit so that very large images can still be opened.
Image.MAX_IMAGE_PIXELS = 200_000_000
# Module-level handle for the CPU device.
device = torch.device("cpu")

# Historical data processing
class ImageRecognitionHistory(object):
    """Back-fill element labels for image rows of ``file_files_test``.

    Workflow (see :meth:`run`):

    1. Read every ``Material`` row whose ``english_elements`` is NULL.
    2. Split the rows into chunks, one per worker thread.
    3. For each row, download the image, detect objects with DETR (falling
       back to the ModelScope image classifier when nothing is detected),
       translate the labels, and write both label lists back to the row.
    """

    def __init__(self):
        """Create the MySQL engine and load the detection/classification models."""
        # NOTE(review): the connection credentials are hard-coded in source;
        # consider loading them from the environment or a secret store.
        self.mysql_host = "rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com"
        self.mysql_port = "3306"
        self.mysql_username = "adv_yswg"
        self.mysql_pwd = "HCL1zcUgQesaaXNLbL37O5KhpSAy0c"
        # Attribute name (sic) is kept as-is so external readers do not break.
        self.db_mame = "selection"
        self.mysql_engine = create_engine(
            f"mysql+pymysql://{self.mysql_username}:{self.mysql_pwd}@{self.mysql_host}:{self.mysql_port}/{self.db_mame}",
            pool_size=30,  # number of pooled connections
            max_overflow=30  # extra connections allowed beyond the pool size
        )
        # Model location: on hadoop15 (192.168.10.224).
        self.processor = DetrImageProcessor.from_pretrained(r"/home/wangrui/detr-resnet-50", revision="no_timm")
        self.model = DetrForObjectDetection.from_pretrained(r"/home/wangrui/detr-resnet-50", revision="no_timm")
        self.image_classification = pipeline(Tasks.image_classification,
                                        model='/home/wangrui/cv_nextvit-small_image-classification_Dailylife-labels')
        self.df_image = pd.DataFrame()
        self.df_image_sub_list = []
        self.threads_num = 25

    def rea_and_prepare_data(self):
        """Load the unlabeled rows and partition them into per-thread chunks.

        Populates ``self.df_image`` with the query result and
        ``self.df_image_sub_list`` with non-empty, contiguous slices of it
        (at most ``self.threads_num`` slices, sizes differing by at most one).

        Raises:
            SystemExit: with code 1 when the query returns no rows and more
                than one thread is configured.
        """
        print("开始准备图片数据")
        query = "SELECT id, files FROM file_files_test where module = 'Material' and english_elements is null"
        self.df_image = pd.read_sql(query, self.mysql_engine)
        # Start from a clean list so a repeated call does not duplicate chunks.
        self.df_image_sub_list = []
        total = len(self.df_image)
        print("数据总量是:", total)
        if self.threads_num == 1:
            self.df_image_sub_list.append(self.df_image)
            return
        if total == 0:
            print("没有信息,请注意检查!")
            raise SystemExit(1)
        # Fewer rows than threads is still valid work: hand out one row per
        # chunk instead of treating it as "no data".
        workers = max(1, self.threads_num)
        base, extra = divmod(total, workers)
        start_idx = 0
        for i in range(workers):
            size = base + (1 if i < extra else 0)
            if size == 0:
                break
            end_idx = start_idx + size
            self.df_image_sub_list.append(self.df_image.iloc[start_idx:end_idx])
            start_idx = end_idx
        print("数据已准备好")

    def thread_task(self, sub_df):
        """Label every row of ``sub_df`` and write the labels back to MySQL.

        Args:
            sub_df: DataFrame slice with ``id`` and ``files`` columns; ``files``
                is a JSON string whose ``raw`` key holds the image path.

        A failure on one row is printed and does not stop the remaining rows.
        """
        thread_id = threading.current_thread().ident
        print(f"线程-{thread_id} 开始处理数据")
        # Bound parameters keep labels containing quotes from breaking
        # (or injecting into) the statement.
        update_stmt = text(
            "UPDATE file_files_test "
            "SET english_elements = :english_elements, "
            "chinese_elements = :chinese_elements "
            "WHERE id = :row_id"
        )
        count = 0
        for _, row in sub_df.iterrows():
            try:
                parse_value = json.loads(row['files'])
                img_path = parse_value['raw']
                image_url = f"http://soundasia.oss-cn-shenzhen.aliyuncs.com/{img_path}"
                english_elements, chinese_elements = self.get_image_element(image_url)
                if english_elements and chinese_elements:
                    row_id = row['id']
                    # Convert numpy scalars to native Python values for the driver.
                    if hasattr(row_id, "item"):
                        row_id = row_id.item()
                    # begin() commits on success and always returns the
                    # connection to the pool, so nothing leaks or is lost.
                    with self.mysql_engine.begin() as conn:
                        conn.execute(
                            update_stmt,
                            {
                                "english_elements": english_elements,
                                "chinese_elements": chinese_elements,
                                "row_id": row_id,
                            },
                        )
            except Exception as e:
                print(f"处理数据时出现异常: {e}")
            count += 1
            if count % 200 == 0:
                print(f"线程-{thread_id} 已经处理了 {count}条数据")
        print(f"线程-{thread_id} 处理数据完成")

    def handle_image_history(self):
        """Run :meth:`thread_task` over all prepared chunks in a thread pool.

        Blocks until every chunk is processed and prints the elapsed time.
        """
        print("开始图片元素标注")
        start_time = time.time()
        if self.df_image_sub_list:
            # Follow the configured thread count rather than a hard-coded one.
            max_workers = max(1, min(self.threads_num, len(self.df_image_sub_list)))
            with concurrent.futures.ThreadPoolExecutor(max_workers) as executor:
                futures = [executor.submit(self.thread_task, sub_df) for sub_df in self.df_image_sub_list]
                for future in concurrent.futures.as_completed(futures):
                    future.result()  # wait for the task; re-raises its exception
        print("标签标注完成!")
        elapsed_time = time.time() - start_time
        print(f"耗时:{elapsed_time:.2f} 秒")

    def translate_with_retry(self, translator, text, src='en', dest='zh-CN', max_retries=10):
        """Translate ``text`` with exponential backoff on request failures.

        Args:
            translator: object exposing ``translate(text, src=..., dest=...)``
                whose result has a ``.text`` attribute.
            text: string to translate.
            src: source language code.
            dest: destination language code.
            max_retries: maximum number of attempts.

        Returns:
            The translated string, or ``None`` when every attempt failed with a
            ``RequestException`` (or when ``max_retries`` is not positive).
            Other exception types propagate to the caller.
        """
        for attempt in range(1, max_retries + 1):
            try:
                return translator.translate(text, src=src, dest=dest).text
            except RequestException as e:
                print(f"Request failed: {e}")
                if attempt < max_retries:
                    print(f"Retrying... ({attempt}/{max_retries})")
                    time.sleep(2 ** attempt)  # exponential backoff between attempts
                else:
                    print("Max retries reached. Translation failed.")
        return None

    def get_image_element(self, url):
        """Detect the elements shown in the image at ``url``.

        Object detection (score threshold 0.7) is tried first; when it yields
        no label, the image classifier is used and labels scoring above 0.5
        are kept.

        Args:
            url: HTTP URL of the image.

        Returns:
            ``(english_labels, chinese_labels)`` as comma-joined strings with
            duplicates removed in first-seen order, or ``(None, None)`` when the
            image cannot be fetched/parsed or no English label was produced.
        """
        try:
            translator = Translator()
            # The context manager releases the HTTP connection; convert()
            # forces the pixel data to load before the response is closed.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                image = Image.open(response.raw).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                outputs = self.model(**inputs)
            target_sizes = torch.tensor([image.size[::-1]])
            results = self.processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=0.7
            )[0]
        except Exception as e:
            print(f"该图片解析失败: {e}")
            return None, None

        # Dicts serve as insertion-ordered sets so the joined output is stable.
        english_labels = {}
        chinese_labels = {}
        for label_id in results['labels']:
            label = self.model.config.id2label[label_id.item()]
            # Translate each distinct label once, however many times it was detected.
            if not label or label in english_labels:
                continue
            english_labels[label] = None
            chinese_label = self.translate_with_retry(translator, label, src='en', dest='zh-CN')
            if chinese_label:
                chinese_labels[chinese_label] = None

        if not english_labels:
            # Nothing detected: fall back to whole-image classification.
            results = self.image_classification(url)
            for score, label in zip(results['scores'], results['labels']):
                if score > 0.5:
                    english_label = self.translate_with_retry(translator, label, src='zh-CN', dest='en')
                    if english_label:
                        english_labels[english_label] = None
                    chinese_labels[label] = None

        if english_labels:
            return ','.join(english_labels), ','.join(chinese_labels)
        return None, None

    def run(self):
        """Prepare the data chunks, then label them."""
        self.rea_and_prepare_data()
        self.handle_image_history()


if __name__ == '__main__':
    # Script entry point: build the job (this loads the models) and run it once.
    recognizer = ImageRecognitionHistory()
    recognizer.run()