import os
import sys

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from sqlalchemy import create_engine

sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from pyspark.storagelevel import StorageLevel
from utils.templates import Templates
# from ..utils.templates import Templates
from pyspark.sql.window import Window
from pyspark.sql import functions as F


class CreateParquet(Templates):

    def __init__(self):
        super(CreateParquet, self).__init__()
        self.engine = self.mysql_conn()
        self.df_features = pd.DataFrame()
        self.db_save = 'create_parquet'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}")

    @staticmethod
    def mysql_conn():
        # Build a SQLAlchemy engine for the site-specific MySQL database
        # (the default "us" site uses the `selection` schema).
        mysql_arguments = {
            'user': 'adv_yswg',
            'password': 'HCL1zcUgQesaaXNLbL37O5KhpSAy0c',
            'host': 'rm-wz9yg9bsb2zf01ea4yo.mysql.rds.aliyuncs.com',
            'port': 3306,
            'database': 'selection',
            'charset': 'utf8mb4',
        }

        def get_country_engine(site_name="us"):
            if site_name != 'us':
                mysql_arguments["database"] = f"selection_{site_name}"
            db_url = 'mysql+pymysql://{}:{}@{}:{}/{}?charset={}'.format(*mysql_arguments.values())
            return create_engine(db_url)  # , pool_recycle=3600

        return get_country_engine()
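
    # Example usage of the engine returned above -- an illustrative sketch only, not
    # part of the original job (the query below just mirrors the commented-out
    # pandas path in read_data):
    #   engine = CreateParquet.mysql_conn()
    #   df = pd.read_sql("select id, img_vector from us_asin_extract_features limit 10", con=engine)
    # pandas.read_sql accepts a SQLAlchemy engine as its `con` argument.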

    def read_data(self):
        # sql = "select id, img_vector as features from us_asin_extract_features;"
        # self.df_features = pd.read_sql(sql, con=self.engine)
        sql = "select id, img_vector as embedding from ods_asin_extract_features;"
        print("sql:", sql)
        self.df_features = self.spark.sql(sql).cache()
        # Add a 0-based index column over a window ordered by id (no partitioning,
        # so the numbering is global across the whole table).
        window = Window.orderBy("id")
        self.df_features = self.df_features.withColumn("index", F.row_number().over(window) - 1)
        self.df_features.show(20, truncate=False)

    def handle_data(self):
        # Earlier approach (kept for reference): bucket rows into blocks of 200,000 ids
        # and let Spark write one partition per block, assuming 'id' is a unique,
        # increasing integer starting from 1.
        # self.df_features = self.df_features.withColumn('block', F.floor(self.df_features['index'] / 200000))
        # self.df_features.write.partitionBy('block').parquet('/home/ffman/parquet/files')
        # self.df_features.write.mode('overwrite').parquet("/home/ffman/parquet/image.parquet")
        # df_count = self.df_features.count()
        # df = self.df_features.filter("index < 200000").select("embedding")
        # print("df.count():", df_count, df.count())
        # df = df.toPandas()
        # df.embedding = df.embedding.apply(lambda x: eval(x))
        # table = pa.Table.from_pandas(df)
        # pq.write_table(table, "/root/part1.parquet")

        # Current approach: slice the indexed DataFrame into 200,000-row chunks, pull
        # each chunk to the driver and write it out as its own parquet part file.
        # df_count = self.df_features.count()
        df_count = 35000000  # hard-coded total row count to avoid a full count()
        image_num = df_count
        # os.makedirs("my_parquet", exist_ok=True)
        step = 200000
        index_list = list(range(0, image_num, step))
        file_id_list = [f"{i:04}" for i in range(len(index_list))]
        print("index_list:", index_list)
        print("file_id_list:", file_id_list)
        for index, file_id in zip(index_list, file_id_list):
            df_part = self.df_features.filter(f"index >= {index} and index < {index + step}").select("embedding")
            df_part = df_part.toPandas()
            table = pa.Table.from_pandas(df=df_part)
            file_name = f"/mnt/ffman/my_parquet/part_{file_id}.parquet"
            pq.write_table(table, file_name)
            print("df_part.shape, index, file_name:", df_part.shape, index, file_name)

    def save_data(self):
        # Alternative Spark-native write (not used): cap each output file at
        # 200,000 records and let Spark write the parquet files directly.
        # self.spark.conf.set("spark.sql.files.maxRecordsPerFile", 200000)
        # self.df_features.write.partitionBy("index").parquet("/home/ffman/parquet/files")
        pass

    def run(self):
        self.read_data()
        self.handle_data()
        self.save_data()
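

# Illustration only: a minimal sketch (not part of the original pipeline) showing how
# the part files written by handle_data() could be read back and sanity-checked with
# pyarrow. The function name and the default directory are assumptions for this
# example; only glob.glob and pyarrow.parquet.read_table are relied on.
def verify_parquet_parts(parquet_dir="/mnt/ffman/my_parquet"):
    import glob
    total_rows = 0
    for path in sorted(glob.glob(os.path.join(parquet_dir, "part_*.parquet"))):
        table = pq.read_table(path)  # load one part file into an Arrow table
        total_rows += table.num_rows
        print(path, table.num_rows)
    print("total rows across parts:", total_rows)
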

if __name__ == '__main__':
    handle_obj = CreateParquet()
    handle_obj.run()