import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from pyspark.sql.types import BooleanType
from utils.db_util import DBUtil
from utils.common_util import CommonUtil, DateTypes
from utils.hdfs_utils import HdfsUtils
from utils.spark_util import SparkUtil
from pyspark.sql import functions as F
from yswg_utils.common_udf import category_craw_flag

"""
获取当前榜单系统中没有asin详情的数据
"""


class DwdDayAsin(object):

    def __init__(self, site_name, date_info):
        self.site_name = site_name
        self.date_info = date_info
        app_name = f"{self.__class__.__name__}:{site_name}:{date_info}"
        self.spark = SparkUtil.get_spark_session(app_name)
        self.udf_category_craw_flag = F.udf(category_craw_flag, BooleanType())
        self.hive_tb = "dwd_day_asin"

    def run(self):
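        """
        Build the daily asin crawl queue: asins from the monthly BSR/NSR detail tables that
        still have no title, excluding asins already present in the history dim table and
        asins whose crawl failed within the last 7 days.
        """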
        # Check whether the previous day's asin data has been imported; send an alert if not
        df_self_asin_part = CommonUtil.select_partitions_df(self.spark, "ods_self_asin_detail")
        last_day = CommonUtil.get_day_offset(self.date_info, -1)
        last_7day = CommonUtil.get_day_offset(self.date_info, -7)

        flag = df_self_asin_part.filter(
            f"date_type = '{DateTypes.day.name}' and site_name = '{self.site_name}' and date_info = '{last_day}'").count()

        if flag == 0:
            CommonUtil.send_wx_msg(['wujicang', 'hezhe'], "Alert", f"No asin data for {last_day}, please check")

        month = CommonUtil.reformat_date(self.date_info, "%Y-%m-%d", "%Y-%m")
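        # asin_list1 / asin_list2: asins from the monthly BSR/NSR detail tables that still have no title;
        # asin_exist: asins already present in the history dim table;
        # asin_lost: asins that failed crawling (craw_state = 2) in the day_queue within the last 7 days;
        # the final select keeps asins that are in neither asin_exist nor asin_lost.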
        sql = f"""
with asin_list1 as (
	select asin, category_first_id
	from dwt_bsr_asin_detail
	where title is null
	  and site_name = '{self.site_name}'
	  and date_info = '{month}'
),
	 asin_list2 as (
		 select asin, category_first_id
		 from dwt_nsr_asin_detail
		 where title is null
		   and site_name = '{self.site_name}'
		   and date_info = '{month}'
	 ),
	 asin_exist as (
		 select asin
		 from dim_cal_asin_history_detail
		 where site_name = '{self.site_name}'
	 ),
--  invalid asins (crawl failed recently)
	 asin_lost as (
        select asin
        from dwd_day_asin
        where date_info >= '{last_7day}'
          and date_info < '{self.date_info}'
          and site_name = '{self.site_name}'
          and queue_name = 'day_queue'
          and craw_state = 2
	 ),
	 asin_all as (
		 select asin, category_first_id
		 from asin_list1
		 union all
		 select asin, category_first_id
		 from asin_list2
	 )

select asin_all.asin, category_first_id
from asin_all
		 left anti join asin_exist on asin_all.asin = asin_exist.asin
		 left anti join asin_lost  on asin_all.asin = asin_lost.asin
"""
        #  left anti join behaves like NOT EXISTS: keep left-side rows with no match on the right
        print("====================== query sql ======================")
        print(sql)
        df_save = self.spark.sql(sql)
        #  Filter out asins whose first-level category should not be crawled
        df_save = df_save.filter(self.udf_category_craw_flag(F.col("category_first_id")))
        # Deduplicate and cap the daily queue size at 400,000 asins
        df_save = df_save.dropDuplicates(['asin'])
        count_all = df_save.count()
        count = 40 * 10000
        df_save = df_save.limit(count)
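        # Every queued asin starts with craw_state = 0 (pending) in the day_queue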
        df_save = df_save.select(
            "asin",
            F.lit("day_queue").alias("queue_name"),
            F.lit(0).alias("craw_state"),
            F.lit(self.site_name).alias("site_name"),
            F.lit(self.date_info).alias("date_info"),
        ).orderBy(F.col("asin").desc())

        count = df_save.count()

        partition_dict = {
            "site_name": self.site_name,
            "date_info": self.date_info,
        }
        hdfs_path = CommonUtil.build_hdfs_path(self.hive_tb, partition_dict=partition_dict)
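        # Clear the existing partition directory first so re-runs do not leave duplicate data behind the append below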
        HdfsUtils.delete_file_in_folder(hdfs_path)
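        # repartition(1): write a single output file into the partition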
        df_save = df_save.repartition(1)
        partition_by = list(partition_dict.keys())
        print(f"当前存储的表名为:{self.hive_tb},分区为{partition_by}", )
        df_save.write.saveAsTable(name=self.hive_tb, format='hive', mode='append', partitionBy=partition_by)
        CommonUtil.send_wx_msg(['wujicang'], "提醒", f"{date_info}日asin计算成功,条数为{count}/{count_all}")

    def run_update(self):
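        """
        Update craw_state for the current day's queue based on crawl results synced back to
        the mysql {site_name}_self_all_syn table.
        """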
        # Fetch the set of asins that have not been crawled (state 4, 12 or 13 in the mysql sync table)
        craw_sql = f"""
        select asin
        from {self.site_name}_self_all_syn
        where data_type = 4
          and date_info = '{self.date_info}'
          and state in (4, 12, 13)
        """
        conn_info = DBUtil.get_connection_info("mysql", self.site_name)
        craw_asin_df = SparkUtil.read_jdbc_query(
            session=self.spark,
            url=conn_info["url"],
            pwd=conn_info["pwd"],
            username=conn_info["username"],
            query=craw_sql
        ).cache()
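        # cache so the jdbc result is not re-read from mysql by downstream actions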
        craw_asin_df = craw_asin_df.dropDuplicates(['asin'])
        # Mark the crawl state of every asin in this day's queue
        sql = f"""
select asin, queue_name, craw_state,date_info,site_name
from dwd_day_asin
where date_info = '{self.date_info}'
  and site_name = '{self.site_name}'
  and queue_name = 'day_queue'
"""
        df_update = self.spark.sql(sql)
        if df_update.count() == 0:
            return
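        # asins still present in the not-crawled set are marked craw_state = 2 (failed);
        # the rest are marked craw_state = 1 (assumed to mean crawled successfully, matching
        # 0 = pending and 2 = failed as used in run())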
        df_update = df_update.join(craw_asin_df, on=['asin'], how="left").select(
            df_update['asin'].alias("asin"),
            F.col("queue_name"),
            F.when(craw_asin_df['asin'].isNotNull(), F.lit("2")).otherwise(F.lit("1")).alias("craw_state"),
            F.col("date_info"),
            F.col("site_name"),
        )
        # Rewrite the hive partition with the updated crawl states
        partition_dict = {
            "site_name": self.site_name,
            "date_info": self.date_info,
        }
        CommonUtil.save_or_update_table(
            spark_session=self.spark,
            hive_tb_name=self.hive_tb,
            partition_dict=partition_dict,
            df_save=df_update
        )


if __name__ == '__main__':
    site_name = CommonUtil.get_sys_arg(1, None)
    date_info = CommonUtil.get_sys_arg(2, None)
    update_flag = CommonUtil.get_sys_arg(len(sys.argv) - 1, None) == 'update'
    obj = DwdDayAsin(site_name, date_info)
    if update_flag:
        #  update crawl states for an existing day
        obj.run_update()
    else:
        # build the daily asin crawl queue
        obj.run()