1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# author : wangrui
# date : 2024/3/25 9:52
import time
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from sqlalchemy import create_engine, text
import pandas as pd
import concurrent.futures
import json
import threading
from googletrans import Translator
from requests.exceptions import RequestException
# Module-level settings shared by the whole script.
Image.MAX_IMAGE_PIXELS = 200000000 # Raise Pillow's pixel-count limit so larger images can be opened
device = torch.device("cpu")  # NOTE(review): `device` is not referenced by the code visible in this file — confirm it is still needed
# Historical data
class ImageRecognitionHistory(object):
    """Backfill English/Chinese element labels for historical image rows.

    Rows of ``file_files_test`` with ``module = 'Material'`` and a NULL
    ``english_elements`` column are loaded, split into chunks, and processed
    by a thread pool.  Each image is labelled with a DETR object detector;
    when the detector finds nothing, an image-classification pipeline is used
    instead.  Labels are translated with ``googletrans`` and written back to
    the same table.
    """

    def __init__(self):
        """Create the DB engine, load both models and set the thread count."""
        # NOTE(review): credentials are hard-coded in source. They should be
        # moved to configuration / environment and rotated.
        self.mysql_host = "rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com"
        self.mysql_port = "3306"
        self.mysql_username = "adv_yswg"
        self.mysql_pwd = "HCL1zcUgQesaaXNLbL37O5KhpSAy0c"
        # The attribute keeps its historical spelling ("db_mame") so that any
        # external reference to it keeps working.
        self.db_mame = "selection"
        self.mysql_engine = create_engine(
            f"mysql+pymysql://{self.mysql_username}:{self.mysql_pwd}@{self.mysql_host}:{self.mysql_port}/{self.db_mame}",
            pool_size=30,  # size of the connection pool
            max_overflow=30  # connections allowed beyond the pool size
        )
        # Model path: on hadoop15, 192.168.10.224
        self.processor = DetrImageProcessor.from_pretrained(r"/home/wangrui/detr-resnet-50", revision="no_timm")
        self.model = DetrForObjectDetection.from_pretrained(r"/home/wangrui/detr-resnet-50", revision="no_timm")
        self.image_classification = pipeline(Tasks.image_classification,
                                             model='/home/wangrui/cv_nextvit-small_image-classification_Dailylife-labels')
        self.df_image = pd.DataFrame()
        self.df_image_sub_list = []
        self.threads_num = 25

    @staticmethod
    def _split_frame(df, parts):
        """Split *df* into at most *parts* contiguous, near-equal, non-empty chunks."""
        total = len(df)
        # Never create more chunks than rows, and always at least one chunk.
        parts = max(1, min(parts, total))
        base, extra = divmod(total, parts)
        chunks = []
        start = 0
        for i in range(parts):
            # The first `extra` chunks take one additional row each so the
            # remainder is spread out instead of piling onto the last chunk.
            stop = start + base + (1 if i < extra else 0)
            chunks.append(df.iloc[start:stop])
            start = stop
        return chunks

    def rea_and_prepare_data(self):
        """Load the unlabelled rows and split them into per-thread chunks.

        Stores the query result in ``self.df_image`` and the chunks in
        ``self.df_image_sub_list``.  With ``threads_num == 1`` the whole frame
        is queued as a single chunk (even when empty).

        Raises:
            SystemExit: with code 1 when more than one thread is configured
                and the query returns no rows.
        """
        print("开始准备图片数据")
        query = "SELECT id, files FROM file_files_test where module = 'Material' and english_elements is null"
        self.df_image = pd.read_sql(query, self.mysql_engine)
        total = len(self.df_image)
        print("数据总量是:", total)
        # Start from a clean list so a repeated call does not queue rows twice.
        self.df_image_sub_list = []
        if self.threads_num == 1:
            self.df_image_sub_list.append(self.df_image)
            return
        if total == 0:
            print("没有信息,请注意检查!")
            # Same effect as quit(1) without relying on the `site` helper.
            raise SystemExit(1)
        # Fewer rows than threads is valid input: it simply yields fewer chunks.
        self.df_image_sub_list = self._split_frame(self.df_image, self.threads_num)
        print("数据已准备好")

    def thread_task(self, sub_df):
        """Label every row of *sub_df* and write the result back to MySQL.

        Args:
            sub_df: DataFrame chunk with ``id`` and ``files`` columns, where
                ``files`` is a JSON string containing a ``raw`` image path.

        Per-row failures are printed and skipped so that a single bad row does
        not abort the whole chunk.
        """
        thread_id = threading.current_thread().ident
        print(f"线程-{thread_id} 开始处理数据")
        # Bound parameters: label text can never break or inject into the SQL.
        update_stmt = text(
            """
            UPDATE file_files_test
            SET english_elements = :english_elements,
                chinese_elements = :chinese_elements
            WHERE id = :row_id
            """
        )
        count = 0
        for _, row in sub_df.iterrows():
            try:
                parse_value = json.loads(row['files'])
                img_path = parse_value['raw']
                image_url = f"http://soundasia.oss-cn-shenzhen.aliyuncs.com/{img_path}"
                english_elements, chinese_elements = self.get_image_element(image_url)
                if english_elements and chinese_elements:
                    row_id = row['id']
                    # pandas yields numpy scalars; hand the driver a native value.
                    if hasattr(row_id, "item"):
                        row_id = row_id.item()
                    # One short transaction per row: committed on success and
                    # the connection is always returned to the pool.
                    with self.mysql_engine.begin() as conn:
                        conn.execute(
                            update_stmt,
                            {
                                "english_elements": english_elements,
                                "chinese_elements": chinese_elements,
                                "row_id": row_id,
                            },
                        )
            except Exception as e:
                # Deliberate best-effort boundary: report and move on.
                print(f"处理数据时出现异常: {e}")
            finally:
                count += 1
                if count % 200 == 0:
                    print(f"线程-{thread_id} 已经处理了 {count}条数据")
        print(f"线程-{thread_id} 处理数据完成")

    def handle_image_history(self):
        """Run ``thread_task`` over every prepared chunk in a thread pool.

        Blocks until all chunks are done, then prints the elapsed time.  An
        exception that escapes a worker is re-raised here by ``result()``.
        """
        print("开始图片元素标注")
        start_time = time.time()
        # Pool size follows the configured thread count (default 25).
        max_workers = max(1, self.threads_num)
        with concurrent.futures.ThreadPoolExecutor(max_workers) as executor:
            futures = [executor.submit(self.thread_task, sub_df) for sub_df in self.df_image_sub_list]
            for future in concurrent.futures.as_completed(futures):
                future.result()  # wait for each worker to finish
        print("标签标注完成!")
        elapsed_time = time.time() - start_time
        print(f"耗时:{elapsed_time:.2f} 秒")

    def translate_with_retry(self, translator, text, src='en', dest='zh-CN', max_retries=10):
        """Translate *text*, retrying with exponential backoff on request errors.

        Args:
            translator: object exposing ``translate(text, src=..., dest=...)``
                whose result has a ``.text`` attribute.
            text: the string to translate.
            src: source language code.
            dest: destination language code.
            max_retries: maximum number of attempts.

        Returns:
            The translated string, or ``None`` when every attempt failed.
        """
        # NOTE(review): only `requests` errors are retried; confirm that the
        # installed googletrans version actually raises RequestException.
        for attempt in range(1, max_retries + 1):
            try:
                return translator.translate(text, src=src, dest=dest).text
            except RequestException as e:
                print(f"Request failed: {e}")
                if attempt < max_retries:
                    print(f"Retrying... ({attempt}/{max_retries})")
                    time.sleep(2 ** attempt)  # Exponential backoff for retries
                else:
                    print("Max retries reached. Translation failed.")
        return None

    def get_image_element(self, url):
        """Return comma-joined English and Chinese labels for the image at *url*.

        The DETR detector is tried first (score threshold 0.7).  When it yields
        no label, the classification pipeline is used and labels scoring above
        0.5 are kept.

        Returns:
            ``(english_labels, chinese_labels)`` as comma-joined strings, or
            ``(None, None)`` when the image cannot be processed or no label is
            found.  The Chinese string may be empty if translation failed.
        """
        try:
            translator = Translator()
            # Context manager closes the HTTP connection; an HTTP error status
            # is reported as such instead of as an image-decoding failure.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                image = Image.open(response.raw)
                image.load()  # decode while the response is still open
            if image.mode != "RGB":
                image = image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                outputs = self.model(**inputs)
            target_sizes = torch.tensor([image.size[::-1]])
            results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]
        except Exception as e:
            print(f"该图片解析失败: {e}")
            return None, None

        # Dicts are used as insertion-ordered sets so the joined output is
        # de-duplicated and deterministic.
        english_labels = {}
        chinese_labels = {}
        for label_id in results['labels']:
            label = self.model.config.id2label[label_id.item()]
            if not label or label in english_labels:
                # Empty label, or already recorded: no need to translate again.
                continue
            english_labels[label] = None
            chinese_label = self.translate_with_retry(translator, label, src='en', dest='zh-CN')
            if chinese_label:
                chinese_labels[chinese_label] = None

        if not english_labels:
            # Nothing detected: fall back to whole-image classification, whose
            # labels are translated zh-CN -> en.
            results = self.image_classification(url)
            for score, label in zip(results['scores'], results['labels']):
                if score > 0.5:
                    english_label = self.translate_with_retry(translator, label, src='zh-CN', dest='en')
                    if english_label:
                        english_labels[english_label] = None
                        chinese_labels[label] = None

        if english_labels:
            return ','.join(english_labels), ','.join(chinese_labels)
        return None, None

    def run(self):
        """Prepare the data, then label it."""
        self.rea_and_prepare_data()
        self.handle_image_history()
if __name__ == '__main__':
    # Script entry point: construct the job and run it end to end.
    ImageRecognitionHistory().run()