ImageRecognitionHistory.py 7.63 KB
Newer Older
chenyuanjie committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
# author : wangrui
# date : 2024/3/25 9:52
import time

from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from sqlalchemy import create_engine, text
import pandas as pd
import concurrent.futures
import json
import threading
from googletrans import Translator
from requests.exceptions import RequestException
Image.MAX_IMAGE_PIXELS = 200000000  # 设置更高的限制
device = torch.device("cpu")

# 历史数据
class ImageRecognitionHistory(object):
    def __init__(self):
        self.mysql_host = "rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com"
        self.mysql_port = "3306"
        self.mysql_username = "adv_yswg"
        self.mysql_pwd = "HCL1zcUgQesaaXNLbL37O5KhpSAy0c"
        self.db_mame = "selection"
        self.mysql_engine = create_engine(
            f"mysql+pymysql://{self.mysql_username}:{self.mysql_pwd}@{self.mysql_host}:{self.mysql_port}/{self.db_mame}",
            pool_size=30,  # 连接池的大小
            max_overflow=30  # 超出连接池大小之外可以创建的连接数
        )
        # self.processor = DetrImageProcessor.from_pretrained(r"E:\Model\model_new\detr-resnet-50", revision="no_timm")
        # self.model = DetrForObjectDetection.from_pretrained(r"E:\Model\model_new\detr-resnet-50", revision="no_timm")
        # self.image_classification = pipeline(Tasks.image_classification,
        #                                 model='E:\Model\model_new\cv_nextvit-small_image-classification_Dailylife-labels')
        # 模型路径: hadoop15上 192.168.10.224
        self.processor = DetrImageProcessor.from_pretrained(r"/home/wangrui/detr-resnet-50", revision="no_timm")
        self.model = DetrForObjectDetection.from_pretrained(r"/home/wangrui/detr-resnet-50", revision="no_timm")
        self.image_classification = pipeline(Tasks.image_classification,
                                        model='/home/wangrui/cv_nextvit-small_image-classification_Dailylife-labels')
        self.df_image = pd.DataFrame()
        self.df_image_sub_list = []
        self.threads_num = 25

    def rea_and_prepare_data(self):
        print("开始准备图片数据")
        query = "SELECT id, files FROM file_files_test where module = 'Material' and english_elements is null"
        self.df_image = pd.read_sql(query, self.mysql_engine)
        sub_size = len(self.df_image) // self.threads_num
        print("数据总量是:", len(self.df_image))
        if self.threads_num == 1:
            self.df_image_sub_list.append(self.df_image)
            return
        if sub_size:
            for i in range(self.threads_num):
                start_idx = i * sub_size
                end_idx = (i + 1) * sub_size if i < (self.threads_num - 1) else len(self.df_image)
                sub_df = self.df_image.iloc[start_idx: end_idx]
                self.df_image_sub_list.append(sub_df)
            print("数据已准备好")
        else:
            print("没有信息,请注意检查!")
            quit(1)

    def thread_task(self, sub_df):
        print(f"线程-{threading.current_thread().ident} 开始处理数据")
        conn = self.mysql_engine.connect()
        count = 0
        for index, row in sub_df.iterrows():
            try:
                parse_value = json.loads(row['files'])
                img_path = parse_value['raw']
                image_url = f"http://soundasia.oss-cn-shenzhen.aliyuncs.com/{img_path}"
                english_elements, chinese_elements = self.get_image_element(image_url)
                if english_elements and chinese_elements:
                    id = row['id']
                    # 构造参数绑定的 SQL 语句模板
                    sql_template = text(f"""
                        UPDATE file_files_test 
                        SET english_elements = '{english_elements}',
                            chinese_elements = '{chinese_elements}'
                        WHERE id = {id}
                    """)
                    conn.execute(sql_template)
            except Exception as e:
                print(f"处理数据时出现异常: {e}")
            finally:
                count += 1
            if count % 200 == 0 and count > 0:
                print(f"线程-{threading.current_thread().ident} 已经处理了 {count}条数据")
        print(f"线程-{threading.current_thread().ident} 处理数据完成")

    def handle_image_history(self):
        print("开始图片元素标注")
        start_time = time.time()
        with concurrent.futures.ThreadPoolExecutor(25) as executor:
            futures = [executor.submit(self.thread_task, sub_df) for sub_df in self.df_image_sub_list]
            for future in concurrent.futures.as_completed(futures):
                future.result()  # 等待每个线程任务完成
        print("标签标注完成!")
        elapsed_time = time.time() - start_time
        print(f"耗时:{elapsed_time:.2f} 秒")

    # def handle_image_history(self):
    #     print("开始图片元素标注")
    #     self.thread_task(self.df_image_sub_list[0])
    #     print("标签标注完成!")

    def translate_with_retry(self, translator, text, src='en', dest='zh-CN', max_retries=10):
        retries = 0
        while retries < max_retries:
            try:
                translation = translator.translate(text, src=src, dest=dest)
                return translation.text
            except RequestException as e:
                print(f"Request failed: {e}")
                retries += 1
                if retries < max_retries:
                    print(f"Retrying... ({retries}/{max_retries})")
                    time.sleep(2 ** retries)  # Exponential backoff for retries
                else:
                    print("Max retries reached. Translation failed.")
                    return None

    def get_image_element(self, url):
        try:
            translator = Translator()
            image = Image.open(requests.get(url, stream=True, timeout=30).raw)
            if image.mode != "RGB":
                image = image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            outputs = self.model(**inputs)
            target_sizes = torch.tensor([image.size[::-1]])
            results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]
        except Exception as e:
            print(f"该图片解析失败: {e}")
            return None, None
        englist_label_set = set()
        chinese_lable_set = set()
        for label in results['labels']:
            label = self.model.config.id2label[label.item()]
            if label:
                englist_label_set.add(label)
                chinese_label = self.translate_with_retry(translator, label, src='en', dest='zh-CN')
                if chinese_label:
                    chinese_lable_set.add(chinese_label)
        if not englist_label_set:
            results = self.image_classification(url)
            for score, label in zip(results['scores'], results['labels']):
                if score > 0.5:
                    english_label = self.translate_with_retry(translator, label, src='zh-CN', dest='en')
                    if english_label:
                        englist_label_set.add(english_label)
                    chinese_lable_set.add(label)
        if englist_label_set:
            return ','.join(englist_label_set), ','.join(chinese_lable_set)
        else:
            return None, None

    def run(self):
        self.rea_and_prepare_data()
        self.handle_image_history()


if __name__ == '__main__':
    handle_obj = ImageRecognitionHistory()
    handle_obj.run()