1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import re
import shutil
import socket
import sys
import traceback

import numpy as np
import pandas as pd
sys.path.append(os.path.dirname(sys.path[0])) # 上级目录
from utils.utils import Utils
from vgg_model import VGGNet
from transfer import TransferImages
class ExtractFeatures(Utils):
    """Worker that extracts image feature vectors one category at a time.

    Each call to :meth:`run` claims one row of ``{site_name}_asin_extract_cate``,
    extracts a vector for every image in that category's directory and appends
    the results to ``{site_name}_asin_extract_features``.

    Values of ``state`` written/read on the category table by this class:
    ``1`` is selected as pending, ``2`` is set when the row is claimed and
    ``3`` is set once the category has been processed. On the feature table,
    ``state`` is ``1`` for a successful extraction and ``2`` for a failed one.

    NOTE(review): ``self.engine`` and ``self.connection`` come from ``Utils``;
    ``connection(db_type="mysql", ...)`` presumably (re)creates ``self.engine``
    -- confirm against ``utils.utils``.
    """

    # Host on which the images are stored; other hosts work on a transferred copy.
    _IMAGE_HOST = "hadoop6"
    # Root of the image tree: <root>/<site>/<cate_1_id>/<cate_current_id>/
    _IMAGE_ROOT = "/home/data"
    # Captures the file stem of a ".../<asin>.jpg" path (literal dot before "jpg").
    _ASIN_PATTERN = re.compile(r".*/(.*?)\.jpg")

    def __init__(self, site_name="us"):
        """Create the model/transfer helpers and open the database connections.

        :param site_name: site prefix used for the table names and image directory
        """
        super().__init__()
        self.vgg_model = VGGNet()
        self.transfer = TransferImages()
        self.connection(db_type="milvus", db_conn="hadoop10")  # open the milvus connection
        self.connection(db_type="mysql", db_conn="aliyun")  # open the mysql connection
        self.site_name = site_name
        self.id = 0
        self.cate_1_id = ""
        self.cate_current_id = ""
        self.img_share_path = ""  # shared image directory of the current category
        self.img_path_list = []  # paths of all images under the shared directory
        self.img_data_list = []  # extracted rows: [asin, vector-as-string, state]
        self.img_dim = 512  # length of the zero vector stored when extraction fails
        self.db_read = f"{self.site_name}_asin_extract_cate"
        self.db_save = f"{self.site_name}_asin_extract_features"
        self.df_read = pd.DataFrame()  # DataFrame read from the database
        self.df_save = pd.DataFrame()  # DataFrame written to the database
        self.hostname = socket.gethostname()

    @staticmethod
    def get_img_path(img_share_path):
        """Return the path of every entry directly inside ``img_share_path``.

        :param img_share_path: directory to list, with or without a trailing slash
        :return: list of joined paths, in ``os.listdir`` order
        """
        return [os.path.join(img_share_path, name) for name in os.listdir(img_share_path)]

    def read_and_update(self):
        """Claim one pending category row, or exit when none is left.

        Within one transaction a row with ``state=1`` is selected ``for update``
        and then set to ``state=2``. Its ``id``, ``cate_1_id`` and
        ``cate_current_id`` are stored on the instance.

        :return: None
        :raises SystemExit: when the select returns no row
        """
        with self.engine.begin() as conn:
            sql_read = f"select id, cate_1_id, cate_current_id, state from {self.db_read} where state=1 limit 1 for update;"
            print("sql_read:", sql_read)
            rows = conn.execute(sql_read)
            self.df_read = pd.DataFrame(rows, columns=['id', 'cate_1_id', 'cate_current_id', 'state'])
            if self.df_read.shape[0] != 1:
                # No pending category: stop the worker (this also ends the
                # endless loop of the script entry point).
                sys.exit()
            self.id = self.df_read["id"].tolist()[0]
            self.cate_1_id = self.df_read["cate_1_id"].tolist()[0]
            self.cate_current_id = self.df_read["cate_current_id"].tolist()[0]
            sql_update = f"update {self.db_read} set state=2 where id={self.id}"
            print("sql_update:", sql_update)
            conn.execute(sql_update)

    def transfer_images(self):
        """
        The images are stored on hadoop6, so on any other machine they are
        transferred from hadoop6 before extraction.
        :return: None
        """
        if self.hostname == self._IMAGE_HOST:
            print("本地hadoop6主服务器无需传输")
            return
        print("hello, 进入图片传输:", self.hostname)
        transfer_obj = TransferImages(cate_1_id=self.cate_1_id, cate_current_id=self.cate_current_id)
        transfer_obj.transfer_images()

    def extract_features(self):
        """Extract a feature vector for every path in ``self.img_path_list``.

        One ``[asin, str(vector), state]`` row per image is appended to
        ``self.img_data_list``. When extraction raises, a zero vector of
        length ``self.img_dim`` is stored with ``state=2`` so one unreadable
        image does not abort the whole category.

        :return: None
        """
        total = len(self.img_path_list)
        for position, img_path in enumerate(self.img_path_list, start=1):
            found = self._ASIN_PATTERN.search(img_path)
            asin = found.group(1) if found else None
            print(f"asin:{asin}, 当前提取图片index+1: {position}, 总提取图片数量: {total}, 提取进度: {round(position / total, 4)}")
            try:
                feats = self.vgg_model.vgg_extract_feat(img_path=img_path)
                state = 1
            except Exception as e:
                print("RuntimeError: Read image", img_path, e)
                # tolist() yields plain Python floats, so str() below gives
                # "[0.0, 0.0, ...]" regardless of the NumPy scalar repr.
                feats = np.zeros(shape=(self.img_dim,)).tolist()
                state = 2
            self.img_data_list.append([asin, str(feats), state])

    def sava_data(self):
        """Append ``self.img_data_list`` to the feature table, retrying on error.

        On any exception the mysql connection is re-opened and the write is
        attempted again, without an upper bound on the number of attempts.
        The method name (including its spelling) is kept for existing callers.

        :return: None
        """
        while True:
            try:
                print("存储数据", len(self.img_data_list))
                self.df_save = pd.DataFrame(self.img_data_list, columns=['asin', 'img_vector', 'state'])
                self.df_save.to_sql(f"{self.db_save}", con=self.engine, if_exists='append', index=False)
                break
            except Exception as e:
                print("存储异常,重新连接存储:", e, traceback.format_exc())
                self.connection(db_type="mysql", db_conn="aliyun")  # re-open the mysql connection

    def delete_and_update(self):
        """Remove the transferred images and mark the category as finished.

        On hosts other than hadoop6 the category directory is a transferred
        copy and is deleted. The category row is then set to ``state=3``,
        re-opening the mysql connection and retrying for as long as the
        update raises.

        :return: None
        """
        if self.hostname != self._IMAGE_HOST:
            print(f"当前hostname为{self.hostname}, 不是hadoop6, 因此删除传输的图片")
            # Same effect as "rm -rf <dir>" without going through a shell;
            # missing paths and removal errors are ignored.
            shutil.rmtree(self.img_share_path, ignore_errors=True)
        while True:
            try:
                # Mark the claimed category row as successfully extracted.
                with self.engine.begin() as conn:
                    sql_update = f"update {self.db_read} set state=3 where id={self.id}"
                    print("sql_update:", sql_update)
                    conn.execute(sql_update)
                break
            except Exception as e:
                print("存储异常,重新连接存储:", e, traceback.format_exc())
                self.connection(db_type="mysql", db_conn="aliyun")  # re-open the mysql connection

    def run(self):
        """Process one category end to end.

        :return: None
        :raises SystemExit: propagated from :meth:`read_and_update` when no
            pending category is left
        """
        self.read_and_update()  # claim a category id
        self.transfer_images()  # fetch the images of that category if needed
        # Shared image directory of the claimed category (trailing slash kept).
        self.img_share_path = f"{self._IMAGE_ROOT}/{self.site_name}/{self.cate_1_id}/{self.cate_current_id}/"
        self.img_path_list = self.get_img_path(img_share_path=self.img_share_path)
        self.img_data_list = []  # start from an empty result list for this category
        self.extract_features()  # extract the image features
        self.sava_data()  # store the image features
        self.delete_and_update()  # clean up and update the database
if __name__ == '__main__':
    print("hello")
    # A single worker instance is reused for every category. The loop has no
    # exit condition of its own: it ends when read_and_update() terminates
    # the interpreter because no row with state=1 is left.
    worker = ExtractFeatures()
    while True:
        worker.run()