1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import sys
import threading
import time
import traceback
import socket
import uuid
import numpy as np
import pandas as pd
import redis
import logging
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # 上级目录
from utils.templates import Templates
from sqlalchemy import text
from vgg_model import VGGNet
from utils.db_util import DbTypes, DBUtil
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)
class AsinImageFeatures(Templates):
    """Compute VGG image features for ASIN images and store them in StarRocks.

    Work is processed in batches, possibly by several threads and processes:

    1. ``read_data`` claims up to ``limit`` rows of
       ``selection.asin_image_local_path`` with ``state=1`` for this
       ``site_name``/``asin_type`` and flips them to ``state=2``. Claiming is
       serialised across workers by a Redis lock named ``local_name``.
    2. ``handle_data`` runs each row's local image through
       ``VGGNet.vgg_extract_feat``.
    3. ``save_data`` appends the results to ``asin_image_features`` and marks
       the processed rows ``state=3``.
    """

    # Hosts on which the stored "/mnt..." image paths are used unchanged; on any
    # other host "/mnt" is rewritten to "/home" before the file is opened.
    MNT_HOSTS = frozenset({'hadoop5', 'hadoop6', 'hadoop7', 'hadoop8'})
    # Length of the all-zero placeholder stored when feature extraction fails.
    FEATURE_DIM = 512
    # Redis connection used for the distributed lock.
    # NOTE(review): credentials are hard-coded in source; consider config/env vars.
    REDIS_KWARGS = {'host': '192.168.10.224', 'port': 6379, 'db': 9, 'password': 'yswg2023'}
    # Lua script: delete the lock key only if it still holds our token, so a
    # worker never deletes a lock that now belongs to someone else.
    RELEASE_LOCK_SCRIPT = """
    if redis.call("get", KEYS[1]) == ARGV[1] then
        return redis.call("del", KEYS[1])
    else
        return 0
    end
    """

    def __init__(self, site_name='us', asin_type=1, thread_num=10, limit=1000):
        """Set up connections, the VGG model and the batch parameters.

        Args:
            site_name: site whose rows are processed (also prefixes the lock name).
            asin_type: only rows with this ``asin_type`` are processed.
            thread_num: number of worker threads started by ``run_thread``.
            limit: maximum number of rows claimed per batch.
        """
        super(AsinImageFeatures, self).__init__()
        self.site_name = site_name
        self.asin_type = asin_type
        self.thread_num = thread_num
        self.limit = limit
        self._connect_storage()
        self.local_name = f"{self.site_name}_asin_image_features"
        self.vgg_model = VGGNet()
        self.hostname = socket.gethostname()

    def _connect_storage(self):
        """(Re)create the StarRocks engine and the Redis client."""
        self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
        self.client = redis.Redis(**self.REDIS_KWARGS)

    def acquire_lock(self, lock_name, timeout=10):
        """Try to take the distributed Redis lock ``lock_name``.

        The key is written with ``NX`` (only if absent) and an expiry of
        ``timeout`` seconds, so a holder that dies cannot block others forever.

        Args:
            lock_name: Redis key used as the lock; by convention the task name.
            timeout: lock expiry in seconds.

        Returns:
            ``(lock_acquired, lock_value)``. ``lock_acquired`` is ``True`` when
            the lock was set and ``None`` otherwise. ``lock_value`` is the random
            token that must be passed to ``release_lock``.
        """
        lock_value = str(uuid.uuid4())
        lock_acquired = self.client.set(lock_name, lock_value, nx=True, ex=timeout)
        return lock_acquired, lock_value

    def release_lock(self, lock_name, lock_value):
        """Release the Redis lock ``lock_name`` if it still holds ``lock_value``.

        Returns:
            1 if the key was deleted, 0 if the lock no longer held our token.
        """
        return self.client.eval(self.RELEASE_LOCK_SCRIPT, 1, lock_name, lock_value)

    def _release_lock_quietly(self, lock_value):
        """Best-effort release of ``self.local_name``; the key also expires by itself."""
        try:
            self.release_lock(lock_name=self.local_name, lock_value=lock_value)
        except Exception as e:
            print(f"release_lock failed (lock will expire by itself): {e}")

    @staticmethod
    def _sql_id_list(ids):
        """Render integer ids as the body of an SQL ``IN (...)`` list, e.g. ``1,2,3``."""
        # int() turns NumPy integer scalars into plain ints, so each item renders
        # as a bare number (NumPy >= 2 reprs them as "np.int64(1)").
        return ','.join(str(int(i)) for i in ids)

    def _claim_batch(self):
        """Select up to ``self.limit`` pending rows and mark them ``state=2`` in one transaction."""
        # NOTE(review): the Redis lock expires after acquire_lock's default 10 s; if
        # this transaction can take longer, two workers could claim the same rows.
        with self.engine_srs.begin() as conn:
            sql_read = text(f"SELECT id, asin, local_path, asin_type FROM selection.asin_image_local_path WHERE site_name='{self.site_name}' and asin_type={self.asin_type} and state=1 LIMIT {self.limit};")
            df = pd.read_sql(sql=sql_read, con=conn)
            id_list = df['id'].tolist()
            print(f"sql_read: {sql_read}, {df.shape}", id_list[:10])
            if id_list:
                sql_update = text(f"UPDATE selection.asin_image_local_path SET state=2 WHERE id IN ({self._sql_id_list(id_list)});")
                print("sql_update:", sql_update)
                conn.execute(sql_update)
        return df

    def read_data(self):
        """Claim the next batch of pending rows and return it.

        Retries forever. While another worker holds the Redis lock it waits 5 s.
        On any error it waits 5 s, reconnects StarRocks/Redis and tries again.

        Returns:
            DataFrame with columns ``id, asin, local_path, asin_type``; it is
            empty when no ``state=1`` rows remain.
        """
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if not lock_acquired:
                    print(f"当前有其它进程占用redis的锁, 等待5秒继续获取数据")
                    time.sleep(5)  # wait 5 s, then try the lock again
                    continue
                print("self.hostname:", self.hostname)
                try:
                    # _claim_batch has committed the state=2 update when it returns,
                    # so the lock is released only after the claim is persisted.
                    return self._claim_batch()
                finally:
                    self._release_lock_quietly(lock_value)
            except Exception as e:
                print(f"读取数据错误: {e}", traceback.format_exc())
                time.sleep(5)
                self._connect_storage()

    def handle_data(self, df, thread_id):
        """Extract a feature vector for every row of ``df``.

        Args:
            df: batch returned by ``read_data`` (uses ``id``, ``asin``, ``local_path``).
            thread_id: worker id, used only in progress output.

        Returns:
            DataFrame with columns ``id, asin, features, asin_type, site_name``.
            ``features`` is the ``str()`` of the extractor's output, or of a
            zero vector of length ``FEATURE_DIM`` when extraction fails.
        """
        remap_mnt = self.hostname not in self.MNT_HOSTS
        data_list = []
        rows = zip(df['id'].tolist(), df['asin'], df['local_path'])
        for index, (row_id, asin, local_path) in enumerate(rows):
            print(f"thread_id, index, id, asin, local_path: {thread_id, index, row_id, asin, local_path}")
            try:
                if remap_mnt:
                    local_path = local_path.replace("/mnt", "/home")
                features = self.vgg_model.vgg_extract_feat(file=local_path)
            except Exception as e:
                # A single bad image (or NULL path) must not abort the whole batch.
                print(e, traceback.format_exc())
                # Plain floats, so str() renders "[0.0, ...]" on every NumPy version.
                features = [0.0] * self.FEATURE_DIM
            data_list.append([row_id, asin, str(features), self.asin_type, self.site_name])
        columns = ['id', 'asin', 'features', 'asin_type', 'site_name']
        return pd.DataFrame(data_list, columns=columns)

    def save_data(self, df):
        """Append ``df`` to ``asin_image_features`` and mark its ids ``state=3``."""
        df.to_sql("asin_image_features", con=self.engine_srs, if_exists="append", index=False)
        id_list = df['id'].tolist()
        if not id_list:
            return
        sql_update = f"update selection.asin_image_local_path set state=3 where id in ({self._sql_id_list(id_list)});"
        print(f"sql_update: {sql_update}")
        with self.engine_srs.begin() as conn:
            # text(): SQLAlchemy 2.x refuses to execute plain strings.
            conn.execute(text(sql_update))

    def run(self, thread_id=1):
        """Worker loop: claim, process and save batches until none remain.

        On any error it reconnects StarRocks/Redis, reloads the VGG model,
        sleeps 20 s and carries on.
        """
        while True:
            try:
                df = self.read_data()
                if df.empty:
                    break
                df_save = self.handle_data(df=df, thread_id=thread_id)
                self.save_data(df=df_save)
            except Exception as e:
                print(e, traceback.format_exc())
                # NOTE(review): these attributes are shared by every worker thread
                # (run_thread targets the same instance), so all threads see them replaced.
                self._connect_storage()
                self.vgg_model = VGGNet()
                time.sleep(20)

    def run_thread(self):
        """Start ``self.thread_num`` threads running ``run`` and wait for all of them."""
        thread_list = []
        for thread_id in range(self.thread_num):
            thread = threading.Thread(target=self.run, args=(thread_id,))
            thread_list.append(thread)
            thread.start()
        for thread in thread_list:
            thread.join()
        logging.info("所有线程处理完成")
if __name__ == '__main__':
    # Job parameters are fixed here; edit them to process another site / asin_type.
    site_name, asin_type = 'us', 1
    thread_num, limit = 3, 500
    handle_obj = AsinImageFeatures(
        site_name=site_name,
        asin_type=asin_type,
        thread_num=thread_num,
        limit=limit,
    )
    handle_obj.run_thread()