import json

import pandas as pd
import numpy as np
import orjson, requests, time
from typing import List

# -------- Mapping dictionaries --------
# Site codes accepted by get_remote_engine. Only the keys are consulted in this
# file (membership check on ``site_name``); the values look like per-site
# database names -- presumably resolved on the server side, confirm before
# relying on them.
site_name_db_dict = {
    "us": "selection",
    "uk": "selection_uk",
    "de": "selection_de",
    "es": "selection_es",
    "fr": "selection_fr",
    "it": "selection_it",
}

# Backend names accepted by get_remote_engine. The map is currently an identity
# map; the looked-up value is what RemoteEngine sends as the "db" field.
db_type_alias_map = {
    "mysql": "mysql",  # Alibaba Cloud MySQL
    "postgresql_14": "postgresql_14",  # PG14 crawler database - internal network
    "postgresql_14_outer": "postgresql_14_outer",  # PG14 crawler database - external network
    "postgresql_15": "postgresql_15",  # PG15 production database - internal network
    "postgresql_15_outer": "postgresql_15_outer",  # PG15 production database - external network
    "postgresql_cluster": "postgresql_cluster",  # PG cluster - internal network
    "postgresql_cluster_outer": "postgresql_cluster_outer",  # PG cluster - external network
    "doris": "doris",  # Doris cluster - internal network
}

# Base URLs used by get_remote_engine when the caller passes no ``servers``.
# RemoteEngine tries them in list order. The internal-network entries are
# currently disabled (commented out).
DEFAULT_SERVERS = [
    # "http://192.168.200.210:7777",   # internal network
    # "http://192.168.10.217:7777",   # internal network (h7)
    "http://61.145.136.61:7777",   # external network
    "http://61.145.136.61:7779",   # external network

]
# ---------------------------

def df_to_json_records(df: pd.DataFrame) -> list:
    """Convert a DataFrame into JSON-safe ``records`` (a list of row dicts).

    Every missing scalar (``None``, ``NaN``, ``NaT``, ``pd.NA``) and every
    ``+/-Inf`` float is emitted as ``None``. Non-scalar cells (lists, dicts,
    arrays) are passed through unchanged. The input frame is not modified.

    The cleaning is applied to the finished records instead of to the frame:
    writing ``None`` back into a float column lets pandas re-infer the dtype
    and can turn the value into ``NaN`` again, and ``DataFrame.applymap`` (the
    previous approach) is deprecated since pandas 2.1.

    Args:
        df: Frame to convert.

    Returns:
        One ``{column: value}`` dict per row, in row order.
    """

    def _json_safe(value):
        # pd.isna() returns an array for containers, which cannot be used as a
        # truth value -- only scalars can be "missing" here.
        if not pd.api.types.is_scalar(value):
            return value
        if pd.isna(value):
            return None
        if isinstance(value, (float, np.floating)) and np.isinf(value):
            return None
        return value

    return [
        {column: _json_safe(value) for column, value in record.items()}
        for record in df.to_dict("records")
    ]


def clean_json_field_for_orjson(v):
    """Normalise a single JSON-column value before it is serialised.

    Args:
        v: Cell value; may be ``None``, a missing scalar, a string, a dict,
            a list or any other object.

    Returns:
        * ``None`` for ``None``, missing scalars (``NaN`` / ``NaT`` /
          ``pd.NA``), empty dicts, blank strings and strings that parse to
          an empty JSON object (so empty objects are not stored).
        * The re-serialised text (``json.dumps(..., ensure_ascii=False)``)
          for strings that contain valid JSON.
        * The value unchanged otherwise (non-JSON strings, non-empty dicts,
          lists, other objects).
    """
    if v is None:
        return None

    if isinstance(v, str):
        stripped = v.strip()
        if not stripped:
            return None
        try:
            parsed = json.loads(stripped)
            normalised = json.dumps(parsed, ensure_ascii=False)
        except (ValueError, RecursionError):
            # Not JSON (JSONDecodeError is a ValueError) or too deeply nested:
            # keep the original string as-is.
            return v
        if isinstance(parsed, dict) and not parsed:
            return None
        return normalised

    if isinstance(v, dict):
        return v if v else None

    # pd.isna() returns an array for list-likes, which cannot be used as a
    # truth value; only scalars are checked for missingness.
    if pd.api.types.is_scalar(v) and pd.isna(v):
        return None
    return v


def fully_clean_for_orjson(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose missing / non-finite cells are ``None``.

    Steps:
        1. ``+/-Inf`` is replaced by ``NaN``.
        2. The frame is cast to ``object`` dtype and every missing cell is
           replaced by ``None``. The cast is required: in a float column a
           ``None`` would be coerced back to ``NaN``.
        3. Every column whose label is a string ending in ``_json`` is passed
           value-by-value through ``clean_json_field_for_orjson``.

    Args:
        df: Frame to clean; it is not modified.

    Returns:
        A new DataFrame with the same index and columns; all columns have
        ``object`` dtype.
    """
    cleaned = df.replace([np.inf, -np.inf], np.nan).astype(object)
    cleaned = cleaned.where(cleaned.notna(), None)

    for col in cleaned.columns:
        # Column labels are not guaranteed to be strings (e.g. integer labels).
        if not (isinstance(col, str) and col.endswith("_json")):
            continue
        # Build the column explicitly with object dtype so that pandas does
        # not re-infer a dtype and turn the ``None`` values back into NaN.
        cleaned[col] = pd.Series(
            [clean_json_field_for_orjson(value) for value in cleaned[col]],
            index=cleaned.index,
            dtype=object,
        )

    return cleaned


class RemoteTransaction:
    """Client-side buffer of statements sent as one ``/transaction`` request.

    ``execute`` only queues statements locally. The queue is posted to the
    server when the ``with`` block exits without an exception; if the block
    raises, nothing is sent and the exception propagates.
    """

    def __init__(self, db: str, database: str,
                 session: "requests.Session", urls: List[str]):
        """Store the target and transport.

        Args:
            db: Backend name, sent as the ``"db"`` field.
            database: Site name, sent as the ``"site_name"`` field.
            session: Session used to POST the transaction.
            urls: Server base URLs, tried in order until one succeeds.
        """
        self.db = db
        self.database = database
        self.session = session
        self.urls = urls
        self.sql_queue = []

    def execute(self, sql: str, params=None):
        """Queue one statement for the transaction.

        ``params`` may be:
            * ``None``       -> plain-text SQL
            * ``dict``       -> single parameterised statement,
                                e.g. ``{"id": 1, "name": "a"}``
            * ``list/tuple`` -> batch (executemany)
                - ``list[dict]``  <-> ``INSERT .. VALUES (:id, :name)``
                - ``list[tuple]`` <-> ``INSERT .. VALUES (%s, %s)``
        """
        self.sql_queue.append({"sql": sql, "params": params})

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        """Send the queued statements, unless the ``with`` body failed.

        Raises:
            RuntimeError: if every server in ``urls`` rejected the request;
                the last underlying error is attached as ``__cause__``.
        """
        if exc_type is not None:
            # The body raised part-way through: sending the partial queue
            # would commit an incomplete transaction, so discard it and let
            # the original exception propagate.
            self.sql_queue = []
            return False

        payload = {"db": self.db,
                   "sql_list": self.sql_queue,
                   "site_name": self.database}  # site_name not needed on server, kept for clarity
        last_error = None
        for url in self.urls:
            try:
                # NOTE(review): requests interprets ``timeout`` in seconds, so
                # 3000 is 50 minutes -- confirm this is intended.
                self.session.post(
                    url + "/transaction",
                    json=payload,
                    timeout=3000,
                ).raise_for_status()
            except Exception as e:
                # Fail-over boundary: log and try the next server.
                # NOTE(review): a request that timed out may still have been
                # applied by that server -- confirm fail-over is safe.
                last_error = e
                print(f"[WARN] 事务失败 {url}: {e}")
            else:
                # Rebind (do not clear in place) so the payload that was just
                # sent is left intact and a reused object starts empty.
                self.sql_queue = []
                return False
        raise RuntimeError("All servers failed for transaction") from last_error


class RemoteEngine:
    """HTTP client for the remote SQL service.

    Each public method builds a JSON payload and POSTs it to
    ``<server>/<endpoint>``, failing over across ``server_urls`` in order.
    """

    def __init__(self, db: str, database: str,
                 server_urls: List[str], retries: int = 2):
        """Create the engine and its HTTP session.

        Args:
            db: Backend name, sent as the ``"db"`` field of every request.
            database: Site name, sent as the ``"site_name"`` field.
            server_urls: Base URLs tried in order; trailing slashes are
                stripped.
            retries: Number of attempts per server (values below 1 are
                treated as 1).
        """
        self.db = db
        self.database = database
        self.urls = [u.rstrip("/") for u in server_urls]
        self.session = requests.Session()
        self.retries = retries

    def close(self):
        """Close the underlying HTTP session."""
        self.session.close()

    def _request(self, endpoint: str, payload):
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded reply.

        Every server is tried in order, up to ``retries`` times each, with a
        one-second pause between attempts.

        Raises:
            RuntimeError: if the payload cannot be serialised, or if every
                attempt on every server failed (the last error is attached
                as ``__cause__``).
        """
        # Serialise once: the payload does not change between attempts, and a
        # serialisation error is not a server failure, so retrying is pointless.
        try:
            json_bytes = orjson.dumps(payload)
        except orjson.JSONEncodeError as e:
            raise RuntimeError(
                f"Cannot serialize payload for {endpoint}: {e}"
            ) from e

        headers = {"Content-Type": "application/json"}
        attempts = max(1, self.retries)
        targets = [url for url in self.urls for _ in range(attempts)]
        last_error = None
        for position, url in enumerate(targets, start=1):
            try:
                # NOTE(review): requests interprets ``timeout`` in seconds, so
                # 3000 is 50 minutes -- confirm this is intended.
                r = self.session.post(f"{url}/{endpoint}",
                                      data=json_bytes,
                                      headers=headers,
                                      timeout=3000)
                r.raise_for_status()
                return r.json()
            except Exception as e:
                # Fail-over boundary: log and move on to the next attempt.
                # NOTE(review): a request that timed out may still have been
                # applied server-side; retrying "insert" could then duplicate
                # rows -- confirm the endpoints are safe to retry.
                last_error = e
                print(f"[WARN] {endpoint} fail @ {url}: {e}")
                if position < len(targets):
                    time.sleep(1)  # no pause after the final attempt
        raise RuntimeError(f"All servers failed for {endpoint}") from last_error

    # ---------- Public API ----------
    def read_sql(self, sql: str) -> pd.DataFrame:
        """Run ``sql`` through the ``query`` endpoint and return its
        ``"result"`` rows as a DataFrame."""
        data = self._request("query",
                             {"db": self.db,
                              "sql": sql,
                              "site_name": self.database})
        return pd.DataFrame(data["result"])

    def to_sql(self, df: pd.DataFrame, table: str, if_exists="append"):
        """Send the rows of ``df`` to the ``insert`` endpoint for ``table``.

        The frame is cleaned with ``fully_clean_for_orjson`` first.
        ``if_exists`` is forwarded to the server unchanged. Returns the
        decoded server reply.
        """
        return self._request("insert",
                             {"db": self.db,
                              "table": table,
                              "if_exists": if_exists,
                              "data": fully_clean_for_orjson(df=df).to_dict("records"),
                              "site_name": self.database})

    def read_then_update(
            self,
            select_sql: str,
            update_table: str,
            set_values: dict,
            where_keys: List[str],
            error_if_empty: bool = False,
    ):
        """Call the ``read_then_update`` endpoint.

        The request asks the server to read rows with ``select_sql`` and
        update ``update_table`` with ``set_values`` for those rows, matched
        on ``where_keys``; the UPDATE itself is built server-side.

        Returns:
            DataFrame built from the reply's ``"read_result"``. The reply's
            ``"rows_updated"`` count (0 if absent) is available as
            ``df.attrs["rows_updated"]``.
        """
        payload = {
            "db": self.db,
            "site_name": self.database,
            "select_sql": select_sql,
            "update_table": update_table,
            "set_values": set_values,
            "where_keys": where_keys,
            "error_if_empty": error_if_empty,
        }
        resp = self._request("read_then_update", payload)
        df = pd.DataFrame(resp["read_result"])
        # The return type stays a DataFrame for existing callers; the update
        # count was previously computed and then dropped.
        df.attrs["rows_updated"] = resp.get("rows_updated", 0)
        return df

    def begin(self):
        """Return a ``RemoteTransaction`` that shares this engine's session
        and server list."""
        return RemoteTransaction(self.db, self.database,
                                 self.session, self.urls)
# ---------------------------------


def get_remote_engine(site_name: str, db_type: str,
                      servers: List[str] = None) -> RemoteEngine:
    """Build a ``RemoteEngine`` for a known site and backend.

    Args:
        site_name: Must be a key of ``site_name_db_dict``; passed to the
            engine as ``database``.
        db_type: Must be a key of ``db_type_alias_map``; the mapped value is
            passed to the engine as ``db``.
        servers: Server base URLs; ``DEFAULT_SERVERS`` is used when this is
            ``None`` or empty.

    Raises:
        ValueError: if ``site_name`` or ``db_type`` is not recognised.
    """
    site_is_known = site_name in site_name_db_dict
    if not site_is_known:
        raise ValueError(f"Unknown site_name: {site_name}")

    type_is_known = db_type in db_type_alias_map
    if not type_is_known:
        raise ValueError(f"Unknown db_type: {db_type}")

    backend_alias = db_type_alias_map[db_type]
    endpoints = servers if servers else DEFAULT_SERVERS
    return RemoteEngine(
        db=backend_alias,
        database=site_name,
        server_urls=endpoints,
    )
