# usr_mask_update.py
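"""
Propagate user edits of the mask fields (e.g. usr_mask_type / usr_mask_progress) recorded in
sys_edit_log (us PostgreSQL) into the month-partitioned us_aba_last_month_* tables on the
PostgreSQL cluster: the latest edit per key is staged in usr_mask_update_tmp and then applied
with one join-update per month partition.
"""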
import os
import sys

# Make the parent directory importable so the sibling utils package resolves when run as a script.
sys.path.append(os.path.dirname(sys.path[0]))
from utils.db_util import DBUtil, DbTypes
from utils.common_util import CommonUtil
from utils.spark_util import SparkUtil

from pyspark.sql import functions as F
from pyspark.sql import SparkSession


def get_update_df(spark: SparkSession, module: str, site_name: str, fields: list):
    """Pull the latest non-admin edit per (module, site_name, field, edit_key_id) from
    sys_edit_log, one query per field, and merge the results into one row per edit_key_id."""
    info = DBUtil.get_connection_info(DbTypes.postgresql.name, 'us')
    df_all = []
    for field in fields:
        # Note: the sys_edit_log column is spelled "filed" in the source query; kept as-is.
        sql = f"""
        select edit_key_id,
               val_after as {field}
        from (
                 select filed,
                        edit_key_id,
                        val_after,
                        row_number() over (partition by module, site_name, filed, edit_key_id order by id desc) as last_row
                 from sys_edit_log
                 where val_after is not null
                   and edit_key_id is not null
                   and edit_key_id != ''
                   and user_id != 'admin'
                   and site_name = '{site_name}'
                   and module in ('{module}')
                   and filed in ('{field}')
             ) tmp
        where last_row = 1
        """
        print(sql)
        df = SparkUtil.read_jdbc_query(
            session=spark,
            url=info['url'],
            pwd=info['pwd'],
            username=info['username'],
            query=sql
        )
        df_all.append(df)

    # Full-outer-join the per-field frames on edit_key_id so every edited key keeps
    # whichever of its fields were changed.
    df_rel = df_all[0]

    for df in df_all[1:]:
        df_rel = df_rel.join(df, how="fullouter", on=['edit_key_id'])

    return df_rel
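
# Illustrative sketch only (assumption): SparkUtil.read_jdbc_query is an in-house helper whose
# source is not shown here; it presumably wraps Spark's standard JDBC reader roughly as below.
# The PostgreSQL driver class name is likewise an assumption.
def _read_jdbc_query_sketch(session: SparkSession, url: str, username: str, pwd: str, query: str):
    return (
        session.read.format("jdbc")
        .option("url", url)                         # JDBC URL from DBUtil.get_connection_info
        .option("user", username)
        .option("password", pwd)
        .option("query", query)                     # push the whole SELECT down to the database
        .option("driver", "org.postgresql.Driver")  # assumed driver for the Postgres source
        .load()
    )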


def update():
    spark = SparkUtil.get_spark_session("usr_mask_update")
    # Keys are module names as stored in sys_edit_log.module; "ABA搜索词(新)" is the
    # "ABA search terms (new)" module.
    config_row = {
        "ABA搜索词(新)": {
            "fields": ['usr_mask_type', 'usr_mask_progress'],
            "key": "search_term",
            "hive_table": "dwt_aba_st_analytics",
            "rel_table": "us_aba_last_month",
        },
        # "店铺Feedback": {
        #     "fields": ['usr_mask_type', 'usr_mask_progress'],
        #     "key": "search_term",
        #     "hive_table": "dwt_aba_st_analytics"
        # },
        # "AbaWordYear": {
        #     "fields": ['usr_mask_type', 'usr_mask_progress'],
        #     "key": "search_term",
        #     "hive_table": "dwt_aba_last365"
        # }
    }

    for module in config_row.keys():
        fields = config_row[module]['fields']
        key = config_row[module]['key']
        hive_table = config_row[module]['hive_table']
        rel_table = config_row[module]['rel_table']

        site_name = 'us'
        target_update_tb = "usr_mask_update_tmp"
        df_all = get_update_df(spark, module, site_name, fields=fields)
        info = DBUtil.get_connection_info(DbTypes.postgresql_cluster.name, 'us')
        df_all.write.jdbc(info['url'], target_update_tb, mode='overwrite',
                          properties={'user': info['username'], 'password': info['pwd']})
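        # The staged table now holds one row per edited key with one column per entry in
        # `fields`; a column is null when that field was never edited for the key, since
        # get_update_df merges the per-field results with a full outer join on edit_key_id.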

        # Collect the month partitions (YYYY-MM, from 2024-01 onward) of the Hive table, newest first.
        month_rows = CommonUtil.select_partitions_df(spark, hive_table) \
            .where(f"site_name = '{site_name}' and date_type = 'month' and date_info >= '2024-01'") \
            .select(F.col("date_info")) \
            .sort(F.col("date_info").desc()) \
            .toPandas().to_dict(orient='records')
        suffixes = [it['date_info'] for it in month_rows]

        DBUtil.get_db_engine(DbTypes.postgresql_cluster.name, site_name)
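        # Each iteration below issues one join-update per month partition; with the ABA
        # config above and e.g. the 2024-01 partition it executes roughly:
        #   update us_aba_last_month_2024_01 tb1
        #   set usr_mask_type = tmp.usr_mask_type, usr_mask_progress = tmp.usr_mask_progress
        #   from usr_mask_update_tmp tmp
        #   where tb1.search_term = tmp.edit_key_id;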

        for suffix in suffixes:
            suffix = suffix.replace("-", "_")
            update_sql = f"""
            update {rel_table}_{suffix} tb1
            set {", ".join([f"{it} = tmp.{it}" for it in fields])}
            from {target_update_tb} tmp
            where tb1.{key} = tmp.edit_key_id;
            """
            # Join-update the month partition table from the staged edits
            DBUtil.exec_sql(DbTypes.postgresql_cluster.name, site_name, update_sql, True)

        print("success")


if __name__ == '__main__':
    update()