# pictures_features.py (last committed by chenyuanjie)
import os
import re
import sys
import time
import traceback
import socket

import numpy as np
import pandas as pd
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录
from utils.templates_mysql import TemplatesMysql
from sqlalchemy import text
from vgg_model import VGGNet


class PicturesFeatures():
    """Batch worker that turns stored product pictures into VGG feature vectors.

    One cycle (see ``run``) consists of:

    1. ``read_data`` - claim up to ``pictures_num`` rows with ``state=1`` from
       the local-path table and mark them ``state=2``.
    2. ``extract_features`` - run every claimed picture through ``VGGNet``.
    3. ``save_and_update_data`` - replace the feature rows of the claimed ASINs
       and mark the claimed rows ``state=3``.
    """

    # Hosts on which the picture path stored in the database is used verbatim;
    # on every other host "/mnt" in the path is rewritten to "/home".
    _MNT_HOSTS = ('hadoop5', 'hadoop6', 'hadoop7', 'hadoop8')
    # Length of the all-zero vector stored when feature extraction fails.
    _FEATURE_DIM = 512

    def __init__(self, site_name='us', pictures_num=1000, self_flag='_self'):
        """Set up the database engine, the VGG model and the table names.

        Args:
            site_name: Prefix of both table names.
            pictures_num: Maximum number of rows claimed per batch.
            self_flag: Suffix of the features table name. When truthy, the saved
                rows additionally carry ``platform`` and ``site`` columns.
        """
        self.site_name = site_name
        self.pictures_num = pictures_num
        self.self_flag = self_flag
        self.engine = TemplatesMysql().engine_pg
        self.vgg_model = VGGNet()
        # The source table is the "_copy" table regardless of ``self_flag``.
        self.tn_pics_local_path = f"{self.site_name}_pictures_local_path_copy"
        self.tn_pics_features = f"{self.site_name}_pictures_features{self_flag}"
        # Per-batch state, filled by read_data() / extract_features().
        self.id_list = list()
        self.asin_list = list()
        self.local_path_list = list()
        self.df_save = pd.DataFrame()
        self.hostname = socket.gethostname()

    def read_data(self):
        """Claim the next batch of pictures awaiting feature extraction.

        Selects up to ``pictures_num`` rows with ``state=1`` under ``FOR UPDATE``,
        stores their ``id`` / ``asin`` / ``local_path`` values on the instance and
        marks them ``state=2`` in the same transaction.

        Raises:
            SystemExit: When no row with ``state=1`` is returned.
        """
        print("self.hostname:", self.hostname)
        with self.engine.begin() as conn:
            # NOTE(review): if several workers run concurrently against
            # PostgreSQL, a plain FOR UPDATE blocks on rows locked by another
            # worker and may then return a short or empty batch; consider
            # "FOR UPDATE SKIP LOCKED" once the server version is confirmed.
            # int() keeps the interpolated LIMIT a plain number.
            sql_read = text(
                f"SELECT * FROM {self.tn_pics_local_path} WHERE state=1 "
                f"LIMIT {int(self.pictures_num)} FOR UPDATE;")
            print("sql_read:", sql_read)
            result = conn.execute(sql_read)
            df = pd.DataFrame(result.fetchall())
            print("df.shape:", df.shape)

            if df.empty:
                print("没有数据需要提取, 退出")
                # sys.exit() raises the same SystemExit as quit() but does not
                # depend on the ``site`` module being loaded.
                sys.exit()

            df.columns = result.keys()
            # tolist() yields plain Python values instead of numpy scalars.
            self.id_list = df["id"].tolist()
            self.asin_list = df["asin"].tolist()
            self.local_path_list = df["local_path"].tolist()

            id_csv = ','.join(map(str, self.id_list))
            sql_update = text(
                f"UPDATE {self.tn_pics_local_path} SET state=2 WHERE id IN ({id_csv});")
            print("sql_update:", sql_update)
            conn.execute(sql_update)

    def extract_features(self):
        """Compute one feature vector per claimed picture into ``self.df_save``.

        ``self.df_save`` gets the columns ``id``, ``asin``, ``features`` (the
        vector rendered with ``str``) and, when ``self_flag`` is truthy,
        ``platform`` and ``site`` parsed from the stored picture path. A picture
        whose extraction raises is stored with an all-zero vector instead of
        aborting the batch.
        """
        # Defined up front so an empty batch still yields a well-formed frame.
        columns = ['id', 'asin', 'features']
        if self.self_flag:
            columns += ['platform', 'site']
        use_stored_path = self.hostname in self._MNT_HOSTS

        data_list = []
        for pic_id, asin, stored_path in zip(self.id_list, self.asin_list, self.local_path_list):
            if use_stored_path:
                local_path = stored_path
            else:
                local_path = stored_path.replace("/mnt", "/home")
            print("id, asin, local_path:", pic_id, asin, local_path)
            try:
                features = self.vgg_model.vgg_extract_feat(file=local_path)
            except Exception as e:
                print(e, traceback.format_exc())
                # tolist() gives plain floats, so str() renders "[0.0, ...]"
                # on every numpy version.
                features = np.zeros(shape=(self._FEATURE_DIM,)).tolist()

            row = [pic_id, asin, str(features)]
            if self.self_flag:
                # Parse the path as stored in the database: the patterns are
                # anchored on "/mnt", which the host-specific rewrite removes.
                # NOTE(review): findall() returns lists, so these two cells hold
                # lists (kept as before) - confirm the column types expect that.
                row.append(re.findall(r"/mnt/data/img_data/(.*?)/", stored_path))
                row.append(re.findall(r"/mnt/data/img_data/.*?/(.*?)/", stored_path))
            data_list.append(row)
        self.df_save = pd.DataFrame(data_list, columns=columns)

    def save_and_update_data(self):
        """Persist the extracted features and mark the batch as finished.

        In a single transaction: delete existing feature rows for the batch's
        ASINs, append ``self.df_save`` (without its ``id`` column) to the
        features table, and move the claimed rows from ``state=2`` to
        ``state=3``. A failure rolls all three steps back together.
        """
        if self.df_save.empty:
            # Nothing was extracted; "IN ()" would be invalid SQL.
            return

        from sqlalchemy import bindparam  # only needed for the expanding IN list

        asins = self.df_save["asin"].tolist()
        # An expanding bound parameter renders a correct IN list for any number
        # of ASINs (including exactly one) and leaves quoting to the driver.
        sql_delete = text(
            f"DELETE from {self.tn_pics_features} where asin IN :asins;"
        ).bindparams(bindparam("asins", expanding=True))
        id_csv = ','.join(map(str, self.id_list))
        sql_update = text(
            f"UPDATE {self.tn_pics_local_path} SET state=3 WHERE state=2 and id IN ({id_csv});")
        df_insert = self.df_save.drop(columns=["id"])

        with self.engine.begin() as conn:
            print("sql_delete:", sql_delete)
            conn.execute(sql_delete, {"asins": asins})
            print(df_insert.columns)
            df_insert.to_sql(self.tn_pics_features, con=conn, if_exists="append", index=False)
            print("sql_update:", sql_update)
            conn.execute(sql_update)

    def run(self):
        """Process batches in a loop until ``read_data`` exits the process.

        Any ``Exception`` in a cycle is printed, the engine and the model are
        rebuilt, and the loop resumes after a 20-second pause.
        """
        while True:
            try:
                self.read_data()
                self.extract_features()
                self.save_and_update_data()
            except Exception as e:
                print(e, traceback.format_exc())
                # NOTE(review): rows already marked state=2 by the failed cycle
                # are not reset anywhere in this class.
                self.engine = TemplatesMysql().engine_pg
                self.vgg_model = VGGNet()
                time.sleep(20)


if __name__ == '__main__':
    # Argument 1 (optional): maximum number of pictures claimed per batch.
    # Falls back to the class default of 1000 when omitted.
    pictures_num = int(sys.argv[1]) if len(sys.argv) > 1 else 1000
    # self_flag='' targets the "<site>_pictures_features" table and stores no
    # platform/site columns; pass self_flag='_self' for the "_self" variant.
    handle_obj = PicturesFeatures(pictures_num=pictures_num, self_flag='')
    handle_obj.run()