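"""Image similarity search over Amazon product pictures.

Extracts a VGG feature vector for a query (an ASIN or an uploaded image),
searches a prebuilt FAISS index for the top-k nearest vectors, and returns
the matching ASINs with cosine-similarity scores and image URLs.
"""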
import os
import random
import re
import sys
import traceback

import faiss
import numpy as np
import pandas as pd

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # add the parent directory to the import path
from utils.templates_mysql import TemplatesMysql
from vgg_model import VGGNet
from curl_cffi import requests, Curl
import requests as requests2
from lxml import etree


class PicturesSearch(object):
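    """Similarity search over product images.

    Combines a VGG feature extractor, a prebuilt FAISS index, and a PostgreSQL
    feature store. A query can be an ASIN (features read from the DB, or the
    image scraped from Amazon and embedded on a miss) or an uploaded image file.
    """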

    def __init__(self, site_name='us', search_key='asin', search_value='B0BBNQCXZL', top_k=100):
        self.site_name = site_name
        self.search_key = search_key
        self.search_value = search_value
        self.top_k = top_k
        self.index_path = "/mnt/data/img_data/index/knn.index"
        self.my_index = self.load_index()
        self.vgg_model = VGGNet()
        self.query_vector = None  # set once query features are extracted
        # self.engine = TemplatesMysql().engine
        self.engine_pg = TemplatesMysql().engine_pg
        # self.server_ip = "192.168.200.210"
        self.server_ip = "113.100.143.162:8000"

    def xpath_imgurl(self, resp):
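        """Try each XPath expression against the product-page HTML and return
        the first matching main-image URL, or None if nothing matches."""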
        img_url_list = ['//div[@id="imgTagWrapperId"]/img/@src', '//div[@id="img-canvas"]/img/@src',
                        '//div[@class="image-wrapper"]/img/@src', '//div[@id="mainImageContainer"]/img/@src',
                        '//img[@id="js-masrw-main-image"]/@src', '//div[@id="a2s-dp-main-content"]//img/@src',
                        '//div[@id="ppd-left"]//img/@src',
                        '//div[@id="ebooks-img-canvas"]//img[@class="a-dynamic-image frontImage"]/@src',
                        '//div[@class="a-column a-span12 a-text-center"]/img/@src',
                        '(//img)[1]/@src',  # first <img> on the page (XPath is 1-indexed)
                        '//img[@id="seriesImageBlock"]/@src', '//img[@class="main-image"]/@src',
                        '//img[@class="mainImage"]/@src', '//img[@id="gc-standard-design-image"]/@src',
                        '//div[@class="a-row a-spacing-medium"]//img[1]/@src',
                        '//div[@class="a-image-container a-dynamic-image-container greyBackground"]/img/@src']
        image_url = None
        response_s = etree.HTML(resp)
        for xpath_expr in img_url_list:  # try each selector until one matches
            image = response_s.xpath(xpath_expr)
            if image:
                image_url = image[0]
                break
        return image_url

    def downlad_img(self, asin, site='us'):
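        """Download the main product image for `asin` into /mnt/data/img_data/tmp_img/.
        Retries up to 6 times; returns {"state": "ok"} on success, {"state": "error"} otherwise."""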
        for i in range(6):
            try:
                if site == 'us':
                    asin_url = f'https://www.amazon.com/dp/{asin}'
                elif site == 'uk':
                    asin_url = f'https://www.amazon.co.uk/dp/{asin}'  # site-specific URL
                elif site == 'de':
                    asin_url = f'https://www.amazon.de/dp/{asin}'
                elif site == 'fr':
                    asin_url = f'https://www.amazon.fr/dp/{asin}'
                elif site == 'es':
                    asin_url = f'https://www.amazon.es/dp/{asin}'
                elif site == 'it':
                    asin_url = f'https://www.amazon.it/dp/{asin}'
                headers = {
                    'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
                    'Accept-Encoding': 'gzip, deflate, br', 'Accept-Language': 'zh-CN,zh;q=0.9',
                    'Cache-Control': 'no-cache',
                    'Pragma': 'no-cache',
                    'Sec-Ch-Ua-Mobile': '?0', 'Sec-Ch-Ua-Platform': '"Windows"',
                    'Sec-Ch-Ua-Platform-Version': '"10.0.0"',
                    'Sec-Fetch-Dest': 'document', 'Sec-Fetch-Mode': 'navigate', 'Sec-Fetch-Site': 'none',
                    'Sec-Fetch-User': '?1',
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36',
                    'Viewport-Width': '1920'}
                session = requests.Session()
                print('asin_url::', asin_url)
                # Randomly alternate between curl_cffi (browser impersonation) and plain requests
                random_num = random.randint(1, 2)
                if random_num == 1:
                    resp = session.get(asin_url, headers=headers, timeout=10, verify=False,
                                       impersonate="chrome110")  # fetch the product page
                else:
                    resp = requests2.get(asin_url, headers=headers, timeout=10, verify=False)
                image_url = self.xpath_imgurl(resp.text)
                if image_url:
                    # Strip the size suffix (e.g. "._AC_SL1500_") to get the full-resolution image URL
                    str_ = re.findall(r'\._(.*)\.jpg', image_url)[0]
                    _url = image_url.replace('._' + str_, '')
                    print('image url:', _url)
                    resp_img = session.get(_url, headers=headers, timeout=10, verify=False,
                                           impersonate="chrome110")  # fetch the image itself
                    asin_upper = asin.upper()
                    path_1 = "/mnt/data/img_data/tmp_img/"
                    os.makedirs(path_1, exist_ok=True)  # create the directory if it does not exist
                    with open(rf"{path_1}{asin_upper}.jpg", 'wb') as f:
                        f.write(resp_img.content)  # write the image bytes to disk
                    return {"state": "ok"}
            except Exception as e:
                print(e, f"\n{traceback.format_exc()}")
        return {'state': 'error'}

    def download_img(self, asin, site='us'):
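        """Download the main product image for `asin`, save it locally, and return
        its VGG feature vector; returns None if all 5 attempts fail."""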
        for i in range(5):
            try:
                if site == 'us':
                    asin_url = f'https://www.amazon.com/dp/{asin}'
                elif site == 'uk':
                    asin_url = f'https://www.amazon.co.uk/dp/{asin}'  # site-specific URL
                elif site == 'de':
                    asin_url = f'https://www.amazon.de/dp/{asin}'
                elif site == 'fr':
                    asin_url = f'https://www.amazon.fr/dp/{asin}'
                elif site == 'es':
                    asin_url = f'https://www.amazon.es/dp/{asin}'
                elif site == 'it':
                    asin_url = f'https://www.amazon.it/dp/{asin}'
                headers = {
                    'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
                    'Accept-Encoding': 'gzip, deflate, br', 'Accept-Language': 'zh-CN,zh;q=0.9',
                    'Cache-Control': 'no-cache',
                    'Pragma': 'no-cache',
                    'Sec-Ch-Ua-Mobile': '?0', 'Sec-Ch-Ua-Platform': '"Windows"',
                    'Sec-Ch-Ua-Platform-Version': '"10.0.0"',
                    'Sec-Fetch-Dest': 'document', 'Sec-Fetch-Mode': 'navigate', 'Sec-Fetch-Site': 'none',
                    'Sec-Fetch-User': '?1',
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36',
                    'Viewport-Width': '1920'}
                curl = Curl(cacert="/path/to/your/cert")  # placeholder CA-bundle path
                session = requests.Session(curl=curl)
                print('asin_url::', asin_url)
                resp = session.get(asin_url, headers=headers, timeout=10, verify=False,
                                   impersonate="chrome110")  # fetch the product page
                image_url = self.xpath_imgurl(resp.text)
                if image_url:
                    # Strip the size suffix (e.g. "._AC_SL1500_") to get the full-resolution image URL
                    str_ = re.findall(r'\._(.*)\.jpg', image_url)[0]
                    _url = image_url.replace('._' + str_, '')
                    print('image url:', _url)
                    resp_img = session.get(_url, headers=headers, timeout=10, verify=False,
                                           impersonate="chrome110")  # fetch the image itself
                    asin_upper = asin.upper()
                    path_1 = "/mnt/data/img_data/tmp_img/"
                    os.makedirs(path_1, exist_ok=True)  # create the directory if it does not exist
                    asin_path = rf"{path_1}{asin_upper}.jpg"
                    with open(asin_path, 'wb') as f:
                        f.write(resp_img.content)  # write the image bytes to disk
                    query_vec = self.vgg_model.vgg_extract_feat(file=asin_path, file_type='bytes')
                    return query_vec
            except Exception as e:
                print(e, f"\n{traceback.format_exc()}")
        return None  # all attempts failed

    def load_index(self):
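        """Load the prebuilt FAISS index from disk."""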
        my_index = faiss.read_index(self.index_path)
        print("加载索引完成, type(my_index):", type(my_index))
        return my_index

    def get_features(self, search_key, search_value, file):
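        """Dispatch feature extraction: DB lookup (with scrape fallback) for 'asin',
        or direct VGG extraction for an uploaded image 'file'."""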
        print("search_key, search_value, file:", search_key, search_value, file)
        if search_key == 'asin':
            return self.get_asin_features(search_value)
        elif search_key == 'file':
            return self.vgg_model.vgg_extract_feat(file=file, file_type='bytes')
        else:
            print("Unsupported search type: only 'asin' and image 'file' are supported")

    def get_asin_features(self, search_value):
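        """Look up the stored feature vector for an ASIN; on a miss, scrape the
        product image and extract features on the fly."""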
        sql = f"select * from {self.site_name}_pictures_features where asin='{search_value}'"
        print("sql:", sql)
        df = pd.read_sql(sql, con=self.engine_pg)
        if df.shape[0] == 0:
            # ASIN missing from the feature table: scrape the image and embed it on the fly
            query_vector = self.download_img(search_value)
        else:
            # features are stored as a stringified list; eval() parses it back
            query_vector = eval(list(df.features)[0])
        return query_vector

    def get_asin_by_index(self, site_name, indices_list):
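        """Map FAISS row indices back to distinct ASINs via the id-index table."""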
        sql = f"select distinct(asin) as asin from {site_name}_pictures_id_index where indice in {tuple(indices_list)}"
        print("sql:", sql)
        df = pd.read_sql(sql=sql, con=self.engine_pg)
        asin_list = list(df.asin)
        return asin_list

    def calculate_similarity(self, site_name, indices_list, query_vector):
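        """Fetch stored feature vectors for the matched indices and return
        {asin: (similarity_percent, image_url)}."""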
        # sql = f"select asin, features from {site_name}_pictures_features where asin in {tuple(asin_list)}"
        # sql = f"""
        #     select asin, features from {site_name}_pictures_features main
        #     left join (select distinct(asin) as asin from {site_name}_pictures_id_index where index in {tuple(indices_list)}) a
        #     on main.asin=a.asin
        # """
        sql = f"""    
            SELECT main.asin, a.features FROM  
            (select asin from {site_name}_pictures_id_index where index in {indices_list}) main 
            inner join {site_name}_pictures_features a 
            on main.asin=a.asin

        """
        df = pd.read_sql(sql, con=self.engine_pg)
        df = df.drop_duplicates(['asin'])
        # features are stored as stringified lists; eval() parses them back
        all_vecs_dict = {asin: eval(features) for asin, features in zip(df.asin, df.features)}
        similarities = self.cosine_similarity_matrix(query_vec=query_vector, all_vecs=list(all_vecs_dict.values()))
        # Image URLs are sharded by ASIN prefix: {a}/{ab}/{abc}/{abcd}/{abcde}/{abcdef}/{ASIN}.jpg
        similarities_dict = {
            asin: (similarity,
                   f"http://{self.server_ip}/images/{site_name}/{asin[:1]}/{asin[:2]}/{asin[:3]}"
                   f"/{asin[:4]}/{asin[:5]}/{asin[:6]}/{asin}.jpg")
            for similarity, asin in zip(similarities, df.asin)
        }
        return similarities_dict

    def cosine_similarity_matrix(self, query_vec, all_vecs):
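        """Return cosine similarities (as percentages) between query_vec and each row of all_vecs."""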
        # Cosine similarity between the query vector and every candidate vector
        query_vec_norm = np.linalg.norm(query_vec)
        all_vecs_norm = np.linalg.norm(all_vecs, axis=1)
        dot_products = np.dot(all_vecs, query_vec)
        similarities = dot_products / (query_vec_norm * all_vecs_norm)
        # Convert to percentages and round to two decimal places
        similarities_percentage = np.round(similarities * 100, 2)
        return similarities_percentage
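
    # Minimal sanity check for cosine_similarity_matrix (hypothetical toy
    # vectors, not data from the real feature store):
    #   query_vec = [1.0, 0.0]; all_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    #   dot products = [1.0, 0.0, 1.0]; norms: query 1.0, rows [1.0, 1.0, sqrt(2)]
    #   -> returns [100.0, 0.0, 70.71] (percent, rounded to two decimals)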

    def search_api(self, site_name, search_key, search_value, top_k, file):
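        """API entry point: extract the query feature vector, search the FAISS index
        for the top_k nearest neighbours, and return {asin: (similarity, image_url)}."""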
        # Extract the query feature vector
        query_vector = self.get_features(search_key, search_value, file)
        self.query_vector = np.array(query_vector)
        self.query_vector = self.query_vector.reshape(1, -1)
        distances, indices = self.my_index.search(self.query_vector, top_k)
        indices_list = tuple(indices.tolist()[0])
        # Compute cosine similarities (and image URLs) for the matched indices
        similarities_dict = self.calculate_similarity(site_name, indices_list, query_vector=query_vector)
        return similarities_dict

    def search(self):
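        """CLI path: resolve the query vector from the constructor arguments,
        run the FAISS search, and print the top-k results."""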
        self.query_vector = np.array(self.get_features(self.search_key, self.search_value, file=None))
        self.query_vector = self.query_vector.reshape(1, -1)
        distances, indices = self.my_index.search(self.query_vector, self.top_k)
        print("type(indices), type(distances):", type(indices), type(distances))
        print("indices:", indices[:20])
        print(f"Top {self.top_k} elements in the dataset for max inner product search:")
        for i, (dist, indice) in enumerate(zip(distances[0], indices[0])):
            print(f"{i + 1}: Vector number {indice:4} with distance {dist}")

    def run(self):
        self.search()


if __name__ == '__main__':
    site_name = sys.argv[1]  # arg 1: site
    search_key = sys.argv[2]  # arg 2: search type ('asin' or 'file')
    search_value = sys.argv[3]  # arg 3: search value
    top_k = int(sys.argv[4])  # arg 4: number of top-k most similar results

    handle_obj = PicturesSearch(site_name=site_name, search_key=search_key, search_value=search_value, top_k=top_k)
    handle_obj.run()
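
    # Example invocation (ASIN taken from the class defaults; hypothetical usage):
    #   python pictures_search.py us asin B0BBNQCXZL 100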