import os
import re
import sys
import traceback

import numpy as np
import pandas as pd
import faiss
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录
from utils.templates_mysql import TemplatesMysql
from utils.db_util import DbTypes, DBUtil
from vgg_model import VGGNet
from curl_cffi import requests, Curl
from lxml import etree


class ImgSearch:
    """Reverse image search over a faiss index backed by Doris feature tables.

    A query is turned into a feature vector in one of three ways:
    - looked up in ``img_features``;
    - scraped from the Amazon product page for an ASIN and run through VGG;
    - extracted with VGG from an uploaded file.

    Its nearest neighbours are fetched from the faiss index. The neighbours'
    stored vectors are then re-scored with cosine similarity.
    """

    # Column layout of the DataFrame returned by calculate_similarity/search_api.
    _RESULT_COLUMNS = ("img_id", "img_type", "similarity", "img_url")

    # Amazon storefront host for each supported site code.
    _SITE_DOMAINS = {
        'us': 'www.amazon.com',
        'uk': 'www.amazon.co.uk',
        'de': 'www.amazon.de',
        'fr': 'www.amazon.fr',
        'es': 'www.amazon.es',
        'it': 'www.amazon.it',
    }

    # Browser-like request headers used when scraping Amazon pages.
    _HEADERS = {
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
        'Accept-Encoding': 'gzip, deflate, br', 'Accept-Language': 'zh-CN,zh;q=0.9',
        'Cache-Control': 'no-cache',
        'Pragma': 'no-cache',
        'Sec-Ch-Ua-Mobile': '?0', 'Sec-Ch-Ua-Platform': '"Windows"',
        'Sec-Ch-Ua-Platform-Version': '"10.0.0"',
        'Sec-Fetch-Dest': 'document', 'Sec-Fetch-Mode': 'navigate', 'Sec-Fetch-Site': 'none',
        'Sec-Fetch-User': '?1',
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36',
        'Viewport-Width': '1920'}

    # ASINs are 10 alphanumeric characters. Anything else is rejected before it
    # is embedded in a URL or used as a file name (prevents path traversal).
    _ASIN_RE = re.compile(r"[A-Za-z0-9]{10}")

    # XPath candidates for the main product image, tried in order; the first
    # non-empty match wins. Exact duplicates of earlier entries were removed,
    # which does not change the result.
    # NOTE(review): '//img[0]/@src' can never match (XPath positions are 1-based).
    _IMG_XPATHS = (
        "//div[@id='imgTagWrapperId']/img/@src",
        '//div[@id="img-canvas"]/img/@src',
        '//div[@class="image-wrapper"]/img/@src',
        '//div[@id="mainImageContainer"]/img/@src',
        '//img[@id="js-masrw-main-image"]/@src',
        '//div[@id="a2s-dp-main-content"]//img/@src',
        '//div[@id="ppd-left"]//img/@src',
        '//div[@id="ebooks-img-canvas"]//img[@class="a-dynamic-image frontImage"]/@src',
        '//div[@class="a-column a-span12 a-text-center"]/img/@src',
        '//img[0]/@src',
        '//img[@id="seriesImageBlock"]/@src',
        '//img[@class="main-image"]/@src',
        '//img[@class="mainImage"]/@src',
        '//img[@id="gc-standard-design-image"]/@src',
        '//div[@class="a-row a-spacing-medium"]//img[1]/@src',
        '//div[@class="a-image-container a-dynamic-image-container greyBackground"]/img/@src',
    )

    def __init__(self, site_name='us', search_key='asin', search_value='B0BBNQCXZL', top_k=100, index_path=None):
        """Open the Doris engine, build the VGG model and load the faiss index.

        ``index_path`` overrides the default index file. When it is None, the
        original hard-coded 'us' index is used.
        """
        self.site_name = site_name
        self.search_key = search_key
        self.search_value = search_value
        self.top_k = top_k
        # Host:port that serves the stored images referenced in result URLs.
        self.server_ip = "113.100.143.162:8000"
        self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
        self.vgg_model = VGGNet()
        # NOTE(review): the default index is the 'us' one regardless of site_name.
        self.index_path = index_path or "/mnt/data/img_data/img_index/us/amazon_inv/knn.index"
        self.my_index = self.load_index()  # Load the index (to be refreshed later).

    @staticmethod
    def _sql_escape(value):
        """Escape ``value`` for use inside a single-quoted SQL string literal.

        Backslashes are doubled (MySQL-style escape character) and single
        quotes are doubled (standard SQL). This blocks injection through
        user-supplied values.
        """
        return str(value).replace("\\", "\\\\").replace("'", "''")

    def load_index(self):
        """Read the faiss index from ``self.index_path`` and return it."""
        my_index = faiss.read_index(self.index_path)
        print("加载索引完成, type(my_index):", type(my_index))
        return my_index

    def get_img_feature(self, search_key, search_value, file_path, site_name='us', img_type='amazon_inv'):
        """Return the feature vector for a query, or an empty/None value when unavailable.

        - ``search_key == 'img_unique'``: look the vector up in ``img_features``.
          If no row exists and ``search_value`` is a 10-character string
          (ASIN-like), scrape the product image instead; that call returns
          None when the download fails. Otherwise the result is ``[]``.
        - ``search_key == 'file'``: extract features from ``file_path`` with VGG.
        - Any other key: print a message and return None.
        """
        print("search_key, search_value, file_path:", search_key, search_value, file_path)
        if search_key == 'img_unique':
            sql = (
                f"select * from img_features where site_name='{self._sql_escape(site_name)}' "
                f"and img_type='{self._sql_escape(img_type)}' and img_unique='{self._sql_escape(search_value)}';"
            )
            df = pd.read_sql(sql, con=self.engine_doris)
            if not df.empty:
                # NOTE(review): eval() on database content runs arbitrary code if
                # that column is ever attacker-controlled. Switch to json.loads or
                # ast.literal_eval once the stored text format is confirmed.
                return eval(df.features.iloc[0])
            if isinstance(search_value, str) and len(search_value) == 10:
                return self.download_img(asin=search_value, site=site_name)
            return []

        if search_key == 'file':
            return self.vgg_model.vgg_extract_feat(file=file_path, file_type='bytes')

        print("不符合的查询方式,仅支持img_unique和图片file两种方式")
        return None

    def calculate_similarity(self, index_list, query_vector, site_name='us', img_type='amazon_inv'):
        """Score the images at faiss positions ``index_list`` against ``query_vector``.

        Each position is mapped to an ``img_unique`` through ``img_id_index``,
        and its stored vector is fetched from ``img_features``. Only
        ``img_unique`` values with at least three underscores are reported.

        Returns a DataFrame with columns img_id, img_type, similarity (percent,
        2 dp) and img_url. The DataFrame is empty, with the same columns, when
        nothing matches.
        """
        columns = list(self._RESULT_COLUMNS)
        # Cast to int so the IN list is always valid SQL and injection-proof.
        # This avoids the Python repr "(5,)" for a single element and "()" when empty.
        positions = [int(i) for i in index_list]
        if not positions:
            return pd.DataFrame(columns=columns)
        in_clause = ", ".join(str(p) for p in positions)
        site_sql = self._sql_escape(site_name)
        type_sql = self._sql_escape(img_type)
        sql = f"""
            SELECT main.img_unique, a.features FROM
            (select img_unique from img_id_index where `index` in ({in_clause})) main
            inner join img_features a
            on main.img_unique=a.img_unique
            where a.site_name='{site_sql}' and img_type='{type_sql}';
        """
        df = pd.read_sql(sql, con=self.engine_doris)
        df = df.drop_duplicates(['img_unique'])
        if df.empty:
            return pd.DataFrame(columns=columns)

        # NOTE(review): eval() on database content; see get_img_feature.
        vectors = [eval(features) for features in df.features]
        similarities = self.cosine_similarity_matrix(query_vec=query_vector, all_vecs=vectors)

        rows = []
        for similarity, img_unique in zip(similarities, df.img_unique):
            img_url = f"http://{self.server_ip}/images/{img_type}/{site_name}/{img_unique.split('_')[0]}/{img_unique}.jpg"
            # Split on the first three underscores only; the last two parts
            # are the image id and the image type.
            parts = re.sub(r'_', '@@', img_unique, count=3).split("@@")
            if len(parts) == 4:
                rows.append(
                    {
                        "img_id": parts[-2],
                        "img_type": parts[-1],
                        "similarity": similarity,
                        "img_url": img_url,
                    }
                )
        return pd.DataFrame(rows, columns=columns)

    def cosine_similarity_matrix(self, query_vec, all_vecs):
        """Return the cosine similarity of ``query_vec`` with each row of ``all_vecs``.

        The result is a 1-D array, in percent, rounded to 2 decimals. A
        zero-norm vector produces nan (division by zero).
        """
        query = np.asarray(query_vec)
        matrix = np.asarray(all_vecs)
        query_norm = np.linalg.norm(query)
        row_norms = np.linalg.norm(matrix, axis=1)
        similarities = np.dot(matrix, query) / (query_norm * row_norms)
        # Convert to a percentage and keep two decimal places.
        return np.round(similarities * 100, 2)

    def cosine_similarity_two_img(self, file1_path, file2_path):
        """Return the cosine similarity (percent, 2 dp) of two image files' VGG features."""
        query_vector1 = self.vgg_model.vgg_extract_feat(file=file1_path, file_type='bytes')
        query_vector2 = self.vgg_model.vgg_extract_feat(file=file2_path, file_type='bytes')
        query_vector1_norm = np.linalg.norm(query_vector1)
        query_vector2_norm = np.linalg.norm(query_vector2)
        dot_product = np.dot(query_vector1, query_vector2)
        similarity = dot_product / (query_vector1_norm * query_vector2_norm)
        return np.round(similarity * 100, 2)

    def search_api(self, search_key, search_value, file_path, site_name='us', img_type='amazon_inv', top_k=100):
        """Find the ``top_k`` nearest images to the query and score them.

        Returns the DataFrame produced by calculate_similarity. If no query
        vector can be obtained, the result is an empty DataFrame with the same
        columns.
        """
        # Get the vector for the queried asin / image file.
        query_vector = self.get_img_feature(search_key, search_value, file_path, site_name, img_type=img_type)
        if query_vector is None or len(query_vector) == 0:
            return pd.DataFrame(columns=list(self._RESULT_COLUMNS))
        # faiss expects a C-contiguous float32 matrix with one row per query.
        self.query_vector = np.ascontiguousarray(query_vector, dtype='float32').reshape(1, -1)
        _distances, indices = self.my_index.search(self.query_vector, top_k)
        # faiss pads with -1 when fewer than top_k vectors exist; drop those.
        index_list = tuple(int(i) for i in indices[0] if i >= 0)
        return self.calculate_similarity(
            site_name=site_name, img_type=img_type, index_list=index_list, query_vector=query_vector
        )

    def download_img(self, asin, site='us', retries=5):
        """Scrape an ASIN's main product image and return its VGG feature vector.

        The image is saved to ``/mnt/data/img_data/tmp_img/<ASIN>.jpg`` before
        features are extracted. Returns None in three cases: the site is
        unsupported, the ASIN is not 10 alphanumeric characters, or every one
        of the ``retries`` attempts fails.
        """
        domain = self._SITE_DOMAINS.get(site)
        if domain is None or not isinstance(asin, str) or not self._ASIN_RE.fullmatch(asin):
            print(f"download_img: unsupported site {site!r} or invalid asin {asin!r}")
            return None
        asin_url = f'https://{domain}/dp/{asin}'
        headers = dict(self._HEADERS)
        tmp_dir = "/mnt/data/img_data/tmp_img/"

        for _ in range(retries):
            session = None
            try:
                # NOTE(review): placeholder CA path. With verify=False it is
                # presumably unused; confirm.
                curl = Curl(cacert="/path/to/your/cert")
                session = requests.Session(curl=curl)
                print('asin_url::', asin_url)
                # NOTE(review): TLS verification is disabled for these requests.
                resp = session.get(asin_url, headers=headers, timeout=10, verify=False, impersonate="chrome110")
                image_url = self.xpath_imgurl(resp.text)
                if not image_url:
                    continue  # e.g. captcha/robot page; try again
                # Strip the "._<size spec>" suffix to get the full-size image.
                # Keep the URL as-is when no suffix is present.
                size_spec = re.search(r'\._(.*)\.jpg', image_url)
                full_url = image_url.replace('._' + size_spec.group(1), '') if size_spec else image_url
                print('获取图片url:', full_url)
                resp_img = session.get(full_url, headers=headers, timeout=10, verify=False,
                                       impersonate="chrome110")
                os.makedirs(tmp_dir, exist_ok=True)
                asin_path = os.path.join(tmp_dir, f"{asin.upper()}.jpg")
                with open(asin_path, 'wb') as f:
                    f.write(resp_img.content)
                return self.vgg_model.vgg_extract_feat(file=asin_path, file_type='bytes')
            except Exception as e:
                print(e, f"\n{traceback.format_exc()}")
            finally:
                if session is not None:
                    session.close()
        return None

    def xpath_imgurl(self, resp):
        """Return the first main-image URL found in the HTML text ``resp``, or None."""
        if not resp:
            return None
        tree = etree.HTML(resp)
        if tree is None:
            return None
        for xpath in self._IMG_XPATHS:
            found = tree.xpath(xpath)
            if found:
                return found[0]
        return None
