import ast
import os
import random
import re
import sys
import traceback

import faiss
import numpy as np
import pandas as pd
import requests as requests2
from curl_cffi import Curl, requests
from lxml import etree

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory, so `utils`/`vgg_model` resolve

from utils.db_util import DbTypes, DBUtil
from utils.templates_mysql import TemplatesMysql
from vgg_model import VGGNet


class ImageSearchClass(object):
    """Image-similarity search over a prebuilt FAISS index of VGG feature vectors.

    Given an ASIN (or a local image file), the class fetches/extracts the VGG
    feature vector, runs a top-k FAISS search, then re-ranks the candidate
    ASINs by exact cosine similarity and attaches an image URL and rank.
    """

    def __init__(self, site_name='us', search_key='asin', search_value='B0BBNQCXZL', top_k=100, asin_type=1):
        """
        :param site_name: Amazon site code ('us', 'uk', 'de', 'fr', 'es', 'it')
        :param search_key: 'asin' (look up by ASIN) or 'file' (local image file)
        :param search_value: the ASIN to search for when search_key == 'asin'
        :param top_k: number of nearest neighbours to retrieve from FAISS
        :param asin_type: 1 -> amazon_self index, 2 -> amazon index
        """
        self.site_name = site_name
        self.search_key = search_key
        self.search_value = search_value
        self.top_k = top_k
        self.asin_type = asin_type
        self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
        self.vgg_model = VGGNet()
        # Fix: was `np.array` (the function object) as a placeholder, which made
        # search() crash; now a real sentinel filled lazily by search()/search_api().
        self.query_vector = None
        # self.engine = TemplatesMysql().engine
        # self.server_ip = "192.168.200.210"
        self.server_ip = "113.100.143.162:8000"
        self.index_path, self.image_path = self.get_knn_index_path()
        self.my_index = self.load_index()

    def get_knn_index_path(self):
        """Return (local FAISS index path, image base URL) for the configured asin_type.

        Unknown asin_type values yield empty strings (load_index will then fail loudly).
        """
        if self.asin_type == 1:
            index_path = rf"/mnt/data/img_data/image_index/amazon_self/knn.index"
            image_path = rf"http://{self.server_ip}/images/amazon_self/{self.site_name}"
        elif self.asin_type == 2:
            index_path = rf"/mnt/data/img_data/image_index/amazon/knn.index"
            image_path = rf"http://{self.server_ip}/images/amazon/{self.site_name}"
        else:
            index_path = ''
            image_path = ''
        return index_path, image_path

    def xpath_imgurl(self, resp):
        """Extract the main product image URL from an Amazon detail-page HTML string.

        Tries a list of known layout XPaths in order and returns the first match,
        or None when nothing matches.
        """
        # Deduplicated from the original list (several entries repeated verbatim).
        img_url_list = [
            "//div[@id='imgTagWrapperId']/img/@src",
            '//div[@id="img-canvas"]/img/@src',
            '//div[@class="image-wrapper"]/img/@src',
            '//div[@id="mainImageContainer"]/img/@src',
            '//img[@id="js-masrw-main-image"]/@src',
            '//div[@id="a2s-dp-main-content"]//img/@src',
            '//div[@id="ppd-left"]//img/@src',
            '//div[@id="ebooks-img-canvas"]//img[@class="a-dynamic-image frontImage"]/@src',
            '//div[@class="a-column a-span12 a-text-center"]/img/@src',
            '//img[@id="seriesImageBlock"]/@src',
            '//img[@class="main-image"]/@src',
            '//img[@class="mainImage"]/@src',
            '//img[@id="gc-standard-design-image"]/@src',
            '//div[@class="a-row a-spacing-medium"]//img[1]/@src',
            '//div[@class="a-image-container a-dynamic-image-container greyBackground"]/img/@src',
        ]
        response_s = etree.HTML(resp)
        for xpath in img_url_list:
            image = response_s.xpath(xpath)
            if image:
                return image[0]
        return None

    def download_img(self, asin, site='us'):
        """Download the main product image for `asin` from the given Amazon site
        and return its VGG feature vector.

        Retries up to 5 times; returns None when every attempt fails or no
        image URL could be extracted.
        """
        # Fix: the original had a duplicated `if site == 'us'` branch and left
        # asin_url unbound (NameError) for unknown sites; unknown sites now fall
        # back to the .com domain.
        site_domains = {
            'us': 'https://www.amazon.com',
            'uk': 'https://www.amazon.co.uk',
            'de': 'https://www.amazon.de',
            'fr': 'https://www.amazon.fr',
            'es': 'https://www.amazon.es',
            'it': 'https://www.amazon.it',
        }
        asin_url = f"{site_domains.get(site, site_domains['us'])}/dp/{asin}"
        headers = {
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
            'Accept-Encoding': 'gzip, deflate, br',
            'Accept-Language': 'zh-CN,zh;q=0.9',
            'Cache-Control': 'no-cache',
            'Pragma': 'no-cache',
            'Sec-Ch-Ua-Mobile': '?0',
            'Sec-Ch-Ua-Platform': '"Windows"',
            'Sec-Ch-Ua-Platform-Version': '"10.0.0"',
            'Sec-Fetch-Dest': 'document',
            'Sec-Fetch-Mode': 'navigate',
            'Sec-Fetch-Site': 'none',
            'Sec-Fetch-User': '?1',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36',
            'Viewport-Width': '1920'}
        for _attempt in range(5):
            try:
                curl = Curl(cacert="/path/to/your/cert")  # NOTE(review): placeholder path — verify=False below makes it moot
                session = requests.Session(curl=curl)
                print('asin_url::', asin_url)
                resp = session.get(asin_url, headers=headers, timeout=10, verify=False,
                                   impersonate="chrome110")  # fetch the detail page
                image_url = self.xpath_imgurl(resp.text)
                if not image_url:
                    continue  # page layout not recognised; retry
                # Strip the "._SX300_" style size suffix to get the full-resolution image.
                # Fix: original indexed findall()[0] unguarded -> IndexError on no match.
                size_tokens = re.findall(r'\._(.*)\.jpg', image_url)
                _url = image_url.replace('._' + size_tokens[0], '') if size_tokens else image_url
                print('获取图片url:', _url)
                resp_img = session.get(_url, headers=headers, timeout=10, verify=False,
                                       impersonate="chrome110")  # fetch the image bytes
                path_1 = fr"/mnt/data/img_data/tmp_img/"
                os.makedirs(path_1, exist_ok=True)
                asin_path = rf"{path_1}{asin.upper()}.jpg"
                with open(asin_path, 'wb') as f:
                    f.write(resp_img.content)
                return self.vgg_model.vgg_extract_feat(file=asin_path, file_type='bytes')
            except Exception as e:
                print(e, f"\n{traceback.format_exc()}")
        return None  # all attempts failed

    def load_index(self):
        """Load the prebuilt FAISS index from disk (raises if index_path is invalid)."""
        my_index = faiss.read_index(self.index_path)
        print("加载索引完成, type(my_index):", type(my_index))
        return my_index

    def get_features(self, search_key, search_value, file_path):
        """Dispatch feature extraction: by ASIN lookup or from a local image file.

        Returns the feature vector, or None for an unsupported search_key.
        """
        print("search_key, search_value, file_path:", search_key, search_value, file_path)
        if search_key == 'asin':
            return self.get_asin_features(search_value)
        elif search_key == 'file':
            return self.vgg_model.vgg_extract_feat(file=file_path, file_type='bytes')
        else:
            print("不符合的查询方式,仅支持asin和图片file两种方式")

    def get_asin_features(self, search_value):
        """Fetch the stored feature vector for an ASIN, downloading the image and
        extracting features on the fly when no DB row exists."""
        # NOTE(review): f-string SQL — search_value comes from argv; parameterize
        # before exposing this path to untrusted callers.
        sql = f"select * from asin_image_features where site_name='{self.site_name}' and asin_type='{self.asin_type}' and asin='{search_value}'"
        print("sql:", sql)
        df = pd.read_sql(sql, con=self.engine_srs)
        if df.shape[0] == 0:
            # Fix: pass the configured site instead of silently defaulting to 'us'.
            query_vector = self.download_img(search_value, site=self.site_name)
            # df = pd.read_sql(sql, con=self.engine)
        else:
            # Features are stored as a Python-literal list string.
            # Fix: literal_eval instead of eval — same result, no code execution.
            query_vector = ast.literal_eval(list(df.features)[0])
        return query_vector

    def get_asin_by_index(self, site_name, indices_list):
        """Map FAISS result indices back to distinct ASINs via image_id_index."""
        # Fix: tuple(...) formatting produced invalid SQL `in (5,)` for a single
        # element; build the IN clause explicitly (indices are ints from FAISS).
        in_clause = ", ".join(str(int(i)) for i in indices_list)
        sql = (f"select distinct(asin) as asin from image_id_index where site_name='{self.site_name}' "
               f"and asin_type='{self.asin_type}' and indice in ({in_clause})")
        print("sql:", sql)
        df = pd.read_sql(sql=sql, con=self.engine_srs)
        return list(df.asin)

    @staticmethod
    def sorted_dict(similarities_dict):
        """Sort {asin: [similarity, ...]} by similarity descending and append a 1-based rank."""
        sorted_data = sorted(similarities_dict.items(), key=lambda item: item[1][0], reverse=True)
        ranked_data = {key: value + [rank + 1] for rank, (key, value) in enumerate(sorted_data)}
        return ranked_data

    def calculate_similarity(self, site_name, indices_list, query_vector, asin_type=None):
        """Re-rank candidate FAISS indices by exact cosine similarity to query_vector.

        :param indices_list: FAISS result indices (iterable of ints)
        :param query_vector: raw feature vector of the query image
        :param asin_type: defaults to self.asin_type (fix: was hard-coded to 1,
            so type-2 searches silently queried type-1 rows)
        :return: {asin: (similarity_pct, image_url, rank)} in descending-similarity order
        """
        if asin_type is None:
            asin_type = self.asin_type
        in_clause = ", ".join(str(int(i)) for i in indices_list)  # fix: single-element tuple broke SQL
        sql = f"""
        SELECT main.asin, a.features FROM (
            select asin from image_id_index
            where site_name='{site_name}' and asin_type='{asin_type}' and indice in ({in_clause})
        ) main inner join asin_image_features a on main.asin=a.asin
        """
        df = pd.read_sql(sql, con=self.engine_srs)
        df = df.drop_duplicates(['asin'])
        all_vecs = [ast.literal_eval(features) for features in df.features]
        similarities = self.cosine_similarity_matrix(query_vec=query_vector, all_vecs=all_vecs)
        # Image URL mirrors the on-disk sharding: /B/B0/B0B/.../{asin}.jpg
        similarities_dict = {
            asin: (similarity,
                   f"{self.image_path}/{asin[:1]}/{asin[:2]}/{asin[:3]}/{asin[:4]}/{asin[:5]}/{asin[:6]}/{asin}.jpg")
            for similarity, asin in zip(similarities, df.asin)
        }
        # Sort by similarity descending; dict insertion order preserves the ranking.
        sorted_data = sorted(similarities_dict.items(), key=lambda item: item[1][0], reverse=True)
        similarities_dict = {key: value + (rank + 1,) for rank, (key, value) in enumerate(sorted_data)}
        return similarities_dict

    def cosine_similarity_matrix(self, query_vec, all_vecs):
        """Cosine similarity of query_vec against every row of all_vecs.

        :return: np.ndarray of percentages rounded to 2 decimal places
        """
        query_vec = np.asarray(query_vec)
        all_vecs = np.asarray(all_vecs)
        query_vec_norm = np.linalg.norm(query_vec)
        all_vecs_norm = np.linalg.norm(all_vecs, axis=1)
        dot_products = all_vecs @ query_vec
        similarities = dot_products / (query_vec_norm * all_vecs_norm)
        # Express as a percentage with two decimal places.
        return np.round(similarities * 100, 2)

    def search_api(self, site_name, search_key, search_value, top_k, file_path):
        """End-to-end search: feature lookup -> FAISS top-k -> cosine re-rank.

        :return: {asin: (similarity_pct, image_url, rank)}
        """
        query_vector = self.get_features(search_key=search_key, search_value=search_value, file_path=file_path)
        self.query_vector = np.array(query_vector).reshape(1, -1)  # FAISS expects a 2-D batch
        distances, indices = self.my_index.search(self.query_vector, top_k)
        indices_list = tuple(indices.tolist()[0])
        return self.calculate_similarity(site_name, indices_list, query_vector=query_vector)

    def search(self):
        """FAISS top-k search with the instance's configured query; prints results."""
        if self.query_vector is None:
            # Fix: the original never populated query_vector on the run() path,
            # so search() always crashed; compute it from the configured query.
            self.query_vector = self.get_features(search_key=self.search_key,
                                                  search_value=self.search_value,
                                                  file_path=None)
        self.query_vector = np.array(self.query_vector).reshape(1, -1)
        distances, indices = self.my_index.search(self.query_vector, self.top_k)
        print("type(indices), type(distances):", type(indices), type(distances))
        print("indices:", indices[:20])
        print(f"Top {self.top_k} elements in the dataset for max inner product search:")
        for i, (dist, indice) in enumerate(zip(distances[0], indices[0])):
            print(f"{i + 1}: Vector number {indice:4} with distance {dist}")

    def run(self):
        self.search()


if __name__ == '__main__':
    site_name = sys.argv[1]     # arg 1: site code
    search_key = sys.argv[2]    # arg 2: query type ('asin' or 'file')
    search_value = sys.argv[3]  # arg 3: query value
    top_k = int(sys.argv[4])    # arg 4: number of most-similar results
    handle_obj = ImageSearchClass(site_name=site_name, search_key=search_key,
                                  search_value=search_value, top_k=top_k)
    handle_obj.run()