import os
import sys

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from sqlalchemy import create_engine
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from pyspark.storagelevel import StorageLevel
from utils.templates import Templates
# from ..utils.templates import Templates
from pyspark.sql.window import Window
from pyspark.sql import functions as F


class CreateParquet(Templates):

    def __init__(self):
        super(CreateParquet, self).__init__()
        self.engine = self.mysql_conn()
        self.df_features = pd.DataFrame()
        self.db_save = 'create_parquet'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}")
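        # The Templates base class (utils.templates) is assumed to provide
        # create_spark_object(); the MySQL engine created above is currently
        # only used by the commented-out pandas read in read_data().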

    @staticmethod
    def mysql_conn():
        Mysql_arguments = {
            'user': 'adv_yswg',
            'password': 'HCL1zcUgQesaaXNLbL37O5KhpSAy0c',
            'host': 'rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com',
            'port': 3306,
            'database': 'selection',
            'charset': 'utf8mb4',
        }

        def get_country_engine(site_name="us"):
            if site_name != 'us':
                Mysql_arguments["database"] = f"selection_{site_name}"
            db_ = 'mysql+pymysql://{}:{}@{}:{}/{}?charset={}'.format(*Mysql_arguments.values())
            engine = create_engine(db_)  # , pool_recycle=3600
            return engine

        engine = get_country_engine()
        return engine
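    # Usage note (sketch): get_country_engine() defaults to the us site and the
    # "selection" database; any other site_name is assumed to map to a database
    # named f"selection_{site_name}". A hypothetical pandas read through the
    # returned engine would mirror the commented-out read in read_data():
    #   pd.read_sql("select id, img_vector as features from us_asin_extract_features", con=engine)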

    def read_data(self):
        # sql = f"select id, img_vector as features from us_asin_extract_features;"
        # self.df_features = pd.read_sql(sql, con=self.engine)

        sql = f"select id, img_vector as embedding from ods_asin_extract_features;"
        print("sql:", sql)
        self.df_features = self.spark.sql(sql).cache()
        # Add an index column ordered by id
        window = Window.orderBy("id")
        self.df_features = self.df_features.withColumn("index", F.row_number().over(window) - 1)  # starts from 0
        self.df_features.show(20, truncate=False)
        # self.df_features = self.df_features.cache()
        # The window defined above orders rows by id
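        # Note: the un-partitioned Window.orderBy("id") above makes Spark compute
        # row_number() on a single partition, which can be slow at this scale.
        # F.monotonically_increasing_id() would be cheaper, but its ids are not
        # consecutive, so the fixed-step slicing in handle_data() relies on the
        # contiguous 0-based index built here.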

    def handle_data(self):
        # Assume the DataFrame has an 'id' column whose values are unique, increasing integers starting from 1.
        # The 'block' column groups every 200000 'id' values into one block.
        # self.df_features = self.df_features.withColumn('block', F.floor(self.df_features['index'] / 200))

        # Partitioned write using the 'block' column
        # self.df_features.write.partitionBy('block').parquet('/home/ffman/parquet/files')
        # self.df_features.write.mode('overwrite').parquet("/home/ffman/parquet/image.parquet")
        # df_count = self.df_features.count()
        # df = self.df_features.filter("index < 200000").select("embedding")
        # print("df.count():", df_count, df.count())
        # df = df.toPandas()
        # df.embedding = df.embedding.apply(lambda x: eval(x))
        # table = pa.Table.from_pandas(df)
        # pq.write_table(table, "/root/part1.parquet")

        # df_count = self.df_features.count()
        df_count = 35000000  # hard-coded total row count; the count() above is commented out
        image_num = df_count

        # os.makedirs("my_parquet", exist_ok=True)
        step = 200000
        index_list = list(range(0, image_num, step))
        file_id_list = [f"{i:04}" for i in range(len(index_list))]  # zero-padded file ids: 0000, 0001, ...
        print("index_list:", index_list)
        print("file_id_list:", file_id_list)
        for index, file_id in zip(index_list, file_id_list):
            df_part = self.df_features.filter(f"index >= {index} and index < {index+step}").select("embedding")
            df_part = df_part.toPandas()
            table = pa.Table.from_pandas(df=df_part)
            file_name = f"/mnt/ffman/my_parquet/part_{file_id}.parquet"
            pq.write_table(table, file_name)
            print("df_part.shape, index, file_name:", df_part.shape, index, file_name)
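        # Alternative sketch, assuming exact per-file index ranges are not required:
        # let Spark cap the records per output file instead of filtering the
        # DataFrame once per 200000-row chunk.
        #   self.spark.conf.set("spark.sql.files.maxRecordsPerFile", step)
        #   self.df_features.select("embedding").write.mode("overwrite").parquet("/mnt/ffman/my_parquet")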

    def save_data(self):
        # # Set the partition size to 200000
        # self.spark.conf.set("spark.sql.files.maxRecordsPerFile", 200)
        #
        # # Store the data in Parquet format
        # self.df_features.write.partitionBy("index").parquet("/home/ffman/parquet/files")
        pass  # no-op: the parquet parts are already written in handle_data()

    def run(self):
        self.read_data()
        self.handle_data()
        self.save_data()


if __name__ == '__main__':
    handle_obj = CreateParquet()
    handle_obj.run()