import os
import sys
import pandas as pd
import numpy as np
import json
import orjson, requests, time
from typing import List, Optional, Union

# DEFAULT_USER = "fangxingjun"
# DEFAULT_USER_TOKEN = "fxj_token_123"
# sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录

# Load the default credentials used when a caller does not pass user/user_token.
# The top-level module `user` is tried first; if that import fails for any reason,
# the same names are imported from the sibling module `.user` instead.
try:
    from user import DEFAULT_USER, DEFAULT_USER_TOKEN
except Exception as e:
    from .user import DEFAULT_USER, DEFAULT_USER_TOKEN

# Default WeChat notification recipients (used for alerts when a task fails).
# RemoteEngine falls back to this list when no alert_users are given.
NOTIFY_DEFAULT_USERS = ['pengyanbing', 'hezhe', 'chenyuanjie', 'fangxingjun']

# -------- Mapping dictionaries --------
# Site code -> name. In this file only the keys are used (get_remote_engine
# validates site_name against them); the values presumably name the per-site
# database -- confirm against the server side.
site_name_db_dict = {
    "us": "selection",
    "uk": "selection_uk",
    "de": "selection_de",
    "es": "selection_es",
    "fr": "selection_fr",
    "it": "selection_it",
}

# Accepted db_type values -> the value sent to the server (currently an identity mapping).
db_type_alias_map = {
    "mysql": "mysql",  # Alibaba Cloud MySQL
    "postgresql_14": "postgresql_14",  # pg14 crawler DB - intranet
    "postgresql_14_outer": "postgresql_14_outer",  # pg14 crawler DB - external network
    "postgresql_15": "postgresql_15",  # pg15 production DB - intranet
    "postgresql_15_outer": "postgresql_15_outer",  # pg15 production DB - external network
    "postgresql_cluster": "postgresql_cluster",  # pg cluster - intranet
    "postgresql_cluster_outer": "postgresql_cluster_outer",  # pg cluster - external network
    "doris": "doris",  # Doris cluster - intranet
}

# Server base URLs used by get_remote_engine when the caller passes no `servers`.
# Only the uncommented entries are active.
DEFAULT_SERVERS = [
    # "http://192.168.200.210:7777",   # intranet - h5
    "http://192.168.200.210:7778",   # intranet - h5
    # "http://192.168.200.210:7778",  # intranet - big-data test - h5
    # "http://192.168.10.216:7777",  # intranet - big-data test - h6
    # "http://192.168.10.216:7777",   # intranet - h6
    # "http://192.168.10.217:7777",   # intranet - h7
    # "http://61.145.136.61:7777",   # external network
    # "http://61.145.136.61:7779",   # external network
]


# ---------------------------


class RemoteTransaction:
    """Client-side queue of SQL statements posted together to ``/transaction``.

    Statements added with :meth:`execute` are only buffered locally. When the
    ``with`` block finishes without an exception, the whole queue is posted in
    a single request to the first server that accepts it. When the ``with``
    block raises, nothing is sent and the exception propagates.
    """

    def __init__(self, db_type: str, site_name: str, session: "requests.Session", urls: List[str],
                 user: Optional[str] = None, user_token: Optional[str] = None):
        """Store the target, the HTTP session and the credentials.

        Args:
            db_type: Value sent as ``db`` in the request body.
            site_name: Value sent as ``site_name`` in the request body.
            session: Object with a requests-style ``post(url, json=..., timeout=...)``.
            urls: Server base URLs, tried in order.
            user: Credential sent as ``user``; defaults to ``DEFAULT_USER``.
            user_token: Credential sent as ``user_token``; defaults to
                ``DEFAULT_USER_TOKEN``.
        """
        self.db_type = db_type
        self.site_name = site_name
        self.session = session
        self.urls = urls
        self.sql_queue = []
        self.user = user if user is not None else DEFAULT_USER
        self.user_token = user_token if user_token is not None else DEFAULT_USER_TOKEN

    def execute(self, sql: str, params=None):
        """Queue one statement; nothing is sent until the ``with`` block exits.

        ``params`` may be:
        - ``None``: plain-text SQL
        - ``dict``: a single parameterized statement, e.g. ``{"id": 1, "name": "a"}``
        - ``list``/``tuple``: batch executemany
            - ``list[dict]``  <-> ``INSERT .. VALUES (:id,:name)``
            - ``list[tuple]`` <-> ``INSERT .. VALUES (%s,%s)``
        """
        self.sql_queue.append({"sql": sql, "params": params})

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        """Send the queued statements, or discard them if the block raised.

        Returns:
            ``False`` so an exception raised inside the ``with`` block is never
            suppressed.

        Raises:
            RuntimeError: Every server failed; the last underlying error is
                attached as ``__cause__``.
        """
        if exc_type is not None:
            # The body failed part-way: sending the partial queue would commit
            # work the caller never finished. Drop it and let the error surface.
            self.sql_queue = []
            return False

        payload = {
            "db": self.db_type,
            "sql_list": self.sql_queue,
            "site_name": self.site_name,  # site_name not needed on server, kept for clarity
            "user": self.user,
            "user_token": self.user_token,
        }
        last_error = None
        for url in self.urls:
            try:
                self.session.post(
                    url + "/transaction",
                    json=payload,
                    timeout=3000,  # requests timeouts are in seconds
                ).raise_for_status()
            except Exception as e:
                last_error = e
                print(f"[WARN] 事务失败 {url}: {e}")
            else:
                # Rebind (rather than clear in place) so the list already handed
                # to the HTTP layer is left untouched, and a reused transaction
                # object does not resend these statements.
                self.sql_queue = []
                return False
        raise RuntimeError("All servers failed for transaction") from last_error


class RemoteEngine:
    def __init__(self, db_type: str, site_name: str, server_urls: List[str], retries: int = 2,
                 alert_users: Optional[List[str]] = None,
                 user: Optional[str] = None, user_token: Optional[str] = None):
        """Configure the engine for one database target.

        Args:
            db_type: Value sent to the server to select the database.
            site_name: Value sent to the server as ``site_name``.
            server_urls: Server base URLs, tried in order; trailing slashes are stripped.
            retries: Number of attempts made against each server by ``_request``.
            alert_users: Notification recipients; falls back to ``NOTIFY_DEFAULT_USERS``.
            user: Credential sent with every request.
            user_token: Credential sent with every request.

        Raises:
            ValueError: ``user`` or ``user_token`` is missing or empty.
        """
        # Validate first so a rejected call does not leave an unclosed HTTP session behind.
        if not user or not user_token:
            raise ValueError("user and user_token are required (or set env MDB_USER / MDB_USER_TOKEN)")
        self.db_type = db_type
        self.site_name = site_name
        self.urls = [u.rstrip("/") for u in server_urls]
        self.session = requests.Session()
        self.retries = retries
        self.alert_users = alert_users or NOTIFY_DEFAULT_USERS  # default notification recipients
        self.user = user
        self.user_token = user_token

    def _request(self, endpoint: str, payload, timeout=3000):
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded JSON reply.

        Servers in ``self.urls`` are tried in order, ``self.retries`` attempts
        each, with a one-second pause between attempts. The first reply with a
        successful status and a JSON body wins.

        Args:
            endpoint: Path appended to each server base URL.
            payload: Request body. When it is a dict, ``user`` and
                ``user_token`` are filled in (in place) unless already present.
            timeout: Passed unchanged to ``requests`` (seconds).

        Raises:
            RuntimeError: The payload cannot be serialized, or every attempt
                on every server failed (the last error is attached as
                ``__cause__``).
        """
        # Inject user / user_token uniformly.
        if isinstance(payload, dict):
            payload.setdefault("user", self.user)
            payload.setdefault("user_token", self.user_token)

        # Serialize once, outside the retry loop: a payload that cannot be
        # encoded will never succeed, so retrying it against every server only
        # hides the real cause behind "All servers failed".
        try:
            json_bytes = orjson.dumps(payload)
        except TypeError as e:  # orjson.JSONEncodeError is a TypeError subclass
            raise RuntimeError(f"Payload for {endpoint} is not JSON-serializable: {e}") from e

        last_error = None
        last_url_index = len(self.urls) - 1
        for url_index, url in enumerate(self.urls):
            for attempt in range(self.retries):
                try:
                    r = self.session.post(f"{url}/{endpoint}",
                                          data=json_bytes,
                                          headers={"Content-Type": "application/json"},
                                          timeout=timeout)
                    r.raise_for_status()
                    return r.json()
                except Exception as e:
                    last_error = e
                    print(f"[WARN] {endpoint} fail @ {url}: {e}")
                    # No point pausing when there is nothing left to try.
                    is_final_attempt = url_index == last_url_index and attempt == self.retries - 1
                    if not is_final_attempt:
                        time.sleep(1)
        raise RuntimeError(f"All servers failed for {endpoint}") from last_error

    # def _request(self, endpoint: str, payload, timeout: int = 3000):
    #     """
    #     统一注入 user / user_token，并对非 200 响应解析 detail/error 文本，
    #     让调用方能看到“凭证错误/未授权”等服务端具体原因。
    #     """
    #     # 注入 user / token
    #     if isinstance(payload, dict):
    #         payload.setdefault("user", self.user)
    #         payload.setdefault("user_token", self.user_token)
    #
    #     # 可能有多个服务端，逐个尝试
    #     import json as orjson
    #     json_bytes = orjson.dumps(payload)
    #
    #     errors = []  # 记录每个 URL 的错误摘要
    #     last_exc = None
    #
    #     for url in self.urls:
    #         try:
    #             r = self.session.post(
    #                 f"{url}/{endpoint}",
    #                 data=json_bytes,
    #                 headers={"Content-Type": "application/json"},
    #                 timeout=timeout / 1000.0 if timeout and timeout > 100 else timeout
    #             )
    #
    #             # 200：一定是 JSON
    #             if r.status_code == 200:
    #                 return r.json()
    #
    #             # 非 200：尽量从 JSON 里抽取服务端给的 detail / error
    #             msg = None
    #             try:
    #                 j = r.json()
    #                 msg = j.get("detail") or j.get("error") or j.get("message") or str(j)
    #             except Exception:
    #                 # 不是 JSON 就用纯文本，避免遮蔽真实报错
    #                 msg = (r.text or "").strip()
    #
    #             brief_sql = ""
    #             if isinstance(payload, dict) and payload.get("sql"):
    #                 s = str(payload["sql"]).strip().replace("\n", " ")
    #                 brief_sql = f" sql={s[:200]}{'...' if len(s) > 200 else ''}"
    #
    #             warn = f"[WARN] {endpoint} fail @ {url}: {r.status_code} {msg}{brief_sql}"
    #             print(warn)
    #             errors.append(warn)
    #             continue
    #
    #         except Exception as e:
    #             last_exc = e
    #             warn = f"[WARN] {endpoint} fail @ {url}: {e}"
    #             print(warn)
    #             errors.append(warn)
    #             continue
    #
    #     # 全部失败：把最后一次/全部错误拼到异常里，便于定位
    #     if errors:
    #         raise RuntimeError(" ; ".join(errors))
    #     if last_exc:
    #         raise RuntimeError(str(last_exc))
    #     raise RuntimeError(f"All servers failed for {endpoint}")

    def _notify_textcard(self, title: str, content: str):
        """Send a ``textcard`` notification; best effort, never raises.

        Posts to ``/notify/wx`` on each server in turn and stops at the first
        one that answers with a successful HTTP status. Failures are only
        printed so that alerting can never break the main flow.

        Args:
            title: Card title.
            content: Card body text.
        """
        payload = {
            "users": self.alert_users,  # could be omitted and defaulted server-side; passed explicitly here
            "title": title,
            "content": content,
            "msgtype": "textcard",
            "user": self.user,
            "user_token": self.user_token,
        }
        for url in self.urls:
            try:
                resp = self.session.post(f"{url}/notify/wx", json=payload, timeout=8)
                # A 4xx/5xx reply means the alert was NOT delivered: treat it
                # like a transport error and fall through to the next server.
                resp.raise_for_status()
                return
            except Exception as e:
                print(f"[WARN] notify fail @ {url}: {e}")

    @staticmethod
    def _short(s: str, n: int = 300) -> str:
        if not s:
            return ""
        s = str(s)
        return s if len(s) <= n else s[-n:]

    def read_sql(self, sql: str) -> pd.DataFrame:
        """Send ``sql`` to the ``query`` endpoint and return its ``result`` as a DataFrame."""
        request_body = {
            "db": self.db_type,
            "sql": sql,
            "site_name": self.site_name,
        }
        response = self._request("query", request_body)
        rows = response["result"]
        return pd.DataFrame(rows)

    def to_sql(self,
               df: pd.DataFrame,
               table: str, if_exists="append"):
        """Send the rows of ``df`` to the ``insert`` endpoint for ``table``.

        The frame is passed through ``fully_clean_for_orjson`` and converted to
        a list of record dicts before sending. ``if_exists`` is forwarded to
        the server unchanged. Returns the server's decoded reply.
        """
        records = fully_clean_for_orjson(df=df).to_dict("records")
        request_body = {
            "db": self.db_type,
            "table": table,
            "if_exists": if_exists,
            "data": records,
            "site_name": self.site_name,
        }
        return self._request("insert", request_body)

    def read_then_update(
            self,
            select_sql: str,
            update_table: str,
            set_values: dict,
            where_keys: List[str],
            error_if_empty: bool = False,
    ):
        """Ask the server to read rows with ``select_sql`` and update them.

        All arguments are forwarded to the ``read_then_update`` endpoint; the
        original note describes the server as building an UPDATE that applies
        ``set_values`` to the selected rows, matched on ``where_keys``.

        Returns:
            A DataFrame built from the reply's ``read_result``. The reply's
            ``rows_updated`` count is only printed, not returned.
        """
        request_body = dict(
            db=self.db_type,
            site_name=self.site_name,
            select_sql=select_sql,
            update_table=update_table,
            set_values=set_values,
            where_keys=where_keys,
            error_if_empty=error_if_empty,
        )
        response = self._request("read_then_update", request_body)
        frame = pd.DataFrame(response["read_result"])
        updated = response.get("rows_updated", 0)
        message = (
            f"啥也没查到，就不做更新, read_sql: {select_sql}"
            if frame.empty
            else f"更新 {updated} 行"
        )
        print(message)
        return frame


    def sqoop_raw_import(self,
                         query: str,
                         hive_table: str,
                         hdfs_path: str = None,  # HDFS path
                         default_db: str = "big_data_selection",
                         partitions: dict = None,
                         queue_name: str = "default",
                         m: int = 1,
                         split_by: Optional[str] = None,  # must be given when several mappers are used
                         outdir: str = "/tmp/sqoop/",
                         sqoop_home: str = "/opt/module/sqoop-1.4.6/bin/sqoop",
                         job_name: str = None,
                         dry_run: bool = False,
                         check_count: bool = True,
                         clean_hdfs: bool = True,  # whether to delete the HDFS path; deleted by default
                         timeout_sec: int = 36000):
        """Run a Sqoop import through the ``sqoop/raw_import`` endpoint.

        All arguments are forwarded to the server unchanged (``partitions``
        defaults to ``{}``, ``job_name`` to ``sqoop_task---<hive_table>``);
        their exact meaning is defined server-side. ``timeout_sec`` is also
        used as the HTTP timeout.

        The reply's ``schema_check_result``, ``import_check_result`` and
        ``count_check_result`` are printed in that order. The first step whose
        ``ok`` is not ``True`` triggers a notification and an exception. The
        count step is only enforced when ``check_count`` is true, so a reply
        without a count result does not fail a run that asked for no count
        check.

        Raises:
            RuntimeError: A step failed, or every server failed.
        """
        print("site_name:", self.site_name, "db_type:", self.db_type)
        body = {
            "site_name": self.site_name,
            "db_type": self.db_type,
            "default_db": default_db,  # must always be sent explicitly
            "hive_table": hive_table,
            "query": query,
            "partitions": partitions or {},
            "queue_name": queue_name,
            "m": m,
            "split_by": split_by,
            "outdir": outdir,
            "sqoop_home": sqoop_home,
            "job_name": job_name or f"sqoop_task---{hive_table}",
            "dry_run": dry_run,
            "check_count": check_count,
            "hdfs_path": hdfs_path,
            "clean_hdfs": clean_hdfs,
            "timeout_sec": timeout_sec
        }
        resp = self._request("sqoop/raw_import", body, timeout=timeout_sec)

        def _print_step(name, r):
            """Print one step's command, message, output streams and status."""
            print(f"\n===== {name} =====")
            if r.get("cmd"):
                print("CMD:", r["cmd"])
            if r.get("msg"):
                print("MSG:", r["msg"])
            if r.get("stdout"):
                print("\nSTDOUT:\n", r["stdout"])
            if r.get("stderr"):
                print("\nSTDERR:\n", r["stderr"])
            print("OK:", r.get("ok"), " CODE:", r.get("code"))

        def _fail(title, details):
            """Notify the alert recipients about a failed step, then raise."""
            header = [f"任务：{body.get('job_name')}", f"表：{hive_table}"]
            self._notify_textcard(title=title, content="\n".join(header + details))
            raise RuntimeError(f"{title}。详见上面日志。")

        # A missing or null section is treated as an empty result.
        sch = resp.get("schema_check_result", {}) or {}
        icr = resp.get("import_check_result", {}) or {}
        ccr = resp.get("count_check_result", {}) or {}

        _print_step("SCHEMA CHECK", sch)
        if sch.get("ok") is not True:
            _fail("Schema 校验失败", [
                f"code：{sch.get('code')}",
                f"msg：{self._short(sch.get('msg'))}",
            ])

        _print_step("SQOOP IMPORT", icr)
        if icr.get("ok") is not True:
            _fail("Sqoop 导入失败", [
                f"code：{icr.get('code')}",
                f"msg：{self._short(icr.get('msg'))}",
                f"stderr：{self._short(icr.get('stderr'))}",
            ])

        _print_step("COUNT CHECK", ccr)
        # Only enforce the count comparison when the caller asked for it.
        if check_count and ccr.get("ok") is not True:
            _fail("数量校验失败", [
                f"db_count：{ccr.get('db_count')}",
                f"hive_count：{ccr.get('hive_count')}",
                f"msg：{self._short(ccr.get('msg'))}",
            ])

        print("\n✅ 全部步骤通过。")

    def sqoop_raw_export(self,
                         hive_table: str,
                         import_table: str,
                         partitions: dict = None,
                         default_db: str = "big_data_selection",
                         queue_name: str = "default",
                         m: int = 1,
                         cols: str = None,
                         outdir: str = "/tmp/sqoop/",
                         sqoop_home: str = "/opt/module/sqoop-1.4.6/bin/sqoop",
                         dry_run: bool = False,
                         timeout_sec: int = 36000):
        """Export from Hive to MySQL / PostgreSQL via the ``sqoop/raw_export`` endpoint.

        The arguments are forwarded to the server unchanged (``partitions``
        defaults to ``{}``) and ``timeout_sec`` doubles as the HTTP timeout.
        The reply is printed; when its ``ok`` field is falsy a notification is
        sent and ``RuntimeError`` is raised.
        """
        print("site_name:", self.site_name, "db_type:", self.db_type)
        request_body = dict(
            site_name=self.site_name,
            db_type=self.db_type,
            hive_table=hive_table,
            import_table=import_table,
            default_db=default_db,
            partitions=partitions or {},
            queue_name=queue_name,
            m=m,
            cols=cols,
            outdir=outdir,
            sqoop_home=sqoop_home,
            dry_run=dry_run,
            timeout_sec=timeout_sec,
        )
        resp = self._request("sqoop/raw_export", request_body, timeout=timeout_sec)

        print("\n===== SQOOP EXPORT =====")
        # Print each optional field only when the reply carries a non-empty value.
        labelled_fields = (
            ("CMD:", "cmd"),
            ("MSG:", "msg"),
            ("\nSTDOUT:\n", "stdout"),
            ("\nSTDERR:\n", "stderr"),
        )
        for label, key in labelled_fields:
            value = resp.get(key)
            if value:
                print(label, value)
        print("OK:", resp.get("ok"), " CODE:", resp.get("code"))

        if resp.get("ok", False):
            print("\n✅ Sqoop export 成功。")
            return

        summary_lines = [
            f"任务：{hive_table} → {import_table}",
            f"code：{resp.get('code')}",
            f"msg：{self._short(resp.get('msg'))}",
            f"stderr：{self._short(resp.get('stderr'))}",
        ]
        self._notify_textcard(title="Sqoop 导出失败", content="\n".join(summary_lines))
        raise RuntimeError("Sqoop 导出失败。详见上面日志。")

    def begin(self):
        """Return a new RemoteTransaction sharing this engine's target, session, servers and credentials."""
        transaction = RemoteTransaction(
            self.db_type,
            self.site_name,
            self.session,
            self.urls,
            user=self.user,
            user_token=self.user_token,
        )
        return transaction

    def execute(self, sql: str, params: Optional[Union[dict, list, tuple]] = None, as_df: bool = True):
        """Execute arbitrary SQL through the ``execute`` endpoint.

        - ``params=None``: plain-text execution
        - ``params=dict``: a single parameterized execution
        - ``params=list/tuple``: executemany batch (elements are tuple/list)
        - ``as_df``: return a DataFrame when the reply says rows were returned
          and carries a ``result``; defaults to True

        Otherwise the server's reply is returned as-is. The original note
        gives its shape as:
          success: {"ok": True, ...}
          failure: {"ok": False, "error": "...", "trace": "...", "sql": "...", "params": ...}

        Raises:
            RuntimeError: The reply's ``ok`` is falsy; the message is the
                reply's ``error`` or ``"execute failed"``.
        """
        request_body = {
            "db": self.db_type,
            "site_name": self.site_name,
            "sql": sql,
            "params": params,
        }
        resp = self._request("execute", request_body)

        # Per the original note the server always answers 200 + JSON, so only ``ok`` is inspected.
        if resp.get("ok", False):
            wants_frame = as_df and resp.get("returns_rows") and "result" in resp
            if wants_frame:
                return pd.DataFrame(resp["result"])
            return resp

        # Show the error details to the caller before raising.
        print("\n===== EXECUTE ERROR =====")
        for label, key in (("SQL:", "sql"), ("PARAMS:", "params"), ("ERROR:", "error")):
            print(label, resp.get(key))
        trace = resp.get("trace")
        if trace:
            print("\nTRACE (tail):\n", trace)
        raise RuntimeError(resp.get("error") or "execute failed")


def get_remote_engine(site_name: str, db_type: str, servers: List[str] = None,
                      user: Optional[str] = None, user_token: Optional[str] = None) -> RemoteEngine:
    """Validate the arguments and build a RemoteEngine.

    Args:
        site_name: Must be a key of ``site_name_db_dict``.
        db_type: Must be a key of ``db_type_alias_map``; the mapped value is
            what the engine receives.
        servers: Server base URLs; ``DEFAULT_SERVERS`` is used when empty or None.
        user: Credential; ``DEFAULT_USER`` is used when None.
        user_token: Credential; ``DEFAULT_USER_TOKEN`` is used when None.

    Raises:
        ValueError: Unknown ``site_name`` or ``db_type``, or empty credentials.
    """
    if site_name not in site_name_db_dict:
        raise ValueError(f"Unknown site_name: {site_name}")
    if db_type not in db_type_alias_map:
        raise ValueError(f"Unknown db_type: {db_type}")

    # Fall back to the module-level defaults for anything the caller left out.
    if user is None:
        user = DEFAULT_USER
    if user_token is None:
        user_token = DEFAULT_USER_TOKEN

    # Minimal sanity check so a forgotten configuration fails early.
    if not (user and user_token):
        raise ValueError(
            "user and user_token are required (set in file top DEFAULT_USER/DEFAULT_USER_TOKEN, or pass explicitly)")

    engine_kwargs = {
        "db_type": db_type_alias_map[db_type],
        "site_name": site_name,
        "server_urls": servers or DEFAULT_SERVERS,
        "user": user,
        "user_token": user_token,
    }
    return RemoteEngine(**engine_kwargs)


### Utility functions
def df_to_json_records(df: pd.DataFrame) -> list:
    """Convert a DataFrame to JSON-safe records (handles NaN / +-Inf).

    Args:
        df: Source frame; it is not modified.

    Returns:
        A list with one dict per row. Missing values (NaN, NaT, None) and,
        in numeric columns, +-Inf are replaced by ``None``.
    """
    df_clean = df.copy()

    # 1) +-Inf -> NaN, numeric columns only.
    num_cols = df_clean.select_dtypes(include=[np.number]).columns
    if len(num_cols):
        df_clean[num_cols] = df_clean[num_cols].replace([np.inf, -np.inf], np.nan)

    # 2) NaN -> None. The cast to object is required: on a float column
    #    ``where(..., None)`` would store NaN again. Doing this frame-wide also
    #    avoids the deprecated ``applymap`` and the "truth value is ambiguous"
    #    error that a per-cell ``pd.isna`` raises on list-valued cells.
    df_clean = df_clean.astype(object).where(pd.notna(df_clean), None)

    # 3) One dict per row.
    return df_clean.to_dict("records")


def clean_json_field_for_orjson(v):
    """Normalize one JSON-field value so empty objects are stored as ``None``.

    Rules, in order:
    - ``None`` and scalar missing values (NaN, NaT, ``pd.NA``) -> ``None``
    - an empty dict -> ``None``
    - a blank string or a string that is just ``{}`` -> ``None``
    - a string holding valid JSON -> re-serialized compactly (non-ASCII kept);
      if it decodes to an empty dict -> ``None``
    - a string that is not JSON -> returned unchanged
    - anything else (non-empty dict, list, number ...) -> returned unchanged
    """
    # Only ask pd.isna about scalars: on a list or array it answers
    # element-wise and the result cannot be used as a single truth value.
    if v is None or (pd.api.types.is_scalar(v) and pd.isna(v)):
        return None

    # 1) Empty dict object.
    if isinstance(v, dict) and not v:
        return None

    # 2) Strings: blank, "{}", or JSON text.
    if isinstance(v, str):
        stripped = v.strip()
        if not stripped or stripped == "{}":
            return None
        try:
            parsed = json.loads(stripped)
        except (ValueError, RecursionError):
            # Not JSON (json.JSONDecodeError is a ValueError): keep the original text.
            return v
        if isinstance(parsed, dict) and not parsed:
            return None
        return json.dumps(parsed, ensure_ascii=False)

    return v


# def fully_clean_for_orjson(df: pd.DataFrame) -> pd.DataFrame:
#     """全面清洗 DataFrame 以符合 orjson 要求"""
#     df = df.replace([np.inf, -np.inf], np.nan)
#     df = df.applymap(lambda x: None if pd.isna(x) else x)
#
#     # 找出所有可能为 JSON 字符串的字段
#     json_like_cols = [col for col in df.columns if col.endswith('_json')]
#
#     # 针对每个 JSON-like 字段，应用清洗函数
#     for col in json_like_cols:
#         df[col] = df[col].apply(clean_json_field_for_orjson)
#
#     return df

def fully_clean_for_orjson(df: pd.DataFrame) -> pd.DataFrame:
    """Return a cleaned copy of ``df`` that is safe to serialize with orjson.

    - +-Inf becomes a missing value.
    - Every missing value (NaN, NaT, None) becomes ``None``; all columns of
      the result have object dtype.
    - Columns whose name ends with ``_json`` are additionally passed through
      ``clean_json_field_for_orjson`` value by value.

    The input frame is not modified.
    """
    df = df.replace([np.inf, -np.inf], np.nan)

    # NaN -> None. The cast to object is required: on a float column
    # ``where(..., None)`` would store NaN again instead of None.
    df = df.astype(object).where(pd.notna(df), None)

    # Columns that may hold JSON strings. Labels are not always strings
    # (e.g. integer labels), so check the type before calling endswith.
    json_like_cols = [col for col in df.columns if isinstance(col, str) and col.endswith('_json')]

    # Apply the per-value cleaner to each JSON-like column.
    for col in json_like_cols:
        df[col] = df[col].apply(clean_json_field_for_orjson)

    return df