Commit 7293af95 by fangxingjun

no message

parent dbb59488
import pandas as pd
import numpy as np
import json
import orjson, requests, time
from typing import List, Optional
mima = 'xsfasg'

# New: default WeChat notification recipients (used to alert when a task fails)
NOTIFY_DEFAULT_USERS = ['fangxingjun', 'chenyuanjie']

# -------- Mapping dictionaries --------
site_name_db_dict = {
    "us": "selection",
@@ -30,83 +32,27 @@ db_type_alias_map = {
DEFAULT_SERVERS = [
    # "http://192.168.200.210:7777",  # intranet - h5
    # "http://192.168.200.210:7778",  # intranet - test big data - h5
    "http://192.168.10.216:7777",     # intranet - test big data - h6
    # "http://192.168.10.216:7777",   # intranet - h6
    # "http://192.168.10.217:7777",   # intranet - h7
    # "http://61.145.136.61:7777",    # external
    # "http://61.145.136.61:7779",    # external
]
# ---------------------------
class RemoteTransaction:
    def __init__(self, db: str, database: str, session: requests.Session, urls: List[str]):
        self.db = db
        self.database = database
        self.session = session
        self.urls = urls
        self.sql_queue = []

    def execute(self, sql: str, params=None):
        """
        params may be:
@@ -118,7 +64,8 @@ class RemoteTransaction:
""" """
self.sql_queue.append({"sql": sql, "params": params}) self.sql_queue.append({"sql": sql, "params": params})

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        for url in self.urls:
@@ -137,13 +84,13 @@ class RemoteTransaction:
class RemoteEngine:
    def __init__(self, db: str, database: str, server_urls: List[str], retries: int = 2, alert_users: Optional[List[str]] = None):
        self.db = db
        self.database = database
        self.urls = [u.rstrip("/") for u in server_urls]
        self.session = requests.Session()
        self.retries = retries
        self.alert_users = alert_users or NOTIFY_DEFAULT_USERS  # ✅ default notification recipients

    def _request(self, endpoint: str, payload, timeout=3000):
        for url in self.urls:
@@ -164,6 +111,27 @@ class RemoteEngine:
            time.sleep(1)
        raise RuntimeError(f"All servers failed for {endpoint}")

    def _notify_textcard(self, title: str, content: str):
        """Textcard only; a failure here must not affect the main flow."""
        payload = {
            "users": self.alert_users,  # could be omitted to use the server-side default; passed explicitly here
            "title": title,
            "content": content
        }
        for url in self.urls:
            try:
                self.session.post(f"{url}/notify/wx", json=payload, timeout=8)
                break
            except Exception as e:
                print(f"[WARN] notify fail @ {url}: {e}")

    @staticmethod
    def _short(s: str, n: int = 300) -> str:
        if not s:
            return ""
        s = str(s)
        return s if len(s) <= n else s[-n:]

    def read_sql(self, sql: str) -> pd.DataFrame:
        data = self._request("query",
                             {"db": self.db,
@@ -171,7 +139,9 @@ class RemoteEngine:
"site_name": self.database}) "site_name": self.database})
return pd.DataFrame(data["result"]) return pd.DataFrame(data["result"])

    def to_sql(self,
               df: pd.DataFrame,
               table: str, if_exists="append"):
        return self._request("insert",
                             {"db": self.db,
@@ -212,11 +182,11 @@ class RemoteEngine:
        return df

    def sqoop_raw_import(self,
                         # site_name: str,
                         # db_type: str,
                         query: str,
                         hive_table: str,
                         hdfs_path: str = None,  # HDFS path
                         default_db: str = "big_data_selection",
                         partitions: dict = None,
                         queue_name: str = "default",
@@ -228,9 +198,12 @@ class RemoteEngine:
                         dry_run: bool = False,
                         clean_hdfs: bool = True,  # whether to delete the HDFS path first; defaults to deleting it
                         timeout_sec: int = 36000):
print("site_name:", self.database, "db_type:", self.db)
body = { body = {
"site_name": site_name, # "site_name": site_name,
"db_type": db_type, # "db_type": db_type,
"site_name": self.database,
"db_type": self.db,
"default_db": default_db, # 👈 必须明确带上 "default_db": default_db, # 👈 必须明确带上
"hive_table": hive_table, "hive_table": hive_table,
"query": query, "query": query,
@@ -240,47 +213,105 @@ class RemoteEngine:
"split_by": split_by, "split_by": split_by,
"outdir": outdir, "outdir": outdir,
"sqoop_home": sqoop_home, "sqoop_home": sqoop_home,
"job_name": job_name, "job_name": job_name or f"sqoop_task---{hive_table}",
"dry_run": dry_run, "dry_run": dry_run,
"hdfs_path": hdfs_path, "hdfs_path": hdfs_path,
"clean_hdfs": clean_hdfs, "clean_hdfs": clean_hdfs,
"timeout_sec": timeout_sec "timeout_sec": timeout_sec
} }
# 导入前对变量进行校验 resp = self._request("sqoop/raw_import", body, timeout=timeout_sec)
if hive_table not in hdfs_path or len(hdfs_path) < 10 or len(hive_table) < 3:
raise ValueError( def _print_step(name, r):
f"❌ 导入前变量校验失败, 请检查以下情况: \n" print(f"\n===== {name} =====")
f"1. hdfs_path长度<10: 传入长度{len(hdfs_path)}\n" if "cmd" in r and r["cmd"]:
f"2. hive_table长度<3: 传入长度{len(hive_table)}\n" print("CMD:", r["cmd"])
f"3. hive_table='{hive_table}' 不在 hdfs_path='{hdfs_path}' 路径中,请检查参数设置是否正确" if r.get("msg"): print("MSG:", r["msg"])
if r.get("stdout"): print("\nSTDOUT:\n", r["stdout"])
if r.get("stderr"): print("\nSTDERR:\n", r["stderr"])
print("OK:", r.get("ok"), " CODE:", r.get("code"))
sch = resp.get("schema_check_result", {}) or {}
icr = resp.get("import_check_result", {}) or {}
ccr = resp.get("count_check_result", {}) or {}
_print_step("SCHEMA CHECK", sch)
if sch.get("ok") is not True:
self._notify_textcard(
title="Schema 校验失败",
content=(
f"任务:{body.get('job_name')}\n"
f"表:{hive_table}\n"
f"code:{sch.get('code')}\n"
f"msg:{self._short(sch.get('msg'))}"
) )
# 导入前对变量进行校验
if m > 1 and split_by is None and split_by in query:
raise ValueError(
f"❌ 导入前变量校验失败, 请检查以下情况: \n"
f"1. map_num的数量大于1时, 必须要指定split_by参数, m: {m}, split_by: {split_by}\n"
f"2. split_by的字段不在query查询的字段里面: split_by: {split_by}, query: {query}\n"
) )
# 执行导入, 进行异常检验 raise RuntimeError("Schema 校验失败。详见上面日志。")
import_result = self._request("sqoop/raw_import", body, timeout=timeout_sec)
# 1. cmd执行日志信息 _print_step("SQOOP IMPORT", icr)
print(f"sqoop导入完整日志: {import_result['stdout']}") if icr.get("ok") is not True:
del import_result['stdout'] self._notify_textcard(
# 2. 检查导入后的hdfs路径是否成功 title="Sqoop 导入失败",
hdfs_path_exists_after = import_result['hdfs_path_exists_after'] content=(
if not hdfs_path_exists_after: f"任务:{body.get('job_name')}\n"
# del import_result['hdfs_path_exists_after'] f"表:{hive_table}\n"
raise ValueError( f"code:{icr.get('code')}\n"
f"❌ sqoop导入失败, 请检查返回结果: {import_result}" f"msg:{self._short(icr.get('msg'))}\n"
f"stderr:{self._short(icr.get('stderr'))}"
) )
else: )
print(f"√ sqoop导入成功: {hdfs_path_exists_after}") raise RuntimeError("Sqoop 导入失败。详见上面日志。")
return import_result
_print_step("COUNT CHECK", ccr)
if ccr.get("ok") is not True:
self._notify_textcard(
title="数量校验失败",
content=(
f"任务:{body.get('job_name')}\n"
f"表:{hive_table}\n"
f"db_count:{ccr.get('db_count')}\n"
f"hive_count:{ccr.get('hive_count')}\n"
f"msg:{self._short(ccr.get('msg'))}"
)
)
raise RuntimeError("数量校验失败。详见上面日志。")
print("\n✅ 全部步骤通过。")
# return resp
        # # Pre-import parameter validation
        # if hive_table not in hdfs_path or len(hdfs_path) < 10 or len(hive_table) < 3:
        #     raise ValueError(
        #         f"❌ Pre-import validation failed; check the following:\n"
        #         f"1. len(hdfs_path) < 10: got {len(hdfs_path)}\n"
        #         f"2. len(hive_table) < 3: got {len(hive_table)}\n"
        #         f"3. hive_table='{hive_table}' does not appear in hdfs_path='{hdfs_path}'; check the parameters"
        #     )
        # # Pre-import parameter validation
        # if m > 1 and split_by is None and split_by in query:
        #     raise ValueError(
        #         f"❌ Pre-import validation failed; check the following:\n"
        #         f"1. when map_num > 1, split_by must be specified; m: {m}, split_by: {split_by}\n"
        #         f"2. split_by is not among the fields selected by query; split_by: {split_by}, query: {query}\n"
        #     )
        # # Run the import, then check for errors
        # import_result = self._request("sqoop/raw_import", body, timeout=timeout_sec)
        # # 1. cmd execution log
        # print(f"full sqoop import log: {import_result['stdout']}")
        # del import_result['stdout']
        # # 2. Check whether the HDFS path exists after the import
        # hdfs_path_exists_after = import_result['hdfs_path_exists_after']
        # if not hdfs_path_exists_after:
        #     # del import_result['hdfs_path_exists_after']
        #     raise ValueError(
        #         f"❌ sqoop import failed; inspect the response: {import_result}"
        #     )
        # else:
        #     print(f"√ sqoop import succeeded: {hdfs_path_exists_after}")
        # return import_result
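
        # Note: the checks above assume a response contract roughly of this
        # shape, reconstructed from the keys read; it is not documented in this
        # commit, so treat it as an assumption:
        #   resp = {
        #       "schema_check_result": {"ok": bool, "code": int, "msg": str,
        #                               "cmd": str, "stdout": str, "stderr": str},
        #       "import_check_result": {...},  # same keys as schema_check_result
        #       "count_check_result":  {"ok": bool, "db_count": int,
        #                               "hive_count": int, "msg": str},
        #   }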

    def begin(self):
        return RemoteTransaction(self.db, self.database,
                                 self.session, self.urls)
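
# A minimal usage sketch of the queued transaction (not part of this commit).
# It assumes a reachable server in DEFAULT_SERVERS and that "mysql" is a valid
# db alias; the table name and values are hypothetical:
#
#     engine = RemoteEngine(db="mysql", database="us", server_urls=DEFAULT_SERVERS)
#     with engine.begin() as tx:
#         # execute() only queues the statement; __exit__ posts the whole batch
#         tx.execute("UPDATE us_demo_table SET state = %s WHERE id = %s",
#                    params=(1, 42))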
# ---------------------------------
def get_remote_engine(site_name: str, db_type: str,
@@ -294,3 +325,60 @@ def get_remote_engine(site_name: str, db_type: str,
        database=site_name,
        server_urls=servers or DEFAULT_SERVERS,
    )
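
# A minimal sketch of the factory in use (not part of this commit). It assumes
# a reachable server in DEFAULT_SERVERS; "us" maps to the "selection" database
# via site_name_db_dict, while "mysql" and the SQL text are illustrative
# assumptions rather than values taken from this commit.
if __name__ == "__main__":
    _eng = get_remote_engine(site_name="us", db_type="mysql")
    _df = _eng.read_sql("SELECT 1 AS id")  # a DataFrame built from data["result"]
    print(_df)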
### Utility functions
def df_to_json_records(df: pd.DataFrame) -> list:
    """Make a DataFrame safely serializable to JSON records (handles NaN / ±Inf)."""
    df_clean = df.copy()
    # 1️⃣ Replace ±Inf -> NaN
    num_cols = df_clean.select_dtypes(include=[np.number]).columns
    if len(num_cols):
        df_clean[num_cols] = df_clean[num_cols].replace([np.inf, -np.inf], np.nan)
    # 2️⃣ Replace NaN -> None (note: astype(object) is sometimes not thorough, so use applymap)
    df_clean = df_clean.applymap(lambda x: None if pd.isna(x) else x)
    # 3️⃣ Convert to dict records
    return df_clean.to_dict("records")

def clean_json_field_for_orjson(v):
    """Clean a single JSON field value so it satisfies orjson and empty dicts are not persisted."""
    if v is None or pd.isna(v):
        return None
    # 1️⃣ Empty dict object -> None
    if isinstance(v, dict) and not v:
        return None
    # 2️⃣ Empty string, or the literal "{}" -> None
    if isinstance(v, str):
        stripped = v.strip()
        if not stripped or stripped == "{}":
            return None
        try:
            parsed = json.loads(stripped)
            if isinstance(parsed, dict) and not parsed:
                return None
            return json.dumps(parsed, ensure_ascii=False)
        except Exception:
            return v  # keep non-JSON strings as-is
    return v

def fully_clean_for_orjson(df: pd.DataFrame) -> pd.DataFrame:
    """Fully clean a DataFrame so it satisfies orjson."""
    df = df.replace([np.inf, -np.inf], np.nan)
    df = df.applymap(lambda x: None if pd.isna(x) else x)
    # Find all columns that may hold JSON strings
    json_like_cols = [col for col in df.columns if col.endswith('_json')]
    # Apply the cleaning function to each JSON-like column
    for col in json_like_cols:
        df[col] = df[col].apply(clean_json_field_for_orjson)
    return df
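
# A small, self-contained demo of the cleaning helpers above (not part of this
# commit); the column names are hypothetical. ±Inf and NaN become None, and an
# empty "{}" JSON string is dropped to None so it is not persisted.
if __name__ == "__main__":
    _demo = pd.DataFrame({
        "price": [1.5, np.inf, np.nan],
        "attrs_json": ['{"a": 1}', "{}", None],
    })
    _demo = fully_clean_for_orjson(_demo)
    print(df_to_json_records(_demo))
    # -> [{'price': 1.5, 'attrs_json': '{"a": 1}'},
    #     {'price': None, 'attrs_json': None},
    #     {'price': None, 'attrs_json': None}]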