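# --- Earlier ResNet50-based prototype, kept commented out for reference ---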
# import numpy as np
# from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
# from tensorflow.keras.models import Model
# from tensorflow.keras.preprocessing import image
# from pyspark.sql.functions import udf
# from pyspark.sql.types import ArrayType, FloatType
#
# # Load ResNet50 and expose the global-average-pool layer as the feature output
# base_model = ResNet50(weights='imagenet')
# model = Model(inputs=base_model.input, outputs=base_model.get_layer('avg_pool').output)
#
#
# def extract_features(file_path):
#     try:
#         # Load image, resized to ResNet50's expected input size
#         img = image.load_img(file_path, target_size=(224, 224))
#
#         # Preprocess image
#         x = image.img_to_array(img)
#         x = np.expand_dims(x, axis=0)
#         x = preprocess_input(x)
#
#         # Extract features
#         features = model.predict(x)
#         features = features.flatten()
#
#         return features.tolist()
#
#     except Exception as e:
#         print(e)
#         return []
#
#
# # Create UDF (the features are a float array, not a string)
# extract_features_udf = udf(extract_features, ArrayType(FloatType()))
#
# # Get list of images from HDFS
# # NOTE: wholeTextFiles decodes files as text, which corrupts JPEG bytes;
# # the active pipeline below uses binaryFiles instead.
# images_rdd = sc.wholeTextFiles("hdfs://<your-hdfs-path>/*.jpg")
# images_df = images_rdd.toDF(["path", "image"])
#
# # Extract features
# features_df = images_df.withColumn("features", extract_features_udf(images_df["path"]))
#
# features_df.show()
import os
import sys
import traceback
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates import Templates
# from ..utils.templates import Templates
import io
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from numpy import linalg as LA
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from tensorflow.keras.preprocessing import image
from pyspark.sql.functions import udf, lit
from pyspark.sql.types import ArrayType, FloatType
class VGGNet:
    def __init__(self):
        # ... other configuration omitted ...
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.model_vgg = VGG16(weights=self.weight,
                               input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
                               pooling=self.pooling, include_top=False)

    # Extract the max-pooled features of VGG16's last convolutional block
    def vgg_extract_feat(self, image_bytes):
        try:
            img = image.load_img(io.BytesIO(image_bytes), target_size=(self.input_shape[0], self.input_shape[1]))
            img = image.img_to_array(img)
            img = np.expand_dims(img, axis=0)
            img = preprocess_input_vgg(img)
            feat = self.model_vgg.predict(img)
            # L2-normalize so downstream similarity search can use dot products
            norm_feat = feat[0] / LA.norm(feat[0])
            return norm_feat.tolist()
        except Exception as e:
            print(e, traceback.format_exc())
            # Fall back to a zero vector of the same dimensionality (512)
            return [0.0] * 512


# Create a VGGNet instance; it is serialized to the executors with the UDF closure
vgg_net = VGGNet()

# Create the UDF
extract_features_udf = udf(vgg_net.vgg_extract_feat, ArrayType(FloatType()))

# # Use the UDF
# # df = df.withColumn("features", extract_features_udf("image"))
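# A vectorized alternative (a sketch only, not wired into the pipeline below).
# A pandas_udf receives whole Arrow batches, so the model can run one predict()
# per batch instead of per row. `get_model` is a hypothetical helper, not part
# of the original code, that lazily builds one VGGNet per executor process:
#
# import pandas as pd
# from pyspark.sql.functions import pandas_udf
#
# _model = None
#
# def get_model():
#     global _model
#     if _model is None:
#         _model = VGGNet()
#     return _model
#
# @pandas_udf(ArrayType(FloatType()))
# def extract_features_batch_udf(images: pd.Series) -> pd.Series:
#     model = get_model()
#     if len(images) == 0:
#         return pd.Series([], dtype=object)
#     arrays = [image.img_to_array(image.load_img(io.BytesIO(b), target_size=(224, 224)))
#               for b in images]
#     batch = preprocess_input_vgg(np.stack(arrays))
#     feats = model.model_vgg.predict(batch)  # one predict per batch
#     feats = feats / LA.norm(feats, axis=1, keepdims=True)
#     return pd.Series(feats.tolist())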
class ExtractFeatures(Templates):
    def __init__(self, site_name='us', hdfs_pictures_path="/home/ffman/pictures/us/*/*/*/*/*/*/*.jpg"):
        super(ExtractFeatures, self).__init__()
        self.site_name = site_name
        self.db_save = 'extract_features'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        # Placeholder DataFrames, overwritten in read_data()/run()
        self.df_save = self.spark.sql("select 1+1;")
        self.df_features = self.spark.sql("select 1+1;")
        self.hdfs_pictures_path = hdfs_pictures_path
        self.partitions_by = ['site_name']
        self.partitions_num = 10

    def read_data(self):
        # binaryFiles yields (path, bytes) pairs, preserving the raw JPEG bytes
        self.df_features = self.spark.sparkContext.binaryFiles(self.hdfs_pictures_path).toDF(["path", "image"])
        # self.df_features.show(10)
        print("self.df_features.count():", self.df_features.count())

    def handle_data(self):
        # Use the UDF
        self.df_features = self.df_features.withColumn("features", extract_features_udf("image")).cache()
        # self.df_features.show(10)

    def run(self):
        self.read_data()
        self.handle_data()
        self.df_save = self.df_features
        self.df_save = self.df_save.withColumn('site_name', lit(self.site_name))
        self.save_data()
if __name__ == '__main__':
    handle_obj = ExtractFeatures()
    handle_obj.run()
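    # Example invocation for another site (hypothetical path layout):
    # handle_obj = ExtractFeatures(site_name='uk',
    #                              hdfs_pictures_path="/home/ffman/pictures/uk/*/*/*/*/*/*/*.jpg")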