# img_dwd_id_index.py
import ast
import logging
import os
import re
import sys
import threading
import time
import traceback

import pandas as pd
import redis
from pyspark.sql.types import ArrayType, FloatType

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from utils.templates_mysql import TemplatesMysql
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from sqlalchemy import text
from pyspark.sql import functions as F
import pyarrow as pa
import pyarrow.parquet as pq
from multiprocessing import Process
from multiprocessing import Pool
import multiprocessing
from utils.db_util import DbTypes, DBUtil


class PicturesIdIndex(Templates):

    def __init__(self, site_name='us', img_type=1, thread_num=10):
        super(PicturesIdIndex, self).__init__()
        self.site_name = site_name
        self.img_type = img_type
        self.thread_num = thread_num
        self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
        self.db_save = 'img_dwd_id_index'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        # placeholder DataFrames, replaced in read_data()/handle_data()
        self.df_features = self.spark.sql("select 1+1;")
        self.df_save = self.spark.sql("select 1+1;")
        self.df_save_local = self.spark.sql("select 1+1;")
        self.partitions_by = ['site_name', 'img_type']
        self.partitions_num = 1
        # state loaded from the MySQL index table
        self.tn_pics_hdfs_index = "img_hdfs_index"
        self.id = int()
        self.current_counts = int()
        self.all_counts = int()
        self.hdfs_path = str()
        self.hdfs_block_name = str()
        self.local_path = str()
        self.local_name = self.db_save

    def read_data(self):
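        """Claim one unprocessed block from the MySQL index table under a Redis lock.

        Takes the first row with state=1, marks it state=2, remembers its
        id / counts / hdfs_path, drops any previous HDFS output for that block,
        loads the block's raw text lines into self.df_features and returns the
        claimed row as a pandas DataFrame (empty when nothing is left).
        """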
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if lock_acquired:
                    # read the hdfs_path info for one pending block from MySQL
                    with self.engine_srs.begin() as conn:
                        sql_read = text(f"SELECT * FROM {self.tn_pics_hdfs_index} WHERE state=1 LIMIT 1;")
                        print("sql_read:", sql_read)
                        result = conn.execute(sql_read)
                        df = pd.DataFrame(result.fetchall())
                        if df.shape[0]:
                            df.columns = result.keys()
                            self.id = list(df.id)[0] if list(df.id) else None
                            self.current_counts = list(df.current_counts)[0] if list(df.current_counts) else None
                            self.all_counts = list(df.all_counts)[0] if list(df.all_counts) else None
                            self.hdfs_path = list(df.hdfs_path)[0] if list(df.hdfs_path) else None
                            self.hdfs_block_name = re.findall(r"(part-\d+)-", self.hdfs_path)[0] if self.hdfs_path else None
                            self.local_path = rf"/mnt/image_search/image_id_index/{self.site_name}/{self.hdfs_block_name}"
                            print(f"df.shape:{df.shape}, self.id:{self.id}, self.current_counts:{self.current_counts}, self.all_counts:{self.all_counts}, self.hdfs_path:{self.hdfs_path}")
                            if self.id:
                                os.system(f"hdfs dfs -rm -r /home/big_data_selection/dwd/image_dwd_id_index/site_name={self.site_name}/block={self.hdfs_block_name}")
                                sql_update = text(
                                    f"UPDATE {self.tn_pics_hdfs_index} SET state=2 WHERE id={self.id};")
                                print("sql_update:", sql_update)
                                conn.execute(sql_update)
                        else:
                            # no pending block left: release the lock and return the empty df so run() can stop
                            self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                            return df
                    # read the raw text lines of the claimed block from HDFS
                    self.df_features = self.spark.read.text(self.hdfs_path).cache()
                    self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                    return df
                else:
                    print("Another process is holding the Redis lock, retrying in 5 seconds")
                    time.sleep(5)  # wait 5s before trying the lock again
                    continue
            except Exception as e:
                print(f"Error while reading data: {e}", traceback.format_exc())
                time.sleep(5)
                # re-create the DB engine and Redis client before retrying
                self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
                self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='yswg2023')
                continue

    def handle_data(self):
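        """Parse the raw tab-separated lines into the final DataFrame.

        Splits each line into id / asin / embedding, casts id to bigint, adds a
        global index offset by all_counts, parses the embedding string into a
        float array and tags the site_name / img_type partition columns.
        """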
        # split each raw text line on tabs; every line is expected to hold three fields: id, asin, embedding
        split_df = self.df_features.select(F.split(self.df_features['value'], '\t').alias('split_values'))
        # turn the split values into separate columns
        final_df = split_df.select(
            split_df['split_values'].getItem(0).alias('id'),
            split_df['split_values'].getItem(1).alias('asin'),
            split_df['split_values'].getItem(2).alias('embedding')
        )
        print("分块前分区数量:", final_df.rdd.getNumPartitions())

        # the fields coming from the HDFS text file are strings, so cast id to a numeric type
        final_df = final_df.withColumn('id', final_df['id'].cast('bigint'))
        # add a global index column, offset by the number of rows already processed
        final_df = final_df.withColumn("index", F.monotonically_increasing_id() + self.all_counts)
        final_df.show()

        # the type conversion is already done in pictures_dim_features_slice, but since we read the
        # raw lzo/text file instead of the table, the embedding still has to be converted here
        # UDF that parses the embedding string (e.g. "[0.1, 0.2, ...]") into a list of floats
        str_to_list_udf = F.udf(lambda s: ast.literal_eval(s), ArrayType(FloatType()))
        final_df = final_df.withColumn("embedding", str_to_list_udf(final_df["embedding"]))

        # final_df.write.mode('overwrite').parquet("hdfs://hadoop5:8020/home/ffman/parquet")

        # final_df = final_df.withColumn("block", F.lit(self.hdfs_block_name))
        self.df_save = final_df.withColumn("site_name", F.lit(self.site_name))
        self.df_save = self.df_save.withColumn("img_type", F.lit(self.img_type))
        self.df_save.show(10)
        self.df_save_local = self.df_save.select("embedding")

    def save_data_local(self):
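        """Write the embedding column of the current block to a local Parquet file via pyarrow."""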
        print("当前存储到本地:", self.local_path)
        if os.path.exists(self.local_path):
            os.system(f"rm -rf {self.local_path}")
        os.makedirs(self.local_path)
        # Convert DataFrame to Arrow Table
        df = self.df_save_local.toPandas()
        table = pa.Table.from_pandas(df)
        # Save to Parquet
        pq.write_table(table, f"{self.local_path}/{self.hdfs_block_name}.parquet")

    def update_state_after_save(self):
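        """Mark the claimed row as finished (state 2 -> 3) in the MySQL index table."""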
        with self.engine_srs.begin() as conn:
            sql_update = text(
                f"UPDATE {self.tn_pics_hdfs_index} SET state=3 WHERE state=2 and id={self.id};")
            print("sql_update:", sql_update)
            conn.execute(sql_update)

    def run(self):
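        """Process blocks until none are left: read, transform, write to HDFS via
        save_data() (presumably provided by the Templates base class), write the
        local Parquet copy, then flag the block as done."""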
        while True:
            try:
                df = self.read_data()
                if df.shape[0]:
                    self.handle_data()
                    self.save_data()
                    # also keep a local copy of the embedding column
                    self.save_data_local()
                    self.update_state_after_save()
                else:
                    break
            except Exception as e:
                print(f"error: {e}", traceback.format_exc())
                # re-create the DB engine before retrying
                self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
                time.sleep(10)

    def run_thread(self):
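        """Run self.run() in thread_num parallel threads and wait for them all to finish."""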
        thread_list = []
        for thread_id in range(self.thread_num):
            thread = threading.Thread(target=self.run)
            thread_list.append(thread)
            thread.start()
        for thread in thread_list:
            thread.join()
        logging.info("all threads finished")

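
# acquire_lock()/release_lock() used in read_data() come from the Templates base class and are
# not shown in this file. The helpers below are only an illustrative sketch of how such a Redis
# lock is commonly built (SET NX EX plus a token-checked delete); the function names, the TTL
# and the key format are assumptions, not the actual Templates implementation.
def _sketch_acquire_lock(client, lock_name, ttl_seconds=60):
    """Try to take the lock once; returns (acquired, token)."""
    import uuid
    token = uuid.uuid4().hex
    # SET key value NX EX ttl: only succeeds if the key does not exist yet
    acquired = bool(client.set(f"lock:{lock_name}", token, nx=True, ex=ttl_seconds))
    return acquired, token


def _sketch_release_lock(client, lock_name, token):
    """Release the lock only if this worker still owns it (token matches).

    A production version would do the check-and-delete atomically, e.g. with a Lua script.
    """
    key = f"lock:{lock_name}"
    value = client.get(key)
    if value is not None and value.decode() == token:
        client.delete(key)
        return True
    return False
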
# def main():
#     handle_obj = PicturesIdIndex()
#     handle_obj.run()


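# Illustrative sketch (not called anywhere in this job): how a downstream consumer could read one
# of the locally written block files back and stack its embeddings into a float32 matrix, e.g. to
# build a vector index. The example path below is hypothetical.
def _sketch_load_block_embeddings(parquet_path):
    import numpy as np
    table = pq.read_table(parquet_path)                 # schema: embedding: list<float>
    embeddings = table.column("embedding").to_pylist()  # list of python float lists
    return np.asarray(embeddings, dtype="float32")      # shape: (rows_in_block, dim)

# example (hypothetical path):
# matrix = _sketch_load_block_embeddings("/mnt/image_search/image_id_index/us/part-00000/part-00000.parquet")
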
if __name__ == "__main__":
    handle_obj = PicturesIdIndex(thread_num=5)
    handle_obj.run()
    # handle_obj.run_thread()
    quit()


    # NOTE: everything below is unreachable because of the quit() above; it is kept as an
    # alternative entry point that runs several PicturesIdIndex workers as separate processes.
    processes = []
    for _ in range(5):  # number of worker processes
        handle_obj = PicturesIdIndex()
        process = multiprocessing.Process(target=handle_obj.run)
        process.start()
        processes.append(process)

    # wait for all processes to finish
    for process in processes:
        process.join()

# if __name__ == '__main__':
#     # number of worker processes
#     num_processes = 4
#     # create the process pool
#     pool = Pool(processes=num_processes)
#     # run the task with a process from the pool
#     pool.apply(main)
#     # close the pool
#     pool.close()
#     # wait for all processes to finish
#     pool.join()
# if __name__ == '__main__':
#     # create the process object
#     process = Process(target=main)
#     # start the process
#     process.start()
#     # wait for the process to finish
#     process.join()