import os
import re
import sys
import time
import traceback
import socket

import numpy as np
import pandas as pd

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory, so `utils` is importable

from utils.templates_mysql import TemplatesMysql
from sqlalchemy import bindparam, text
from vgg_model import VGGNet

# Hosts on which the stored image paths are used unchanged; on every other
# host the "/mnt" prefix is rewritten to "/home" before reading the image.
MNT_HOSTS = frozenset({'hadoop5', 'hadoop6', 'hadoop7', 'hadoop8'})

# Length of the all-zero placeholder vector stored when extraction fails.
FEATURE_DIM = 512

# Platform / site directory components of a stored image path
# (/mnt/data/img_data/<platform>/<site>/...).
_PLATFORM_RE = re.compile(r"/mnt/data/img_data/(.*?)/")
_SITE_RE = re.compile(r"/mnt/data/img_data/.*?/(.*?)/")


def _first_group(pattern, value):
    """Return group 1 of the first match of *pattern* in *value*, or None if it does not match."""
    match = pattern.search(value)
    return match.group(1) if match else None


class PicturesFeatures:
    """Batch worker: claim image rows, extract VGG features, store them, mark rows done.

    Row states used in the source table: 1 = pending, 2 = claimed by a worker,
    3 = features saved.
    """

    def __init__(self, site_name='us', pictures_num=1000, self_flag='_self'):
        """
        :param site_name: prefix of the source/target table names.
        :param pictures_num: maximum number of rows claimed per batch.
        :param self_flag: table-name suffix; when truthy, platform/site columns are also saved.
        """
        self.site_name = site_name
        self.pictures_num = pictures_num
        self.self_flag = self_flag
        self.engine = TemplatesMysql().engine_pg
        self.vgg_model = VGGNet()
        # NOTE: reads from the "_copy" table; the suffixed variant was
        # f"{self.site_name}_pictures_local_path{self_flag}".
        self.tn_pics_local_path = f"{self.site_name}_pictures_local_path_copy"
        self.tn_pics_features = f"{self.site_name}_pictures_features{self_flag}"
        self.id_list = list()
        self.asin_list = list()
        self.local_path_list = list()
        self.df_save = pd.DataFrame()
        self.hostname = socket.gethostname()

    def read_data(self):
        """Claim up to ``pictures_num`` pending rows (state=1) and mark them claimed (state=2).

        The locking SELECT ... FOR UPDATE and the UPDATE run in one transaction,
        presumably so concurrent workers do not claim the same rows.
        Calls ``sys.exit()`` when no pending rows remain; the resulting
        SystemExit is not an ``Exception`` subclass, so ``run()`` does not catch it.
        """
        print("self.hostname:", self.hostname)
        with self.engine.begin() as conn:
            sql_read = text(f"SELECT * FROM {self.tn_pics_local_path} WHERE state=1 LIMIT {self.pictures_num} FOR UPDATE;")
            print("sql_read:", sql_read)
            result = conn.execute(sql_read)
            df = pd.DataFrame(result.fetchall(), columns=list(result.keys()))
            print("df.shape:", df.shape)
            if df.empty:
                print("没有数据需要提取, 退出")
                sys.exit()
            # .tolist() yields native Python values (not numpy scalars) so the
            # ids can be passed straight to the DB driver as bind parameters.
            self.id_list = df["id"].tolist()
            self.asin_list = df["asin"].tolist()
            self.local_path_list = df["local_path"].tolist()
            sql_update = text(
                f"UPDATE {self.tn_pics_local_path} SET state=2 WHERE id IN :ids;"
            ).bindparams(bindparam("ids", expanding=True))
            print("sql_update:", sql_update)
            conn.execute(sql_update, {"ids": self.id_list})

    def extract_features(self):
        """Extract a feature vector for every claimed image into ``self.df_save``.

        Columns are id, asin, features (the ``str()`` of the vector) and, when
        ``self_flag`` is truthy, platform and site parsed from the stored path.
        """
        if self.self_flag:
            columns = ['id', 'asin', 'features', 'platform', 'site']
        else:
            columns = ['id', 'asin', 'features']
        rows = []
        for row_id, asin, stored_path in zip(self.id_list, self.asin_list, self.local_path_list):
            local_path = stored_path
            if self.hostname not in MNT_HOSTS:
                local_path = local_path.replace("/mnt", "/home")
            print("id, asin, local_path:", row_id, asin, local_path)
            try:
                features = self.vgg_model.vgg_extract_feat(file=local_path)
            except Exception as e:
                # Best effort: an unreadable image must not stall the batch,
                # so log it and store an all-zero vector instead.
                print(e, traceback.format_exc())
                # Plain Python floats, so str() gives "[0.0, ...]" on any numpy version.
                features = np.zeros(shape=(FEATURE_DIM,)).tolist()
            if self.self_flag:
                # Parse the original (un-rewritten) path: the regexes expect "/mnt/...".
                platform = _first_group(_PLATFORM_RE, stored_path)
                site = _first_group(_SITE_RE, stored_path)
                rows.append([row_id, asin, str(features), platform, site])
            else:
                rows.append([row_id, asin, str(features)])
        self.df_save = pd.DataFrame(rows, columns=columns)

    def save_and_update_data(self):
        """Replace stored features for this batch's ASINs, then mark the batch done (state=3).

        Three separate transactions: delete old feature rows for these ASINs,
        append ``self.df_save`` (without ``id``), then set state=3 on the claimed
        ids still at state=2. A failure part-way leaves earlier steps committed.
        """
        asins = self.df_save["asin"].tolist()
        with self.engine.begin() as conn:
            # Expanding bind parameter: correct for one or many ASINs and safe for quotes.
            sql_delete = text(
                f"DELETE from {self.tn_pics_features} where asin IN :asins;"
            ).bindparams(bindparam("asins", expanding=True))
            print("sql_delete:", sql_delete)
            conn.execute(sql_delete, {"asins": asins})
        self.df_save.drop(columns=["id"], inplace=True)
        print(self.df_save.columns)
        self.df_save.to_sql(self.tn_pics_features, con=self.engine, if_exists="append", index=False)
        with self.engine.begin() as conn:
            sql_update = text(
                f"UPDATE {self.tn_pics_local_path} SET state=3 WHERE state=2 and id IN :ids;"
            ).bindparams(bindparam("ids", expanding=True))
            print("sql_update:", sql_update)
            conn.execute(sql_update, {"ids": list(self.id_list)})

    def run(self):
        """Process batches forever; on error, log, rebuild engine and model, sleep 20 s, retry.

        NOTE(review): rows claimed (state=2) by a failed batch are not reset to 1,
        so they are not picked up again by this loop.
        """
        while True:
            try:
                self.read_data()
                self.extract_features()
                self.save_and_update_data()
            except Exception as e:
                print(e, traceback.format_exc())
                # Close the old engine's pooled connections explicitly instead of
                # leaving them to garbage collection when the engine is replaced.
                self.engine.dispose()
                self.engine = TemplatesMysql().engine_pg
                self.vgg_model = VGGNet()
                time.sleep(20)


if __name__ == '__main__':
    # For self-hosted images use: PicturesFeatures(self_flag='_self')
    pictures_num = int(sys.argv[1])  # argument 1: number of pictures claimed per batch
    handle_obj = PicturesFeatures(pictures_num=pictures_num, self_flag='')
    handle_obj.run()