import multiprocessing
import os
import sys
import time
import traceback

import pandas as pd

# Set at import time, before the project imports below run.
# NOTE(review): presumably silences a PyArrow timezone warning — confirm.
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory of sys.path[0]; presumably needed so the `utils` imports below resolve
from utils.templates import Templates
from utils.db_util import DbTypes, DBUtil


class JudgeFinished(Templates):
    """Check whether the index table still holds rows in state 1 or 2.

    The lookup is scoped to one ``site_name`` / ``img_type`` pair and runs
    against the engine returned by ``DBUtil.get_db_engine`` for Doris.
    """

    def __init__(self, site_name='us', img_type="amazon_inv"):
        """Store the filters and create the Doris engine.

        Args:
            site_name: Value matched against the ``site_name`` column.
            img_type: Value matched against the ``img_type`` column.
        """
        super().__init__()
        self.site_name = site_name
        self.img_type = img_type
        self.engine_doris = DBUtil.get_db_engine(
            db_type=DbTypes.doris.name,
            site_name=self.site_name,
        )
        # Name of the table queried by judge().
        self.tn_pics_hdfs_index = "img_hdfs_index"

    def judge(self):
        """Return True when at least one row with state 1 or 2 matches.

        The query text and the shape of the result are printed on every call.
        NOTE(review): the meaning of states 1 and 2 is not visible in this
        file — confirm against the table definition.
        """
        query = (
            f"select * from {self.tn_pics_hdfs_index} "
            f"where state in (1, 2) "
            f"and site_name='{self.site_name}' "
            f"and img_type='{self.img_type}';"
        )
        matching_rows = pd.read_sql(query, con=self.engine_doris)
        print(f"sql: {query}, {matching_rows.shape}")
        return matching_rows.shape[0] > 0


def main(site_name='us', img_type='amazon_inv', p_num=0):
    """Re-run the Spark indexing job until ``JudgeFinished.judge()`` is False.

    Each pass builds a fresh ``JudgeFinished`` and calls ``judge()``. While it
    returns True, ``img_dwd_id_index.py`` is launched through ``spark-submit``
    (``os.system`` blocks until the command returns) and the check repeats.
    The loop ends normally once ``judge()`` returns False.

    Any exception raised during a pass is printed with its traceback. If the
    traceback contains the "Length mismatch ... 0 elements" ValueError text,
    the process exits; otherwise the pass is retried after 20 seconds.

    Args:
        site_name: Site code given to ``JudgeFinished`` and to the Spark job.
        img_type: Image type given to ``JudgeFinished`` and to the Spark job.
        p_num: Index of the calling worker; only used in the exit message.

    Raises:
        SystemExit: When the "finished" ValueError marker is seen.
    """
    # The "ValueError: " prefix exists only in the formatted traceback, not in
    # str(exc), so the marker is matched against traceback.format_exc().
    finished_marker = "ValueError: Length mismatch: Expected axis has 0 elements"
    submit_cmd = (
        "/opt/module/spark/bin/spark-submit --master yarn "
        "--driver-memory 5g --executor-memory 10g "
        "--executor-cores 2 --num-executors 1 --queue spark "
        "/opt/module/spark/demo/py_demo/img_search/img_dwd_id_index.py "
        f"{site_name} {img_type}"
    )
    while True:
        try:
            judge_obj = JudgeFinished(site_name=site_name, img_type=img_type)
            result_flag = judge_obj.judge()
            if not result_flag:
                print(f"结束, result_flag: {result_flag}")
                break
            print(f"继续, result_flag: {result_flag}")
            exit_status = os.system(submit_cmd)
            if exit_status != 0:
                # Surface failed submissions instead of silently looping.
                print(f"spark-submit returned non-zero status: {exit_status}")
        except Exception as e:
            tb_text = traceback.format_exc()
            print(e, tb_text)
            # Check before sleeping so a finished worker exits immediately.
            if finished_marker in tb_text:
                print(f"当前已经跑完所有block块id对应的index关系，退出进程-{p_num}")
                sys.exit()
            time.sleep(20)


if __name__ == "__main__":
    # Positional arguments: <site_name> <img_type> <process_num>
    site_name, img_type = sys.argv[1], sys.argv[2]
    process_num = int(sys.argv[3])  # number of worker processes to launch

    # One worker per index; each runs main() with its own index as p_num.
    processes = []
    for worker_index in range(process_num):
        worker = multiprocessing.Process(
            target=main,
            args=(site_name, img_type, worker_index),
        )
        worker.start()
        processes.append(worker)

    # Wait for every worker to finish.
    for worker in processes:
        worker.join()
