"""
提取图片特征信息的模型：vgg16
"""
import io

import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from numpy import linalg as LA
from keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg


class VGGNet:
    """Image feature extractor backed by a headless VGG16 network.

    The network is built with ImageNet weights and without the fully
    connected classifier head (``include_top=False``). It uses global max
    pooling, so each image is reduced to one feature vector taken from the
    last convolutional block. That vector is then L2-normalised.
    """

    def __init__(self):
        self.input_shape = (224, 224, 3)  # (height, width, channels) fed to VGG16
        self.weight = 'imagenet'          # pretrained weight set passed to VGG16
        self.pooling = 'max'              # global pooling applied after the last conv block
        self.model_vgg = VGG16(weights=self.weight,
                               input_shape=self.input_shape,
                               pooling=self.pooling, include_top=False)

    def vgg_extract_feat(self, file, file_type='file'):
        """Extract an L2-normalised VGG16 feature vector from a single image.

        The image is resized to ``self.input_shape[:2]``. It is then run
        through VGG16's ``preprocess_input`` and the headless model, and the
        pooled output is scaled to unit length.

        Args:
            file: The image to read. With ``file_type='file'`` (the default),
                this is passed to ``image.load_img`` unchanged, e.g. a path.
                With ``file_type='bytes'``, it is either raw image bytes or a
                readable binary file-like object (e.g. an uploaded file), whose
                remaining content is read into memory.
            file_type: ``'bytes'`` selects the in-memory path described above.
                Any other value keeps the original behaviour of handing
                ``file`` straight to ``image.load_img``.

        Returns:
            list[float]: The feature vector divided by its L2 norm, as plain
            Python floats. An all-zero feature vector is returned as zeros
            instead of NaN.
        """
        source = self._as_loadable(file, file_type)
        img = image.load_img(source, target_size=self.input_shape[:2])
        # The model expects a batch axis: (1, height, width, channels).
        batch = np.expand_dims(image.img_to_array(img), axis=0)
        batch = preprocess_input_vgg(batch)
        feat = self.model_vgg.predict(batch)
        return self._l2_normalize(feat[0]).tolist()

    @staticmethod
    def _as_loadable(file, file_type='file'):
        """Return ``file`` in a form ``image.load_img`` can open.

        For ``file_type='bytes'``, raw bytes-like data is wrapped in
        ``io.BytesIO``. Any other object is read with ``.read()`` and its
        content is wrapped. All other file types are returned unchanged.
        """
        if file_type != 'bytes':
            return file
        if isinstance(file, (bytes, bytearray, memoryview)):
            return io.BytesIO(bytes(file))
        return io.BytesIO(file.read())

    @staticmethod
    def _l2_normalize(vector):
        """Scale ``vector`` to unit L2 norm.

        A zero vector is returned unchanged, which avoids the 0/0 = NaN result.
        """
        norm = LA.norm(vector)
        if norm == 0:
            return vector
        return vector / norm

    # # 提取vgg16最后一层卷积特征
    # def vgg_extract_feat(self, image_bytes):
    #     try:
    #         img = image.load_img(io.BytesIO(image_bytes), target_size=(self.input_shape[0], self.input_shape[1]))
    #         img = image.img_to_array(img)
    #         img = np.expand_dims(img, axis=0)
    #         img = preprocess_input_vgg(img)
    #         feat = self.model_vgg.predict(img)
    #         norm_feat = feat[0] / LA.norm(feat[0])
    #         return norm_feat.tolist()
    #     except Exception as e:
    #         print(e, traceback.format_exc())
    #             return list(np.zeros(shape=(512,)))


if __name__ == '__main__':
    # Manual smoke test: python <this_script> <image_path>
    import sys

    if len(sys.argv) < 2:
        # An empty path can never be loaded, so require one explicitly.
        sys.exit('usage: %s <image_path>' % sys.argv[0])
    handle_obj = VGGNet()
    features = handle_obj.vgg_extract_feat(file=sys.argv[1])
    print(len(features), features[:5])