import ast
import logging
import os
import re
import sys
import threading
import time
import traceback
import pandas as pd
import redis
from pyspark.sql.types import ArrayType, FloatType

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from utils.templates_mysql import TemplatesMysql
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from sqlalchemy import text
from pyspark.sql import functions as F
import pyarrow as pa
import pyarrow.parquet as pq
from multiprocessing import Process
from multiprocessing import Pool
import multiprocessing
from utils.db_util import DbTypes, DBUtil


class PicturesIdIndex(Templates):
    def __init__(self, site_name='us', asin_type=1, thread_num=10):
        super(PicturesIdIndex, self).__init__()
        self.site_name = site_name
        self.asin_type = asin_type
        self.thread_num = thread_num
        self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
        self.db_save = f'image_dwd_id_index'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        # placeholder DataFrames, replaced in read_data/handle_data
        self.df_features = self.spark.sql(f"select 1+1;")
        self.df_save = self.spark.sql(f"select 1+1;")
        self.df_save_local = self.spark.sql(f"select 1+1;")
        self.partitions_by = ['site_name', 'asin_type']
        self.partitions_num = 1
        # variables populated from the MySQL index table
        self.tn_pics_hdfs_index = f"image_hdfs_index"
        self.id = int()
        self.current_counts = int()
        self.all_counts = int()
        self.hdfs_path = str()
        self.hdfs_block_name = str()
        self.local_path = str()
        self.local_name = self.db_save

    def read_data(self):
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if lock_acquired:
                    # read the pending hdfs_path record from MySQL
                    with self.engine_srs.begin() as conn:
                        sql_read = text(f"SELECT * FROM {self.tn_pics_hdfs_index} WHERE state=1 LIMIT 1;")
                        print("sql_read:", sql_read)
                        result = conn.execute(sql_read)
                        df = pd.DataFrame(result.fetchall())
                        if df.shape[0]:
                            df.columns = result.keys()
                            self.id = list(df.id)[0] if list(df.id) else None
                            self.current_counts = list(df.current_counts)[0] if list(df.current_counts) else None
                            self.all_counts = list(df.all_counts)[0] if list(df.all_counts) else None
                            self.hdfs_path = list(df.hdfs_path)[0] if list(df.hdfs_path) else None
                            self.hdfs_block_name = re.findall(r"(part-\d+)-", self.hdfs_path)[0] if self.hdfs_path else None
                            self.local_path = rf"/mnt/image_search/image_id_index/{self.site_name}/{self.hdfs_block_name}"
                            print(f"df.shape:{df.shape}, self.id:{self.id}, self.current_counts:{self.current_counts}, "
                                  f"self.all_counts:{self.all_counts}, self.hdfs_path:{self.hdfs_path}")
                            if self.id:
                                # drop any previous output for this block, then mark the record as in progress
                                os.system(f"hdfs dfs -rm -r /home/big_data_selection/dwd/image_dwd_id_index/site_name={self.site_name}/block={self.hdfs_block_name}")
                                sql_update = text(
                                    f"UPDATE {self.tn_pics_hdfs_index} SET state=2 WHERE id={self.id};")
                                print("sql_update:", sql_update)
                                conn.execute(sql_update)
                        else:
                            quit()
                    # read the text file under the HDFS path (tab-separated id, asin, embedding)
                    self.df_features = self.spark.read.text(self.hdfs_path).cache()
                    self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                    return df
                else:
                    print("Another process is holding the Redis lock; retrying in 5 seconds")
                    time.sleep(5)  # wait 5s before trying the lock again
                    continue
            except Exception as e:
                print(f"read_data error: {e}", traceback.format_exc())
                time.sleep(5)
                self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
                self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='yswg2023')
                continue

    def handle_data(self):
        # split each tab-separated line into a DataFrame where every field becomes its own column
        split_df = self.df_features.select(F.split(self.df_features['value'], '\t').alias('split_values'))
        # the data is known to contain three fields: id, asin, embedding
        final_df = split_df.select(
            split_df['split_values'].getItem(0).alias('id'),
            split_df['split_values'].getItem(1).alias('asin'),
            split_df['split_values'].getItem(2).alias('embedding')
        )
        print("partitions before re-blocking:", final_df.rdd.getNumPartitions())
        # the fields produced by the split are strings, so cast id back to a numeric type
        final_df = final_df.withColumn('id', final_df['id'].cast('bigint'))
        # add a global index column, offset by the running total of rows already indexed
        final_df = final_df.withColumn("index", F.monotonically_increasing_id() + self.all_counts)
        final_df.show()
        # the type conversion was already done in pictures_dim_features_slice -- but because we read the LZO file
        # rather than the table itself, the embedding still has to be converted here
        # UDF that parses the embedding string into a list of floats
        str_to_list_udf = F.udf(lambda s: ast.literal_eval(s), ArrayType(FloatType()))
        # apply the UDF to the embedding column
        final_df = final_df.withColumn("embedding", str_to_list_udf(final_df["embedding"]))
        # final_df.write.mode('overwrite').parquet("hdfs://hadoop5:8020/home/ffman/parquet")
        # final_df = final_df.withColumn("block", F.lit(self.hdfs_block_name))
        self.df_save = final_df.withColumn("site_name", F.lit(self.site_name))
        self.df_save = self.df_save.withColumn("asin_type", F.lit(self.asin_type))
        self.df_save.show(10)
        self.df_save_local = self.df_save.select("embedding")

    def save_data_local(self):
        print("saving locally to:", self.local_path)
        if os.path.exists(self.local_path):
            os.system(f"rm -rf {self.local_path}")
        os.makedirs(self.local_path)
        # Convert DataFrame to Arrow Table
        df = self.df_save_local.toPandas()
        table = pa.Table.from_pandas(df)
        # Save to Parquet
        pq.write_table(table, f"{self.local_path}/{self.hdfs_block_name}.parquet")

    def update_state_after_save(self):
        with self.engine_srs.begin() as conn:
            sql_update = text(
                f"UPDATE {self.tn_pics_hdfs_index} SET state=3 WHERE state=2 and id={self.id};")
            print("sql_update:", sql_update)
            conn.execute(sql_update)

    def run(self):
        while True:
            try:
                df = self.read_data()
                if df.shape[0]:
                    self.handle_data()
                    self.save_data()  # save_data is provided by the Templates base class (writes the partitioned table)
                    # keep a local Parquet copy as well
                    self.save_data_local()
                    self.update_state_after_save()
                else:
                    break
            except Exception as e:
                print(f"error: {e}", traceback.format_exc())
                self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
                time.sleep(10)

    def run_thread(self):
        thread_list = []
        for thread_id in range(self.thread_num):
            thread = threading.Thread(target=self.run)
            thread_list.append(thread)
            thread.start()
        for thread in thread_list:
            thread.join()
        logging.info("all threads finished")


# def main():
#     handle_obj = PicturesIdIndex()
#     handle_obj.run()


if __name__ == "__main__":
    handle_obj = PicturesIdIndex(thread_num=5)
    handle_obj.run()
    # handle_obj.run_thread()
    quit()
    # the multiprocessing variant below is kept for reference; it is never reached because of quit() above
    processes = []
    for _ in range(5):  # number of processes
        handle_obj = PicturesIdIndex()
        process = multiprocessing.Process(target=handle_obj.run)
        process.start()
        processes.append(process)
    # wait for all processes to finish
    for process in processes:
        process.join()

    # if __name__ == '__main__':
    #     # set the number of processes
    #     num_processes = 4  # set to however many processes you need
    #     # create the process pool
    #     pool = Pool(processes=num_processes)
    #     # run the task with a process from the pool
    #     pool.apply(main)
    #     # close the pool
    #     pool.close()
    #     # wait for all processes to finish
    #     pool.join()

    # if __name__ == '__main__':
    #     # create the process object
    #     process = Process(target=main)
    #     # start the process
    #     process.start()
    #     # wait for the process to finish
    #     process.join()