# extract.py 5.85 KB
import os
import re
import shutil
import socket
import sys
import traceback

import numpy as np
import pandas as pd

sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.utils import Utils
from vgg_model import VGGNet
from transfer import TransferImages


class ExtractFeatures(Utils):
    """Extract image feature vectors for one category per ``run()`` call.

    A single ``run()``:

    1. claims one row with ``state=1`` from ``{site_name}_asin_extract_cate``
       and marks it ``state=2``;
    2. transfers that category's images to this host unless it is the
       master host;
    3. extracts a feature vector for every file in the category directory;
    4. appends the vectors to ``{site_name}_asin_extract_features``;
    5. removes the transferred images (non-master hosts only) and marks the
       claimed row ``state=3``.

    ``self.engine`` and ``self.connection`` are not defined in this class;
    presumably they are provided by ``Utils`` -- confirm against that module.
    """

    # Host on which the images are stored locally; every other host works on
    # a transferred copy and deletes it afterwards.
    MASTER_HOSTNAME = "hadoop6"

    # Captures the file stem between the last "/" and the ".jpg" extension.
    _ASIN_PATTERN = re.compile(r".*/(.*?)\.jpg")

    def __init__(self, site_name="us"):
        """Open the database connections and initialise the per-run state.

        :param site_name: prefix used to build the read/save table names and
            the image directory path.
        """
        super(ExtractFeatures, self).__init__()
        self.vgg_model = VGGNet()
        self.transfer = TransferImages()
        self.connection(db_type="milvus", db_conn="hadoop10")  # open the milvus connection
        self.connection(db_type="mysql", db_conn="aliyun")  # open the mysql connection
        self.site_name = site_name
        self.id = 0  # primary key of the category row currently claimed
        self.cate_1_id = ""
        self.cate_current_id = ""
        self.img_share_path = ""  # shared image directory of the current category
        self.img_path_list = []  # every file path found under the shared directory
        self.img_data_list = []  # [asin, vector-as-string, state] rows awaiting storage
        # Length of the all-zero placeholder vector stored when extraction fails.
        self.img_dim = 512
        self.db_read = f"{self.site_name}_asin_extract_cate"
        self.db_save = f"{self.site_name}_asin_extract_features"
        self.df_read = pd.DataFrame()  # rows read from the database
        self.df_save = pd.DataFrame()  # rows written to the database
        self.hostname = socket.gethostname()

    @staticmethod
    def get_img_path(img_share_path):
        """Return the path of every entry directly inside ``img_share_path``.

        :param img_share_path: directory to list, with or without a trailing
            slash.
        :return: list of joined paths in ``os.listdir`` order.
        :raises OSError: if the directory cannot be listed.
        """
        return [os.path.join(img_share_path, name) for name in os.listdir(img_share_path)]

    def read_and_update(self):
        """Claim one pending category row, or exit when none is left.

        Selects a single ``state=1`` row with ``for update`` and sets it to
        ``state=2`` inside the same transaction, storing its ids on ``self``.

        :raises SystemExit: when the query does not return exactly one row.
        """
        with self.engine.begin() as conn:
            sql_read = f"select id, cate_1_id, cate_current_id, state from {self.db_read} where state=1 limit 1 for update;"
            print("sql_read:", sql_read)
            result = conn.execute(sql_read)
            self.df_read = pd.DataFrame(result, columns=['id', 'cate_1_id', 'cate_current_id', 'state'])
            if self.df_read.shape[0] != 1:
                # Nothing left to process. sys.exit() is used instead of the
                # interactive-only quit() helper installed by the site module.
                sys.exit()
            self.id = self.df_read['id'].tolist()[0]
            self.cate_1_id = self.df_read['cate_1_id'].tolist()[0]
            self.cate_current_id = self.df_read['cate_current_id'].tolist()[0]
            sql_update = f"update {self.db_read} set state=2 where id={self.id}"
            print("sql_update:", sql_update)
            conn.execute(sql_update)

    def transfer_images(self):
        """Transfer the current category's images to this host if needed.

        The images are stored on the master host, so any other machine has
        them transferred over before extracting features.

        :return: None
        """
        if self.hostname == self.MASTER_HOSTNAME:
            print("本地hadoop6主服务器无需传输")
            return
        print("hello, 进入图片传输:", self.hostname)
        transfer_obj = TransferImages(cate_1_id=self.cate_1_id, cate_current_id=self.cate_current_id)
        transfer_obj.transfer_images()

    def extract_features(self):
        """Extract a feature vector for every path in ``self.img_path_list``.

        Appends ``[asin, str(vector), state]`` to ``self.img_data_list`` for
        each path. ``state`` is 1 on success; on any extraction error it is 2
        and the vector is ``self.img_dim`` zeros. ``asin`` is the file stem of
        a ``.jpg`` path, or ``None`` when the path does not match.
        """
        img_list_len = len(self.img_path_list)
        # enumerate() gives the true position even when a path occurs twice
        # and avoids a linear list.index() search per image.
        for index, img_path in enumerate(self.img_path_list):
            matched = self._ASIN_PATTERN.search(img_path)
            asin = matched.group(1) if matched else None
            print(f"asin:{asin}, 当前提取图片index+1: {index+1}, 总提取图片数量: {img_list_len}, 提取进度: {round((index+1)/img_list_len, 4)}")
            try:
                feats = self.vgg_model.vgg_extract_feat(img_path=img_path)
                state = 1
            except Exception as e:
                print("RuntimeError: Read image", img_path, e)
                # tolist() yields plain Python floats, so str(feats) is
                # "[0.0, 0.0, ...]" regardless of the NumPy scalar repr.
                feats = np.zeros(shape=(self.img_dim,)).tolist()
                state = 2
            self.img_data_list.append([asin, str(feats), state])

    def sava_data(self):
        """Append the extracted rows to the features table, retrying forever.

        The misspelt name is kept because callers use it. On any error the
        mysql connection is re-opened and the write is attempted again; there
        is no retry limit.
        """
        while True:
            try:
                print("存储数据", len(self.img_data_list))
                self.df_save = pd.DataFrame(self.img_data_list, columns=['asin', 'img_vector', 'state'])
                self.df_save.to_sql(self.db_save, con=self.engine, if_exists='append', index=False)
                break
            except Exception as e:
                print("存储异常,重新连接存储:", e, traceback.format_exc())
                self.connection(db_type="mysql", db_conn="aliyun")  # re-open the mysql connection

    def delete_and_update(self):
        """Remove transferred images and mark the claimed row as finished.

        On a non-master host the category directory is deleted, then the
        claimed row is set to ``state=3``. On any error the mysql connection
        is re-opened and the whole step is retried; there is no retry limit.
        """
        while True:
            try:
                # Only hosts that received a transferred copy delete it.
                if self.hostname != self.MASTER_HOSTNAME:
                    print(f"当前hostname为{self.hostname}, 不是hadoop6, 因此删除传输的图片")
                    # rmtree avoids building a shell command from the path;
                    # ignore_errors keeps the best-effort behaviour of the
                    # former "rm -rf" call, whose exit status was discarded.
                    if self.img_share_path:
                        shutil.rmtree(self.img_share_path, ignore_errors=True)
                # Mark the category row whose features were stored.
                with self.engine.begin() as conn:
                    sql_update = f"update {self.db_read} set state=3 where id={self.id}"
                    print("sql_update:", sql_update)
                    conn.execute(sql_update)
                break
            except Exception as e:
                print("存储异常,重新连接存储:", e, traceback.format_exc())
                self.connection(db_type="mysql", db_conn="aliyun")  # re-open the mysql connection

    def run(self):
        """Process one category end to end (claim, transfer, extract, store)."""
        self.read_and_update()  # claim a category id
        self.transfer_images()  # transfer the images of that category
        self.img_share_path = f"/home/data/{self.site_name}/{self.cate_1_id}/{self.cate_current_id}/"  # shared image directory
        self.img_path_list = self.get_img_path(img_share_path=self.img_share_path)  # every file under it
        self.img_data_list = []  # reset the rows collected by the previous run
        self.extract_features()  # extract the image features
        self.sava_data()  # store the image features
        self.delete_and_update()  # clean up and update the database


if __name__ == '__main__':
    print("hello")
    extractor = ExtractFeatures()
    # Handle one category per iteration. The loop has no exit condition of
    # its own: read_and_update() exits the interpreter once no row with
    # state=1 is left.
    while True:
        extractor.run()