abel_cjy / Amazon-Selection-Data · Commits

Commit 6840c5f3
Authored Aug 27, 2025 by fangxingjun
Commit message: "no message"
Parent: ee461a97

Showing 1 changed file with 82 additions and 31 deletions:
Pyspark_job/utils/secure_db_client.py (+82, -31)
@@ -3,7 +3,9 @@ import json
 import pandas as pd
 import numpy as np
 import orjson, requests, time
-from typing import List
+from typing import List, Optional
 
+mima = 'xsfasg'
+
 # -------- Mapping dictionaries --------
 site_name_db_dict = {
@@ -27,13 +29,16 @@ db_type_alias_map = {
 }
 
 DEFAULT_SERVERS = [
-    # "http://192.168.200.210:7777",  # intranet
+    # "http://192.168.200.210:7777",  # intranet - h5
+    "http://192.168.200.210:7778",  # intranet - big-data test - h5
+    # "http://192.168.10.216:7777",  # intranet - h6
     # "http://192.168.10.217:7777",  # intranet - h7
-    "http://61.145.136.61:7777",  # public network
-    "http://61.145.136.61:7779",  # public network
+    # "http://61.145.136.61:7777",  # public network
+    # "http://61.145.136.61:7779",  # public network
 ]
 
 # ---------------------------
 def df_to_json_records(df: pd.DataFrame) -> list:
     """Ensure the DataFrame serializes safely to JSON records (handles NaN / ±Inf)."""
     df_clean = df.copy()
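The hunk cuts off the body of df_to_json_records right after the initial copy. The following is only a minimal sketch of what the docstring promises, under the assumption that NaN and ±Inf are converted to JSON null; the helper name df_to_json_records_sketch and the exact replacement strategy are illustrative, not the committed code:

import numpy as np
import pandas as pd

def df_to_json_records_sketch(df: pd.DataFrame) -> list:
    """Illustrative only: make a DataFrame JSON-safe by removing NaN / ±Inf."""
    df_clean = df.copy()
    # Strict JSON has no representation for ±Inf, so map it to NaN first
    df_clean = df_clean.replace([np.inf, -np.inf], np.nan)
    # Replace NaN with None so it serializes as JSON null
    df_clean = df_clean.astype(object).where(pd.notnull(df_clean), None)
    return df_clean.to_dict(orient="records")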
@@ -140,7 +145,7 @@ class RemoteEngine:
         self.session = requests.Session()
         self.retries = retries
 
-    def _request(self, endpoint: str, payload):
+    def _request(self, endpoint: str, payload, timeout=3000):
         for url in self.urls:
             for _ in range(self.retries):
                 try:
@@ -148,7 +153,7 @@ class RemoteEngine:
                     r = self.session.post(f"{url}/{endpoint}",
                                           data=json_bytes,
                                           headers={"Content-Type": "application/json"},
-                                          timeout=3000)
+                                          timeout=timeout)
                     # r = self.session.post(f"{url}/{endpoint}",
                     #                       json=payload, timeout=10)
@@ -158,31 +163,7 @@ class RemoteEngine:
                     print(f"[WARN] {endpoint} fail @ {url}: {e}")
                     time.sleep(1)
         raise RuntimeError(f"All servers failed for {endpoint}")
 
-    # def _request(self, endpoint: str, payload):
-    #     # With orjson, "allow_nan" writes NaN/Inf as null
-    #     # json_bytes = orjson.dumps(payload,
-    #     #                           option=orjson.OPT_NON_STR_KEYS | orjson.OPT_NAIVE_UTC | orjson.OPT_OMIT_MICROSECOND | orjson.OPT_ALLOW_INF_AND_NAN)
-    #     json_bytes = orjson.dumps(
-    #         payload,
-    #         option=orjson.OPT_NON_STR_KEYS | orjson.OPT_NAIVE_UTC | orjson.OPT_ALLOW_INF_AND_NAN
-    #     )
-    #
-    #     headers = {"Content-Type": "application/json"}
-    #
-    #     for url in self.urls:
-    #         for _ in range(self.retries):
-    #             try:
-    #                 r = self.session.post(f"{url}/{endpoint}",
-    #                                       data=json_bytes, headers=headers,
-    #                                       timeout=15)
-    #                 r.raise_for_status()
-    #                 return r.json()
-    #             except Exception as e:
-    #                 print(f"[WARN] {endpoint} fail @ {url}: {e}")
-    #                 time.sleep(1)
-    #     raise RuntimeError(f"All servers failed for {endpoint}")
-    # ---------- Public API ----------
     def read_sql(self, sql: str) -> pd.DataFrame:
         data = self._request("query",
                              {"db": self.db,
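Taken together, the hunks at @@ -140, @@ -148 and @@ -158 define the request path: serialize the payload once, then walk every server in self.urls, retrying each up to self.retries times with a one-second pause, and only raise once all of them fail; the new timeout parameter replaces the previously hard-coded 3000-second limit. A condensed, standalone sketch of that pattern (the orjson options are taken from the deleted commented-out variant and may differ from the live code):

import time
import orjson
import requests

def post_with_failover(session: requests.Session, urls: list, endpoint: str,
                       payload: dict, retries: int = 3, timeout: int = 3000):
    # Serialize once; NON_STR_KEYS mirrors the options seen in the removed variant
    json_bytes = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS)
    for url in urls:                      # try every configured server in order
        for _ in range(retries):          # retry each server a few times
            try:
                r = session.post(f"{url}/{endpoint}",
                                 data=json_bytes,
                                 headers={"Content-Type": "application/json"},
                                 timeout=timeout)  # per-call override, default 3000 s
                r.raise_for_status()
                return r.json()
            except Exception as e:
                print(f"[WARN] {endpoint} fail @ {url}: {e}")
                time.sleep(1)
    raise RuntimeError(f"All servers failed for {endpoint}")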
@@ -224,8 +205,78 @@ class RemoteEngine:
         resp = self._request("read_then_update", payload)
         df = pd.DataFrame(resp["read_result"])
         rows_updated = resp.get("rows_updated", 0)
+        if df.empty:
+            print(f"Nothing was read, skipping the update, read_sql: {select_sql}")
+        else:
+            print(f"Updated {rows_updated} rows")
         return df
 
+    def sqoop_raw_import(
+            self,
+            site_name: str,
+            db_type: str,
+            query: str,
+            hive_table: str,
+            hdfs_path: str,  # HDFS path
+            default_db: str = "big_data_selection",
+            partitions: dict = None,
+            queue_name: str = "default",
+            m: int = 1,
+            split_by: Optional[str] = None,  # split_by is required when using more than one mapper
+            outdir: str = "/tmp/sqoop/",
+            sqoop_home: str = "/opt/module/sqoop-1.4.6/bin/sqoop",
+            job_name: str = None,
+            dry_run: bool = False,
+            clean_hdfs: bool = True,  # whether to delete the HDFS path first; deleted by default
+            timeout_sec: int = 36000
+    ):
+        body = {
+            "site_name": site_name,
+            "db_type": db_type,
+            "default_db": default_db,  # 👈 must be passed explicitly
+            "hive_table": hive_table,
+            "query": query,
+            "partitions": partitions or {},
+            "queue_name": queue_name,
+            "m": m,
+            "split_by": split_by,
+            "outdir": outdir,
+            "sqoop_home": sqoop_home,
+            "job_name": job_name,
+            "dry_run": dry_run,
+            "hdfs_path": hdfs_path,
+            "clean_hdfs": clean_hdfs,
+            "timeout_sec": timeout_sec
+        }
+        # Validate the arguments before importing
+        if hive_table not in hdfs_path or len(hdfs_path) < 10 or len(hive_table) < 3:
+            raise ValueError(
+                f"❌ Pre-import validation failed, check the following:\n"
+                f"1. hdfs_path length < 10: got length {len(hdfs_path)}\n"
+                f"2. hive_table length < 3: got length {len(hive_table)}\n"
+                f"3. hive_table='{hive_table}' does not appear in hdfs_path='{hdfs_path}', check the parameter settings"
+            )
+        # Validate the arguments before importing
+        if m > 1 and (split_by is None or split_by not in query):
+            raise ValueError(
+                f"❌ Pre-import validation failed, check the following:\n"
+                f"1. split_by must be specified when the mapper count m > 1, m: {m}, split_by: {split_by}\n"
+                f"2. the split_by field does not appear in the query: split_by: {split_by}, query: {query}\n"
+            )
+        # Run the import and check the result
+        import_result = self._request("sqoop/raw_import", body, timeout=timeout_sec)
+        # 1. Full command execution log
+        print(f"Full sqoop import log: {import_result['stdout']}")
+        del import_result['stdout']
+        # 2. Verify the HDFS path exists after the import
+        hdfs_path_exists_after = import_result['hdfs_path_exists_after']
+        if not hdfs_path_exists_after:
+            # del import_result['hdfs_path_exists_after']
+            raise ValueError(f"❌ sqoop import failed, check the response: {import_result}")
+        else:
+            print(f"√ sqoop import succeeded: {hdfs_path_exists_after}")
+        return import_result
+
     def begin(self):
         return RemoteTransaction(self.db, self.database,
                                  self.session, self.urls)
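The new sqoop_raw_import method packages its arguments into a JSON body, validates that hive_table appears inside hdfs_path and that split_by is set (and referenced in the query) whenever m > 1, posts to the sqoop/raw_import endpoint with the long timeout_sec, and finally checks hdfs_path_exists_after in the response. A hypothetical call, assuming an already-constructed RemoteEngine named engine; the table, path, and db_type values are invented for illustration, and whether the query must carry Sqoop's $CONDITIONS placeholder depends on the remote service:

# Hypothetical usage; the names below are illustrative, not taken from the repo.
result = engine.sqoop_raw_import(
    site_name="us",
    db_type="postgresql",
    query="SELECT id, asin, title FROM ods_asin_detail WHERE 1=1",
    hive_table="ods_asin_detail",
    hdfs_path="/home/big_data_selection/ods/ods_asin_detail/site_name=us",
    partitions={"site_name": "us"},
    m=4,                 # more than one mapper, so split_by is mandatory
    split_by="id",       # must also appear in the query above
    timeout_sec=36000,   # forwarded to _request so long imports are not cut off
)
print(result["hdfs_path_exists_after"])  # True when the import landed on HDFS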