# author : wangrui
# date : 2024/3/25 9:52
import time
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from sqlalchemy import create_engine, text
import pandas as pd
import concurrent.futures
import json
import sys
import threading
from googletrans import Translator
from requests.exceptions import RequestException

Image.MAX_IMAGE_PIXELS = 200000000  # raise the decompression limit for very large images
device = torch.device("cpu")


# Back-fill element labels for historical image data
class ImageRecognitionHistory(object):
    def __init__(self):
        self.mysql_host = "rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com"
        self.mysql_port = "3306"
        self.mysql_username = "adv_yswg"
        self.mysql_pwd = "HCL1zcUgQesaaXNLbL37O5KhpSAy0c"
        self.db_name = "selection"
        self.mysql_engine = create_engine(
            f"mysql+pymysql://{self.mysql_username}:{self.mysql_pwd}@{self.mysql_host}:{self.mysql_port}/{self.db_name}",
            pool_size=30,    # size of the connection pool
            max_overflow=30  # extra connections allowed beyond the pool size
        )
        # self.processor = DetrImageProcessor.from_pretrained(r"E:\Model\model_new\detr-resnet-50", revision="no_timm")
        # self.model = DetrForObjectDetection.from_pretrained(r"E:\Model\model_new\detr-resnet-50", revision="no_timm")
        # self.image_classification = pipeline(Tasks.image_classification,
        #                                      model='E:\Model\model_new\cv_nextvit-small_image-classification_Dailylife-labels')

        # Model paths: on hadoop15 (192.168.10.224)
        self.processor = DetrImageProcessor.from_pretrained(r"/home/wangrui/detr-resnet-50", revision="no_timm")
        self.model = DetrForObjectDetection.from_pretrained(r"/home/wangrui/detr-resnet-50", revision="no_timm")
        self.image_classification = pipeline(Tasks.image_classification,
                                             model='/home/wangrui/cv_nextvit-small_image-classification_Dailylife-labels')
        self.df_image = pd.DataFrame()
        self.df_image_sub_list = []
        self.threads_num = 25

    def read_and_prepare_data(self):
        """Load the unlabeled image rows and split them into one chunk per worker thread."""
        print("Start preparing image data")
        query = "SELECT id, files FROM file_files_test where module = 'Material' and english_elements is null"
        self.df_image = pd.read_sql(query, self.mysql_engine)
        sub_size = len(self.df_image) // self.threads_num
        print("Total number of rows:", len(self.df_image))
        if self.threads_num == 1:
            self.df_image_sub_list.append(self.df_image)
            return
        if sub_size:
            for i in range(self.threads_num):
                start_idx = i * sub_size
                # the last chunk also takes the remainder
                end_idx = (i + 1) * sub_size if i < (self.threads_num - 1) else len(self.df_image)
                sub_df = self.df_image.iloc[start_idx: end_idx]
                self.df_image_sub_list.append(sub_df)
            print("Data is ready")
        else:
            print("No data found, please check!")
            sys.exit(1)

    def thread_task(self, sub_df):
        """Label every image in the chunk and write the results back to MySQL."""
        print(f"Thread-{threading.current_thread().ident} starts processing data")
        conn = self.mysql_engine.connect()
        count = 0
        for index, row in sub_df.iterrows():
            try:
                parse_value = json.loads(row['files'])
                img_path = parse_value['raw']
                image_url = f"http://soundasia.oss-cn-shenzhen.aliyuncs.com/{img_path}"
                english_elements, chinese_elements = self.get_image_element(image_url)
                if english_elements and chinese_elements:
                    # Parameter-bound SQL statement (avoids quoting problems in the label strings)
                    sql_template = text("""
                        UPDATE file_files_test
                        SET english_elements = :english_elements, chinese_elements = :chinese_elements
                        WHERE id = :id
                    """)
                    conn.execute(sql_template, {
                        "english_elements": english_elements,
                        "chinese_elements": chinese_elements,
                        "id": row['id']
                    })
            except Exception as e:
                print(f"Exception while processing a row: {e}")
            finally:
                count += 1
                if count % 200 == 0 and count > 0:
                    print(f"Thread-{threading.current_thread().ident} has processed {count} rows")
        conn.close()
        print(f"Thread-{threading.current_thread().ident} finished processing data")

    def handle_image_history(self):
        """Run thread_task over all chunks with a thread pool."""
        print("Start image element labeling")
        start_time = time.time()
        with concurrent.futures.ThreadPoolExecutor(self.threads_num) as executor:
            futures = [executor.submit(self.thread_task, sub_df) for sub_df in self.df_image_sub_list]
            for future in concurrent.futures.as_completed(futures):
                future.result()  # wait for each thread task to finish (re-raises any worker exception)
        print("Labeling finished!")
        elapsed_time = time.time() - start_time
        print(f"Elapsed time: {elapsed_time:.2f} s")

    # Single-threaded variant, kept for debugging:
    # def handle_image_history(self):
    #     print("Start image element labeling")
    #     self.thread_task(self.df_image_sub_list[0])
    #     print("Labeling finished!")

    def translate_with_retry(self, translator, text, src='en', dest='zh-CN', max_retries=10):
        """Translate a label with googletrans, retrying with exponential backoff on request errors."""
        retries = 0
        while retries < max_retries:
            try:
                translation = translator.translate(text, src=src, dest=dest)
                return translation.text
            except RequestException as e:
                print(f"Request failed: {e}")
                retries += 1
                if retries < max_retries:
                    print(f"Retrying... ({retries}/{max_retries})")
                    time.sleep(2 ** retries)  # exponential backoff between retries
                else:
                    print("Max retries reached. Translation failed.")
        return None

    def get_image_element(self, url):
        """Detect objects in the image with DETR; fall back to image classification if nothing is found.

        Returns (english_elements, chinese_elements) as comma-separated strings, or (None, None).
        """
        try:
            translator = Translator()
            image = Image.open(requests.get(url, stream=True, timeout=30).raw)
            if image.mode != "RGB":
                image = image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            outputs = self.model(**inputs)
            target_sizes = torch.tensor([image.size[::-1]])
            results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes,
                                                                   threshold=0.7)[0]
        except Exception as e:
            print(f"Failed to parse this image: {e}")
            return None, None
        english_label_set = set()
        chinese_label_set = set()
        for label in results['labels']:
            label = self.model.config.id2label[label.item()]
            if label:
                english_label_set.add(label)
                chinese_label = self.translate_with_retry(translator, label, src='en', dest='zh-CN')
                if chinese_label:
                    chinese_label_set.add(chinese_label)
        if not english_label_set:
            # DETR found nothing above the threshold; fall back to the NextViT daily-life classifier
            results = self.image_classification(url)
            for score, label in zip(results['scores'], results['labels']):
                if score > 0.5:
                    english_label = self.translate_with_retry(translator, label, src='zh-CN', dest='en')
                    if english_label:
                        english_label_set.add(english_label)
                        chinese_label_set.add(label)
        if english_label_set:
            return ','.join(english_label_set), ','.join(chinese_label_set)
        else:
            return None, None

    def run(self):
        self.read_and_prepare_data()
        self.handle_image_history()


if __name__ == '__main__':
    handle_obj = ImageRecognitionHistory()
    handle_obj.run()