Commit 0feb4e27 by fangxingjun

no message

parent d75ed4cd
......@@ -9,8 +9,6 @@ from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
# 导入udf公共方法
from yswg_utils.common_udf import udf_parse_bs_category
# from ..yswg_utils.common_udf import udf_parse_bs_category
from utils.spark_util import SparkUtil
from utils.hdfs_utils import HdfsUtils
......
import os
import sys
import pandas as pd
import numpy as np
import json
import orjson, requests, time
from typing import List, Optional, Union

# NOTE(review): hardcoded credentials below — consider loading from env/config instead
mima = 'xsfasg'
DEFAULT_USER = "fangxingjun"
DEFAULT_USER_TOKEN = "5f1b2e9c3a4d7f60"
# sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
# try:
#     from user import DEFAULT_USER, DEFAULT_USER_TOKEN
# except Exception as e:
#     from .user import DEFAULT_USER, DEFAULT_USER_TOKEN
# Default WeChat recipients notified when a task fails
NOTIFY_DEFAULT_USERS = ['pengyanbing', 'hezhe', 'chenyuanjie', 'fangxingjun']
# -------- 映射字典 --------
site_name_db_dict = {
......@@ -33,11 +42,11 @@ db_type_alias_map = {
# Candidate API servers, tried in order; only the external :7779 endpoint is active.
DEFAULT_SERVERS = [
    # "http://192.168.200.210:7777",  # intranet - h5
    # "http://192.168.200.210:7778",  # intranet - test bigdata - h5
    # "http://192.168.10.216:7777",  # intranet - test bigdata - h6
    # "http://192.168.10.216:7777",  # intranet - h6
    # "http://192.168.10.217:7777",  # intranet - h7
    # "http://61.145.136.61:7777",  # external
    "http://61.145.136.61:7779",  # external
]
......@@ -46,12 +55,15 @@ DEFAULT_SERVERS = [
class RemoteTransaction:
    """Queued-SQL transaction executed remotely over the HTTP API."""

    def __init__(self, db_type: str, site_name: str, session: "requests.Session", urls: List[str],
                 user: Optional[str] = None, user_token: Optional[str] = None):
        # db_type: logical database alias; site_name: target site/database name
        self.db_type = db_type
        self.site_name = site_name
        self.session = session      # shared HTTP session owned by the RemoteEngine
        self.urls = urls            # candidate server base URLs, tried in order
        self.sql_queue = []         # SQL statements buffered until the transaction is flushed
        # fall back to module-level defaults when credentials are not supplied
        self.user = user if user is not None else DEFAULT_USER
        self.user_token = user_token if user_token is not None else DEFAULT_USER_TOKEN
def execute(self, sql: str, params=None):
"""
......@@ -72,9 +84,12 @@ class RemoteTransaction:
try:
self.session.post(
url + "/transaction",
json={"db": self.db,
json={"db": self.db_type,
"sql_list": self.sql_queue,
"site_name": self.database}, # site_name not needed on server, kept for clarity
"site_name": self.site_name, # site_name not needed on server, kept for clarity
"user": self.user, # ✅ 携带
"user_token": self.user_token, # ✅ 携带
},
timeout=3000,
).raise_for_status()
return
......@@ -84,15 +99,25 @@ class RemoteTransaction:
class RemoteEngine:
    """Client for the remote SQL/HTTP service (query, insert, sqoop import, alerts)."""

    def __init__(self, db_type: str, site_name: str, server_urls: List[str], retries: int = 2,
                 alert_users: Optional[List[str]] = None,
                 user: Optional[str] = None, user_token: Optional[str] = None):
        self.db_type = db_type
        self.site_name = site_name
        self.urls = [u.rstrip("/") for u in server_urls]  # normalize trailing slashes
        self.session = requests.Session()
        self.retries = retries  # attempts per server URL
        self.alert_users = alert_users or NOTIFY_DEFAULT_USERS  # default alert recipients
        self.user = user
        self.user_token = user_token
        # Credentials are mandatory here; get_remote_engine() applies the file-level defaults.
        if not self.user or not self.user_token:
            raise ValueError("user and user_token are required (or set env MDB_USER / MDB_USER_TOKEN)")
def _request(self, endpoint: str, payload, timeout=3000):
# 统一注入 user / user_token
if isinstance(payload, dict):
payload.setdefault("user", self.user)
payload.setdefault("user_token", self.user_token)
for url in self.urls:
for _ in range(self.retries):
try:
......@@ -111,6 +136,69 @@ class RemoteEngine:
time.sleep(1)
raise RuntimeError(f"All servers failed for {endpoint}")
# def _request(self, endpoint: str, payload, timeout: int = 3000):
# """
# 统一注入 user / user_token,并对非 200 响应解析 detail/error 文本,
# 让调用方能看到“凭证错误/未授权”等服务端具体原因。
# """
# # 注入 user / token
# if isinstance(payload, dict):
# payload.setdefault("user", self.user)
# payload.setdefault("user_token", self.user_token)
#
# # 可能有多个服务端,逐个尝试
# import json as orjson
# json_bytes = orjson.dumps(payload)
#
# errors = [] # 记录每个 URL 的错误摘要
# last_exc = None
#
# for url in self.urls:
# try:
# r = self.session.post(
# f"{url}/{endpoint}",
# data=json_bytes,
# headers={"Content-Type": "application/json"},
# timeout=timeout / 1000.0 if timeout and timeout > 100 else timeout
# )
#
# # 200:一定是 JSON
# if r.status_code == 200:
# return r.json()
#
# # 非 200:尽量从 JSON 里抽取服务端给的 detail / error
# msg = None
# try:
# j = r.json()
# msg = j.get("detail") or j.get("error") or j.get("message") or str(j)
# except Exception:
# # 不是 JSON 就用纯文本,避免遮蔽真实报错
# msg = (r.text or "").strip()
#
# brief_sql = ""
# if isinstance(payload, dict) and payload.get("sql"):
# s = str(payload["sql"]).strip().replace("\n", " ")
# brief_sql = f" sql={s[:200]}{'...' if len(s) > 200 else ''}"
#
# warn = f"[WARN] {endpoint} fail @ {url}: {r.status_code} {msg}{brief_sql}"
# print(warn)
# errors.append(warn)
# continue
#
# except Exception as e:
# last_exc = e
# warn = f"[WARN] {endpoint} fail @ {url}: {e}"
# print(warn)
# errors.append(warn)
# continue
#
# # 全部失败:把最后一次/全部错误拼到异常里,便于定位
# if errors:
# raise RuntimeError(" ; ".join(errors))
# if last_exc:
# raise RuntimeError(str(last_exc))
# raise RuntimeError(f"All servers failed for {endpoint}")
def _notify_textcard(self, title: str, content: str):
"""只用 textcard,失败不影响主流程"""
payload = {
......@@ -134,9 +222,9 @@ class RemoteEngine:
def read_sql(self, sql: str) -> pd.DataFrame:
data = self._request("query",
{"db": self.db,
{"db": self.db_type,
"sql": sql,
"site_name": self.database})
"site_name": self.site_name})
return pd.DataFrame(data["result"])
def to_sql(self,
......@@ -144,12 +232,12 @@ class RemoteEngine:
table: str, if_exists="append"):
return self._request("insert",
{"db": self.db,
{"db": self.db_type,
"table": table,
"if_exists": if_exists,
"data": fully_clean_for_orjson(df=df).to_dict("records"),
# "data": df_to_json_records(df), # ← 清洗后的 records
"site_name": self.database})
"site_name": self.site_name})
def read_then_update(
self,
......@@ -164,8 +252,8 @@ class RemoteEngine:
返回 (DataFrame, rows_updated)
"""
payload = {
"db": self.db,
"site_name": self.database,
"db": self.db_type,
"site_name": self.site_name,
"select_sql": select_sql,
"update_table": update_table,
"set_values": set_values,
......@@ -198,12 +286,12 @@ class RemoteEngine:
dry_run: bool = False,
clean_hdfs: bool = True, # 是否删除hdfs路径, 默认删除
timeout_sec: int = 36000):
print("site_name:", self.database, "db_type:", self.db)
print("site_name:", self.site_name, "db_type:", self.db_type)
body = {
# "site_name": site_name,
# "db_type": db_type,
"site_name": self.database,
"db_type": self.db,
"site_name": self.site_name,
"db_type": self.db_type,
"default_db": default_db, # 👈 必须明确带上
"hive_table": hive_table,
"query": query,
......@@ -276,54 +364,71 @@ class RemoteEngine:
raise RuntimeError("数量校验失败。详见上面日志。")
print("\n✅ 全部步骤通过。")
# return resp
# # 导入前对变量进行校验
# if hive_table not in hdfs_path or len(hdfs_path) < 10 or len(hive_table) < 3:
# raise ValueError(
# f"❌ 导入前变量校验失败, 请检查以下情况: \n"
# f"1. hdfs_path长度<10: 传入长度{len(hdfs_path)}\n"
# f"2. hive_table长度<3: 传入长度{len(hive_table)}\n"
# f"3. hive_table='{hive_table}' 不在 hdfs_path='{hdfs_path}' 路径中,请检查参数设置是否正确"
# )
# # 导入前对变量进行校验
# if m > 1 and split_by is None and split_by in query:
# raise ValueError(
# f"❌ 导入前变量校验失败, 请检查以下情况: \n"
# f"1. map_num的数量大于1时, 必须要指定split_by参数, m: {m}, split_by: {split_by}\n"
# f"2. split_by的字段不在query查询的字段里面: split_by: {split_by}, query: {query}\n"
# )
# # 执行导入, 进行异常检验
# import_result = self._request("sqoop/raw_import", body, timeout=timeout_sec)
# # 1. cmd执行日志信息
# print(f"sqoop导入完整日志: {import_result['stdout']}")
# del import_result['stdout']
# # 2. 检查导入后的hdfs路径是否成功
# hdfs_path_exists_after = import_result['hdfs_path_exists_after']
# if not hdfs_path_exists_after:
# # del import_result['hdfs_path_exists_after']
# raise ValueError(
# f"❌ sqoop导入失败, 请检查返回结果: {import_result}"
# )
# else:
# print(f"√ sqoop导入成功: {hdfs_path_exists_after}")
# return import_result
def begin(self):
return RemoteTransaction(self.db, self.database,
self.session, self.urls)
return RemoteTransaction(self.db_type, self.site_name,
self.session, self.urls,
user=self.user, user_token=self.user_token)
def execute(self, sql: str, params: Optional[Union[dict, list, tuple]] = None, as_df: bool = True):
"""
执行任意 SQL。
- params=None → 纯文本执行
- params=dict → 单条参数化执行
- params=list/tuple → executemany 批量执行(元素 tuple/list)
- param as_df: 是否返回 DataFrame(仅对 SELECT 有效), 默认为True
返回服务端原始结构:
成功:{"ok": True, ...}
失败:{"ok": False, "error": "...", "trace": "...", "sql": "...", "params": ...}
"""
payload = {
"db": self.db_type, # 注意用你类里对应字段名
"site_name": self.site_name,
"sql": sql,
"params": params,
}
resp = self._request("execute", payload)
# 服务端保证总是 200 + JSON,这里只需判断 ok
if not resp.get("ok", False):
# 把错误详情展示给使用者
print("\n===== EXECUTE ERROR =====")
print("SQL:", resp.get("sql"))
print("PARAMS:", resp.get("params"))
print("ERROR:", resp.get("error"))
if resp.get("trace"):
print("\nTRACE (tail):\n", resp["trace"])
# 你也可以在这里发 textcard 通知
raise RuntimeError(resp.get("error") or "execute failed")
# ✅ 如果是 SELECT,并且 as_df=True,则直接返回 DataFrame
if as_df and resp.get("returns_rows") and "result" in resp:
return pd.DataFrame(resp["result"])
def get_remote_engine(site_name: str, db_type: str,
servers: List[str] = None) -> RemoteEngine:
return resp
def get_remote_engine(site_name: str, db_type: str, servers: Optional[List[str]] = None,
                      user: Optional[str] = None, user_token: Optional[str] = None) -> RemoteEngine:
    """Factory: validate site/db aliases, apply default credentials, build a RemoteEngine.

    Raises ValueError for unknown site_name/db_type or missing credentials.
    """
    if site_name not in site_name_db_dict:
        raise ValueError(f"Unknown site_name: {site_name}")
    if db_type not in db_type_alias_map:
        raise ValueError(f"Unknown db_type: {db_type}")
    # Fall back to the file-top defaults when the caller does not pass credentials.
    user = user if user is not None else DEFAULT_USER
    user_token = user_token if user_token is not None else DEFAULT_USER_TOKEN
    # Minimal validation so a missing configuration fails fast.
    if not user or not user_token:
        raise ValueError(
            "user and user_token are required (set in file top DEFAULT_USER/DEFAULT_USER_TOKEN, or pass explicitly)")
    return RemoteEngine(
        db_type=db_type_alias_map[db_type],
        site_name=site_name,
        server_urls=servers or DEFAULT_SERVERS,
        user=user, user_token=user_token,
    )
......@@ -369,10 +474,38 @@ def clean_json_field_for_orjson(v):
return v
# def fully_clean_for_orjson(df: pd.DataFrame) -> pd.DataFrame:
# """全面清洗 DataFrame 以符合 orjson 要求"""
# df = df.replace([np.inf, -np.inf], np.nan)
# df = df.applymap(lambda x: None if pd.isna(x) else x)
#
# # 找出所有可能为 JSON 字符串的字段
# json_like_cols = [col for col in df.columns if col.endswith('_json')]
#
# # 针对每个 JSON-like 字段,应用清洗函数
# for col in json_like_cols:
# df[col] = df[col].apply(clean_json_field_for_orjson)
#
# return df
def fully_clean_for_orjson(df: pd.DataFrame) -> pd.DataFrame:
# """全面清洗 DataFrame 以符合 orjson 要求"""
# df = df.replace([np.inf, -np.inf], np.nan)
# df = df.applymap(lambda x: None if pd.isna(x) else x)
#
# # 找出所有可能为 JSON 字符串的字段
# json_like_cols = [col for col in df.columns if col.endswith('_json')]
#
# # 针对每个 JSON-like 字段,应用清洗函数
# for col in json_like_cols:
# df[col] = df[col].apply(clean_json_field_for_orjson)
#
# return df
"""全面清洗 DataFrame 以符合 orjson 要求"""
df = df.replace([np.inf, -np.inf], np.nan)
df = df.applymap(lambda x: None if pd.isna(x) else x)
# NaN → None (比 applymap 高效且不出错)
df = df.where(pd.notna(df), None)
# 找出所有可能为 JSON 字符串的字段
json_like_cols = [col for col in df.columns if col.endswith('_json')]
......@@ -381,4 +514,4 @@ def fully_clean_for_orjson(df: pd.DataFrame) -> pd.DataFrame:
for col in json_like_cols:
df[col] = df[col].apply(clean_json_field_for_orjson)
return df
return df
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment