# create_parquet.py (4.37 KB)
import os
import sys

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from sqlalchemy import create_engine
sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录
from pyspark.storagelevel import StorageLevel
from utils.templates import Templates
# from ..utils.templates import Templates
from pyspark.sql.window import Window
from pyspark.sql import functions as F


class CreateParquet(Templates):
    """Export image embeddings from a Spark SQL table into fixed-size Parquet files.

    Pipeline (see ``run``):
      1. ``read_data`` loads ``id`` / ``embedding`` rows and assigns a contiguous
         0-based ``index`` column ordered by ``id``.
      2. ``handle_data`` slices that index into chunks of ``step`` rows and writes
         each chunk's ``embedding`` column to ``<output_dir>/part_NNNN.parquet``.
      3. ``save_data`` is currently a no-op placeholder.
    """

    def __init__(self):
        super().__init__()
        # SQLAlchemy engine for MySQL; not used by the methods of this class.
        self.engine = self.mysql_conn()
        # Placeholder until read_data() replaces it with a Spark DataFrame.
        self.df_features = pd.DataFrame()
        self.db_save = 'create_parquet'
        # create_spark_object is not defined here; presumably inherited from Templates.
        self.spark = self.create_spark_object(app_name=f"{self.db_save}")

    @staticmethod
    def mysql_conn(site_name="us"):
        """Return a SQLAlchemy engine for the MySQL ``selection`` database.

        Args:
            site_name: Site code. ``"us"`` (the default) connects to the
                ``selection`` database; any other value connects to
                ``selection_<site_name>``.

        Returns:
            A SQLAlchemy engine using the pymysql driver and utf8mb4 charset.
        """
        from urllib.parse import quote_plus

        # NOTE(review): credentials are hard-coded and committed to source control.
        # Move them to environment variables / a secrets store and rotate the password.
        mysql_arguments = {
            'user': 'adv_yswg',
            'password': 'HCL1zcUgQesaaXNLbL37O5KhpSAy0c',
            'host': 'rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com',
            'port': 3306,
            'database': 'selection',
            'charset': 'utf8mb4',
        }
        database = mysql_arguments['database']
        if site_name != 'us':
            database = f"{database}_{site_name}"

        # Format by field name (not dict insertion order) and percent-escape the
        # credentials so characters like '@' or ':' cannot corrupt the URL.
        db_url = 'mysql+pymysql://{user}:{password}@{host}:{port}/{database}?charset={charset}'.format(
            user=quote_plus(mysql_arguments['user']),
            password=quote_plus(mysql_arguments['password']),
            host=mysql_arguments['host'],
            port=mysql_arguments['port'],
            database=database,
            charset=mysql_arguments['charset'],
        )
        return create_engine(db_url)

    def read_data(self):
        """Load embeddings via Spark SQL and add a 0-based ``index`` column ordered by ``id``.

        The result is stored in ``self.df_features`` and persisted, because the
        window that computes ``index`` is costly and ``handle_data`` reads the
        DataFrame once per output chunk.
        """
        sql = "select id, img_vector as embedding from ods_asin_extract_features;"
        print("sql:", sql)
        df = self.spark.sql(sql)

        # A window with no partitionBy is evaluated by Spark in a single partition,
        # so it is expensive on large tables. Persist the *indexed* DataFrame so the
        # window runs once instead of once per chunk in handle_data().
        window = Window.orderBy("id")
        self.df_features = (
            df.withColumn("index", F.row_number().over(window) - 1)  # start at 0
            .persist(StorageLevel.MEMORY_AND_DISK)
        )
        self.df_features.show(20, truncate=False)

    def handle_data(self, total_rows=None, step=200000, output_dir="/mnt/ffman/my_parquet"):
        """Write ``self.df_features`` to Parquet files of at most ``step`` rows each.

        Rows with ``start <= index < start + step`` go to
        ``<output_dir>/part_<k>.parquet``, where ``k`` is the zero-padded chunk
        number. Only the ``embedding`` column is written.

        Args:
            total_rows: Number of indexed rows to export. ``None`` (default) counts
                ``self.df_features``, so no rows are dropped and no empty files
                are written.
            step: Rows per output file; must be positive.
            output_dir: Directory for the part files; created if missing.

        Raises:
            ValueError: If ``step`` is not positive.
        """
        if step <= 0:
            raise ValueError(f"step must be positive, got {step}")
        if total_rows is None:
            # Count the real size instead of assuming one: a fixed total silently
            # skips rows beyond it, or writes empty files when the table is smaller.
            total_rows = self.df_features.count()

        os.makedirs(output_dir, exist_ok=True)

        starts = list(range(0, total_rows, step))
        file_ids = [f"{i:04}" for i in range(len(starts))]
        print("index_list:", starts)
        print("file_id_list:", file_ids)

        index_col = F.col("index")
        for start, file_id in zip(starts, file_ids):
            df_part = (
                self.df_features
                .filter((index_col >= start) & (index_col < start + step))
                .select("embedding")
                .toPandas()
            )
            table = pa.Table.from_pandas(df=df_part)
            file_name = os.path.join(output_dir, f"part_{file_id}.parquet")
            pq.write_table(table, file_name)
            print("df_part.shape, index, file_name:", df_part.shape, start, file_name)

    def save_data(self):
        """No-op placeholder; the Parquet files are already written by ``handle_data``."""
        pass

    def run(self):
        """Run the full export: read, write chunked Parquet, then save (no-op)."""
        self.read_data()
        self.handle_data()
        self.save_data()


if __name__ == '__main__':
    # Script entry point: build the exporter and run read -> handle -> save.
    exporter = CreateParquet()
    exporter.run()