Commit 6840c5f3 by fangxingjun

no message

parent ee461a97
......@@ -3,7 +3,9 @@ import json
import pandas as pd
import numpy as np
import orjson, requests, time
from typing import List
from typing import List, Optional
mima = 'xsfasg'
# -------- mapping dictionaries --------
site_name_db_dict = {
......@@ -27,13 +29,16 @@ db_type_alias_map = {
}
DEFAULT_SERVERS = [
# "http://192.168.200.210:7777", # 内网
# "http://192.168.200.210:7777", # 内网-h5
"http://192.168.200.210:7778", # 内网-测试大数据-h5
# "http://192.168.10.216:7777", # 内网-h6
# "http://192.168.10.217:7777", # 内网-h7
"http://61.145.136.61:7777", # 外网
"http://61.145.136.61:7779", # 外网
# "http://61.145.136.61:7777", # 外网
# "http://61.145.136.61:7779", # 外网
]
# ---------------------------
def df_to_json_records(df: pd.DataFrame) -> list:
"""保证 DataFrame 可安全序列化为 JSON records(处理 NaN / ±Inf)"""
df_clean = df.copy()
......@@ -140,7 +145,7 @@ class RemoteEngine:
self.session = requests.Session()
self.retries = retries
def _request(self, endpoint: str, payload):
def _request(self, endpoint: str, payload, timeout=3000):
for url in self.urls:
for _ in range(self.retries):
try:
......@@ -148,7 +153,7 @@ class RemoteEngine:
r = self.session.post(f"{url}/{endpoint}",
data=json_bytes,
headers={"Content-Type": "application/json"},
timeout=3000)
timeout=timeout)
# r = self.session.post(f"{url}/{endpoint}",
# json=payload, timeout=10)
......@@ -158,31 +163,7 @@ class RemoteEngine:
print(f"[WARN] {endpoint} fail @ {url}: {e}")
time.sleep(1)
raise RuntimeError(f"All servers failed for {endpoint}")
# def _request(self, endpoint: str, payload):
# # With orjson, "allow_nan" writes NaN/Inf out as null
# # json_bytes = orjson.dumps(payload,
# # option=orjson.OPT_NON_STR_KEYS | orjson.OPT_NAIVE_UTC | orjson.OPT_OMIT_MICROSECOND | orjson.OPT_ALLOW_INF_AND_NAN)
# json_bytes = orjson.dumps(
# payload,
# option=orjson.OPT_NON_STR_KEYS | orjson.OPT_NAIVE_UTC | orjson.OPT_ALLOW_INF_AND_NAN
# )
#
# headers = {"Content-Type": "application/json"}
#
# for url in self.urls:
# for _ in range(self.retries):
# try:
# r = self.session.post(f"{url}/{endpoint}",
# data=json_bytes, headers=headers,
# timeout=15)
# r.raise_for_status()
# return r.json()
# except Exception as e:
# print(f"[WARN] {endpoint} fail @ {url}: {e}")
# time.sleep(1)
# raise RuntimeError(f"All servers failed for {endpoint}")
# ---------- public API ----------
def read_sql(self, sql: str) -> pd.DataFrame:
data = self._request("query",
{"db": self.db,
......@@ -224,8 +205,78 @@ class RemoteEngine:
resp = self._request("read_then_update", payload)
df = pd.DataFrame(resp["read_result"])
rows_updated = resp.get("rows_updated", 0)
if df.empty:
print(f"啥也没查到,就不做更新, read_sql: {select_sql}")
else:
print(f"更新 {rows_updated} 行")
return df
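# Usage sketch for read_then_update (the select_sql name is taken from the
# code above; update_sql and the exact call signature are assumptions, since
# the method definition is collapsed out of this diff):
#
#   df = engine.read_then_update(
#       select_sql="SELECT id FROM us_task WHERE state = 1 LIMIT 100",
#       update_sql="UPDATE us_task SET state = 2 WHERE state = 1 LIMIT 100",
#   )
#   # df holds the rows that were read; rows_updated is reported by the server.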
def sqoop_raw_import(self,
site_name: str,
db_type: str,
query: str,
hive_table: str,
hdfs_path: str, # HDFS path
default_db: str = "big_data_selection",
partitions: dict = None,
queue_name: str = "default",
m: int = 1,
split_by: Optional[str] = None, # split_by is required whenever multiple map tasks (m > 1) are used
outdir: str = "/tmp/sqoop/",
sqoop_home: str = "/opt/module/sqoop-1.4.6/bin/sqoop",
job_name: str = None,
dry_run: bool = False,
clean_hdfs: bool = True, # whether to delete the HDFS path first (defaults to True)
timeout_sec: int = 36000):
body = {
"site_name": site_name,
"db_type": db_type,
"default_db": default_db, # 👈 必须明确带上
"hive_table": hive_table,
"query": query,
"partitions": partitions or {},
"queue_name": queue_name,
"m": m,
"split_by": split_by,
"outdir": outdir,
"sqoop_home": sqoop_home,
"job_name": job_name,
"dry_run": dry_run,
"hdfs_path": hdfs_path,
"clean_hdfs": clean_hdfs,
"timeout_sec": timeout_sec
}
# Validate arguments before the import
if hive_table not in hdfs_path or len(hdfs_path) < 10 or len(hive_table) < 3:
raise ValueError(
f"❌ Pre-import argument validation failed; check the following: \n"
f"1. hdfs_path shorter than 10 chars: got length {len(hdfs_path)}\n"
f"2. hive_table shorter than 3 chars: got length {len(hive_table)}\n"
f"3. hive_table='{hive_table}' does not appear in hdfs_path='{hdfs_path}'; check that the parameters are correct"
)
# Validate split_by before the import (it must be set and must appear in query when m > 1)
if m > 1 and (split_by is None or split_by not in query):
raise ValueError(
f"❌ Pre-import argument validation failed; check the following: \n"
f"1. When more than one map task is used, split_by must be specified, m: {m}, split_by: {split_by}\n"
f"2. The split_by column is not among the columns selected by query: split_by: {split_by}, query: {query}\n"
)
# Run the import and check the result
import_result = self._request("sqoop/raw_import", body, timeout=timeout_sec)
# 1. Command execution log
print(f"Full sqoop import log: {import_result['stdout']}")
del import_result['stdout']
# 2. Verify the HDFS path exists after the import
hdfs_path_exists_after = import_result['hdfs_path_exists_after']
if not hdfs_path_exists_after:
# del import_result['hdfs_path_exists_after']
raise ValueError(
f"❌ sqoop导入失败, 请检查返回结果: {import_result}"
)
else:
print(f"√ sqoop导入成功: {hdfs_path_exists_after}")
return import_result
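# Usage sketch for sqoop_raw_import (all values hypothetical). Note the two
# pre-flight rules enforced above: hive_table must be a substring of
# hdfs_path, and split_by must be set and selected in query whenever m > 1.
# Sqoop free-form queries also conventionally need a WHERE $CONDITIONS
# placeholder for split-based parallel imports.
#
#   engine.sqoop_raw_import(
#       site_name="us",
#       db_type="mysql",
#       query="SELECT id, asin FROM us_asin WHERE $CONDITIONS",
#       hive_table="ods_asin",
#       hdfs_path="/home/big_data_selection/ods/ods_asin/site_name=us",
#       m=4,
#       split_by="id",
#   )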
def begin(self):
return RemoteTransaction(self.db, self.database,
self.session, self.urls)
......