# auto_index.py
# Build a distributed FAISS index over HDFS-stored embeddings with autofaiss + PySpark.
import os
import sys
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from pyspark.storagelevel import StorageLevel
# from ..utils.templates import Templates
from autofaiss import build_index
from pyspark.sql import SparkSession

from pyspark import SparkConf, SparkContext


class Create_index():

    def __init__(self, site_name='us'):
        super().__init__()
        self.site_name = site_name
        self.db_save = 'img_search'
        # self.spark = self.create_spark_object(
        #     app_name=f"{self.db_save}: {self.site_name}")
        self.spark = self.create_spark()
        # Smoke-test query to confirm the session can execute SQL.
        self.df_st_detail = self.spark.sql("select 1+1")

    def create_spark(self):
        # spark = SparkSession.builder. \
        #     appName(f"{self.db_save}: self.site_name"). \
        #     config("spark.sql.warehouse.dir", f"hdfs://hadoop5:8020/home/big_data_selection"). \
        #     config("spark.metastore.uris", "thrift://hadoop6:9083"). \
        #     config("spark.network.timeout", 1000000). \
        #     config("spark.sql.parquet.compression.codec", "lzo"). \
        #     config("spark.driver.maxResultSize", "50g"). \
        #     config("spark.sql.autoBroadcastJoinThreshold", -1). \
        #     enableHiveSupport(). \
        #     getOrCreate()
        # # Add the environment (ship the shared utils archive to the executors)
        # # spark.sparkContext.addPyFile(path="hdfs://hadoop5:8020/lib/yswg_utils.zip")
        #
        # spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        # spark.sql('''set mapred.output.compress=true''')
        # spark.sql('''set hive.exec.compress.output=true''')
        # spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        # spark.sql(f"use big_data_selection;")
        # Point PySpark workers at a PEX that bundles autofaiss and its dependencies,
        # so executors do not need a matching Python environment installed locally.
        os.environ['PYSPARK_PYTHON'] = "/root/autofaiss.pex"
        spark = (
            SparkSession.builder
                .config("spark.submit.deployMode", "client")
                .config("spark.executorEnv.PEX_ROOT", "./.pex")
                .config("spark.executor.cores", "4")
                .config("spark.task.cpus", "1")
                .config("spark.driver.port", "4678")
                .config("spark.driver.blockManager.port", "6678")
                .config("spark.driver.host", "192.168.200.210")
                .config("spark.driver.bindAddress", "192.168.200.210")
                .config("spark.executor.memory", "8G")  # reduced from 18G; raise it if you give executors more cores
                .config("spark.executor.memoryOverhead", "2G")  # reduced from 12G
                .config("spark.task.maxFailures", "100")
                .master("yarn")  # point this at your cluster master, e.g. "spark://192.168.200.210:7077" for standalone
                .appName("spark-stats")
                .getOrCreate()
        )

        # spark = SparkSession \
        #     .builder \
        #     .appName("Python Spark SQL basic example") \
        #     .config("spark.submit.deployMode", "client") \
        #     .config("spark.executor.memory", "2g") \
        #     .config("spark.driver.memory", "2g") \
        #     .master("yarn") \
        #     .getOrCreate()
        return spark
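
    # Illustrative sketch, not part of the original pipeline: it shows the kind
    # of parquet layout that build_index(file_format="parquet") can consume
    # (an array<float> column named "embedding", autofaiss's default column
    # name). The output path, vector dimension and row count are made-up
    # example values.
    def write_example_embeddings(self, output_path="hdfs://hadoop5:8020/home/ffman/embeddings/folder_demo",
                                 dim=512, n=1000):
        import numpy as np
        from pyspark.sql.types import ArrayType, FloatType, StructField, StructType

        schema = StructType([StructField("embedding", ArrayType(FloatType()), False)])
        # One row per vector; each row is a single array<float> column.
        rows = [([float(v) for v in np.random.random(dim)],) for _ in range(n)]
        df = self.spark.createDataFrame(rows, schema=schema)
        df.write.mode("overwrite").parquet(output_path)
        return output_path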

    def create_index(self):
        # self.spark.sql("use ffman;")
        index, index_infos = build_index(
            embeddings="hdfs://hadoop5:8020/home/ffman/embeddings/folder",
            distributed="pyspark",
            file_format="parquet",
            max_index_memory_usage="16G",
            current_memory_available="24G",
            temporary_indices_folder="hdfs://hadoop5:8020/home/ffman/tmp/distributed_autofaiss_indices",
            index_path="hdfs://hadoop5:8020/home/ffman/index/knn.index",
            index_infos_path="hdfs://hadoop5:8020/home/ffman/index/infos.json",
            nb_indices_to_keep=10  # keep 10 index shards
        )
        print("index, index_infos:", index, index_infos)
        return index, index_infos
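
    # Illustrative sketch, not part of the original script: one way to query the
    # finished index after copying it from HDFS to local disk. With
    # nb_indices_to_keep > 1 autofaiss writes several shard files, so the local
    # path and the random query vectors below are made-up example values.
    def query_index_example(self, local_index_path="/tmp/knn.index", top_k=5):
        import faiss
        import numpy as np

        index = faiss.read_index(local_index_path)  # load one index shard
        queries = np.random.random((2, index.d)).astype("float32")  # dummy query vectors
        distances, ids = index.search(queries, top_k)
        print("distances:", distances)
        print("ids:", ids)
        return distances, ids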


if __name__ == '__main__':
    handle_obj = Create_index()
    handle_obj.create_index()