# pictures_dwd_id_index.py
import ast
import os
import re
import sys
import time
import traceback

import pandas as pd
from pyspark.sql.types import ArrayType, FloatType

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the module search path
from utils.templates_mysql import TemplatesMysql
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from sqlalchemy import text
from pyspark.sql import functions as F
import pyarrow as pa
import pyarrow.parquet as pq
from multiprocessing import Process
from multiprocessing import Pool
import multiprocessing


class PicturesIdIndex(Templates):

    def __init__(self, site_name='us'):
        super(PicturesIdIndex, self).__init__()
        self.site_name = site_name
        self.engine_pg = TemplatesMysql().engine_pg
        self.db_save = 'pictures_dwd_id_index'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        self.df_features = self.spark.sql("select 1+1;")  # placeholder, replaced in read_data()
        self.df_save = self.spark.sql("select 1+1;")  # placeholder, replaced in handle_data()
        self.df_save_local = self.spark.sql("select 1+1;")  # placeholder, replaced in handle_data()
        self.partitions_by = ['site_name']
        self.partitions_num = 1
        # bookkeeping fields read from the {site_name}_pictures_hdfs_index table
        self.tn_pics_hdfs_index = f"{self.site_name}_pictures_hdfs_index"
        self.id = int()
        self.current_counts = int()
        self.all_counts = int()
        self.hdfs_path = str()
        self.hdfs_block_name = str()
        self.local_path = str()

    def read_data(self):
        # read the hdfs_path info for one pending block, locking the row with FOR UPDATE
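        # The bookkeeping table is assumed to carry at least the columns used
        # below (a sketch inferred from this file; the real DDL may differ):
        #   id              bigint  -- primary key of the block row
        #   hdfs_path       text    -- HDFS path of one part-* feature block
        #   current_counts  bigint  -- row count of this block (assumed)
        #   all_counts      bigint  -- running row total, used below as the index offset
        #   state           int     -- 1 = pending, 2 = in progress, 3 = done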
        with self.engine_pg.begin() as conn:
            sql_read = text(f"SELECT * FROM {self.tn_pics_hdfs_index} WHERE state=1 LIMIT 1 FOR UPDATE;")
            print("sql_read:", sql_read)
            result = conn.execute(sql_read)
            df = pd.DataFrame(result.fetchall())
            if df.shape[0]:
                df.columns = result.keys()
                self.id = df.id.iloc[0]
                self.current_counts = df.current_counts.iloc[0]
                self.all_counts = df.all_counts.iloc[0]
                self.hdfs_path = df.hdfs_path.iloc[0]
                self.hdfs_block_name = re.findall(r"(part-\d+)-", self.hdfs_path)[0] if self.hdfs_path else None
                self.local_path = rf"/mnt/image_search/pictures_id_index/{self.site_name}/{self.hdfs_block_name}"

                print(f"df.shape:{df.shape}, self.id:{self.id}, self.current_counts:{self.current_counts}, self.all_counts:{self.all_counts}, self.hdfs_path:{self.hdfs_path}")
                if self.id:
                    # drop any partition left over from a previous run of this block
                    os.system(f"hdfs dfs -rm -r /home/big_data_selection/dwd/pictures_dwd_id_index/site_name={self.site_name}/block={self.hdfs_block_name}")
                    sql_update = text(
                        f"UPDATE {self.tn_pics_hdfs_index} SET state=2 WHERE id={self.id};")
                    print("sql_update:", sql_update)
                    conn.execute(sql_update)
            else:
                # no pending block (state=1) left to process
                quit()
        # read the raw text (lzo) block under the claimed HDFS path; each line holds tab-separated fields
        self.df_features = self.spark.read.text(self.hdfs_path).cache()

    def save_data_local(self):
        print("当前存储到本地:", self.local_path)
        if os.path.exists(self.local_path):
            os.system(f"rm -rf {self.local_path}")
        os.makedirs(self.local_path)
        # Convert DataFrame to Arrow Table
        df = self.df_save_local.toPandas()
        table = pa.Table.from_pandas(df)
        # Save to Parquet
        pq.write_table(table, f"{self.local_path}/{self.hdfs_block_name}.parquet")
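
    # Optional sanity-check sketch (not called anywhere in run()): read the
    # parquet block written by save_data_local back with pyarrow and print its
    # row count and schema.
    def check_local_parquet(self):
        path = f"{self.local_path}/{self.hdfs_block_name}.parquet"
        table = pq.read_table(path)
        print(f"{path}: {table.num_rows} rows")
        print(table.schema)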

    def handle_data(self):
        # split each line on tabs and build a DataFrame with one column per field
        split_df = self.df_features.select(F.split(self.df_features['value'], '\t').alias('split_values'))
        # the data is expected to have three tab-separated fields: id, asin, embedding
        final_df = split_df.select(
            split_df['split_values'].getItem(0).alias('id'),
            split_df['split_values'].getItem(1).alias('asin'),
            split_df['split_values'].getItem(2).alias('embedding')
        )
        print("分块前分区数量:", final_df.rdd.getNumPartitions())

        # the values produced by the split are strings, so cast id to a numeric type first
        final_df = final_df.withColumn('id', final_df['id'].cast('bigint'))
        # add a global index column offset by the running total all_counts; note that
        # monotonically_increasing_id() is increasing and unique but not consecutive
        # once the data spans more than one partition
        final_df = final_df.withColumn("index", F.monotonically_increasing_id() + self.all_counts)
        final_df.show()

        # the type conversion was already done in pictures_dim_features_slice, but because
        # this job reads the raw lzo file instead of the table, it has to be repeated here
        # define a UDF that parses the embedding string into a list of floats
        str_to_list_udf = F.udf(lambda s: ast.literal_eval(s), ArrayType(FloatType()))
        # apply the UDF to the embedding column
        final_df = final_df.withColumn("embedding", str_to_list_udf(final_df["embedding"]))

        # final_df.write.mode('overwrite').parquet("hdfs://hadoop5:8020/home/ffman/parquet")

        # final_df = final_df.withColumn("block", F.lit(self.hdfs_block_name))
        self.df_save = final_df.withColumn("site_name", F.lit(self.site_name))
        self.df_save_local = self.df_save.select("embedding")
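        # At this point df_save holds: id (bigint), asin (string), embedding (array<float>),
        # index (bigint), site_name (string); df_save_local keeps only the embedding
        # column, which save_data_local writes to a local parquet block.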

    def update_state_after_save(self):
        with self.engine_pg.begin() as conn:
            sql_update = text(
                f"UPDATE {self.tn_pics_hdfs_index} SET state=3 WHERE state=2 and id={self.id};")
            print("sql_update:", sql_update)
            conn.execute(sql_update)

    def run(self):
        self.read_data()
        self.handle_data()
        self.save_data()  # inherited from the Templates base class
        # also store the embedding column locally
        self.save_data_local()
        self.update_state_after_save()


def main():
    handle_obj = PicturesIdIndex()
    handle_obj.run()
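
# A possible loop-style driver (a sketch, not wired into __main__): it keeps
# claiming pending blocks until read_data() finds no row with state=1 and
# exits via quit(), which raises SystemExit.
def run_until_empty():
    while True:
        try:
            PicturesIdIndex().run()
        except SystemExit:
            break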


if __name__ == "__main__":
    main()
    quit()


    # NOTE: this multiprocessing variant is unreachable because of the quit() call above
    processes = []
    for _ in range(5):  # number of worker processes
        handle_obj = PicturesIdIndex()
        process = multiprocessing.Process(target=handle_obj.run)
        process.start()
        processes.append(process)

    # wait for all processes to finish
    for process in processes:
        process.join()

# if __name__ == '__main__':
#     # set the number of processes
#     num_processes = 4  # set to however many processes you need
#     # create the process pool
#     pool = Pool(processes=num_processes)
#     # run the task on the pool's processes
#     pool.apply(main)
#     # close the pool
#     pool.close()
#     # wait for all processes to finish
#     pool.join()
# if __name__ == '__main__':
#     # create the process object
#     process = Process(target=main)
#     # start the process
#     process.start()
#     # wait for the process to finish
#     process.join()