Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
A
Amazon-Selection-Data
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
abel_cjy
Amazon-Selection-Data
Commits
7293af95
Commit
7293af95
authored
Sep 10, 2025
by
fangxingjun
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
no message
parent
dbb59488
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
188 additions
and
100 deletions
+188
-100
secure_db_client.py
Pyspark_job/utils/secure_db_client.py
+188
-100
No files found.
Pyspark_job/utils/secure_db_client.py
View file @
7293af95
import
json
import
pandas
as
pd
import
pandas
as
pd
import
numpy
as
np
import
numpy
as
np
import
json
import
orjson
,
requests
,
time
import
orjson
,
requests
,
time
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
mima
=
'xsfasg'
mima
=
'xsfasg'
# 新增:默认微信通知人(用于任务失败时告警)
NOTIFY_DEFAULT_USERS
=
[
'fangxingjun'
,
'chenyuanjie'
]
# -------- 映射字典 --------
# -------- 映射字典 --------
site_name_db_dict
=
{
site_name_db_dict
=
{
"us"
:
"selection"
,
"us"
:
"selection"
,
...
@@ -30,83 +32,27 @@ db_type_alias_map = {
...
@@ -30,83 +32,27 @@ db_type_alias_map = {
DEFAULT_SERVERS
=
[
DEFAULT_SERVERS
=
[
# "http://192.168.200.210:7777", # 内网-h5
# "http://192.168.200.210:7777", # 内网-h5
"http://192.168.200.210:7778"
,
# 内网-测试大数据-h5
# "http://192.168.200.210:7778", # 内网-测试大数据-h5
"http://192.168.10.216:7777"
,
# 内网-测试大数据-h6
# "http://192.168.10.216:7777", # 内网-h6
# "http://192.168.10.216:7777", # 内网-h6
# "http://192.168.10.217:7777", # 内网-h7
# "http://192.168.10.217:7777", # 内网-h7
# "http://61.145.136.61:7777", # 外网
# "http://61.145.136.61:7777", # 外网
# "http://61.145.136.61:7779", # 外网
# "http://61.145.136.61:7779", # 外网
]
]
# ---------------------------
def
df_to_json_records
(
df
:
pd
.
DataFrame
)
->
list
:
"""保证 DataFrame 可安全序列化为 JSON records(处理 NaN / ±Inf)"""
df_clean
=
df
.
copy
()
# 1️⃣ 替换 ±Inf -> NaN
num_cols
=
df_clean
.
select_dtypes
(
include
=
[
np
.
number
])
.
columns
if
len
(
num_cols
):
df_clean
[
num_cols
]
=
df_clean
[
num_cols
]
.
replace
([
np
.
inf
,
-
np
.
inf
],
np
.
nan
)
# 2️⃣ 替换 NaN -> None(注意:有时 astype(object) 不彻底,需用 applymap)
df_clean
=
df_clean
.
applymap
(
lambda
x
:
None
if
pd
.
isna
(
x
)
else
x
)
# 3️⃣ 转为 dict records
return
df_clean
.
to_dict
(
"records"
)
def clean_json_field_for_orjson(v):
    """Normalise one JSON-column value for orjson serialisation.

    Returns None for: None / scalar NaN, an empty dict object, an empty or
    whitespace-only string, and "{}" (literal or after parsing).  A string
    that parses as non-empty JSON is re-serialised compactly with
    ``ensure_ascii=False``; any other value is returned unchanged.
    """
    if v is None:
        return None
    # pd.isna on list-likes returns an element-wise array whose truth value
    # is ambiguous, so the scalar NaN test must be guarded.
    if not isinstance(v, (list, tuple, set, dict, np.ndarray)) and pd.isna(v):
        return None
    # 1. Empty dict objects would be stored as "{}" -- drop them instead.
    if isinstance(v, dict) and not v:
        return None
    # 2. Empty string / literal "{}" -> None; other strings get parsed.
    if isinstance(v, str):
        stripped = v.strip()
        if not stripped or stripped == "{}":
            return None
        try:
            parsed = json.loads(stripped)
            if isinstance(parsed, dict) and not parsed:
                return None
            return json.dumps(parsed, ensure_ascii=False)
        except Exception:
            # Not valid JSON -> keep the raw string as-is.
            return v
    # Non-string, non-dict values (numbers, bools, containers) pass through.
    return v
def
fully_clean_for_orjson
(
df
:
pd
.
DataFrame
)
->
pd
.
DataFrame
:
"""全面清洗 DataFrame 以符合 orjson 要求"""
df
=
df
.
replace
([
np
.
inf
,
-
np
.
inf
],
np
.
nan
)
df
=
df
.
applymap
(
lambda
x
:
None
if
pd
.
isna
(
x
)
else
x
)
# 找出所有可能为 JSON 字符串的字段
json_like_cols
=
[
col
for
col
in
df
.
columns
if
col
.
endswith
(
'_json'
)]
# 针对每个 JSON-like 字段,应用清洗函数
for
col
in
json_like_cols
:
df
[
col
]
=
df
[
col
]
.
apply
(
clean_json_field_for_orjson
)
return
df
class
RemoteTransaction
:
class
RemoteTransaction
:
def __init__(self, db: str, database: str,
             session: requests.Session, urls: List[str]):
    """Hold the context for a queued remote transaction.

    db       -- logical db/db_type identifier (forwarded to the server;
                see __exit__, truncated in this view -- confirm)
    database -- site/database name (e.g. a site_name key)
    session  -- shared requests.Session owned by the creating RemoteEngine
    urls     -- candidate server base URLs, tried in order
    """
    self.db = db
    self.database = database
    self.session = session
    self.urls = urls
    # Statements queued by execute(); presumably flushed when the
    # context manager exits -- __exit__ body not fully visible here.
    self.sql_queue = []
# def execute(self, sql: str):
# self.sql_queue.append(sql)
def
execute
(
self
,
sql
:
str
,
params
=
None
):
def
execute
(
self
,
sql
:
str
,
params
=
None
):
"""
"""
params 可取:
params 可取:
...
@@ -118,7 +64,8 @@ class RemoteTransaction:
...
@@ -118,7 +64,8 @@ class RemoteTransaction:
"""
"""
self
.
sql_queue
.
append
({
"sql"
:
sql
,
"params"
:
params
})
self
.
sql_queue
.
append
({
"sql"
:
sql
,
"params"
:
params
})
def __enter__(self):
    # Context-manager entry: return the transaction itself so callers can
    # queue statements via execute() inside the `with` block.
    return self
def
__exit__
(
self
,
exc_type
,
exc
,
tb
):
def
__exit__
(
self
,
exc_type
,
exc
,
tb
):
for
url
in
self
.
urls
:
for
url
in
self
.
urls
:
...
@@ -137,13 +84,13 @@ class RemoteTransaction:
...
@@ -137,13 +84,13 @@ class RemoteTransaction:
class
RemoteEngine
:
class
RemoteEngine
:
def __init__(self, db: str, database: str,
             server_urls: List[str], retries: int = 2,
             alert_users: Optional[List[str]] = None):
    """Create a remote DB engine speaking to one of several HTTP servers.

    db          -- logical db/db_type identifier sent with every request
    database    -- site/database name sent with every request
    server_urls -- candidate base URLs; trailing slashes are stripped so
                   endpoint paths can be joined with a single "/"
    retries     -- per-request retry count used by _request
    alert_users -- WeChat users notified on failure; defaults to the
                   module-level NOTIFY_DEFAULT_USERS list
    """
    self.db = db
    self.database = database
    # Normalise base URLs for later f"{url}/endpoint" joins.
    self.urls = [u.rstrip("/") for u in server_urls]
    self.session = requests.Session()
    self.retries = retries
    # Default notification recipients when none are given explicitly.
    self.alert_users = alert_users or NOTIFY_DEFAULT_USERS
def
_request
(
self
,
endpoint
:
str
,
payload
,
timeout
=
3000
):
def
_request
(
self
,
endpoint
:
str
,
payload
,
timeout
=
3000
):
for
url
in
self
.
urls
:
for
url
in
self
.
urls
:
...
@@ -164,6 +111,27 @@ class RemoteEngine:
...
@@ -164,6 +111,27 @@ class RemoteEngine:
time
.
sleep
(
1
)
time
.
sleep
(
1
)
raise
RuntimeError
(
f
"All servers failed for {endpoint}"
)
raise
RuntimeError
(
f
"All servers failed for {endpoint}"
)
def _notify_textcard(self, title: str, content: str):
    """Send a WeChat textcard alert; failures never break the main flow."""
    payload = {
        "users": self.alert_users,  # could be omitted for the server default; passed explicitly here
        "title": title,
        "content": content
    }
    # Try each server in order and stop after the first successful post.
    for url in self.urls:
        try:
            self.session.post(f"{url}/notify/wx", json=payload, timeout=8)
            break
        except Exception as e:
            # Best-effort notification: log the failure and try the next URL.
            print(f"[WARN] notify fail @ {url}: {e}")
@staticmethod
def
_short
(
s
:
str
,
n
:
int
=
300
)
->
str
:
if
not
s
:
return
""
s
=
str
(
s
)
return
s
if
len
(
s
)
<=
n
else
s
[
-
n
:]
def
read_sql
(
self
,
sql
:
str
)
->
pd
.
DataFrame
:
def
read_sql
(
self
,
sql
:
str
)
->
pd
.
DataFrame
:
data
=
self
.
_request
(
"query"
,
data
=
self
.
_request
(
"query"
,
{
"db"
:
self
.
db
,
{
"db"
:
self
.
db
,
...
@@ -171,7 +139,9 @@ class RemoteEngine:
...
@@ -171,7 +139,9 @@ class RemoteEngine:
"site_name"
:
self
.
database
})
"site_name"
:
self
.
database
})
return
pd
.
DataFrame
(
data
[
"result"
])
return
pd
.
DataFrame
(
data
[
"result"
])
def
to_sql
(
self
,
df
:
pd
.
DataFrame
,
table
:
str
,
if_exists
=
"append"
):
def
to_sql
(
self
,
df
:
pd
.
DataFrame
,
table
:
str
,
if_exists
=
"append"
):
return
self
.
_request
(
"insert"
,
return
self
.
_request
(
"insert"
,
{
"db"
:
self
.
db
,
{
"db"
:
self
.
db
,
...
@@ -212,11 +182,11 @@ class RemoteEngine:
...
@@ -212,11 +182,11 @@ class RemoteEngine:
return
df
return
df
def
sqoop_raw_import
(
self
,
def
sqoop_raw_import
(
self
,
site_name
:
str
,
#
site_name: str,
db_type
:
str
,
#
db_type: str,
query
:
str
,
query
:
str
,
hive_table
:
str
,
hive_table
:
str
,
hdfs_path
:
str
,
# hdfs路径
hdfs_path
:
str
=
None
,
# hdfs路径
default_db
:
str
=
"big_data_selection"
,
default_db
:
str
=
"big_data_selection"
,
partitions
:
dict
=
None
,
partitions
:
dict
=
None
,
queue_name
:
str
=
"default"
,
queue_name
:
str
=
"default"
,
...
@@ -228,9 +198,12 @@ class RemoteEngine:
...
@@ -228,9 +198,12 @@ class RemoteEngine:
dry_run
:
bool
=
False
,
dry_run
:
bool
=
False
,
clean_hdfs
:
bool
=
True
,
# 是否删除hdfs路径, 默认删除
clean_hdfs
:
bool
=
True
,
# 是否删除hdfs路径, 默认删除
timeout_sec
:
int
=
36000
):
timeout_sec
:
int
=
36000
):
print
(
"site_name:"
,
self
.
database
,
"db_type:"
,
self
.
db
)
body
=
{
body
=
{
"site_name"
:
site_name
,
# "site_name": site_name,
"db_type"
:
db_type
,
# "db_type": db_type,
"site_name"
:
self
.
database
,
"db_type"
:
self
.
db
,
"default_db"
:
default_db
,
# 👈 必须明确带上
"default_db"
:
default_db
,
# 👈 必须明确带上
"hive_table"
:
hive_table
,
"hive_table"
:
hive_table
,
"query"
:
query
,
"query"
:
query
,
...
@@ -240,47 +213,105 @@ class RemoteEngine:
...
@@ -240,47 +213,105 @@ class RemoteEngine:
"split_by"
:
split_by
,
"split_by"
:
split_by
,
"outdir"
:
outdir
,
"outdir"
:
outdir
,
"sqoop_home"
:
sqoop_home
,
"sqoop_home"
:
sqoop_home
,
"job_name"
:
job_name
,
"job_name"
:
job_name
or
f
"sqoop_task---{hive_table}"
,
"dry_run"
:
dry_run
,
"dry_run"
:
dry_run
,
"hdfs_path"
:
hdfs_path
,
"hdfs_path"
:
hdfs_path
,
"clean_hdfs"
:
clean_hdfs
,
"clean_hdfs"
:
clean_hdfs
,
"timeout_sec"
:
timeout_sec
"timeout_sec"
:
timeout_sec
}
}
# 导入前对变量进行校验
resp
=
self
.
_request
(
"sqoop/raw_import"
,
body
,
timeout
=
timeout_sec
)
if
hive_table
not
in
hdfs_path
or
len
(
hdfs_path
)
<
10
or
len
(
hive_table
)
<
3
:
raise
ValueError
(
def
_print_step
(
name
,
r
):
f
"❌ 导入前变量校验失败, 请检查以下情况:
\n
"
print
(
f
"
\n
===== {name} ====="
)
f
"1. hdfs_path长度<10: 传入长度{len(hdfs_path)}
\n
"
if
"cmd"
in
r
and
r
[
"cmd"
]:
f
"2. hive_table长度<3: 传入长度{len(hive_table)}
\n
"
print
(
"CMD:"
,
r
[
"cmd"
])
f
"3. hive_table='{hive_table}' 不在 hdfs_path='{hdfs_path}' 路径中,请检查参数设置是否正确"
if
r
.
get
(
"msg"
):
print
(
"MSG:"
,
r
[
"msg"
])
if
r
.
get
(
"stdout"
):
print
(
"
\n
STDOUT:
\n
"
,
r
[
"stdout"
])
if
r
.
get
(
"stderr"
):
print
(
"
\n
STDERR:
\n
"
,
r
[
"stderr"
])
print
(
"OK:"
,
r
.
get
(
"ok"
),
" CODE:"
,
r
.
get
(
"code"
))
sch
=
resp
.
get
(
"schema_check_result"
,
{})
or
{}
icr
=
resp
.
get
(
"import_check_result"
,
{})
or
{}
ccr
=
resp
.
get
(
"count_check_result"
,
{})
or
{}
_print_step
(
"SCHEMA CHECK"
,
sch
)
if
sch
.
get
(
"ok"
)
is
not
True
:
self
.
_notify_textcard
(
title
=
"Schema 校验失败"
,
content
=
(
f
"任务:{body.get('job_name')}
\n
"
f
"表:{hive_table}
\n
"
f
"code:{sch.get('code')}
\n
"
f
"msg:{self._short(sch.get('msg'))}"
)
)
)
# 导入前对变量进行校验
raise
RuntimeError
(
"Schema 校验失败。详见上面日志。"
)
if
m
>
1
and
split_by
is
None
and
split_by
in
query
:
raise
ValueError
(
_print_step
(
"SQOOP IMPORT"
,
icr
)
f
"❌ 导入前变量校验失败, 请检查以下情况:
\n
"
if
icr
.
get
(
"ok"
)
is
not
True
:
f
"1. map_num的数量大于1时, 必须要指定split_by参数, m: {m}, split_by: {split_by}
\n
"
self
.
_notify_textcard
(
f
"2. split_by的字段不在query查询的字段里面: split_by: {split_by}, query: {query}
\n
"
title
=
"Sqoop 导入失败"
,
content
=
(
f
"任务:{body.get('job_name')}
\n
"
f
"表:{hive_table}
\n
"
f
"code:{icr.get('code')}
\n
"
f
"msg:{self._short(icr.get('msg'))}
\n
"
f
"stderr:{self._short(icr.get('stderr'))}"
)
)
)
# 执行导入, 进行异常检验
raise
RuntimeError
(
"Sqoop 导入失败。详见上面日志。"
)
import_result
=
self
.
_request
(
"sqoop/raw_import"
,
body
,
timeout
=
timeout_sec
)
# 1. cmd执行日志信息
_print_step
(
"COUNT CHECK"
,
ccr
)
print
(
f
"sqoop导入完整日志: {import_result['stdout']}"
)
if
ccr
.
get
(
"ok"
)
is
not
True
:
del
import_result
[
'stdout'
]
self
.
_notify_textcard
(
# 2. 检查导入后的hdfs路径是否成功
title
=
"数量校验失败"
,
hdfs_path_exists_after
=
import_result
[
'hdfs_path_exists_after'
]
content
=
(
if
not
hdfs_path_exists_after
:
f
"任务:{body.get('job_name')}
\n
"
# del import_result['hdfs_path_exists_after']
f
"表:{hive_table}
\n
"
raise
ValueError
(
f
"db_count:{ccr.get('db_count')}
\n
"
f
"❌ sqoop导入失败, 请检查返回结果: {import_result}"
f
"hive_count:{ccr.get('hive_count')}
\n
"
f
"msg:{self._short(ccr.get('msg'))}"
)
)
)
else
:
raise
RuntimeError
(
"数量校验失败。详见上面日志。"
)
print
(
f
"√ sqoop导入成功: {hdfs_path_exists_after}"
)
return
import_result
print
(
"
\n
✅ 全部步骤通过。"
)
# return resp
# # 导入前对变量进行校验
# if hive_table not in hdfs_path or len(hdfs_path) < 10 or len(hive_table) < 3:
# raise ValueError(
# f"❌ 导入前变量校验失败, 请检查以下情况: \n"
# f"1. hdfs_path长度<10: 传入长度{len(hdfs_path)}\n"
# f"2. hive_table长度<3: 传入长度{len(hive_table)}\n"
# f"3. hive_table='{hive_table}' 不在 hdfs_path='{hdfs_path}' 路径中,请检查参数设置是否正确"
# )
# # 导入前对变量进行校验
# if m > 1 and split_by is None and split_by in query:
# raise ValueError(
# f"❌ 导入前变量校验失败, 请检查以下情况: \n"
# f"1. map_num的数量大于1时, 必须要指定split_by参数, m: {m}, split_by: {split_by}\n"
# f"2. split_by的字段不在query查询的字段里面: split_by: {split_by}, query: {query}\n"
# )
# # 执行导入, 进行异常检验
# import_result = self._request("sqoop/raw_import", body, timeout=timeout_sec)
# # 1. cmd执行日志信息
# print(f"sqoop导入完整日志: {import_result['stdout']}")
# del import_result['stdout']
# # 2. 检查导入后的hdfs路径是否成功
# hdfs_path_exists_after = import_result['hdfs_path_exists_after']
# if not hdfs_path_exists_after:
# # del import_result['hdfs_path_exists_after']
# raise ValueError(
# f"❌ sqoop导入失败, 请检查返回结果: {import_result}"
# )
# else:
# print(f"√ sqoop导入成功: {hdfs_path_exists_after}")
# return import_result
def begin(self):
    """Open a queued RemoteTransaction bound to this engine's db,
    database, HTTP session and server URL list."""
    return RemoteTransaction(self.db, self.database,
                             self.session, self.urls)
# ---------------------------------
def
get_remote_engine
(
site_name
:
str
,
db_type
:
str
,
def
get_remote_engine
(
site_name
:
str
,
db_type
:
str
,
...
@@ -294,3 +325,60 @@ def get_remote_engine(site_name: str, db_type: str,
...
@@ -294,3 +325,60 @@ def get_remote_engine(site_name: str, db_type: str,
database
=
site_name
,
database
=
site_name
,
server_urls
=
servers
or
DEFAULT_SERVERS
,
server_urls
=
servers
or
DEFAULT_SERVERS
,
)
)
### 工具类函数
def
df_to_json_records
(
df
:
pd
.
DataFrame
)
->
list
:
"""保证 DataFrame 可安全序列化为 JSON records(处理 NaN / ±Inf)"""
df_clean
=
df
.
copy
()
# 1️⃣ 替换 ±Inf -> NaN
num_cols
=
df_clean
.
select_dtypes
(
include
=
[
np
.
number
])
.
columns
if
len
(
num_cols
):
df_clean
[
num_cols
]
=
df_clean
[
num_cols
]
.
replace
([
np
.
inf
,
-
np
.
inf
],
np
.
nan
)
# 2️⃣ 替换 NaN -> None(注意:有时 astype(object) 不彻底,需用 applymap)
df_clean
=
df_clean
.
applymap
(
lambda
x
:
None
if
pd
.
isna
(
x
)
else
x
)
# 3️⃣ 转为 dict records
return
df_clean
.
to_dict
(
"records"
)
def clean_json_field_for_orjson(v):
    """Normalise one JSON-column value for orjson serialisation.

    Returns None for: None / scalar NaN, an empty dict object, an empty or
    whitespace-only string, and "{}" (literal or after parsing).  A string
    that parses as non-empty JSON is re-serialised compactly with
    ``ensure_ascii=False``; any other value is returned unchanged.
    """
    if v is None:
        return None
    # pd.isna on list-likes returns an element-wise array whose truth value
    # is ambiguous, so the scalar NaN test must be guarded.
    if not isinstance(v, (list, tuple, set, dict, np.ndarray)) and pd.isna(v):
        return None
    # 1. Empty dict objects would be stored as "{}" -- drop them instead.
    if isinstance(v, dict) and not v:
        return None
    # 2. Empty string / literal "{}" -> None; other strings get parsed.
    if isinstance(v, str):
        stripped = v.strip()
        if not stripped or stripped == "{}":
            return None
        try:
            parsed = json.loads(stripped)
            if isinstance(parsed, dict) and not parsed:
                return None
            return json.dumps(parsed, ensure_ascii=False)
        except Exception:
            # Not valid JSON -> keep the raw string as-is.
            return v
    # Non-string, non-dict values (numbers, bools, containers) pass through.
    return v
def
fully_clean_for_orjson
(
df
:
pd
.
DataFrame
)
->
pd
.
DataFrame
:
"""全面清洗 DataFrame 以符合 orjson 要求"""
df
=
df
.
replace
([
np
.
inf
,
-
np
.
inf
],
np
.
nan
)
df
=
df
.
applymap
(
lambda
x
:
None
if
pd
.
isna
(
x
)
else
x
)
# 找出所有可能为 JSON 字符串的字段
json_like_cols
=
[
col
for
col
in
df
.
columns
if
col
.
endswith
(
'_json'
)]
# 针对每个 JSON-like 字段,应用清洗函数
for
col
in
json_like_cols
:
df
[
col
]
=
df
[
col
]
.
apply
(
clean_json_field_for_orjson
)
return
df
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment