Commit b96fb565 by fangxingjun

no message

parent 00526099
import json
import pandas as pd
import numpy as np
import orjson, requests, time
from typing import List
# -------- Mapping dictionaries --------
# Supported site codes, each mapped to a per-site name (presumably the
# database/schema holding that site's data -- confirm). Within this file only
# the keys are used, to validate the `site_name` argument.
site_name_db_dict = dict(
    us="selection",
    uk="selection_uk",
    de="selection_de",
    es="selection_es",
    fr="selection_fr",
    it="selection_it",
)
# Accepted `db_type` values mapped to the backend identifier placed in the
# request payload. Every alias currently maps to itself.
db_type_alias_map = {
    alias: alias
    for alias in (
        "mysql",  # Alibaba Cloud MySQL
        "postgresql_14",  # PG14 crawler database, intranet
        "postgresql_14_outer",  # PG14 crawler database, extranet
        "postgresql_15",  # PG15 production database, intranet
        "postgresql_15_outer",  # PG15 production database, extranet
        "postgresql_cluster",  # PG cluster, intranet
        "postgresql_cluster_outer",  # PG cluster, extranet
        "doris",  # Doris cluster, intranet
    )
}
# Server base URLs used when the caller does not pass its own list.
# The intranet entries are currently disabled.
DEFAULT_SERVERS = [
    # "http://192.168.200.210:7777",  # intranet
    # "http://192.168.10.217:7777",  # intranet (h7)
    "http://61.145.136.61:7777",  # extranet
    "http://61.145.136.61:7779",  # extranet
]
# ---------------------------
def df_to_json_records(df: pd.DataFrame) -> list:
    """Convert a DataFrame into JSON-safe ``records`` dictionaries.

    Infinite values in numeric columns and every missing value (NaN, NaT,
    ``pd.NA``, ``None``) are emitted as ``None``. Cells that hold containers
    such as lists are passed through unchanged. The input frame is not
    modified.

    Args:
        df: Frame to convert.

    Returns:
        One ``{column: value}`` dict per row, as produced by
        ``DataFrame.to_dict("records")``.
    """
    cleaned = df.copy()

    # +/-Inf has no JSON representation; fold it into NaN so the step below
    # treats it like any other missing value.
    numeric_cols = cleaned.select_dtypes(include=[np.number]).columns
    if not numeric_cols.empty:
        cleaned[numeric_cols] = cleaned[numeric_cols].replace(
            [np.inf, -np.inf], np.nan
        )

    # Cast to object before inserting None: a float column cannot store None
    # and would turn it straight back into NaN. The column-wise mask also
    # avoids calling pd.isna() on individual cells, which is ambiguous for
    # list-valued cells.
    cleaned = cleaned.astype(object)
    cleaned = cleaned.where(cleaned.notna(), None)

    return cleaned.to_dict("records")
def clean_json_field_for_orjson(v):
    """Normalise a single cell of a JSON column before serialisation.

    Args:
        v: The cell value; any type is accepted.

    Returns:
        ``None`` when ``v`` is a missing scalar, an empty dict, a blank
        string, or a string that parses to an empty JSON object.
        A string holding any other valid JSON is returned re-serialised with
        ``json.dumps(..., ensure_ascii=False)``.
        Every other value (non-JSON strings, non-empty dicts, lists, numbers)
        is returned unchanged.
    """
    if v is None:
        return None

    # pd.isna() answers element-wise for list-likes and the truth value of
    # that array is ambiguous, so only scalars are asked.
    if pd.api.types.is_scalar(v) and pd.isna(v):
        return None

    if isinstance(v, dict):
        # An empty dict is stored as NULL rather than "{}".
        return v or None

    if not isinstance(v, str):
        return v

    text = v.strip()
    if not text:
        return None

    try:
        parsed = json.loads(text)
        if isinstance(parsed, dict) and not parsed:
            # Covers "{}" as well as variants such as "{ }".
            return None
        return json.dumps(parsed, ensure_ascii=False)
    except (ValueError, RecursionError):
        # Not JSON (json.JSONDecodeError is a ValueError) or nested too
        # deeply to process: keep the original string untouched.
        return v
def fully_clean_for_orjson(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose cells are safe to serialise as JSON.

    +/-Inf and every missing value become ``None``; columns whose name is a
    string ending in ``_json`` are additionally passed cell by cell through
    ``clean_json_field_for_orjson``. The input frame is not modified.

    Args:
        df: Frame to clean.

    Returns:
        A new DataFrame with the same labels. All columns have ``object``
        dtype, because ``None`` cannot be stored in a numeric column.
    """
    cleaned = df.replace([np.inf, -np.inf], np.nan)

    # Cast to object before inserting None so float columns do not coerce it
    # back to NaN; the column-wise mask also tolerates list-valued cells.
    cleaned = cleaned.astype(object)
    cleaned = cleaned.where(cleaned.notna(), None)

    # Column labels are not necessarily strings (e.g. integer labels), so
    # check the type before looking at the suffix.
    json_like_cols = [
        col
        for col in cleaned.columns
        if isinstance(col, str) and col.endswith("_json")
    ]
    for col in json_like_cols:
        cleaned[col] = cleaned[col].apply(clean_json_field_for_orjson)

    return cleaned
class RemoteTransaction:
    """Collect SQL statements and submit them as one batch on ``with`` exit.

    Statements passed to ``execute`` are only queued locally. When the
    ``with`` block finishes without an exception, the whole queue is POSTed
    to ``<url>/transaction`` on the first server that accepts it.
    """

    def __init__(self, db: str, database: str,
                 session: "requests.Session", urls: List[str]):
        """Store the connection details; no request is made here.

        Args:
            db: Backend identifier, sent as ``"db"`` in the payload.
            database: Site name, sent as ``"site_name"`` in the payload.
            session: HTTP session used to POST the batch.
            urls: Server base URLs, tried in order.
        """
        self.db = db
        self.database = database
        self.session = session
        self.urls = urls
        self.sql_queue = []

    def execute(self, sql: str, params=None):
        """Queue one statement; nothing is sent until the block exits.

        ``params`` is forwarded to the server untouched. According to the
        original author's notes it may be:
          * None        -> plain-text SQL
          * dict        -> one parameterised statement, e.g. {"id":1,"name":"a"}
          * list/tuple  -> batch (executemany)
              - list[dict]  <-> INSERT .. VALUES (:id,:name)
              - list[tuple] <-> INSERT .. VALUES (%s,%s)
        """
        self.sql_queue.append({"sql": sql, "params": params})

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        """Submit the queued statements unless the ``with`` body failed.

        Raises:
            RuntimeError: If every server rejected the batch; the last
                failure is attached as ``__cause__``.
        """
        # The body raised: drop the queue instead of committing a partial
        # batch. Returning False lets the original exception propagate.
        if exc_type is not None:
            return False

        # Nothing queued, so there is nothing to send.
        if not self.sql_queue:
            return False

        payload = {
            "db": self.db,
            "sql_list": self.sql_queue,
            "site_name": self.database,  # site_name not needed on server, kept for clarity
        }
        last_error = None
        # NOTE(review): if a server applied the batch but its response was
        # lost (e.g. a timeout), the next URL receives the same batch again --
        # confirm the servers make that safe.
        for url in self.urls:
            try:
                self.session.post(
                    url + "/transaction",
                    json=payload,
                    timeout=3000,
                ).raise_for_status()
            except Exception as e:
                # Best-effort failover: remember the error, try the next URL.
                last_error = e
                print(f"[WARN] 事务失败 {url}: {e}")
            else:
                return False
        raise RuntimeError("All servers failed for transaction") from last_error
class RemoteEngine:
    """HTTP client that forwards SQL reads and writes to a remote gateway.

    Every public method builds a JSON payload carrying ``db`` and
    ``site_name`` and POSTs it to one of the configured servers, moving on
    to the next server when one keeps failing.
    """

    def __init__(self, db: str, database: str,
                 server_urls: List[str], retries: int = 2):
        """Create the engine and its HTTP session.

        Args:
            db: Backend identifier, sent as ``"db"`` in every payload.
            database: Site name, sent as ``"site_name"`` in every payload.
            server_urls: Server base URLs, tried in the given order.
            retries: Number of attempts made against each server.
        """
        self.db = db
        self.database = database
        # Strip trailing slashes so "/<endpoint>" can be appended directly.
        self.urls = [u.rstrip("/") for u in server_urls]
        self.session = requests.Session()
        self.retries = retries

    def _request(self, endpoint: str, payload):
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded reply.

        Each server is tried ``self.retries`` times, with a one-second pause
        after every failed attempt, before moving on to the next one.

        Raises:
            RuntimeError: If the payload cannot be serialised, or if no
                attempt on any server succeeded (the last failure is attached
                as ``__cause__``).
        """
        # Encode once up front: the bytes are identical for every attempt,
        # and a payload that cannot be encoded will not succeed on a retry.
        try:
            json_bytes = orjson.dumps(payload)
        except TypeError as err:  # orjson.JSONEncodeError subclasses TypeError
            raise RuntimeError(
                f"Cannot serialize payload for {endpoint}: {err}"
            ) from err

        headers = {"Content-Type": "application/json"}
        last_error = None
        for url in self.urls:
            for _ in range(self.retries):
                try:
                    r = self.session.post(f"{url}/{endpoint}",
                                          data=json_bytes,
                                          headers=headers,
                                          timeout=3000)
                    r.raise_for_status()
                    return r.json()
                except Exception as e:
                    # Best-effort failover: log, pause, then retry or move on.
                    last_error = e
                    print(f"[WARN] {endpoint} fail @ {url}: {e}")
                    time.sleep(1)
        raise RuntimeError(f"All servers failed for {endpoint}") from last_error

    # ---------- Public API ----------
    def read_sql(self, sql: str) -> pd.DataFrame:
        """Run ``sql`` through the ``query`` endpoint.

        Returns:
            A DataFrame built from the ``"result"`` field of the response.
        """
        data = self._request("query",
                             {"db": self.db,
                              "sql": sql,
                              "site_name": self.database})
        return pd.DataFrame(data["result"])

    def to_sql(self, df: pd.DataFrame, table: str, if_exists="append"):
        """Send the rows of ``df`` to the ``insert`` endpoint for ``table``.

        The frame is cleaned with ``fully_clean_for_orjson`` and sent as a
        list of record dicts.

        Returns:
            The decoded server response.
        """
        return self._request("insert",
                             {"db": self.db,
                              "table": table,
                              "if_exists": if_exists,
                              "data": fully_clean_for_orjson(df=df).to_dict("records"),
                              "site_name": self.database})

    def read_then_update(
        self,
        select_sql: str,
        update_table: str,
        set_values: dict,
        where_keys: List[str],
        error_if_empty: bool = False,
    ):
        """Call the ``read_then_update`` endpoint and return the rows read.

        Per the original author's note, the server builds the UPDATE itself:
        rows returned by ``select_sql`` are updated in ``update_table`` with
        ``set_values``, matched on ``where_keys``.

        Returns:
            A DataFrame built from the ``"read_result"`` field of the
            response. The response's ``"rows_updated"`` count is not
            returned.
        """
        payload = {
            "db": self.db,
            "site_name": self.database,
            "select_sql": select_sql,
            "update_table": update_table,
            "set_values": set_values,
            "where_keys": where_keys,
            "error_if_empty": error_if_empty,
        }
        resp = self._request("read_then_update", payload)
        return pd.DataFrame(resp["read_result"])

    def begin(self):
        """Return a ``RemoteTransaction`` sharing this engine's session and URLs."""
        return RemoteTransaction(self.db, self.database,
                                 self.session, self.urls)
# ---------------------------------
def get_remote_engine(site_name: str, db_type: str,
                      servers: List[str] = None) -> RemoteEngine:
    """Validate the site and backend names, then build a ``RemoteEngine``.

    Args:
        site_name: Site code; must be a key of ``site_name_db_dict``.
        db_type: Backend alias; must be a key of ``db_type_alias_map``.
        servers: Server base URLs to use. When omitted or empty,
            ``DEFAULT_SERVERS`` is used.

    Returns:
        An engine bound to the resolved backend and the given site.

    Raises:
        ValueError: If ``site_name`` or ``db_type`` is not recognised.
    """
    site_is_known = site_name in site_name_db_dict
    if not site_is_known:
        raise ValueError(f"Unknown site_name: {site_name}")

    backend_is_known = db_type in db_type_alias_map
    if not backend_is_known:
        raise ValueError(f"Unknown db_type: {db_type}")

    backend = db_type_alias_map[db_type]
    endpoints = servers if servers else DEFAULT_SERVERS
    return RemoteEngine(db=backend, database=site_name, server_urls=endpoints)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment