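"""Worker script: claim one pending category row from MySQL, transfer that
category's images over from hadoop6 when running on another machine, extract
a VGG feature vector for every image, write the vectors back to MySQL, then
clean up and mark the category as done.
"""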
import os
import sys
import re
import traceback

import numpy as np
import pandas as pd
import socket
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from utils.utils import Utils
from vgg_model import VGGNet
from transfer import TransferImages


class ExtractFeatures(Utils):

    def __init__(self, site_name="us"):
        super(ExtractFeatures, self).__init__()
        self.vgg_model = VGGNet()
        self.transfer = TransferImages()
        self.connection(db_type="milvus", db_conn="hadoop10")  # 建立milvus连接
        self.connection(db_type="mysql", db_conn="aliyun")  # 建立mysql连接
        self.site_name = site_name
        self.id = 0
        self.cate_1_id = ""
        self.cate_current_id = ""
        self.img_share_path = ""  # shared directory holding the images
        self.img_path_list = []  # all image paths found under the shared directory
        self.img_data_list = []  # extracted rows of [asin, feature vector, state]
        self.img_dim = 512  # dimensionality of the VGG feature vector
        self.db_read = f"{self.site_name}_asin_extract_cate"
        self.db_save = f"{self.site_name}_asin_extract_features"
        self.df_read = pd.DataFrame()  # DataFrame read from the database
        self.df_save = pd.DataFrame()  # DataFrame written to the database
        self.hostname = socket.gethostname()

    @staticmethod
    def get_img_path(img_share_path):
        # list every file in the shared directory and build full paths
        img_path_list = os.listdir(img_share_path)
        return [os.path.join(img_share_path, img) for img in img_path_list]

    def read_and_update(self):
        with self.engine.begin() as conn:
            # "for update" locks the claimed row so concurrent workers
            # cannot pick up the same category
            sql_read = f"select id, cate_1_id, cate_current_id, state from {self.db_read} where state=1 limit 1 for update;"
            print("sql_read:", sql_read)
            result = conn.execute(sql_read)
            self.df_read = pd.DataFrame(result, columns=['id', 'cate_1_id', 'cate_current_id', 'state'])
            if self.df_read.shape[0] == 1:
                self.id = self.df_read.id.iloc[0]
                self.cate_1_id = self.df_read.cate_1_id.iloc[0]
                self.cate_current_id = self.df_read.cate_current_id.iloc[0]
                sql_update = f"update {self.db_read} set state=2 where id={self.id}"  # state=2: row claimed
                print("sql_update:", sql_update)
                conn.execute(sql_update)
            else:
                sys.exit()  # no pending category left; stop the worker

    def transfer_images(self):
        """
        The images live on hadoop6, so any other machine must copy them
        over from hadoop6 before extracting features.
        :return: None
        """
        if self.hostname != "hadoop6":
            print("starting image transfer on host:", self.hostname)
            transfer_obj = TransferImages(cate_1_id=self.cate_1_id, cate_current_id=self.cate_current_id)
            transfer_obj.transfer_images()
        else:
            print("running on the hadoop6 master itself; no transfer needed")

    def extract_features(self):
        img_list_len = len(self.img_path_list)
        for index, img_path in enumerate(self.img_path_list):
            # the file name (minus the .jpg suffix) is the ASIN
            match = re.findall(r".*/(.*?)\.jpg", img_path)
            asin = match[0] if match else None
            print(f"asin: {asin}, current image: {index+1}, total images: {img_list_len}, progress: {round((index+1)/img_list_len, 4)}")
            try:
                feats = self.vgg_model.vgg_extract_feat(img_path=img_path)
                state = 1  # state=1: feature extracted successfully
            except Exception as e:
                print("RuntimeError: Read image", img_path, e)
                feats = list(np.zeros(shape=(self.img_dim,)))  # zero-vector placeholder on failure
                state = 2  # state=2: extraction failed
            self.img_data_list.append([asin, str(feats), state])

    def save_data(self):
        while True:
            try:
                print("saving rows:", len(self.img_data_list))
                self.df_save = pd.DataFrame(self.img_data_list, columns=['asin', 'img_vector', 'state'])
                self.df_save.to_sql(self.db_save, con=self.engine, if_exists='append', index=False)
                break
            except Exception as e:
                print("save failed, reconnecting and retrying:", e, traceback.format_exc())
                self.connection(db_type="mysql", db_conn="aliyun")  # re-open the MySQL connection
                continue

    def delete_and_update(self):
        while True:
            try:
                # skip the cleanup on hadoop6 itself; every other machine
                # removes the images it copied over
                if self.hostname != "hadoop6":
                    print(f"hostname is {self.hostname}, not hadoop6, so the transferred images are deleted")
                    os.system(f"rm -rf {self.img_share_path}")
                # mark the successfully extracted category row as finished (state=3)
                with self.engine.begin() as conn:
                    sql_update = f"update {self.db_read} set state=3 where id={self.id}"
                    print("sql_update:", sql_update)
                    conn.execute(sql_update)
                break
            except Exception as e:
                print("update failed, reconnecting and retrying:", e, traceback.format_exc())
                self.connection(db_type="mysql", db_conn="aliyun")  # re-open the MySQL connection
                continue

    def run(self):
        self.read_and_update()  # claim a category id
        self.transfer_images()  # transfer the images for that category
        self.img_share_path = f"/home/data/{self.site_name}/{self.cate_1_id}/{self.cate_current_id}/"  # shared image directory
        self.img_path_list = self.get_img_path(img_share_path=self.img_share_path)  # collect all image paths
        self.img_data_list = []  # reset the per-category feature buffer
        self.extract_features()  # extract image features
        self.save_data()  # store the features
        self.delete_and_update()  # clean up and update the database


if __name__ == '__main__':
    # handle_obj = ExtractFeatures(img_path=rf"/tmp/pycharm_project_216/data/img")
    handle_obj = ExtractFeatures()
    while True:  # keep claiming categories until read_and_update() finds none and exits
        handle_obj.run()