import socket
from func_timeout import func_timeout, FunctionTimedOut
import pandas as pd
from sqlalchemy.exc import OperationalError
from py_spider.utils.secure_db_client import get_remote_engine
from loguru import logger
import time

# 防止底层握手阶段无限等待
socket.setdefaulttimeout(30)
def sql_execute_agg(f_name, sql_or_table, data=None, site="us", db="mysql"):
    engine = get_remote_engine(site_name=site, db_type=db)

    try:
        if f_name == "to_sql":
            if data is None or not isinstance(data, pd.DataFrame):
                raise ValueError("to_sql 操作必须提供 DataFrame 数据")
            engine.to_sql(data, table=sql_or_table, if_exists="append")
            return True
        elif f_name == "read_sql":
            df = engine.read_sql(sql_or_table)
            return df
        elif f_name == "sql_execute":
            with engine.begin() as conn:
                conn.execute(sql_or_table, data)
            return True
    except OperationalError as e:
        logger.error(f"OperationalError sql_or_table is {sql_or_table}")
        return False
    except RuntimeError as e:
        logger.error(f"RuntimeError sql_or_table is {sql_or_table}")
        return False


def sql_try_again(f_name, sql_or_table, data=None, site="us", db="mysql", max_timeout=15):
    fail_count = 0
    while True:
        try:
            # 使用 func_timeout 强制限制执行时间
            # 如果 sql_execute_agg 在 max_timeout 秒内没反应，直接抛出 FunctionTimedOut
            result = func_timeout(max_timeout,sql_execute_agg, args=(f_name, sql_or_table, data, site, db))
            # --- 成功处理逻辑 ---
            if f_name == 'read_sql':
                if isinstance(result, pd.DataFrame):
                    # logger.success(f"SQL读取成功: {sql_or_table[:30]}...")
                    return result
                else:
                    raise ValueError("返回类型不是 DataFrame")
            else:
                if result is True:
                    # logger.success(f"SQL执行成功: {f_name}")
                    return True
                else:
                    raise ValueError("执行返回 False")

        except FunctionTimedOut:
            fail_count += 1
            logger.error(f" 数据库操作超时 (强制中断) - 第 {fail_count} 次重试")
            # 必须休眠！给网络恢复的时间
            time.sleep(5)
            continue

        except (OperationalError, Exception) as e:
            fail_count += 1
            # 捕获所有异常（包括网络断开、SQL报错等）
            logger.warning(f"数据库报错 (第 {fail_count} 次): {e}")
            # 【关键】休眠！防止死循环把 CPU 跑满
            # 可以做一个简单的策略：前几次快点重试，后面慢点
            sleep_time = 3 if fail_count < 5 else 5
            time.sleep(sleep_time)
            continue
