import sys
import os
import faiss
import numpy as np
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd
from pyspark.sql.types import LongType, ArrayType
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from utils.templates import Templates
# from ..utils.templates import Templates
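
# Flow: load_data_and_create_index() collects the parquet embeddings on the
# driver and writes a flat L2 Faiss index to local disk; handle_data() then
# queries that index from a pandas UDF running on the executors. Note that the
# index file is read via a plain local path inside the UDF, so it must be
# present (e.g. copied or mounted) on every executor node.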


class Search(Templates):
    def __init__(self, site_name='us'):
        super(Search, self).__init__()
        self.site_name = site_name
        self.db_save = 'image_search'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        self.data_path = "hdfs://hadoop5:8020/home/ffman/faiss/embeddings/folder"
        self.index_path = "/home/ffman/tmp/my_index.faiss"

    @staticmethod
    @pandas_udf(LongType(), PandasUDFType.SCALAR)
    def find_nearest_neighbors(series):
        # Load the index inside the UDF so the lookup runs on each executor.
        # (The path is hardcoded because a pandas UDF cannot reference self.index_path.)
        index = faiss.read_index("/home/ffman/tmp/my_index.faiss")
        # Search for the 5 nearest neighbours; faiss returns int64 ids, hence LongType
        _, I = index.search(np.array(series.tolist()).astype('float32'), 5)
        return pd.Series(I[:, 0])  # keep only the id of the closest neighbour

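    # Illustrative variant, not part of the original pipeline: the search above
    # already retrieves 5 neighbours but discards all except the closest. This
    # sketch returns every id as an array column instead (the method name and
    # k=5 mirror the original; both are assumptions for illustration).
    @staticmethod
    @pandas_udf(ArrayType(LongType()), PandasUDFType.SCALAR)
    def find_k_nearest_neighbors(series):
        index = faiss.read_index("/home/ffman/tmp/my_index.faiss")
        _, I = index.search(np.array(series.tolist()).astype('float32'), 5)
        return pd.Series(I.tolist())  # all 5 neighbour ids per row
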
    def load_data_and_create_index(self):
        df = self.spark.read.parquet(self.data_path)
        # collect() pulls every embedding into driver memory; see the batched
        # sketch below for a lower-memory alternative
        data = df.collect()
        embeddings = np.array([row['embedding'] for row in data]).astype('float32')  # faiss requires float32

        # Build the flat L2 Faiss index and persist it to disk
        index = faiss.IndexFlatL2(512)
        index.add(embeddings)
        faiss.write_index(index, self.index_path)

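    def load_data_and_create_index_batched(self, batch_size=10000):
        # Sketch of a lower-memory build, assuming the same parquet schema:
        # stream rows to the driver with toLocalIterator() and add them to the
        # index in batches instead of collect()-ing everything at once.
        # batch_size and the method name are illustrative, not from the original code.
        df = self.spark.read.parquet(self.data_path)
        index = faiss.IndexFlatL2(512)
        batch = []
        for row in df.toLocalIterator():
            batch.append(row['embedding'])
            if len(batch) >= batch_size:
                index.add(np.asarray(batch, dtype='float32'))
                batch = []
        if batch:  # flush the final partial batch
            index.add(np.asarray(batch, dtype='float32'))
        faiss.write_index(index, self.index_path)
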
    def handle_data(self):
        df = self.spark.read.parquet(self.data_path)
        # Run the lookup distributed: each partition queries the on-disk index via the UDF
        df = df.withColumn('nearest_neighbor', self.find_nearest_neighbors(df['embedding']))
        df.show(20)
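
    def demo_query(self):
        # Ad-hoc lookup against the saved index, adapted from an earlier
        # prototype of this script: a random 512-dim vector stands in for a
        # real query embedding, so the output is illustrative only.
        index = faiss.read_index(self.index_path)
        query = np.random.rand(512).astype('float32')
        D, I = index.search(query.reshape(1, -1), 5)
        print("Distances: ", D)  # squared L2 distances to the 5 nearest neighbours
        print("Indices: ", I)    # their positions in the indexed embedding set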


if __name__ == '__main__':
    # Create the search object
    site_name = sys.argv[1]  # argument 1: site
    search = Search(site_name=site_name)

    # Load the data and build the index
    search.load_data_and_create_index()

    # Query the nearest neighbours
    search.handle_data()
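
    # Hypothetical invocation (cluster settings are assumptions, not from the source):
    #   spark-submit search.py us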