import base64
import hashlib
import json
import logging
import os
import re
import time
from io import BytesIO
from typing import Optional

import ddddocr
import redis as rd
from PIL import Image
from playwright.sync_api import Browser, BrowserContext, Page, sync_playwright, TimeoutError, Error


# REDIS = {
#     # 'host': '127.0.0.1',
#     'host': '120.79.147.190',
#     'port': 6379,
#     'password': 'Vm5vQH4ydFXh',
#     'db': 10
# }

# Connection settings consumed by the Redis wrapper class below.
# Every value can be overridden through a CHINATAX_REDIS_* environment
# variable; when a variable is absent the historical value is used.
# NOTE(review): the fallbacks still embed a real host and password in source
# control - move them to the environment / a secret store and rotate them.
REDIS = {
    'host': os.environ.get('CHINATAX_REDIS_HOST', 'wx.yswg.com.cn'),
    'port': int(os.environ.get('CHINATAX_REDIS_PORT', 6379)),
    'password': os.environ.get('CHINATAX_REDIS_PASSWORD', 'yswg@2019'),
    'db': int(os.environ.get('CHINATAX_REDIS_DB', 1)),
}
#

def singleton(cls, *args, **kw):
    """Class decorator that turns ``cls`` into a lazily-built singleton.

    The decorated name is rebound to a zero-argument factory: the first call
    builds ``cls(*args, **kw)``, every later call returns that same object.

    :param cls: class (or any callable) to instantiate once.
    :param args: positional arguments forwarded to ``cls`` on first use.
    :param kw: keyword arguments forwarded to ``cls`` on first use.
    :return: zero-argument callable returning the shared instance.
    """

    cache = {}

    def _singleton():
        try:
            return cache[cls]
        except KeyError:
            # First access: build the one and only instance and remember it.
            created = cache[cls] = cls(*args, **kw)
            return created

    return _singleton


def md5(src: str, algorithm: str = "md5", digits: int = 32) -> str:
    """Hash ``src`` (UTF-8 encoded) and return the hexadecimal digest.

    :param src: text to hash.
    :param algorithm: any algorithm name accepted by ``hashlib.new``.
    :param digits: 16 returns characters 8..23 of the hex digest; any other
        value returns the full hex digest.
    :return: hex digest string.
    """

    hasher = hashlib.new(algorithm)
    hasher.update(src.encode('utf8'))
    hex_digest = hasher.hexdigest()
    # The "16 digit" form is the middle slice of the full digest.
    return hex_digest[8:24] if digits == 16 else hex_digest


@singleton
class Redis(object):
    """Process-wide holder of the Redis connection settings and pool.

    Because of ``@singleton`` the name ``Redis`` is a zero-argument factory
    that always returns the same object, so the pool created here is shared
    by every helper that calls ``Redis().get_instance()``.
    """

    def __init__(self):
        self.host = REDIS['host']
        self.port = REDIS['port']
        self.db = REDIS['db']
        self.password = REDIS['password']
        # Created on first use by get_instance() and then reused.
        self.pool = None

    def get_instance(self):
        """Return a Redis client bound to the shared connection pool.

        The pool is built once; previously a brand-new pool was created on
        every call, which defeated pooling and left the old pools'
        connections open until garbage collection.

        :return: ``redis.Redis`` client using ``self.pool``.
        """
        if self.pool is None:
            self.pool = rd.ConnectionPool(
                host=self.host,
                port=self.port,
                db=self.db,
                password=self.password,
                max_connections=3,
                socket_timeout=5,
                socket_connect_timeout=5,
                retry_on_timeout=True,
            )
        return rd.Redis(connection_pool=self.pool)


def sadd(key, value, use_md5=True):
    """Add one member to the Redis set stored at ``key``.

    :param key: set key.
    :param value: member to add; must be a ``str`` when ``use_md5`` is true.
    :param use_md5: store the 16-character md5 of ``value`` instead of the
        value itself.
    :return: True when Redis reports exactly one newly added member.
    """

    client = Redis().get_instance()
    member = md5(value, digits=16) if use_md5 else value
    return client.sadd(key, member) == 1


def spop(key, count=20) -> list:
    """Remove and return up to ``count`` members of the Redis set at ``key``.

    :param key: set key.
    :param count: maximum number of members to pop.
    :return: whatever ``redis.Redis.spop(key, count)`` returns.
    """

    client = Redis().get_instance()
    popped = client.spop(key, count)
    return popped


def ladd(key, value, use_md5=True):
    """Push one element onto the head of the Redis list stored at ``key``.

    :param key: list key.
    :param value: element to push; must be a ``str`` when ``use_md5`` is true.
    :param use_md5: push the 16-character md5 of ``value`` instead of the
        value itself.
    :return: True when the push succeeded, False otherwise.
    """

    r = Redis().get_instance()
    element = md5(value, digits=16) if use_md5 else value
    # LPUSH replies with the list length *after* the push, so any value >= 1
    # means success. Comparing with == 1 only held for a previously empty list.
    new_length = r.lpush(key, element)
    return new_length >= 1


def listpop(key) -> Optional[bytes]:
    """Pop and return the head element of the Redis list stored at ``key``.

    :param key: list key.
    :return: the popped element, or None when the list is empty or missing.
        The element is raw bytes because the client is created without
        response decoding.
    """

    r = Redis().get_instance()
    return r.lpop(key)


def xadd(key, data):
    """Append the field/value mapping ``data`` to the Redis stream ``key``.

    :param key: stream key.
    :param data: mapping of entry fields to values.
    :return: whatever ``redis.Redis.xadd(key, data)`` returns.
    """
    return Redis().get_instance().xadd(key, data)


class ChinataxSpider(object):
    seed_key = 'finance:sp_invoice_queue'
    save_key = 'finance:sp_invoice_result'

    def __init__(self):
        """Configure logging and initialise the crawler's state."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s %(name)s %(levelname)s %(message)s',
        )
        # Playwright handles: browser is set by main(), context and page by
        # change_user().
        self.browser: Browser = None
        self.context: BrowserContext = None
        self.page: Page = None
        # Counter used by run() for "Incorrect padding" failures.
        self.padding_error = 1
        # Built-in sample seed; only referenced by commented-out debug code.
        sample_seed = dict(
            u_key="1934929493120057346",
            fpdm="",
            fphm="25429165833000096487",
            kprq="20250519",
            kjje="264.50",
            jym="",
        )
        self.seeds = [sample_seed]

    def base64_to_image(self, base64_str):
        """Decode a base64 image string into an RGB PIL image.

        A leading ``data:image/...;base64,`` data-URI prefix is stripped
        before decoding.

        :param base64_str: base64 payload, with or without a data-URI prefix.
        :return: ``PIL.Image.Image`` converted to RGB mode.
        """
        payload = re.sub('^data:image/.+;base64,', '', base64_str)
        raw_bytes = base64.b64decode(payload)
        return Image.open(BytesIO(raw_bytes)).convert("RGB")

    def get_img_base64(self):
        """Read the captcha hint and image source from the current page.

        Polls ``#yzminfo`` and ``#yzm_img`` up to five times, one second
        apart, until the image ``src`` is longer than ``"images/code.png"``
        (presumably the placeholder shown before the real captcha loads).

        :return: ``(color, img_base64)``. ``color`` is ``'blue'`` or ``'red'``
            when the hint text contains "蓝色" / "红色" (red wins if both
            appear) and ``'black'`` otherwise. ``img_base64`` is the last
            ``src`` value read, or ``""`` if it was never available.
        """
        placeholder_length = len("images/code.png")
        yzminfo = ""
        img_base64 = ""
        for _ in range(5):
            hint_element = self.page.query_selector("#yzminfo")
            image_element = self.page.query_selector("#yzm_img")
            # query_selector / text_content / get_attribute may all return
            # None while the page is still rendering; treat that as
            # "not ready yet" instead of crashing on len(None).
            yzminfo = (hint_element.text_content() if hint_element else "") or ""
            img_base64 = (image_element.get_attribute("src") if image_element else "") or ""
            self.page.wait_for_timeout(1000)
            if len(img_base64) > placeholder_length:
                break

        color = "black"
        if "蓝色" in yzminfo:
            color = 'blue'
        if "红色" in yzminfo:
            color = 'red'

        return color, img_base64

    def get_img(self, max_refresh=30):
        """Fetch a captcha, refreshing it until a black one is shown.

        Clicking ``#yzm_img`` asks the site for a new captcha. The refresh
        loop used to be unbounded; it now stops after ``max_refresh`` clicks
        so a run of coloured captchas cannot hang the worker.

        :param max_refresh: maximum number of captcha refresh clicks.
        :return: ``(color, img_base64)`` of the last captcha read. ``color``
            is ``'black'`` unless the refresh budget was exhausted.
        """
        attempt = 1
        color, img_base64 = self.get_img_base64()
        logging.info(f"第{attempt}次是{color}色")
        while color != 'black' and attempt <= max_refresh:
            self.page.query_selector("#yzm_img").click()
            self.page.wait_for_timeout(3000)
            self.page.wait_for_load_state()
            attempt += 1
            color, img_base64 = self.get_img_base64()
            logging.info(f"第{attempt}次是{color}色")
        return color, img_base64

    def ddddocr_imge_get_code(self, color, img_base64):
        """Run OCR on a captcha image and return the recognised text.

        :param color: captcha colour reported by get_img(); currently unused
            because every colour goes through the same recognition path.
        :param img_base64: base64 (or data-URI) encoded captcha image.
        :return: result of ``DdddOcr.classification`` for the decoded image.
        """
        # Build the OCR engine once per spider instead of on every call.
        ocr = getattr(self, "_ocr", None)
        if ocr is None:
            ocr = self._ocr = ddddocr.DdddOcr()
        img = self.base64_to_image(img_base64)
        # TODO: colour-specific preprocessing for non-black captchas.
        return ocr.classification(img)

    def get_seed(self):
        """Pop one task from the seed set and return it as a normalised dict.

        String values are stripped; every non-string value is replaced by
        ``""``. A seed without both ``fphm`` and ``kprq`` is reported to the
        result stream as ``"seed error"`` and skipped.

        :return: the seed dict, or None when there is no usable task.
        """
        seeds = spop(self.seed_key, 1)
        if not seeds:
            return None

        raw_seed = seeds[0]
        try:
            seed = json.loads(raw_seed)
        except ValueError as exc:
            # A corrupt queue entry must not take the whole worker down.
            logging.info(f"seed is not valid JSON {raw_seed!r}: {exc}")
            return None
        if not isinstance(seed, dict):
            logging.info(f"seed is not a JSON object {raw_seed!r}")
            return None

        # NOTE(review): non-string values (numbers included) become "" here;
        # confirm producers always send strings.
        seed = {k: v.strip() if isinstance(v, str) else "" for k, v in seed.items()}
        if seed.get("fphm") and seed.get("kprq"):
            return seed

        data = {
            # Fall back to "" so a missing u_key is not sent to Redis as None.
            "u_key": seed.get("u_key", ""),
            "dom": "seed error",
        }
        xadd(self.save_key, data)
        logging.info(f"seeds 缺少字段 {seed}")
        return None

    def recaptcha(self):
        """Solve the captcha once and submit the verification form.

        :return: one of
            ``True`` - the result dialog (``#dialog-body``) appeared;
            ``False`` - captcha was not black or was rejected (caller retries);
            ``"seed error"`` - the submit button ``#checkfp`` is hidden;
            ``"count error"`` - daily check limit for this invoice reached;
            ``"recaptcha count error"`` - captcha requested too frequently.
        """
        color, img_base64 = self.get_img()
        if color != 'black':
            # Only black captchas are handled; skip the OCR work entirely.
            return False

        code = self.ddddocr_imge_get_code(color, img_base64)
        self.page.locator("#yzm").fill(code)
        # Click a neutral area so the captcha input loses focus.
        self.page.locator("#pageshow").click()

        # get_attribute returns None when the element has no inline style.
        button_style = self.page.query_selector("#checkfp").get_attribute("style") or ""
        if "display: none" in button_style:
            return "seed error"

        self.page.locator("#checkfp").click()
        self.page.wait_for_timeout(2000)
        # The result dialog only exists when the captcha was accepted.
        if self.page.query_selector("#dialog-body"):
            return True

        popup = self.page.query_selector("#popup_message")
        if popup:
            message = popup.text_content() or ""
            fatal_popups = (
                ("超过该张发票当日查验次数", "count error"),
                ("验证码请求次数过于频繁", "recaptcha count error"),
            )
            for marker, verdict in fatal_popups:
                if marker in message:
                    self.page.locator("#popup_ok").click()
                    self.page.wait_for_timeout(2000)
                    return verdict

        # Any other popup: dismiss it and request a fresh captcha.
        self.page.locator("#popup_ok").click()
        self.page.locator("#yzm_img").click()
        self.page.wait_for_timeout(2000)
        logging.info("验证码处理错误")
        return False

    def crawl(self, url, seed):
        """Verify one invoice on the site and publish the outcome.

        Fills the form from ``seed``, then either reports a form-validation
        message, or solves the captcha (up to 8 attempts) and stores the
        result dialog's HTML. When no attempt succeeds, or the site reports
        that captchas are being requested too often, the seed is pushed back
        to the seed set. The page and context are closed on every path that
        returns normally.

        :param url: address of the verification page.
        :param seed: normalised task dict (``u_key``, ``fpdm``, ``fphm``,
            ``kprq``, ``kjje``, ``jym``).
        :return: None.
        """
        self.page.goto(url)
        self.page.wait_for_timeout(1000)
        logging.info(f"fpdm --> {seed.get('fpdm', '')}")
        logging.info(f"seed --> {seed}")
        # The invoice code is optional; only fill it when the seed has one.
        if seed.get('fpdm'):
            self.page.locator("#fpdm").fill(seed.get("fpdm", ""))
            self.page.wait_for_timeout(5000)
        self.page.locator("#fphm").fill(seed.get("fphm"))
        self.page.wait_for_timeout(1000)
        self.page.locator("#kprq").fill(seed.get("kprq"))
        self.page.wait_for_timeout(3000)

        # The label of the last field tells which value the site expects:
        # an amount, or otherwise the last six characters of the check code.
        field_label = self.page.query_selector("span[id='context']").text_content()
        if any(keyword in field_label for keyword in ("开具金额", "价税合计", "票价")):
            # "" instead of None so a seed without kjje yields a site-side
            # validation message rather than an invalid fill() argument.
            kjje = seed.get("kjje") or ""
        else:
            kjje = (seed.get("jym") or "")[-6:]
        self.page.locator("#kjje").fill(kjje)

        # Blur the inputs so the page runs its field validation.
        self.page.locator("#pageshow").click()
        self.page.wait_for_timeout(2000)

        failure = self._form_validation_error()
        if failure is not None:
            dom, log_text = failure
            self._report_and_close(seed, dom, log_text)
            return

        for _ in range(8):
            outcome = self.recaptcha()
            if not outcome:
                logging.info("验证码处理失败重试")
                continue
            if outcome in ("count error", "seed error"):
                self._report_and_close(seed, outcome, f"{outcome} {seed}")
                return
            if outcome == "recaptcha count error":
                # Rate limited: stop trying and fall through to the re-queue.
                logging.info(f"recaptcha count error {seed}")
                break
            # Captcha accepted: the result is rendered inside a frame named
            # "dialog-body".
            self.page.wait_for_timeout(2000)
            text_contents = self.page.frame("dialog-body").content()
            self._report_and_close(seed, text_contents, f"seed spider ok --> {seed}")
            return

        logging.info("失败8次处理,将任务重新推送到redis")
        sadd(self.seed_key, json.dumps(seed), use_md5=False)
        self.page.close()
        self.context.close()

    def _form_validation_error(self):
        """Return ``(dom, log_text)`` for the first form error shown, else None."""
        for field in ("fpdmjy", "fphmjy"):
            message = self.page.query_selector(f"#{field}").text_content().strip()
            if message:
                return message, f"{field} - >>{message}"
        if self.page.query_selector("xpath=.//div[@class='tip_common_wrong font_red tip_common_right']"):
            # This marker carries no text of its own, so report (and log) a
            # fixed message; the log line used to be empty here.
            return "发票号码有误!", "发票号码有误!"
        for field in ("kprqjy", "kjjejy"):
            message = self.page.query_selector(f"#{field}").text_content().strip()
            if message:
                return message, f"{field} - >>{message}"
        return None

    def _report_and_close(self, seed, dom, log_text):
        """Publish ``dom`` for ``seed`` to the result stream, log, close the page and context."""
        xadd(self.save_key, {"u_key": seed.get("u_key"), "dom": dom})
        logging.info(log_text)
        self.page.close()
        self.context.close()

    def change_user(self):
        """Open a fresh browser context and page for the next task.

        Stores the new context in ``self.context`` and its page in
        ``self.page``, clicks the page body, then registers an init script
        that redefines ``navigator.webdriver`` to return ``undefined``.
        """
        context_options = dict(
            locale='en-GB',
            bypass_csp=True,
            user_agent='Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36',
            ignore_https_errors=True,
            no_viewport=True,
        )
        self.context = self.browser.new_context(**context_options)
        self.page = self.context.new_page()

        # Imitate a regular browser session.
        self.page.locator("body").click()
        hide_webdriver_js = """
        Object.defineProperties(navigator, {webdriver:{get:()=>undefined}});
        """
        self.page.add_init_script(hide_webdriver_js)

    def run(self):
        """Main worker loop: fetch seeds forever and crawl each one.

        Failure handling per seed:
          * Playwright error or Redis connection error - close the page and
            context, push the seed back to the seed set;
          * exception whose text is "Incorrect padding" (raised by base64
            decoding, presumably when the captcha image was not delivered) -
            re-queue, and pause 1200 seconds once the counter reaches 5;
          * any other exception - publish ``"code error"`` for the seed.
        """
        redis_retry_delay = 5  # seconds to wait before polling Redis again
        while True:
            try:
                seed = self.get_seed()
            except rd.exceptions.ConnectionError as e:
                logging.info(f"ConnectionError error {e}")
                # Back off instead of hammering an unreachable server.
                time.sleep(redis_retry_delay)
                continue

            if not seed:
                logging.info('no task sleep 30s')
                time.sleep(30)
                continue

            try:
                logging.info("获取任务成功")
                self.change_user()
                url = "https://inv-veri.chinatax.gov.cn/index.html"
                self.crawl(url, seed)
            except Error as e:
                logging.info(f"playwright error {e}")
                self._close_quietly()
                self._requeue(seed)
            except rd.exceptions.ConnectionError as e:
                logging.info(f"ConnectionError error {e}")
                self._close_quietly()
                self._requeue(seed)
            except Exception as e:
                self._close_quietly()
                if f"{e}" == "Incorrect padding":
                    self._requeue(seed)
                    logging.info(f"Incorrect padding error {e}")
                    self.padding_error += 1
                    if self.padding_error >= 5:
                        # The pause is 1200 seconds (20 minutes); announce it
                        # before sleeping so the log explains the silence.
                        logging.info("等待1200秒")
                        time.sleep(1200)
                        self.padding_error = 1
                else:
                    data = {
                        "u_key": seed.get("u_key"),
                        "dom": "code error",
                    }
                    logging.info(f"unexpected error {e}")
                    xadd(self.save_key, data)

    def _close_quietly(self):
        """Close the current page and context, tolerating missing or already-closed handles."""
        for handle in (self.page, self.context):
            if handle is None:
                continue
            try:
                handle.close()
            except Error as exc:
                logging.info(f"close error {exc}")

    def _requeue(self, seed):
        """Push ``seed`` back to the seed set; log instead of crashing if Redis is unreachable."""
        try:
            sadd(self.seed_key, json.dumps(seed), use_md5=False)
        except rd.exceptions.ConnectionError as exc:
            logging.error(f"requeue failed, seed dropped {seed}: {exc}")

    def main(self, headless=False,
             executable_path=r"C:\Program Files (x86)\ChatAI Chrome\ChatAI_Chrome.exe"):
        """Launch Chromium through Playwright and start the worker loop.

        :param headless: run the browser without a visible window. The value
            is now actually passed to ``launch``; before, it was only logged.
        :param executable_path: browser binary to launch; defaults to the
            path that used to be hard-coded.
        """
        logging.info(f"{headless}")
        with sync_playwright() as _playwright:
            self.browser = _playwright.chromium.launch(
                headless=headless,
                executable_path=executable_path,
            )
            self.run()


if __name__ == '__main__':
    # Script entry point: build the spider and hand control to its main loop.
    ChinataxSpider().main()
    # Example of enqueueing a task by hand:
    # seed = {
    #     "u_key": "xxx",
    #     "fpdm": "044002311111",
    #     "fphm": "59210491",
    #     "kprq": "20231214",
    #     "kjje": "303791"
    # }
    # ladd('finance:sp_invoice_queue', json.dumps(seed), use_md5=False)


