import functools
import json
import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))

from utils.spark_util import SparkUtil
from utils.db_util import DBUtil
from utils.common_util import CommonUtil
from utils.redis_utils import RedisUtils
from pyspark.sql import functions as F

redis_key = "asinDetail:jmInfo"


def save_to_redis_map(iterator, redis_key, hash_field_key, ttl: int = -1, batch: int = 1000):
    """Write one partition of JSON rows into a Redis hash, flushing the pipeline every `batch` commands."""
    redis_cli = RedisUtils.get_redis_client_by_type(db_type='selection', dbNum=0)
    cnt = 0
    pipeline = redis_cli.pipeline()
    for json_row in iterator:
        # the hash field is taken from the row itself (e.g. the asin column)
        pipeline.hset(redis_key, json.loads(json_row)[hash_field_key], json_row)
        cnt += 1
        if cnt % batch == 0:
            pipeline.execute()
    # flush any remaining commands that did not fill a full batch
    if cnt % batch != 0:
        pipeline.execute()
    pipeline.close()
    if ttl > 0:
        redis_cli.expire(redis_key, ttl)
    redis_cli.close()


def calc_and_save():
    spark = SparkUtil.get_spark_session("asin_jm_info_redis")
    conn_info = DBUtil.get_connection_info("mysql", "us")
    # aggregate auction and sku counts per asin from the audit table
    jm_info_df = SparkUtil.read_jdbc_query(
        session=spark,
        url=conn_info["url"],
        pwd=conn_info["pwd"],
        username=conn_info["username"],
        query="""
        select asin,
               count(id)                                 as auctions_num,
               count((case when sku != '' then sku end)) as skus_num_creat
        from product_audit_asin_sku
        where asin != ''
          and length(asin) = 10
        group by asin
        """
    ).cache()

    # fetch the parent asin for each asin
    parent_asin_df = spark.sql("""
        select asin, parent_asin
        from dim_asin_variation_info
        where site_name = 'us'
    """)

    df_tmp = jm_info_df.join(parent_asin_df, on=['asin'], how='left').select(
        F.col("asin"),
        F.col("auctions_num"),
        F.col("skus_num_creat"),
        F.col("parent_asin"),
    )

    # roll auction data up to the variation (parent asin) level
    parent_all_df = df_tmp.where("parent_asin is not null") \
        .groupby(F.col("parent_asin")) \
        .agg(
            F.sum("auctions_num").alias("auctions_num_all"),
            F.sum("skus_num_creat").alias("skus_num_creat_all"),
        ).cache()

    save_all = df_tmp.join(parent_all_df, on=['parent_asin'], how='left').select(
        F.col("asin"),
        F.col("auctions_num"),
        F.col("skus_num_creat"),
        F.col("parent_asin"),
        F.col("auctions_num_all"),
        F.col("skus_num_creat_all"),
    )

    save_all.write.saveAsTable(name="tmp_jm_info", format='hive', mode='overwrite')
    print("success")


def save_to_redis():
    spark = SparkUtil.get_spark_session("asin_jm_info_redis")
    df_all = spark.sql("""
        select asin,
               auctions_num       as auctionsNum,
               skus_num_creat     as skusNumCreat,
               parent_asin        as parentAsin,
               auctions_num_all   as auctionsNumAll,
               skus_num_creat_all as skusNumCreatAll
        from tmp_jm_info
    """)
    # export every partition to the Redis hash keyed by asin, expiring after 7 days
    df_all.toJSON().foreachPartition(
        functools.partial(save_to_redis_map,
                          batch=1000,
                          redis_key=redis_key,
                          hash_field_key='asin',
                          ttl=3600 * 24 * 7)
    )
    print("success")


def check():
    redis_cli = RedisUtils.get_redis_client_by_type(db_type='selection', dbNum=0)
    size = redis_cli.hlen(redis_key)
    if size < 10000:
        CommonUtil.send_wx_msg(
            ['wujicang'],
            title='Data sync warning',
            content=f"Auction data [{redis_key}] only contains {size} entries; please check whether the export failed!"
        )
    redis_cli.close()


if __name__ == '__main__':
    arg = CommonUtil.get_sys_arg(1, None)
    if "calc" == arg:
        calc_and_save()
    elif "redis" == arg:
        save_to_redis()
    elif "check" == arg:
        check()