import ast
import logging
import os
import re
import sys
import threading
import time
import traceback

import pandas as pd
import redis
from pyspark.sql.types import ArrayType, FloatType

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory so utils.* can be imported
from utils.templates_mysql import TemplatesMysql
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from sqlalchemy import text
from pyspark.sql import functions as F
import pyarrow as pa
import pyarrow.parquet as pq
from multiprocessing import Process
from multiprocessing import Pool
import multiprocessing
from utils.db_util import DbTypes, DBUtil


class PicturesIdIndex(Templates):
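    """Build the id/asin/embedding index for one image-feature block at a time.

    Each pass claims a pending block from the MySQL table `image_hdfs_index`
    (guarded by a Redis lock), splits the tab-separated lines read from HDFS into
    id/asin/embedding columns, hands the result to the inherited save_data()
    (assumed to write the partitioned table `image_dwd_id_index`), and also dumps
    the embedding column to a local Parquet file for the image search index.
    """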

    def __init__(self, site_name='us', asin_type=1, thread_num=10):
        super(PicturesIdIndex, self).__init__()
        self.site_name = site_name
        self.asin_type = asin_type
        self.thread_num = thread_num
        self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
        self.db_save = 'image_dwd_id_index'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        self.df_features = self.spark.sql("select 1+1;")
        self.df_save = self.spark.sql("select 1+1;")
        self.df_save_local = self.spark.sql("select 1+1;")
        self.partitions_by = ['site_name', 'asin_type']
        self.partitions_num = 1
        # variables populated from the MySQL index table
        self.tn_pics_hdfs_index = "image_hdfs_index"
        self.id = int()
        self.current_counts = int()
        self.all_counts = int()
        self.hdfs_path = str()
        self.hdfs_block_name = str()
        self.local_path = str()
        self.local_name = self.db_save

    def read_data(self):
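        """Claim one pending block (state=1) from MySQL under the Redis lock.

        Marks the claimed row as in progress (state=2), loads the block from HDFS
        as text into self.df_features, and returns the claimed row as a pandas
        DataFrame; if no pending row is left, the process exits via quit().
        """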
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if lock_acquired:
                    # read the pending hdfs_path record from MySQL
                    with self.engine_srs.begin() as conn:
                        sql_read = text(f"SELECT * FROM {self.tn_pics_hdfs_index} WHERE state=1 LIMIT 1;")
                        print("sql_read:", sql_read)
                        result = conn.execute(sql_read)
                        df = pd.DataFrame(result.fetchall())
                        if df.shape[0]:
                            df.columns = result.keys()
                            self.id = list(df.id)[0] if list(df.id) else None
                            self.current_counts = list(df.current_counts)[0] if list(df.current_counts) else None
                            self.all_counts = list(df.all_counts)[0] if list(df.all_counts) else None
                            self.hdfs_path = list(df.hdfs_path)[0] if list(df.hdfs_path) else None
                            self.hdfs_block_name = re.findall(r"(part-\d+)-", self.hdfs_path)[0] if self.hdfs_path else None
                            self.local_path = rf"/mnt/image_search/image_id_index/{self.site_name}/{self.hdfs_block_name}"
                            print(f"df.shape:{df.shape}, self.id:{self.id}, self.current_counts:{self.current_counts}, self.all_counts:{self.all_counts}, self.hdfs_path:{self.hdfs_path}")
                            if self.id:
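                                # remove any previous output for this block before it is re-processed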
                                os.system(f"hdfs dfs -rm -r /home/big_data_selection/dwd/image_dwd_id_index/site_name={self.site_name}/block={self.hdfs_block_name}")
                                sql_update = text(
                                    f"UPDATE {self.tn_pics_hdfs_index} SET state=2 WHERE id={self.id};")
                                print("sql_update:", sql_update)
                                conn.execute(sql_update)
                        else:
                            quit()
                    # read the claimed block from HDFS as plain text lines (tab-separated fields)
                    self.df_features = self.spark.read.text(self.hdfs_path).cache()
                    self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                    return df
                else:
                    print(f"当前有其它进程占用redis的锁, 等待5秒继续获取数据")
                    time.sleep(5)  # 等待5s继续访问锁
                    continue
            except Exception as e:
                print(f"读取数据错误: {e}", traceback.format_exc())
                time.sleep(5)
                self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
                self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='yswg2023')
                continue

    def handle_data(self):
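        """Split the raw lines into id/asin/embedding columns and prepare df_save.

        Casts id to bigint, adds an index column offset by all_counts, parses the
        embedding string into array<float>, and appends the site_name/asin_type
        partition columns.
        """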
        # build a new DataFrame in which each tab-separated field becomes its own column
        split_df = self.df_features.select(F.split(self.df_features['value'], '\t').alias('split_values'))
        # the lines are assumed to contain exactly three fields,
        # so an independent column is created for each of them
        final_df = split_df.select(
            split_df['split_values'].getItem(0).alias('id'),
            split_df['split_values'].getItem(1).alias('asin'),
            split_df['split_values'].getItem(2).alias('embedding')
        )
        print("分块前分区数量:", final_df.rdd.getNumPartitions())

        # splitting the raw text read from HDFS yields string columns --> cast id to a numeric type
        final_df = final_df.withColumn('id', final_df['id'].cast('bigint'))  # then the cast is safe
        # add a global index column offset by the accumulated row count
        final_df = final_df.withColumn("index", F.monotonically_increasing_id() + self.all_counts)
        final_df.show()

        # the type conversion is already done in pictures_dim_features_slice -- but since we read the
        # raw lzo/text file instead of the table, the embedding still has to be converted here
        # UDF that parses the embedding string into a list of floats
        str_to_list_udf = F.udf(lambda s: ast.literal_eval(s), ArrayType(FloatType()))
        # apply the UDF to the embedding column
        final_df = final_df.withColumn("embedding", str_to_list_udf(final_df["embedding"]))

        # final_df.write.mode('overwrite').parquet("hdfs://hadoop5:8020/home/ffman/parquet")

        # final_df = final_df.withColumn("block", F.lit(self.hdfs_block_name))
        self.df_save = final_df.withColumn("site_name", F.lit(self.site_name))
        self.df_save = self.df_save.withColumn("asin_type", F.lit(self.asin_type))
        self.df_save.show(10)
        self.df_save_local = self.df_save.select("embedding")

    def save_data_local(self):
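        """Write the embedding column to a local Parquet file via pandas/pyarrow."""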
        print("当前存储到本地:", self.local_path)
        if os.path.exists(self.local_path):
            os.system(f"rm -rf {self.local_path}")
        os.makedirs(self.local_path)
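        # note: toPandas() collects the entire embedding column to the driver,
        # so each block is assumed to fit in driver memory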
        # Convert DataFrame to Arrow Table
        df = self.df_save_local.toPandas()
        table = pa.Table.from_pandas(df)
        # Save to Parquet
        pq.write_table(table, f"{self.local_path}/{self.hdfs_block_name}.parquet")

    def update_state_after_save(self):
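        """Mark the processed block as finished (state=2 -> 3) in the MySQL index table."""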
        with self.engine_srs.begin() as conn:
            sql_update = text(
                f"UPDATE {self.tn_pics_hdfs_index} SET state=3 WHERE state=2 and id={self.id};")
            print("sql_update:", sql_update)
            conn.execute(sql_update)

    def run(self):
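        """Process pending blocks in a loop until the index table has none left."""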
        while True:
            try:
                df = self.read_data()
                if df.shape[0]:
                    self.handle_data()
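                    # save_data() is assumed to be provided by the Templates base class
                    # (writes self.df_save to the partitioned table self.db_save)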
                    self.save_data()
                    # also dump the embeddings to a local parquet file
                    self.save_data_local()
                    self.update_state_after_save()
                else:
                    break
            except Exception as e:
                print(f"error: {e}", traceback.format_exc())
                self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
                time.sleep(10)

    def run_thread(self):
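        """Run self.run() in thread_num threads and wait for them all to finish.

        Note: all threads share this instance's mutable state (id, hdfs_path, df_features, ...).
        """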
        thread_list = []
        for thread_id in range(self.thread_num):
            thread = threading.Thread(target=self.run)
            thread_list.append(thread)
            thread.start()
        for thread in thread_list:
            thread.join()
        logging.info("所有线程处理完成")

# def main():
#     handle_obj = PicturesIdIndex()
#     handle_obj.run()


if __name__ == "__main__":
    handle_obj = PicturesIdIndex(thread_num=5)
    handle_obj.run()
    # handle_obj.run_thread()
    quit()


    # unreachable after quit(); kept as an alternative multiprocessing entrypoint:
    # processes = []
    # for _ in range(5):  # number of worker processes
    #     handle_obj = PicturesIdIndex()
    #     process = multiprocessing.Process(target=handle_obj.run)
    #     process.start()
    #     processes.append(process)
    #
    # # wait for all processes to finish
    # for process in processes:
    #     process.join()

# if __name__ == '__main__':
#     # set the number of processes
#     num_processes = 4  # set to the number of processes you need
#     # create the process pool
#     pool = Pool(processes=num_processes)
#     # run the task on a pool worker
#     pool.apply(main)
#     # close the pool
#     pool.close()
#     # wait for all processes to finish
#     pool.join()
# if __name__ == '__main__':
#     # create the process object
#     process = Process(target=main)
#     # start the process
#     process.start()
#     # wait for the process to finish
#     process.join()