# pictures_features.py (5 KB)
import os
import re
import sys
import time
import traceback
import socket

import numpy as np
import pandas as pd
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录
from utils.templates_mysql import TemplatesMysql
from sqlalchemy import text
from vgg_model import VGGNet


class PicturesFeatures:
    """Batch worker that computes VGG feature vectors for stored pictures.

    Each cycle selects up to ``pictures_num`` rows with ``state=1`` from
    ``<site>_pictures_local_path_copy`` and sets them to ``state=2``. It then
    runs the VGG model on each row's ``local_path`` and writes the results to
    ``<site>_pictures_features<self_flag>``. Finally it sets those rows to
    ``state=3``.
    """

    # Length of the all-zero vector stored when feature extraction fails.
    FEATURE_DIM = 512
    # Hosts that use the stored paths unchanged. Every other host has "/mnt"
    # rewritten to "/home" (presumably a different mount point of the image
    # share -- confirm with ops).
    _MNT_HOSTS = frozenset({'hadoop5', 'hadoop6', 'hadoop7', 'hadoop8'})
    # Platform / site components of "/mnt/data/img_data/<platform>/<site>/...".
    _PLATFORM_RE = re.compile(r"/mnt/data/img_data/(.*?)/")
    _SITE_RE = re.compile(r"/mnt/data/img_data/.*?/(.*?)/")

    def __init__(self, site_name='us', pictures_num=1000, self_flag='_self'):
        """Create the DB engine and the VGG model, and derive the table names.

        :param site_name: prefix of both table names (e.g. ``'us'``).
        :param pictures_num: maximum number of rows claimed per batch.
        :param self_flag: suffix of the features table name. When truthy,
            output rows also carry ``platform`` and ``site`` columns.
        """
        self.site_name = site_name
        self.pictures_num = pictures_num
        self.self_flag = self_flag
        self.engine = TemplatesMysql().engine_pg
        self.vgg_model = VGGNet()
        # NOTE: reads from the "_copy" table. The commented-out alternative was
        # f"{site_name}_pictures_local_path{self_flag}".
        self.tn_pics_local_path = f"{self.site_name}_pictures_local_path_copy"
        self.tn_pics_features = f"{self.site_name}_pictures_features{self_flag}"
        self.id_list = list()
        self.asin_list = list()
        self.local_path_list = list()
        self.df_save = pd.DataFrame()
        self.hostname = socket.gethostname()

    @staticmethod
    def _expanding_in(sql, param_name):
        """Return ``text(sql)`` with ``:param_name`` bound as an expanding IN list."""
        from sqlalchemy import bindparam  # only this helper needs it
        return text(sql).bindparams(bindparam(param_name, expanding=True))

    def read_data(self):
        """Claim the next batch of ``state=1`` rows for this worker.

        Selects the rows with ``FOR UPDATE`` and sets them to ``state=2`` in the
        same transaction. Fills ``id_list``, ``asin_list`` and
        ``local_path_list``. Exits the process when no row has ``state=1``.
        """
        print("self.hostname:", self.hostname)
        with self.engine.begin() as conn:
            # int() keeps the interpolated LIMIT injection-safe.
            sql_read = text(
                f"SELECT * FROM {self.tn_pics_local_path} WHERE state=1 "
                f"LIMIT {int(self.pictures_num)} FOR UPDATE;"
            )
            print("sql_read:", sql_read)
            result = conn.execute(sql_read)
            df = pd.DataFrame(result.fetchall(), columns=list(result.keys()))
            print("df.shape:", df.shape)

            if df.empty:
                print("没有数据需要提取, 退出")
                # sys.exit rather than quit(): quit() is injected by the site
                # module and is not guaranteed to exist (python -S, frozen apps).
                sys.exit()

            # tolist() yields native Python scalars that DB drivers can bind.
            self.id_list = df['id'].tolist()
            self.asin_list = df['asin'].tolist()
            self.local_path_list = df['local_path'].tolist()
            if self.id_list:
                sql_update = self._expanding_in(
                    f"UPDATE {self.tn_pics_local_path} SET state=2 WHERE id IN :ids;",
                    "ids",
                )
                print("sql_update:", sql_update)
                conn.execute(sql_update, {"ids": list(self.id_list)})

    def extract_features(self):
        """Run the VGG model on every claimed picture and build ``self.df_save``.

        Each picture yields one row ``[id, asin, str(features)]``. When
        ``self_flag`` is truthy, the row also gets ``platform`` and ``site``:
        the ``re.findall`` lists parsed from the stored path. If extraction
        raises, the picture gets an all-zero vector, so it is still saved and
        later marked ``state=3``.
        """
        with_origin = bool(self.self_flag)
        if with_origin:
            columns = ['id', 'asin', 'features', 'platform', 'site']
        else:
            columns = ['id', 'asin', 'features']
        remap_mount = self.hostname not in self._MNT_HOSTS

        data_list = []
        for pic_id, asin, stored_path in zip(self.id_list, self.asin_list, self.local_path_list):
            local_path = stored_path.replace("/mnt", "/home") if remap_mount else stored_path
            print("id, asin, local_path:", pic_id, asin, local_path)
            try:
                features = self.vgg_model.vgg_extract_feat(file=local_path)
            except Exception as e:
                print(e, traceback.format_exc())
                # Plain floats: str() of NumPy scalars is "np.float64(0.0)" on NumPy >= 2.
                features = [0.0] * self.FEATURE_DIM
            row = [pic_id, asin, str(features)]
            if with_origin:
                # Parse the stored path: after the /mnt -> /home rewrite the
                # patterns could never match.
                row.append(self._PLATFORM_RE.findall(stored_path))
                row.append(self._SITE_RE.findall(stored_path))
            data_list.append(row)
        self.df_save = pd.DataFrame(data_list, columns=columns)

    def save_and_update_data(self):
        """Replace the stored features for this batch and set its rows to ``state=3``.

        Steps, in order:

        1. Delete existing rows for the batch's ASINs from ``tn_pics_features``.
        2. Append ``df_save``, with its ``id`` column dropped.
        3. Set ``state=3`` on the claimed ids that are still ``state=2``.

        The DELETE and the UPDATE each run in their own transaction.
        Does nothing when the batch is empty.
        """
        if self.df_save.empty:
            print("df_save is empty, nothing to save")
            return

        asin_list = self.df_save['asin'].tolist()
        with self.engine.begin() as conn:
            # Bound expanding IN list. The old f-string of tuple(...) rendered a
            # single ASIN as "('x',)" (invalid SQL) and broke on quotes in values.
            sql_delete = self._expanding_in(
                f"DELETE FROM {self.tn_pics_features} WHERE asin IN :asins;",
                "asins",
            )
            print("sql_delete:", sql_delete)
            conn.execute(sql_delete, {"asins": asin_list})

        self.df_save.drop(columns=["id"], inplace=True)
        print(self.df_save.columns)
        self.df_save.to_sql(self.tn_pics_features, con=self.engine, if_exists="append", index=False)

        if not self.id_list:
            return
        with self.engine.begin() as conn:
            sql_update = self._expanding_in(
                f"UPDATE {self.tn_pics_local_path} SET state=3 WHERE state=2 AND id IN :ids;",
                "ids",
            )
            print("sql_update:", sql_update)
            conn.execute(sql_update, {"ids": list(self.id_list)})

    def run(self):
        """Process batches forever; exits only via ``read_data`` when nothing has ``state=1``.

        Any error in a cycle is printed. The engine and the model are then
        rebuilt, and the loop waits 20 s before retrying.
        """
        while True:
            try:
                self.read_data()
                self.extract_features()
                self.save_and_update_data()
            except Exception as e:
                print(e, traceback.format_exc())
                # Release the old engine's pooled connections before replacing it.
                try:
                    self.engine.dispose()
                except Exception as dispose_error:
                    print("engine.dispose failed:", dispose_error)
                self.engine = TemplatesMysql().engine_pg
                self.vgg_model = VGGNet()
                time.sleep(20)


if __name__ == '__main__':
    # argv[1] (optional): number of pictures claimed per batch. Defaults to
    # 1000, the same as the PicturesFeatures default.
    pictures_num = int(sys.argv[1]) if len(sys.argv) > 1 else 1000
    # self_flag='' means no table-name suffix and the (id, asin, features)
    # schema. Use self_flag='_self' to also record platform/site.
    handle_obj = PicturesFeatures(pictures_num=pictures_num, self_flag='')
    handle_obj.run()