# import cv2
# import numpy as np
# from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
# from tensorflow.keras.models import Model
# from tensorflow.keras.preprocessing import image
# from pyspark.sql.functions import udf
# from pyspark.sql.types import StringType
# from pyspark import SparkFiles
#
# # Load ResNet50 model
# base_model = ResNet50(weights='imagenet')
# model = Model(inputs=base_model.input, outputs=base_model.get_layer('avg_pool').output)
#
#
# def extract_features(file_path):
#     try:
#         # Load image
#         img = image.load_img(file_path, target_size=(224, 224))
#
#         # Preprocess image
#         x = image.img_to_array(img)
#         x = np.expand_dims(x, axis=0)
#         x = preprocess_input(x)
#
#         # Extract features
#         features = model.predict(x)
#         features = features.flatten()
#
#         return features.tolist()
#
#     except Exception as e:
#         return [str(e)]
#
#
# # Create UDF
# extract_features_udf = udf(extract_features, StringType())
#
# # Get list of images from HDFS
# images_rdd = sc.wholeTextFiles("hdfs://<your-hdfs-path>/*.jpg")
# images_df = images_rdd.toDF(["path", "image"])
#
# # Extract features
# features_df = images_df.withColumn("features", extract_features_udf(images_df["path"]))
#
# features_df.show()

import ast
import os
import sys
import traceback

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # must be set before TensorFlow is imported
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from utils.templates import Templates
# from ..utils.templates import Templates
import io
import numpy as np
from numpy import linalg as LA
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from tensorflow.keras.preprocessing import image
from pyspark.sql.functions import udf, lit
from pyspark.sql.types import ArrayType, FloatType


class VGGNet:
    def __init__(self):
        # ...other initialisation code omitted...
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.model_vgg = VGG16(weights=self.weight,
                               input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
                               pooling=self.pooling,
                               include_top=False)

    # Extract features from the last VGG16 convolutional block
    # (include_top=False with max pooling yields a 512-dim vector)
    def vgg_extract_feat(self, image_bytes):
        try:
            img = image.load_img(io.BytesIO(image_bytes),
                                 target_size=(self.input_shape[0], self.input_shape[1]))
            img = image.img_to_array(img)
            img = np.expand_dims(img, axis=0)
            img = preprocess_input_vgg(img)
            feat = self.model_vgg.predict(img)
            # L2-normalise so downstream cosine similarity reduces to a dot product
            norm_feat = feat[0] / LA.norm(feat[0])
            return norm_feat.tolist()
        except Exception as e:
            print(e, traceback.format_exc())
            # On failure, return a zero vector of the same dimensionality
            # (tolist() yields native Python floats, as ArrayType(FloatType()) expects)
            return np.zeros(shape=(512,)).tolist()


# Create a VGGNet instance (captured in the UDF closure and shipped to the executors)
vgg_net = VGGNet()

# Create the UDF
extract_features_udf = udf(vgg_net.vgg_extract_feat, ArrayType(FloatType()))

# # Apply the UDF
# df = df.withColumn("features", extract_features_udf("image"))


class ExtractFeatures(Templates):
    def __init__(self, site_name='us', hdfs_pictures_path="/home/ffman/pictures/us/*/*/*/*/*/*/*.jpg"):
        super(ExtractFeatures, self).__init__()
        self.site_name = site_name
        self.db_save = 'extract_features'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        self.df_save = self.spark.sql("select 1+1;")  # placeholder DataFrames, replaced in run()
        self.df_features = self.spark.sql("select 1+1;")
        self.hdfs_pictures_path = hdfs_pictures_path
        self.partitions_by = ['site_name']
        self.partitions_num = 10

    def read_data(self):
        # binaryFiles yields (path, bytes) pairs, one per matching image file
        self.df_features = self.spark.sparkContext.binaryFiles(self.hdfs_pictures_path).toDF(["path", "image"])
        # self.df_features.show(10)
        print("self.df_features.count():", self.df_features.count())

    def handle_data(self):
        # Apply the feature-extraction UDF to the raw image bytes
        self.df_features = self.df_features.withColumn("features", extract_features_udf("image")).cache()
        # self.df_features.show(10)

    def run(self):
        self.read_data()
        self.handle_data()
        self.df_save = self.df_features
        self.df_save = self.df_save.withColumn('site_name', lit(self.site_name))
        self.save_data()


if __name__ == '__main__':
    handle_obj = ExtractFeatures()
    handle_obj.run()
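
# --- Optional local smoke test (a minimal sketch, not part of the Spark job) ---
# Assumes a sample JPEG at ./sample.jpg (hypothetical path); handy for checking
# the extractor outside the cluster before submitting the full job. Kept
# commented out, matching the usage examples above:
#
# with open("./sample.jpg", "rb") as f:
#     feat = vgg_net.vgg_extract_feat(f.read())
# print(len(feat))  # 512 for VGG16 with include_top=False and pooling='max'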