# faiss_search.py
# import os
# import re
# import sys
# import faiss
#
# import pandas as pd
# import numpy as np
# sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
# from pyspark.storagelevel import StorageLevel
# from utils.templates import Templates
# from ..utils.templates import Templates
#
# class FaissSearch(Templates):
#
#     def __init__(self, site_name='us'):
#         super(FaissSearch, self).__init__()
#         self.site_name = site_name
#         self.db_save = f'faiss_search'
#         self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
#         self.df_features = self.spark.sql(f"select 1+1;")
#         self.df_save = self.spark.sql(f"select 1+1;")
#         self.query_vectors = np.random.rand(512)
#
#     def read_data(self):
#         # Create a DataFrame with two columns
#         df_features = pd.DataFrame({
#             'id': range(1, 1001),  # assume 1,000 images with ids 1 to 1000
#             'features': [np.random.rand(512) for _ in range(1000)]  # each feature vector is a random vector of length 512
#         })
#         self.df_features = self.spark.createDataFrame(df_features)
#
#     def handle_data(self):
#         # 2. Partition the DataFrame
#         # Add a new column holding the partition number
#         partition_size = 1000000  # size of each partition
#         self.df_features = self.df_features.withColumn('partition_id', (self.df_features['id'] / partition_size).cast('integer'))
#
#         # Use repartition to redistribute the DataFrame by partition_id
#         self.df_features = self.df_features.repartition('partition_id')
#
#         # 3. Build a FAISS index on each partition
#         def create_index(partition):
#             # Build one index per partition
#             d = 512  # dimensionality of the feature vectors
#             index = faiss.IndexFlatL2(d)
#
#             # Add the partition's vectors to the index
#             for row in partition:
#                 index.add(np.array(row['features'], dtype=np.float32).reshape(-1, d))
#
#             return [(index, row['id']) for row in partition]
#
#         rdd = self.df_features.rdd.mapPartitions(create_index)
#
#         # 4. Broadcast the query vectors, then run the search on each partition
#         # query_vectors = ...  # a numpy array containing the query vectors
#         query_vectors_broadcast = self.spark.sparkContext.broadcast(self.query_vectors)
#
#         def search(partition):
#             # Run the search on each partition
#             for index, _ in partition:
#                 D, I = index.search(query_vectors_broadcast.value, k)
#                 yield (D, I)
#
#         results = rdd.mapPartitions(search).collect()
#
#         # 5. Merge the results
#         D, I = zip(*results)
#         D = np.concatenate(D)
#         I = np.concatenate(I)
#
#         # Sort the results and keep the top k
#         indices = np.argsort(D, axis=1)[:, :k]
#         D = np.take_along_axis(D, indices, axis=1)
#         I = np.take_along_axis(I, indices, axis=1)

import pandas as pd
import numpy as np
import faiss
from sqlalchemy import create_engine
from pyspark.sql import SparkSession
from pyspark import pandas as pdf
from pyspark.sql import Row

# Create a DataFrame with two columns
# df_features = pd.DataFrame({
#     'id': range(1, 101),  # assume 100 images with ids 1 to 100
#     'features': [np.random.rand(512).tolist() for _ in range(100)]  # each feature vector is a random vector of length 512
# })
# data = []
# for i in range(1, 101):
#     data.append({
#         'id': i,
#         'features': np.random.rand(512).tolist()
#     })
#
# df_features = pdf.DataFrame(data)
import sys
import os
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates import Templates
# from ..utils.templates import Templates

spark = Templates().create_spark_object(app_name=f"faiss")
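# Sanity check: run a small query and show a few rows before the main workflow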
sql = "select * from ods_brand_analytics limit 100"
df = spark.sql(sql)
df.show(10)

Mysql_arguments = {
    'user': 'adv_yswg',
    'password': 'HCL1zcUgQesaaXNLbL37O5KhpSAy0c',
    'host': 'rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com',
    'port': 3306,
    'database': 'selection',
    'charset': 'utf8mb4',
}
def get_country_engine(site_name="us"):
    if site_name != 'us':
        Mysql_arguments["database"] = f"selection_{site_name}"
    db_ = 'mysql+pymysql://{}:{}@{}:{}/{}?charset={}'.format(*Mysql_arguments.values())
    engine = create_engine(db_)  # , pool_recycle=3600
    return engine
engine = get_country_engine()
sql = f"select id, img_vector as features from us_asin_extract_features limit 100;"
df_features = pd.read_sql(sql, con=engine)
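# Note (assumption about the MySQL schema, not verified here): if img_vector is stored
# as a delimited text column, it has to be parsed into a numeric list before Spark and
# FAISS can consume it. A minimal sketch for a comma-separated "[...]" string:
# df_features['features'] = df_features['features'].apply(
#     lambda s: [float(x) for x in s.strip('[]').split(',')]
# )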


# 1. Create a SparkSession
# spark = SparkSession.builder.getOrCreate()

# Convert the pandas DataFrame to a PySpark DataFrame
df_spark = spark.createDataFrame(df_features)
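# createDataFrame ships the pandas rows through the JVM; for larger batches, Arrow-based
# conversion can speed this up (optional, assumes pyarrow is installed on the cluster):
# spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")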


# 2. Partition the DataFrame
# Add a new column holding the partition number
partition_size = 100  # size of each partition
df_spark = df_spark.withColumn('partition_id', (df_spark['id'] / partition_size).cast('integer'))

# Use repartition to redistribute the DataFrame by partition_id
df_spark = df_spark.repartition('partition_id')
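# Optional sanity check (not part of the original flow): confirm how the rows were
# distributed before building the per-partition indexes.
# print("number of partitions:", df_spark.rdd.getNumPartitions())
# df_spark.groupBy('partition_id').count().show()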


# 3. Build a FAISS index on each partition
def create_index(partition):
    # Build one index per partition
    rows = list(partition)  # materialize the iterator; it can only be traversed once
    if not rows:
        return []
    d = 512      # dimensionality of the feature vectors
    nlist = 100  # number of IVF clusters (FAISS needs at least nlist training vectors)
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist)

    # Train the index on the partition's vectors, then add them with their original ids
    data = np.array([np.array(row.features, dtype=np.float32) for row in rows])
    ids = np.array([row.id for row in rows], dtype=np.int64)
    index.train(data)
    index.add_with_ids(data, ids)

    # Return a single (index, ids) pair for the whole partition
    return [(index, ids)]


rdd = df_spark.rdd.mapPartitions(create_index)
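# mapPartitions is lazy: the indexes are only built when an action (collect below) runs,
# and each partition's in-memory index lives only for the duration of that task.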

# 4. Broadcast the query vector, then run the search on each partition
query_vectors = np.random.rand(1, 512).astype('float32')  # a numpy array containing the query vector
query_vectors_broadcast = spark.sparkContext.broadcast(query_vectors)


def search(partition):
    # Run the search on each partition
    for index, _ in partition:
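        # Optional recall tweak (an assumption about requirements, not in the original):
        # IndexIVFFlat probes only index.nprobe clusters per query (default 1); raising
        # it trades speed for recall.
        # index.nprobe = 10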
        D, I = index.search(query_vectors_broadcast.value, 4)
        yield (D, I)


results = rdd.mapPartitions(search).collect()

# 5. Merge the per-partition results
D, I = zip(*results)
# Every partition answered the same single query, so stack its top-k candidates side by side
D = np.concatenate(D, axis=1)
I = np.concatenate(I, axis=1)

# Sort the merged candidates by distance and keep the global top k
k = 4
indices = np.argsort(D, axis=1)[:, :k]
D = np.take_along_axis(D, indices, axis=1)
I = np.take_along_axis(I, indices, axis=1)

print("查询向量的最近邻的id:", I)
print("查询向量的最近邻的距离:", D)