# import io
# import numpy as np
# from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
# from tensorflow.keras.models import Model
# from tensorflow.keras.preprocessing import image
# from pyspark.sql.functions import udf
# from pyspark.sql.types import ArrayType, FloatType
# from pyspark import SparkFiles
#
# # Load ResNet50 model
# base_model = ResNet50(weights='imagenet')
# model = Model(inputs=base_model.input, outputs=base_model.get_layer('avg_pool').output)
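# # ('avg_pool' is ResNet50's global average-pooling layer; its output is a 2048-d vector)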
#
#
# def extract_features(image_bytes):
#     try:
#         # Load the image from raw bytes ("path" holds hdfs:// URIs,
#         # which image.load_img cannot open directly)
#         img = image.load_img(io.BytesIO(image_bytes), target_size=(224, 224))
#
#         # Preprocess image
#         x = image.img_to_array(img)
#         x = np.expand_dims(x, axis=0)
#         x = preprocess_input(x)
#
#         # Extract features
#         features = model.predict(x)
#         features = features.flatten()
#
#         return features.tolist()
#
#     except Exception as e:
#         print(e)
#         # Zero vector keeps the return type consistent on failure
#         return np.zeros(2048).tolist()
#
#
# # Create UDF
# extract_features_udf = udf(extract_features, ArrayType(FloatType()))
#
# # Read images from HDFS as (path, bytes) pairs (JPEGs are binary, not text)
# images_rdd = sc.binaryFiles("hdfs://<your-hdfs-path>/*.jpg")
# images_df = images_rdd.toDF(["path", "image"])
#
# # Extract features from the raw image bytes
# features_df = images_df.withColumn("features", extract_features_udf(images_df["image"]))
#
# features_df.show()


import os
import sys
import traceback

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory so `utils` is importable
from utils.templates import Templates

import io
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TensorFlow INFO and WARNING logs
from numpy import linalg as LA
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as preprocess_input_vgg
from tensorflow.keras.preprocessing import image


from pyspark.sql.functions import udf, lit
from pyspark.sql.types import ArrayType, FloatType


class VGGNet:
    def __init__(self):
        # ... other code omitted ...
        self.input_shape = (224, 224, 3)
        self.weight = 'imagenet'
        self.pooling = 'max'
        self.model_vgg = VGG16(weights=self.weight,
                               input_shape=self.input_shape,
                               pooling=self.pooling, include_top=False)
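        # With include_top=False and pooling='max', the network outputs a single
        # 512-dimensional vector per image (global max over the last conv block).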

    # Extract features from the last convolutional block of VGG16
    def vgg_extract_feat(self, image_bytes):
        try:
            img = image.load_img(io.BytesIO(image_bytes), target_size=(self.input_shape[0], self.input_shape[1]))
            img = image.img_to_array(img)
            img = np.expand_dims(img, axis=0)
            img = preprocess_input_vgg(img)
            feat = self.model_vgg.predict(img)
            norm_feat = feat[0] / LA.norm(feat[0])
            return norm_feat.tolist()
        except Exception as e:
            print(e, traceback.format_exc())
            # Fall back to a zero vector so the return type stays consistent
            return np.zeros(512).tolist()


# Create a module-level VGGNet instance
vgg_net = VGGNet()
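
# A quick local sanity check (a sketch, assuming a local file sample.jpg
# exists; not part of the Spark job itself):
#
# with open("sample.jpg", "rb") as f:
#     vec = vgg_net.vgg_extract_feat(f.read())
# print(len(vec))  # 512, L2-normalised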

# Create the UDF (feature vector as an array of floats)
extract_features_udf = udf(vgg_net.vgg_extract_feat, ArrayType(FloatType()))
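
# Note: wrapping the bound method vgg_net.vgg_extract_feat in a UDF captures
# the whole VGGNet instance (Keras model included) in its closure, so Spark
# must pickle it and ship it to every executor. A common alternative is to
# lazy-load the model once per Python worker; a minimal sketch, assuming
# VGGNet is importable on the executors:
#
# _worker_model = None
#
# def vgg_extract_feat_lazy(image_bytes):
#     global _worker_model
#     if _worker_model is None:       # first call in this worker process
#         _worker_model = VGGNet()    # load the weights once per worker
#     return _worker_model.vgg_extract_feat(image_bytes)
#
# extract_features_udf = udf(vgg_extract_feat_lazy, ArrayType(FloatType()))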

# # Apply the UDF
# df = df.withColumn("features", extract_features_udf("image"))


class ExtractFeatures(Templates):

    def __init__(self, site_name='us', hdfs_pictures_path="/home/ffman/pictures/us/*/*/*/*/*/*/*.jpg"):
        super().__init__()
        self.site_name = site_name
        self.db_save = 'extract_features'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        # Placeholder DataFrames, overwritten in read_data()/run()
        self.df_save = self.spark.sql("select 1+1")
        self.df_features = self.spark.sql("select 1+1")
        self.hdfs_pictures_path = hdfs_pictures_path
        self.partitions_by = ['site_name']
        self.partitions_num = 10
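        # partitions_by / partitions_num are presumably consumed by the
        # inherited Templates.save_data() when writing the output table.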

    def read_data(self):
        # binaryFiles yields (path, bytes) pairs, so the "image" column holds raw JPEG bytes
        self.df_features = self.spark.sparkContext.binaryFiles(self.hdfs_pictures_path).toDF(["path", "image"])
        # self.df_features.show(10)
        print("self.df_features.count():", self.df_features.count())

    def handle_data(self):
        # Apply the UDF; cache so the expensive feature extraction is not
        # recomputed if the plan is evaluated more than once
        self.df_features = self.df_features.withColumn("features", extract_features_udf("image")).cache()
        # self.df_features.show(10)

    def run(self):
        self.read_data()
        self.handle_data()
        self.df_save = self.df_features
        self.df_save = self.df_save.withColumn('site_name', lit(self.site_name))
        self.save_data()


if __name__ == '__main__':
    handle_obj = ExtractFeatures()
    handle_obj.run()