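"""autofaiss_use.py

Build a distributed FAISS index with autofaiss on a Spark/YARN cluster:
embeddings are read as parquet files from HDFS and the merged index plus
its info report are written back to HDFS (used here for image search).
"""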
import os
import sys
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the module search path
from pyspark.storagelevel import StorageLevel
from utils.templates import Templates
# from ..utils.templates import Templates
from autofaiss import build_index
from pyspark.sql import SparkSession

from pyspark import SparkConf, SparkContext


class Create_index(Templates):

    def __init__(self, site_name='us'):
        super(Create_index, self).__init__()
        self.site_name = site_name
        self.db_save = 'img_search'
        self.spark = self.create_spark_object(
            app_name=f"{self.db_save}: {self.site_name}")
        # self.spark = self.create_spark()
        self.df_st_detail = self.spark.sql("select 1+1;")  # trivial query, presumably just to confirm the session is usable

    def create_spark(self):
        # spark = SparkSession.builder. \
        #     appName(f"{self.db_save}: self.site_name"). \
        #     config("spark.sql.warehouse.dir", f"hdfs://hadoop5:8020/home/big_data_selection"). \
        #     config("spark.metastore.uris", "thrift://hadoop6:9083"). \
        #     config("spark.network.timeout", 1000000). \
        #     config("spark.sql.parquet.compression.codec", "lzo"). \
        #     config("spark.driver.maxResultSize", "50g"). \
        #     config("spark.sql.autoBroadcastJoinThreshold", -1). \
        #     enableHiveSupport(). \
        #     getOrCreate()
        # # add environment / ship dependencies
        # # spark.sparkContext.addPyFile(path="hdfs://hadoop5:8020/lib/yswg_utils.zip")
        #
        # spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
        # spark.sql('''set mapred.output.compress=true''')
        # spark.sql('''set hive.exec.compress.output=true''')
        # spark.sql('''set mapred.output.compression.codec=com.hadoop.compression.lzo.LzopCodec''')
        # spark.sql(f"use big_data_selection;")
        os.environ['PYSPARK_PYTHON'] = "/root/autofaiss.pex"
        spark = (
            SparkSession.builder
            .config("spark.submit.deployMode", "client")
            # .config("spark.executorEnv.PEX_ROOT", "./.pex")
            # .config("spark.executor.cores", "16")
            # .config("spark.cores.max", "48")  # reduce this to use fewer cores; on yarn the option name differs, check the Spark docs
            .config("spark.task.cpus", "4")
            .config("spark.driver.port", "4678")
            .config("spark.driver.blockManager.port", "6678")
            .config("spark.driver.host", "192.168.200.210")
            .config("spark.driver.bindAddress", "192.168.200.210")
            .config("spark.executor.memory", "18G")  # increase this when using more cores per executor
            .config("spark.executor.memoryOverhead", "8G")
            .config("spark.task.maxFailures", "100")
            .master("yarn")  # point this at your master node; with the tunnelling setup keep it at localhost
            # .master("spark://192.168.200.210:7077")
            # conf.setMaster("yarn")
            .appName("spark-stats")
            .getOrCreate()
        )
        # spark = SparkSession \
        #     .builder \
        #     .appName("Python Spark SQL basic example") \
        #     .config("spark.submit.deployMode", "client") \
        #     .config("spark.executor.memory", "2g") \
        #     .config("spark.driver.memory", "2g") \
        #     .master("yarn") \
        #     .getOrCreate()
        return spark

    def create_index(self):
        # self.spark.sql("use ffman;")
        index, index_infos = build_index(
            embeddings="hdfs://hadoop5:8020/home/ffman/embeddings/folder",
            distributed="pyspark",
            file_format="parquet",
            max_index_memory_usage="16G",
            current_memory_available="24G",
            temporary_indices_folder="hdfs://hadoop5:8020/home/ffman/tmp/distributed_autofaiss_indices",
            index_path="hdfs://hadoop5:8020/home/ffman/index/knn.index",
            index_infos_path="hdfs://hadoop5:8020/home/ffman/index/infos.json"
        )
        print("index, index_infos:", index, index_infos)
        return index, index_infos


if __name__ == '__main__':
    handle_obj = Create_index()
    handle_obj.create_index()
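
# A minimal sketch (not part of the job above) of how the resulting index could be
# queried after copying knn.index off HDFS; the local path and query shape are
# illustrative assumptions.
#
#   import faiss
#   import numpy as np
#
#   index = faiss.read_index("/tmp/knn.index")            # load the serialized index
#   query = np.random.rand(1, index.d).astype("float32")  # one query vector of matching dimension
#   distances, ids = index.search(query, 10)              # top-10 nearest neighbours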