import ast
import os
import sys
import numpy as np
import pandas as pd
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录
from utils.templates_mysql import TemplatesMysql


class TemuSearch:
    """Score an ASIN's competitors by cosine similarity of stored picture features.

    Feature vectors are read from MySQL (``*_pictures_features_self`` tables)
    through ``TemplatesMysql().engine``. Each competitor ASIN is returned with
    its platform, its similarity to the query ASIN (as a percentage) and a URL
    to its image on ``server_ip``.
    """

    # Prefix removed from ``local_path`` before it is appended to the image URL.
    IMAGE_ROOT = '/mnt/data/img_data/'
    # Port of the HTTP server that serves the image files.
    IMAGE_PORT = 8000

    def __init__(self, site_name='us', search_key='asin', search_value='B0BBNQCXZL', top_k=100,
                 server_ip="113.100.143.162"):
        """Store the search configuration and open the database engine.

        Args:
            site_name: Site prefix used in the features table name (e.g. ``'us'``).
            search_key: Stored only; not used by the methods of this class.
            search_value: Default ASIN; stored only (methods take it explicitly).
            top_k: Stored only; not currently applied to the results.
            server_ip: Host serving the image files, used to build image URLs.
                Alternative internal address: "192.168.200.210".
        """
        self.site_name = site_name
        self.search_key = search_key
        self.search_value = search_value
        self.top_k = top_k
        self.engine = TemplatesMysql().engine
        self.df_features = pd.DataFrame()
        # Query vector of the last search; search_api stores it reshaped to (1, n).
        self.query_vector = np.array([])
        self.server_ip = server_ip

    def read_data(self, search_value, site_name):
        """Load feature rows and local image paths for the competitors of an ASIN.

        Competitor ASINs are taken from both ``us_self_asin_compet_amazon`` and
        ``us_self_asin_compet_temu`` and combined with ``UNION ALL``. The result
        is left-joined to ``us_pictures_local_path_self``. Rows may therefore
        repeat an asin, and ``local_path`` is NULL when no path is recorded.

        Args:
            search_value: ASIN whose competitors are loaded.
            site_name: Value matched against the ``site`` column of the
                competitor tables.

        Returns:
            DataFrame with columns ``asin``, ``features``, ``platform`` and
            ``local_path``.
        """
        print("search_value, site_name:", search_value, site_name)
        # NOTE(review): table names are hard-coded to the ``us_`` prefix while
        # get_asin_features uses ``self.site_name``; confirm this is intended.
        # Values are bound as driver parameters (pyformat placeholders) rather
        # than interpolated, so a crafted ASIN cannot alter the query.
        sql = """
        SELECT main.asin, main.features, main.platform, a.local_path from (
        SELECT asin, features, platform from us_pictures_features_self WHERE asin in
            (
            SELECT asin_compet from us_self_asin_compet_amazon WHERE asin=%(asin)s and site=%(site)s
            )
        UNION all
        SELECT asin, features, platform from us_pictures_features_self WHERE asin in
            (
            SELECT asin_compet from us_self_asin_compet_temu WHERE asin=%(asin)s and site=%(site)s
            )
        )
        as main
        left join
        (select asin, local_path from us_pictures_local_path_self) a
        on main.asin = a.asin
        """
        params = {"asin": search_value, "site": site_name}
        print("sql:", sql, params)
        df_features = pd.read_sql(sql, con=self.engine, params=params)
        print(df_features.shape)
        print(df_features.head())
        return df_features

    def search_api(self, site_name, search_value):
        """Compute similarities between ``search_value`` and its competitors.

        Args:
            site_name: Site passed on to read_data.
            search_value: ASIN to search for.

        Returns:
            The dict produced by calculate_similarity.
        """
        print("111 search_value, site_name:", search_value, site_name)
        # NOTE(review): the query vector is looked up with self.site_name, not
        # the site_name argument; confirm both are always the same.
        query_vector = self.get_asin_features(search_value)
        self.query_vector = np.array(query_vector).reshape(1, -1)
        similarities_dict = self.calculate_similarity(site_name=site_name, search_value=search_value, query_vector=query_vector)
        print("similarities_dict:", similarities_dict)
        return similarities_dict

    def get_asin_features(self, search_value):
        """Return the parsed feature vector stored for one ASIN.

        Raises:
            IndexError: if ``search_value`` has no row in the features table.
        """
        # A table name cannot be a bound parameter; site_name comes from the
        # constructor, not from the search request.
        table = f"{self.site_name}_pictures_features_self"
        sql = f"select * from {table} where asin=%(asin)s"
        print("sql:", sql)
        df = pd.read_sql(sql, con=self.engine, params={"asin": search_value})
        if df.empty:
            raise IndexError(f"no features found for asin={search_value!r} in {table}")
        return self._parse_features(df.features.iloc[0])

    def calculate_similarity(self, site_name, search_value, query_vector):
        """Map each competitor asin to ``(platform, similarity_percent, image_url)``.

        ``image_url`` is None when the competitor has no recorded local path.
        When an asin appears on several rows, the last row wins. Returns an
        empty dict when there are no competitors.
        """
        df = self.read_data(search_value, site_name)
        if df.empty:
            return {}
        # One vector per row (not per unique asin) keeps similarities aligned
        # with the platform/local_path columns when an asin repeats.
        all_vecs = [self._parse_features(features) for features in df.features]
        similarities = self.cosine_similarity_matrix(query_vec=query_vector, all_vecs=all_vecs)
        return {
            asin: (platform, similarity, self._image_url(local_path))
            for asin, platform, similarity, local_path
            in zip(df.asin, df.platform, similarities, df.local_path)
        }

    def _image_url(self, local_path):
        """Turn a server-side image path into its HTTP URL (None if the path is missing)."""
        if not isinstance(local_path, str):
            # The left join yields NULL (None/NaN) when no path is recorded.
            return None
        return f"http://{self.server_ip}:{self.IMAGE_PORT}/images/{local_path.replace(self.IMAGE_ROOT, '')}"

    @staticmethod
    def _parse_features(raw):
        """Parse a stored feature value into a list of numbers.

        Stored features are expected to be strings holding a list literal.
        They are parsed with ast.literal_eval rather than eval so database
        content cannot execute code. Non-string values are returned unchanged.
        """
        if isinstance(raw, str):
            return ast.literal_eval(raw)
        return raw

    def cosine_similarity_matrix(self, query_vec, all_vecs):
        """Return the cosine similarity of each vector in ``all_vecs`` to ``query_vec``.

        Args:
            query_vec: Query vector of length d (any shape that flattens to d).
            all_vecs: Sequence of n vectors of length d.

        Returns:
            1-D float array of length n holding similarities as percentages,
            rounded to two decimals. A zero-norm vector yields NaN.
        """
        if len(all_vecs) == 0:
            return np.empty(0)
        query = np.asarray(query_vec, dtype=float).ravel()
        matrix = np.asarray(all_vecs, dtype=float)
        query_norm = np.linalg.norm(query)
        row_norms = np.linalg.norm(matrix, axis=1)
        # 0/0 for zero vectors becomes NaN without emitting RuntimeWarnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            similarities = matrix.dot(query) / (query_norm * row_norms)
        return np.round(similarities * 100, 2)


if __name__ == '__main__':
    # Manual smoke run: compute similarity scores between one US-site ASIN
    # and its competitor ASINs.
    demo_asin = 'B07GL4C9R9'
    searcher = TemuSearch(site_name='us', search_value=demo_asin)
    searcher.search_api(site_name='us', search_value=demo_asin)