1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# import os
# import re
# import sys
#
# import numpy as np
# import pandas as pd
#
# sys.path.append(os.path.dirname(sys.path[0])) # 上级目录
# from pyspark.storagelevel import StorageLevel
# from utils.templates import Templates
# # from ..utils.templates import Templates
# import faiss
# import numpy as np
# from pyspark.sql.functions import pandas_udf, PandasUDFType
# import pandas as pd
#
#
# # 定义一个Pandas UDF,该UDF在每个分区上加载索引并进行查询
# @pandas_udf('int', PandasUDFType.SCALAR)
# def find_nearest_neighbors(series):
# # 加载索引
# index = faiss.read_index("/home/ffman/tmp/my_index.faiss")
# # 查询最近的5个邻居
# _, I = index.search(np.array(series.tolist()).astype('float32'), 5)
# return pd.Series(I[:, 0]) # 返回最近的邻居的索引
#
#
# class Search(Templates):
#
# def __init__(self, site_name='us'):
# super(Search, self).__init__()
# self.site_name = site_name
# self.db_save = f'image_search'
# self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
# self.df_save = self.spark.sql(f"select 1+1;")
#
# def read_data(self):
# self.df_save = self.spark.read.parquet("hdfs://hadoop5:8020/home/ffman/faiss/embeddings/folder")
# self.df_save.show(20)
#
# # 将嵌入向量从数据中提取出来,注意这会将所有的数据加载到内存中,如果数据太大可能会出现内存问题
# data = self.df_save.collect()
# embeddings = np.array([row['embedding'] for row in data])
#
# # 创建Faiss索引
# print("开始创建索引")
# index = faiss.IndexFlatL2(512)
# index.add(embeddings)
# faiss.write_index(index, "/home/ffman/tmp/my_index.faiss")
# print("索引创建完成+存储hdfs完成")
# # 在驱动程序上构建索引并保存到磁盘
# # embeddings = np.random.rand(1000, 512).astype('float32') # 假设你的嵌入向量
# # index = faiss.IndexFlatL2(embeddings.shape[1])
# # index.add(embeddings)
#
# # 假设query是你要查询的嵌入向量
# query = np.random.rand(512).astype('float32')
#
# # 查找最近的5个邻居
# D, I = index.search(query.reshape(1, -1), 5)
#
# # 打印结果
# print("Distances: ", D)
# print("Indices: ", I)
#
# def handle_data(self):
# # 在每个分区上进行查询
# df = self.df_save.withColumn('nearest_neighbor', find_nearest_neighbors(self.df_save['embedding']))
# df.show(20)
# quit()
#
#
# if __name__ == '__main__':
# site_name = sys.argv[1] # 参数1:站点
# handle_obj = Search()
# handle_obj.run()
import sys
import os
import faiss
import numpy as np
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType
sys.path.append(os.path.dirname(sys.path[0])) # 上级目录
from utils.templates import Templates
# from ..utils.templates import Templates
class Search(Templates):
    """Spark + Faiss nearest-neighbor search over image embeddings.

    Reads embedding vectors from parquet on HDFS, builds a flat L2 Faiss
    index on the driver, persists it to local disk, then queries it from
    Spark executors through a scalar pandas UDF.
    """

    def __init__(self, site_name='us'):
        """Create the Spark session and record the data/index locations.

        Args:
            site_name: site identifier used in the Spark app name
                (default ``'us'``).
        """
        super(Search, self).__init__()
        self.site_name = site_name
        self.db_save = 'image_search'  # plain string; no placeholders needed
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        self.data_path = "hdfs://hadoop5:8020/home/ffman/faiss/embeddings/folder"
        self.index_path = "/home/ffman/tmp/my_index.faiss"

    @staticmethod
    @pandas_udf(IntegerType(), PandasUDFType.SCALAR)
    def find_nearest_neighbors(series):
        """Scalar pandas UDF: map each embedding to its nearest neighbor id.

        Args:
            series: pandas Series of embedding vectors (one list per row).

        Returns:
            pandas Series of int — the index of the closest vector in the
            Faiss index for each input embedding.
        """
        # NOTE: the path is hard-coded (duplicating self.index_path) because
        # this UDF executes on the executors, where the driver-side instance
        # attributes are not available. The index file must exist on every
        # executor node at this local path.
        index = faiss.read_index("/home/ffman/tmp/my_index.faiss")
        # Search the 5 nearest neighbors; Faiss requires float32 queries.
        _, neighbor_ids = index.search(np.array(series.tolist()).astype('float32'), 5)
        # Keep only the single closest neighbor per row.
        return pd.Series(neighbor_ids[:, 0])

    def load_data_and_create_index(self):
        """Build a flat L2 Faiss index from the parquet embeddings and save it.

        Raises:
            ValueError: if the parquet source contains no embeddings.
        """
        df = self.spark.read.parquet(self.data_path)
        # WARNING: collect() pulls every embedding into driver memory; this
        # only works while the dataset fits on the driver.
        data = df.collect()
        # Faiss only accepts float32; parquet doubles would otherwise come
        # back as float64 and make index.add() fail.
        embeddings = np.array([row['embedding'] for row in data]).astype('float32')
        if embeddings.size == 0:
            raise ValueError(f"no embeddings found under {self.data_path}")
        # Derive the dimensionality from the data instead of hard-coding 512.
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        faiss.write_index(index, self.index_path)

    def handle_data(self):
        """Annotate each row with its nearest-neighbor index and show a sample."""
        df = self.spark.read.parquet(self.data_path)
        df = df.withColumn('nearest_neighbor', self.find_nearest_neighbors(df['embedding']))
        df.show(20)
if __name__ == '__main__':
    # Arg 1 (optional): site name. Fall back to 'us' — the same default as
    # Search.__init__ — instead of crashing with IndexError when omitted.
    site_name = sys.argv[1] if len(sys.argv) > 1 else 'us'
    search = Search(site_name=site_name)
    # Build the Faiss index on the driver from the parquet embeddings.
    search.load_data_and_create_index()
    # Query the index from executors and show the nearest neighbors.
    search.handle_data()