# Earlier draft using ResNet50, kept for reference; superseded by the
# VGGNet-based Spark pipeline below.
# import cv2
# import numpy as np
# from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
# from tensorflow.keras.models import Model
# from tensorflow.keras.preprocessing import image
# from pyspark.sql.functions import udf
# from pyspark.sql.types import ArrayType, FloatType
# from pyspark import SparkFiles
#
# # Load ResNet50 model
# base_model = ResNet50(weights='imagenet')
# model = Model(inputs=base_model.input, outputs=base_model.get_layer('avg_pool').output)
#
#
# def extract_features(file_path):
#     try:
#         # Load image
#         img = image.load_img(file_path, target_size=(224, 224))
#
#         # Preprocess image
#         x = image.img_to_array(img)
#         x = np.expand_dims(x, axis=0)
#         x = preprocess_input(x)
#
#         # Extract features
#         features = model.predict(x)
#         features = features.flatten()
#
#         return features.tolist()
#
#     except Exception as e:
#         print(e)
#         return []  # empty feature vector on failure
#
#
# # Create UDF
# extract_features_udf = udf(extract_features, ArrayType(FloatType()))
#
# # Get list of images from HDFS
# images_rdd = sc.wholeTextFiles("hdfs://<your-hdfs-path>/*.jpg")
# images_df = images_rdd.toDF(["path", "image"])
#
# # Extract features
# features_df = images_df.withColumn("features", extract_features_udf(images_df["path"]))
#
# features_df.show()


import os
import sys
import traceback

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory so utils can be imported
from utils.templates import Templates

import io
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TensorFlow info/warning logs
from numpy import linalg as LA
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as preprocess_input_vgg
from tensorflow.keras.preprocessing import image


from pyspark.sql.functions import udf, lit
from pyspark.sql.types import ArrayType, FloatType


class VGGNet:
    def __init__(self):
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        # include_top=False with global max pooling yields a 512-dim feature vector
        self.model_vgg = VGG16(weights=self.weight,
                               input_shape=self.input_shape,
                               pooling=self.pooling, include_top=False)

    # Extract a max-pooled feature from the last VGG16 convolutional block
    def vgg_extract_feat(self, image_bytes):
        try:
            img = image.load_img(io.BytesIO(image_bytes), target_size=(self.input_shape[0], self.input_shape[1]))
            img = image.img_to_array(img)
            img = np.expand_dims(img, axis=0)
            img = preprocess_input_vgg(img)
            feat = self.model_vgg.predict(img)
            # L2-normalize so downstream similarity search can use dot products
            norm_feat = feat[0] / LA.norm(feat[0])
            return norm_feat.tolist()
        except Exception as e:
            print(e, traceback.format_exc())
            return [0.0] * 512  # zero vector matching the 512-dim feature size
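
# A minimal local sanity check for VGGNet.vgg_extract_feat, independent of
# Spark. The image path is hypothetical; it only illustrates the expected
# input (raw JPEG bytes) and output (a 512-dim, L2-normalized list):
#
#   net = VGGNet()
#   with open("/tmp/sample.jpg", "rb") as f:  # hypothetical local image
#       feat = net.vgg_extract_feat(f.read())
#   assert len(feat) == 512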


# Create a single VGGNet instance; the Keras model is captured in the UDF
# closure and serialized to each executor.
vgg_net = VGGNet()

# Register the feature extractor as a Spark UDF returning array<float>
extract_features_udf = udf(vgg_net.vgg_extract_feat, ArrayType(FloatType()))

# # Example usage of the UDF:
# df = df.withColumn("features", extract_features_udf("image"))
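
# A fuller sketch of the same step (illustrative names only; ExtractFeatures
# below wires this up against the real HDFS layout):
#
#   df = spark.sparkContext.binaryFiles("hdfs:///some/path/*.jpg").toDF(["path", "image"])
#   df = df.withColumn("features", extract_features_udf("image"))
#   df.select("path", "features").show(5, truncate=False)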


class ExtractFeatures(Templates):

    def __init__(self, site_name='us', hdfs_pictures_path="/home/ffman/pictures/us/*/*/*/*/*/*/*.jpg"):
        super(ExtractFeatures, self).__init__()
        self.site_name = site_name
        self.db_save = 'extract_features'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        # placeholder DataFrames, replaced in read_data()/run()
        self.df_save = self.spark.sql("select 1+1;")
        self.df_features = self.spark.sql("select 1+1;")
        self.hdfs_pictures_path = hdfs_pictures_path
        self.partitions_by = ['site_name']
        self.partitions_num = 10

    def read_data(self):
        # binaryFiles yields (path, bytes) pairs, which suits JPEG content
        self.df_features = self.spark.sparkContext.binaryFiles(self.hdfs_pictures_path).toDF(["path", "image"])
        # self.df_features.show(10)
        print("self.df_features.count():", self.df_features.count())

    def handle_data(self):
        # Apply the UDF; cache so the save step does not recompute the features
        self.df_features = self.df_features.withColumn("features", extract_features_udf("image")).cache()
        # self.df_features.show(10)

    def run(self):
        self.read_data()
        self.handle_data()
        self.df_save = self.df_features
        self.df_save = self.df_save.withColumn('site_name', lit(self.site_name))
        self.save_data()


if __name__ == '__main__':
    handle_obj = ExtractFeatures()
    handle_obj.run()
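
    # Example: run against another site, assuming the same HDFS directory
    # layout (both arguments are illustrative):
    #
    #   handle_obj = ExtractFeatures(
    #       site_name='uk',
    #       hdfs_pictures_path="/home/ffman/pictures/uk/*/*/*/*/*/*/*.jpg")
    #   handle_obj.run()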