Commit f7c24024 by fangxingjun

no message

parent 781a04a7
import os
import sys
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.templates import Templates
from utils.db_util import DbTypes, DBUtil
class ImgAlterTableName(Templates):
def __init__(self, site_name='us', img_type="amazon_inv"):
super(ImgAlterTableName, self).__init__()
self.site_name = site_name
self.img_type = img_type
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
def read_data(self):
pass
def handle_data(self):
with self.engine_doris.begin() as conn:
sql1 = "ALTER TABLE img_id_index RENAME img_id_index_temp;"
conn.execute(sql1)
sql2 = "ALTER TABLE img_id_index_copy RENAME img_id_index;"
conn.execute(sql2)
sql3 = "ALTER TABLE img_id_index_temp RENAME img_id_index_copy;"
conn.execute(sql3)
print(f"交换表名称完成--sql1: {sql1}\nsql2: {sql2}\nsql3: {sql3}")
def save_data(self):
pass
if __name__ == '__main__':
site_name = sys.argv[1]
img_type = sys.argv[2]
handle_obj = ImgAlterTableName(site_name=site_name, img_type=img_type)
handle_obj.run()
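# Hedged standalone sketch (not part of the commit) of the three-way rename used in
# handle_data above: img_id_index and img_id_index_copy are swapped through a temporary
# name inside one transaction block. `engine` is assumed to be any SQLAlchemy engine
# speaking the MySQL protocol to Doris, such as the one returned by DBUtil.get_db_engine.
def swap_tables(engine, a="img_id_index", b="img_id_index_copy", tmp="img_id_index_temp"):
    # a -> tmp, b -> a, tmp -> b
    with engine.begin() as conn:
        conn.execute(f"ALTER TABLE {a} RENAME {tmp};")
        conn.execute(f"ALTER TABLE {b} RENAME {a};")
        conn.execute(f"ALTER TABLE {tmp} RENAME {b};")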
import os
from autofaiss import build_index
from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel
from pyspark import SparkConf, SparkContext
def create_spark_session():
# this must be a path that is available on all worker nodes
# os.environ['PYSPARK_PYTHON'] = "/opt/module/spark/demo/py_demo/img_search/autofaiss.pex"
spark = (
SparkSession.builder
.config("spark.executorEnv.PEX_ROOT", "./.pex")
.config("spark.executor.cores", "4")
.config("spark.executor.memory", "20G") # make sure to increase this if you're using more cores per executor
.config("spark.num.executors", "10")
.config("spark.yarn.queue", "spark")
.master("local") # this should point to your master node, if using the tunnelling version, keep this to localhost
.appName("autofaiss-create-index")
.getOrCreate()
)
return spark
spark = create_spark_session()
index, index_infos = build_index(
# embeddings="hdfs://nameservice1:8020/home/img_search/us/amazon_inv/parquet",
embeddings="hdfs://nameservice1:8020/home/img_search/img_parquet/us/amazon_inv",
distributed="pyspark",
file_format="parquet",
max_index_memory_usage="80G", # 16G
current_memory_available="120G", # 24G
temporary_indices_folder="hdfs://nameservice1:8020/home/img_search/img_tmp/us/amazon_inv//distributed_autofaiss_indices",
index_path="hdfs://nameservice1:8020/home/img_search/img_index/us/amazon_inv/knn.index",
index_infos_path="hdfs://nameservice1:8020/home/img_search/img_index/us/amazon_inv/infos.json",
)
print("index, index_infos:", index, index_infos)
import ast
import os
import sys
import pandas as pd
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.templates import Templates
# from ..utils.templates import Templates
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType
class PicturesDimFeaturesSlice(Templates):
def __init__(self, site_name='us', img_type='amazon_inv'):
super(PicturesDimFeaturesSlice, self).__init__()
self.site_name = site_name
self.img_type = img_type
self.db_save = f'img_dim_features_slice'
self.spark = self.create_spark_object(
app_name=f"{self.db_save}: {self.site_name}")
self.df_asin_features = self.spark.sql(f"select 1+1;")
self.df_save = self.spark.sql(f"select 1+1;")
# self.partitions_by = ['site_name', 'block']
self.partitions_by = ['site_name', 'img_type']
self.partitions_num = 10
def read_data(self):
# sql = f"select id, asin, img_vector as embedding from ods_asin_extract_features;"
sql = f"select id, img_unique, features, img_type from img_ods_features where site_name='{self.site_name}' and img_type='{self.img_type}';"
print("sql:", sql)
self.df_save = self.spark.sql(sql).cache()
self.df_save.show(10)
print(f"self.df_save.count(): {self.df_save.count()}")
# No need to produce the array type at this step
# partitions_num = self.df_asin_features.rdd.getNumPartitions()
# print("partition count:", partitions_num) # 642
# # self.partitions_num = 1000
# self.df_save = self.df_save.repartition(self.partitions_num)
# print("partition count after repartition:", self.partitions_num) # 642
def handle_data(self):
# Define a UDF that converts a string to a list
# str_to_list_udf = F.udf(lambda s: ast.literal_eval(s), ArrayType(FloatType()))
# # Apply this UDF to the DataFrame column
# self.df_save = self.df_save.withColumn("embedding", str_to_list_udf(self.df_save["embedding"]))
self.df_save = self.df_save.withColumn('site_name', F.lit(self.site_name))
if __name__ == '__main__':
site_name = sys.argv[1] # arg 1: site
handle_obj = PicturesDimFeaturesSlice(site_name=site_name)
handle_obj.run()
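# Hedged standalone sketch of the string-to-array conversion that is commented out in
# handle_data above, runnable with a local SparkSession (values are illustrative):
import ast
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import ArrayType, FloatType

spark = SparkSession.builder.master("local[1]").appName("str-to-array-demo").getOrCreate()
df = spark.createDataFrame([(1, "[0.1, 0.2, 0.3]")], ["id", "features"])
str_to_list_udf = F.udf(lambda s: ast.literal_eval(s), ArrayType(FloatType()))
df = df.withColumn("features", str_to_list_udf(F.col("features")))
df.printSchema()  # features becomes array<float>
spark.stop()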
import os
import socket
import sys
import threading
import logging
import time
import traceback
import uuid
import pandas as pd
import redis
import requests
from sqlalchemy import text
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.db_util import DbTypes, DBUtil
class ImgDownload(object):
def __init__(self, site_name='us', img_type="amazon_inv", thread_num=10, limit=200):
self.site_name = site_name
self.img_type = img_type
self.thread_num = thread_num
self.limit = limit
self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
self.hostname = socket.gethostname()
self.first_local_dir, self.read_table = self.get_first_local_dir()
# self.read_table = f"{self.site_name}_inv_img_info"
self.local_name = self.read_table
def get_first_local_dir(self):
if self.img_type == 'amazon_self':
first_local_dir = f"/mnt/data/img_data/amazon_self/{self.site_name}"
image_table = f'{self.site_name}_self_asin_image'
elif self.img_type == 'amazon':
first_local_dir = f"/mnt/data/img_data/amazon/{self.site_name}"
image_table = f'{self.site_name}_amazon_image'
elif self.img_type == 'amazon_inv':
first_local_dir = f"/mnt/data/img_data/amazon_inv/{self.site_name}"
image_table = f'{self.site_name}_inv_img_info'
else:
first_local_dir = ""
image_table = ""
return first_local_dir, image_table
def acquire_lock(self, lock_name, timeout=100):
"""
Try to acquire the distributed lock; returns True when the lock is set, otherwise None.
lock_name: the lock key; it is recommended to keep it the same as the task name.
"""
lock_value = str(uuid.uuid4())
lock_acquired = self.client.set(lock_name, lock_value, nx=True, ex=timeout) # the timeout can also be omitted
# lock_acquired = self.client.set(lock_name, lock_value, nx=True)
return lock_acquired, lock_value
def release_lock(self, lock_name, lock_value):
"""释放分布式锁"""
script = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"""
result = self.client.eval(script, 1, lock_name, lock_value)
return result
@staticmethod
def img_download(img_url, img_path, img_name):
file_path = f"{img_path}{img_name}"
for d_num in range(5):
try:
response = requests.get(img_url)
if response.status_code == 200:
# Open a file in binary write mode
with open(file_path, 'wb') as file:
file.write(response.content)
# print("Image downloaded successfully.")
return True
else:
continue
except Exception as e:
error = "No such file or directory"
if error in str(e):
os.makedirs(img_path)
print(f"{d_num}次--下载图片失败, 图片路径: {file_path}, 图片url: {img_url}, \n错误信息: {e, traceback.format_exc()}")
time.sleep(2)
return False
def update_state(self, id_list, state, state_value="success"):
if id_list:
while True:
try:
with self.engine_mysql.begin() as conn:
id_tuple = tuple(id_list)
print(f"{state_value}--id_tuple: {len(id_tuple)}, {id_tuple[:10]}", )
if id_tuple:
id_tuple_str = f"('{id_tuple[0]}')" if len(id_tuple) == 1 else f"{id_tuple}"
sql_update = f"UPDATE {self.read_table} SET state={state} WHERE id IN {id_tuple_str};"
print("sql_update:", sql_update[:150])
conn.execute(sql_update)
break
except Exception as e:
print(f"读取数据错误: {e}", traceback.format_exc())
time.sleep(20)
self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
continue
def read_data(self):
while True:
try:
lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
if lock_acquired:
print("self.hostname:", self.hostname)
with self.engine_mysql.begin() as conn:
sql_read = text(f"SELECT id, img_id, img_type, img_url, id_segment FROM {self.read_table} WHERE state=1 LIMIT {self.limit};")
df = pd.read_sql(sql=sql_read, con=self.engine_mysql)
id_tuple = tuple(df.id)
print(f"sql_read: {sql_read}, {df.shape}", id_tuple[:10])
if id_tuple:
id_tuple_str = f"('{id_tuple[0]}')" if len(id_tuple) == 1 else f"{id_tuple}"
sql_update = f"UPDATE {self.read_table} SET state=2 WHERE id IN {id_tuple_str};"
print("sql_update:", sql_update[:150])
conn.execute(sql_update)
self.release_lock(lock_name=self.local_name, lock_value=lock_value)
return df
else:
print(f"当前有其它进程占用redis的锁, 等待5秒继续获取数据")
time.sleep(10) # 等待5s继续访问锁
continue
except Exception as e:
print(f"读取数据错误: {e}", traceback.format_exc())
time.sleep(20)
self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
continue
def handle_data(self, df, thread_id):
# 1. Download the images
img_success_id_list = []
img_failed_id_list = []
id_list = list(df.id)
id_len = len(id_list)
for id_segment, id, img_id, img_type, img_url in zip(df.id_segment, df.id, df.img_id, df.img_type, df.img_url):
img_path = f"{self.first_local_dir}/{id_segment}/"
img_name = f"{id_segment}_{id}_{img_id}_{img_type}.jpg"
if self.hostname not in ['hadoop5', 'hadoop6', 'hadoop7', 'hadoop8']:
img_path = img_path.replace("/mnt", "/home")
d_flag = self.img_download(img_url=img_url, img_path=img_path, img_name=img_name)
id_index = id_list.index(id)
print(f"self.hostname: {self.hostname}, 线程: {thread_id}, 是否成功: {d_flag}, id_index: {id_index}, 进度: {round(id_index/id_len * 100, 2)}%, img_path: {img_path}{img_name}")
if d_flag:
img_success_id_list.append(id)
else:
img_failed_id_list.append(id)
# 2. Update state -- 3 on success, 4 on failure
print(f"success: {len(img_success_id_list)}, failed: {len(img_failed_id_list)}")
self.update_state(id_list=img_success_id_list, state=3, state_value="success")
self.update_state(id_list=img_failed_id_list, state=4, state_value="failed")
def save_data(self):
pass
def run(self, thread_id=1):
while True:
try:
df = self.read_data()
if df.shape[0]:
self.handle_data(df=df, thread_id=thread_id)
self.save_data()
# break
else:
break
except Exception as e:
print(e, traceback.format_exc())
self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
time.sleep(20)
continue
def run_thread(self):
logging.info("所有线程处理开始")
thread_list = []
for thread_id in range(self.thread_num):
thread = threading.Thread(target=self.run, args=(thread_id, ))
thread_list.append(thread)
thread.start()
for thread in thread_list:
thread.join()
logging.info("所有线程处理完成")
if __name__ == '__main__':
# handle_obj = PicturesFeatures(self_flag='_self')
# site_name = int(sys.argv[1]) # arg 1: site
# site_name = 'us'
# img_type = "amazon_inv"
# limit = 100
# thread_num = 1
site_name = sys.argv[1] # arg 1: site
img_type = sys.argv[2] # arg 2: image source type
limit = int(sys.argv[3]) # arg 3: rows per read -- 1000
thread_num = int(sys.argv[4]) # arg 4: thread count -- 5
handle_obj = ImgDownload(site_name=site_name, img_type=img_type, thread_num=thread_num, limit=limit)
# handle_obj.run()
handle_obj.run_thread()
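# Hedged standalone sketch of the acquire/release pattern used by ImgDownload above:
# SET with nx+ex acquires the lock, and a compare-and-delete Lua script releases it only
# if the stored value still matches. Host/port/db here are placeholders, not the real server.
import uuid
import redis

client = redis.Redis(host="127.0.0.1", port=6379, db=0)
lock_name = "img_download_demo_lock"
lock_value = str(uuid.uuid4())
if client.set(lock_name, lock_value, nx=True, ex=100):
    try:
        pass  # critical section: read one batch and mark it state=2
    finally:
        release_script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        client.eval(release_script, 1, lock_name, lock_value)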
import ast
import logging
import os
import re
import sys
import threading
import time
import traceback
import pandas as pd
import redis
from pyspark.sql.types import ArrayType, FloatType
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.templates_mysql import TemplatesMysql
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from sqlalchemy import text
from pyspark.sql import functions as F
import pyarrow as pa
import pyarrow.parquet as pq
from multiprocessing import Process
from multiprocessing import Pool
import multiprocessing
from utils.db_util import DbTypes, DBUtil
class PicturesIdIndex(Templates):
def __init__(self, site_name='us', img_type="amazon_inv", thread_num=10):
super(PicturesIdIndex, self).__init__()
self.site_name = site_name
self.img_type = img_type
self.thread_num = thread_num
# self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.db_save = f'img_dwd_id_index'
self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
self.df_features = self.spark.sql(f"select 1+1;")
self.df_save = self.spark.sql(f"select 1+1;")
self.df_save_local = self.spark.sql(f"select 1+1;")
self.partitions_by = ['site_name', 'img_type', 'block_name']
self.partitions_num = 1
# Variables populated from the MySQL table
self.tn_pics_hdfs_index = f"img_hdfs_index"
self.id = int()
self.current_counts = int()
self.all_counts = int()
self.hdfs_path = str()
self.hdfs_block_name = str()
self.local_path = str()
self.local_name = self.db_save
def read_data(self):
while True:
try:
lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
if lock_acquired:
# Read the hdfs_path info from MySQL
with self.engine_doris.begin() as conn:
sql_read = text(f"SELECT * FROM {self.tn_pics_hdfs_index} WHERE state=1 and site_name='{self.site_name}' and img_type='{self.img_type}' LIMIT 1;")
print("sql_read:", sql_read)
result = conn.execute(sql_read)
df = pd.DataFrame(result.fetchall())
if df.shape[0]:
df.columns = result.keys()
self.id = list(df.id)[0] if list(df.id) else None
self.current_counts = list(df.current_counts)[0] if list(df.current_counts) else None
self.all_counts = list(df.all_counts)[0] if list(df.all_counts) else None
self.hdfs_path = list(df.hdfs_path)[0] if list(df.hdfs_path) else None
self.hdfs_block_name = re.findall(r"(part-\d+)-", self.hdfs_path)[0] if self.hdfs_path else None
self.local_path = rf"/mnt/data/img_data/img_parquet/{self.site_name}/{self.img_type}/{self.hdfs_block_name}"
print(f"df.shape:{df.shape}, self.id:{self.id}, self.current_counts:{self.current_counts}, self.all_counts:{self.all_counts}, self.hdfs_path:{self.hdfs_path}")
if self.id:
os.system(f"hdfs dfs -rm -r /home/big_data_selection/dwd/img_dwd_id_index/site_name={self.site_name}/img_type={self.img_type}/block_name={self.hdfs_block_name}")
sql_update = text(
f"UPDATE {self.tn_pics_hdfs_index} SET state=2 WHERE id={self.id};")
print("sql_update:", sql_update)
conn.execute(sql_update)
else:
quit()
# Read the parquet file under the HDFS path
self.df_features = self.spark.read.text(self.hdfs_path).cache()
self.release_lock(lock_name=self.local_name, lock_value=lock_value)
return df
else:
print(f"当前有其它进程占用redis的锁, 等待5秒继续获取数据")
time.sleep(5) # 等待5s继续访问锁
continue
except Exception as e:
print(f"读取数据错误: {e}", traceback.format_exc())
time.sleep(5)
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
continue
def handle_data(self):
# Create a new DataFrame in which every field becomes its own column
split_df = self.df_features.select(F.split(self.df_features['value'], '\t').alias('split_values'))
# Assuming the data is known to have three fields,
# each one can be pulled out into a separate column like this
final_df = split_df.select(
split_df['split_values'].getItem(0).alias('id'),
split_df['split_values'].getItem(1).alias('img_unique'),
split_df['split_values'].getItem(2).alias('features')
)
print("分块前分区数量:", final_df.rdd.getNumPartitions())
final_df.show(10)
# Values split from the file read off HDFS are strings --> cast them to numeric types
final_df = final_df.withColumn('id', final_df['id'].cast('bigint')) # now the cast is safe
# Add an index column
final_df = final_df.withColumn("index", F.monotonically_increasing_id() + self.all_counts)
final_df.show()
# The type conversion was already done in pictures_dim_features_slice -- but because we read the lzo file rather than the table itself, it has to be converted again here
# Define a UDF that converts a string to a list
str_to_list_udf = F.udf(lambda s: ast.literal_eval(s), ArrayType(FloatType()))
# # Apply this UDF to the DataFrame column
final_df = final_df.withColumn("features", str_to_list_udf(final_df["features"]))
# Use Spark's built-in split function to convert the string to an array
# final_df = final_df.withColumn("features", F.split(F.col("features"), ",").cast(ArrayType(FloatType())))
final_df = final_df.withColumnRenamed("features", "embedding") # this column name is required -- the image index is built from it later
# final_df.write.mode('overwrite').parquet("hdfs://hadoop5:8020/home/ffman/parquet")
# final_df = final_df.withColumn("block", F.lit(self.hdfs_block_name))
self.df_save = final_df.withColumn("site_name", F.lit(self.site_name))
self.df_save = self.df_save.withColumn("img_type", F.lit(self.img_type))
self.df_save = self.df_save.withColumn("block_name", F.lit(self.hdfs_block_name))
self.df_save.show(10)
# self.df_save_local = self.df_save.select("features") # features
self.df_save_local = self.df_save.select("embedding") # features
def save_data_local(self):
print("当前存储到本地:", self.local_path)
if os.path.exists(self.local_path):
os.system(f"rm -rf {self.local_path}")
os.makedirs(self.local_path)
# Convert DataFrame to Arrow Table
df = self.df_save_local.toPandas()
table = pa.Table.from_pandas(df)
# Save to Parquet
pq.write_table(table, f"{self.local_path}/{self.hdfs_block_name}.parquet")
def update_state_after_save(self):
with self.engine_doris.begin() as conn:
sql_update = text(
f"UPDATE {self.tn_pics_hdfs_index} SET state=3 WHERE state=2 and id={self.id};")
print("sql_update:", sql_update)
conn.execute(sql_update)
def run(self):
while True:
try:
df = self.read_data()
if df.shape[0]:
self.handle_data()
self.save_data()
# Save to local disk
self.save_data_local()
self.update_state_after_save()
break
else:
break
except Exception as e:
print(f"error: {e}", traceback.format_exc())
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
time.sleep(10)
def run_thread(self):
thread_list = []
for thread_id in range(self.thread_num):
thread = threading.Thread(target=self.run)
thread_list.append(thread)
thread.start()
for thread in thread_list:
thread.join()
logging.info("所有线程处理完成")
# def main():
# handle_obj = PicturesIdIndex()
# handle_obj.run()
if __name__ == "__main__":
site_name = sys.argv[1] # arg 1: site
img_type = sys.argv[2] # arg 2: image source type
# thread_num = int(sys.argv[3]) # arg 3: thread count -- 1
thread_num = 1
handle_obj = PicturesIdIndex(site_name=site_name, img_type=img_type, thread_num=thread_num)
handle_obj.run()
# handle_obj.run_thread()
quit()
processes = []
for _ in range(1): # sets the number of processes
handle_obj = PicturesIdIndex()
process = multiprocessing.Process(target=handle_obj.run)
process.start()
processes.append(process)
# Wait for all processes to finish
for process in processes:
process.join()
# if __name__ == '__main__':
# # Set the number of processes
# num_processes = 4 # set to however many processes you need
# # Create the process pool
# pool = Pool(processes=num_processes)
# # Run the task on processes from the pool
# pool.apply(main)
# # Close the pool
# pool.close()
# # Wait for all processes to finish
# pool.join()
# if __name__ == '__main__':
# # Create the process object
# process = Process(target=main)
# # Start the process
# process.start()
# # Wait for the process to finish
# process.join()
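# Hedged local sketch of the index-offset idea in handle_data above: each HDFS block's rows
# get monotonically increasing ids shifted by all_counts (the running total recorded in
# img_hdfs_index), so index values from different blocks do not collide. Values illustrative.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").appName("index-offset-demo").getOrCreate()
df = spark.createDataFrame([("a",), ("b",), ("c",)], ["img_unique"])
all_counts = 1000  # rows already indexed by previously processed blocks
df = df.withColumn("index", F.monotonically_increasing_id() + all_counts)
df.show()  # index: 1000, 1001, 1002 on a single partition
spark.stop()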
import multiprocessing
import os
import sys
import time
import traceback
import pandas as pd
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.templates import Templates
from utils.db_util import DbTypes, DBUtil
class JudgeFinished(Templates):
def __init__(self, site_name='us', img_type="amazon_inv"):
super(JudgeFinished, self).__init__()
self.site_name = site_name
self.img_type = img_type
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.tn_pics_hdfs_index = f"img_hdfs_index"
def judge(self):
sql = f"select * from {self.tn_pics_hdfs_index} where state in (1, 2) and site_name='{self.site_name}' and img_type='{self.img_type}';"
df = pd.read_sql(sql, con=self.engine_doris)
print(f"sql: {sql}, {df.shape}")
result_flag = True if df.shape[0] else False
return result_flag
def main(site_name='us', img_type='amazon_inv', p_num=0):
while True:
try:
judge_obj = JudgeFinished(site_name=site_name, img_type=img_type)
result_flag = judge_obj.judge()
if result_flag:
print(f"继续, result_flag: {result_flag}")
os.system(f"/opt/module/spark/bin/spark-submit --master yarn --driver-memory 1g --executor-memory 4g --executor-cores 1 --num-executors 1 --queue spark /opt/module/spark/demo/py_demo/img_search/img_dwd_id_index.py {site_name} {img_type}")
else:
print(f"结束, result_flag: {result_flag}")
break
except Exception as e:
print(e, traceback.format_exc())
time.sleep(20)
error = "ValueError: Length mismatch: Expected axis has 0 elements"
if error in e:
print(f"当前已经跑完所有block块id对应的index关系,退出进程-{p_num}")
quit()
continue
if __name__ == "__main__":
site_name = sys.argv[1]
img_type = sys.argv[2]
process_num = int(sys.argv[3]) # arg 3: number of processes
processes = []
for p_num in range(process_num): # sets the number of processes
process = multiprocessing.Process(target=main, args=(site_name, img_type, p_num))
process.start()
processes.append(process)
# Wait for all processes to finish
for process in processes:
process.join()
import os
import sys
import threading
import time
import traceback
import socket
import uuid
import numpy as np
import pandas as pd
import redis
import logging
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
# from utils.templates import Templates
from sqlalchemy import text
from vgg_model import VGGNet
from utils.db_util import DbTypes, DBUtil
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)
class ImgExtractFeatures(object):
def __init__(self, site_name='us', img_type="amazon_inv", thread_num=10, limit=1000):
# super(ImgFeatures, self).__init__()
self.site_name = site_name
self.img_type = img_type
self.thread_num = thread_num
self.limit = limit
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
self.local_name = f"{self.site_name}_img_features"
self.vgg_model = VGGNet()
self.hostname = socket.gethostname()
self.read_table = f"img_local_path"
self.save_table = f"img_features"
def acquire_lock(self, lock_name, timeout=100):
"""
Try to acquire the distributed lock; returns True when the lock is set, otherwise None.
lock_name: the lock key; it is recommended to keep it the same as the task name.
"""
lock_value = str(uuid.uuid4())
lock_acquired = self.client.set(lock_name, lock_value, nx=True, ex=timeout) # the timeout can also be omitted
# lock_acquired = self.client.set(lock_name, lock_value, nx=True)
return lock_acquired, lock_value
def release_lock(self, lock_name, lock_value):
"""释放分布式锁"""
script = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"""
result = self.client.eval(script, 1, lock_name, lock_value)
return result
def read_data(self):
while True:
try:
lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
if lock_acquired:
print("self.hostname:", self.hostname)
with self.engine_doris.begin() as conn:
sql_read = text(f"SELECT id, img_unique, local_path, img_type FROM selection.{self.read_table} WHERE site_name='{self.site_name}' and img_type='{self.img_type}' and state=1 LIMIT {self.limit};")
# result = conn.execute(sql_read)
# df = pd.DataFrame(result.fetchall())
df = pd.read_sql(sql=sql_read, con=self.engine_doris)
img_unique_tuple = tuple(df.img_unique)
print(f"sql_read: {sql_read}, {df.shape}", img_unique_tuple[:10])
if img_unique_tuple:
img_unique_tuple_str = f"('{img_unique_tuple[0]}')" if len(img_unique_tuple) == 1 else f"{img_unique_tuple}"
sql_update = text(f"UPDATE selection.{self.read_table} SET state=2 WHERE img_unique IN {img_unique_tuple_str};")
print("sql_update:", sql_update)
conn.execute(sql_update)
self.release_lock(lock_name=self.local_name, lock_value=lock_value)
return df
else:
print(f"当前有其它进程占用redis的锁, 等待5秒继续获取数据")
time.sleep(5) # 等待5s继续访问锁
continue
except Exception as e:
print(f"读取数据错误: {e}", traceback.format_exc())
time.sleep(5)
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
continue
def handle_data(self, df, thread_id):
id_list = list(df.id)
img_unique_list = list(df.img_unique)
local_path_list = list(df.local_path)
data_list = []
for id, img_unique, local_path in zip(id_list, img_unique_list, local_path_list):
index = id_list.index(id)
print(f"thread_id, index, id, img_unique, local_path: {thread_id, index, id, img_unique, local_path}")
if self.hostname not in ['hadoop5', 'hadoop6', 'hadoop7', 'hadoop8']:
local_path = local_path.replace("/mnt", "/home")
try:
features = self.vgg_model.vgg_extract_feat(file=local_path)
except Exception as e:
print(e, traceback.format_exc())
features = list(np.zeros(shape=(512,)))
data_list.append([id, img_unique, str(features), self.img_type, self.site_name])
columns = ['id', 'img_unique', 'features', 'img_type', 'site_name']
df_save = pd.DataFrame(data_list, columns=columns)
return df_save
def save_data(self, df):
df.to_sql(self.save_table, con=self.engine_doris, if_exists="append", index=False)
with self.engine_doris.begin() as conn:
img_unique_tuple = tuple(df.img_unique)
if img_unique_tuple:
img_unique_tuple_str = f"('{img_unique_tuple[0]}')" if len(img_unique_tuple) == 1 else f"{img_unique_tuple}"
sql_update = f"update selection.{self.read_table} set state=3 where img_unique in {img_unique_tuple_str};"
print(f"sql_update: {sql_update}")
conn.execute(sql_update)
def run(self, thread_id=1):
while True:
try:
df = self.read_data()
if df.shape[0]:
df_save = self.handle_data(df=df, thread_id=thread_id)
self.save_data(df=df_save)
# break
else:
break
except Exception as e:
print(e, traceback.format_exc())
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
self.vgg_model = VGGNet()
time.sleep(20)
continue
def run_thread(self):
thread_list = []
for thread_id in range(self.thread_num):
thread = threading.Thread(target=self.run, args=(thread_id, ))
thread_list.append(thread)
thread.start()
for thread in thread_list:
thread.join()
logging.info("所有线程处理完成")
if __name__ == '__main__':
# handle_obj = PicturesFeatures(self_flag='_self')
# site_name = int(sys.argv[1]) # arg 1: site
# site_name = 'us'
# img_type = "amazon_inv"
# limit = 100
# thread_num = 1
site_name = sys.argv[1] # arg 1: site
img_type = sys.argv[2] # arg 2: image source type
limit = int(sys.argv[3]) # arg 3: rows per read -- 1000
thread_num = int(sys.argv[4]) # arg 4: thread count -- 5
handle_obj = ImgExtractFeatures(site_name=site_name, img_type=img_type, thread_num=thread_num, limit=limit)
# handle_obj.run()
handle_obj.run_thread()
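# Hedged sketch of the round trip for the `features` column: it is written to Doris as the
# str() of a Python list (see handle_data above) and recovered downstream with
# ast.literal_eval (see img_dwd_id_index). Values illustrative.
import ast
import numpy as np

features = list(np.zeros(shape=(512,)))  # placeholder vector, same shape as the failure path above
stored = str(features)                   # what lands in img_features.features
restored = ast.literal_eval(stored)      # what the downstream job parses back
assert restored == features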
import os
import sys
import pandas as pd
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from utils.db_util import DbTypes, DBUtil
class ImgHdfsIndex(Templates):
def __init__(self, site_name='us', img_type="amazon_inv"):
super(ImgHdfsIndex, self).__init__()
self.site_name = site_name
self.img_type = img_type
# self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.db_save = f'img_hdfs_index'
self.img_dim_features_slice = f'img_dim_features_slice'
self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
self.df_features = self.spark.sql(f"select 1+1;")
self.df_save = pd.DataFrame()
# self.hdfs_file_path = f'hdfs://nameservice1:8020/home/big_data_selection/dim/{self.img_dim_features_slice}/site_name={self.site_name}/img_type={self.img_type}/'
self.hdfs_file_path = f'hdfs://nameservice1:8020'
self.hdfs_file_list = self.get_hdfs_file_list()
self.index_count = 0
# def get_hdfs_file_list(self):
# # Import the hadoop package
# java_import(self.spark._jvm, 'org.apache.hadoop.fs.Path')
# # fs = self.spark._jvm.org.apache.hadoop.fs.FileSystem.get(self.spark._jsc.hadoopConfiguration(self.hdfs_file_path))
# # status = fs.listStatus(self.spark._jvm.org.apache.hadoop.fs.Path())
#
# fs = self.spark._jvm.org.apache.hadoop.fs.FileSystem.get(self.spark._jsc.hadoopConfiguration())
# path = self.spark._jvm.org.apache.hadoop.fs.Path(self.hdfs_file_path)
# status = fs.listStatus(path)
#
# hdfs_file_list = [file_status.getPath().getName() for file_status in status]
# return hdfs_file_list
def get_hdfs_file_list(self):
# Use os.popen to run the hdfs dfs -ls command
command = f"hdfs dfs -ls /home/big_data_selection/dim/img_dim_features_slice/site_name={self.site_name}/img_type={self.img_type}"
result = os.popen(command).read()
# Parse the command output
file_list = []
for line in result.split('\n'):
if line:
parts = line.split()
if len(parts) > 7:
file_path = parts[-1]
file_list.append(file_path)
print(f"file_list: {(len(file_list))}", file_list[:3])
return file_list
def read_data(self, hdfs_path):
df = self.spark.read.text(hdfs_path)
index_count = df.count()
return df, index_count
def handle_data(self):
pass
def save_data(self):
self.df_save.to_sql(self.db_save, con=self.engine_doris, if_exists="append", index=False)
def run(self):
data_list = []
for hdfs_file in self.hdfs_file_list:
index = self.hdfs_file_list.index(hdfs_file)
hdfs_path = self.hdfs_file_path + hdfs_file
df, index_count = self.read_data(hdfs_path)
data_list.append([index, hdfs_path, index_count, self.index_count])
print([index, hdfs_path, index_count, self.index_count])
self.index_count += index_count
self.df_save = pd.DataFrame(data=data_list, columns=['index', 'hdfs_path', 'current_counts', 'all_counts'])
self.df_save["site_name"] = self.site_name
self.df_save["img_type"] = self.img_type
self.df_save["id"] = self.df_save["index"] + 1
self.df_save["state"] = 1
# self.df_save.to_csv("/root/hdfs_parquet_block_info.csl")
self.save_data()
if __name__ == '__main__':
# site_name = 'us'
# img_type = 'amazon_inv'
site_name = sys.argv[1] # arg 1: site
img_type = sys.argv[2] # arg 2: image source type
handle_obj = ImgHdfsIndex(site_name=site_name, img_type=img_type)
handle_obj.run()
import ast
import datetime
import logging
import os
import re
import sys
import threading
import time
import traceback
import pandas as pd
import redis
from pyspark.sql.types import ArrayType, FloatType
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.templates_mysql import TemplatesMysql
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from sqlalchemy import text
from pyspark.sql import functions as F
import pyarrow as pa
import pyarrow.parquet as pq
from multiprocessing import Process
from multiprocessing import Pool
import multiprocessing
from utils.db_util import DbTypes, DBUtil
from utils.StarRocksHelper import StarRocksHelper
class ImgIdIndexToDoris(Templates):
def __init__(self, site_name='us', img_type="amazon_inv"):
super(ImgIdIndexToDoris, self).__init__()
self.site_name = site_name
self.img_type = img_type
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
self.df_id_index = self.spark.sql(f"select 1+1;")
self.table_name = "img_dwd_id_index"
self.table_save = "img_id_index_copy"
def read_data(self):
sql = f"select id, index, img_unique, site_name, img_type from {self.table_name} where site_name='{self.site_name}' and img_type = '{self.img_type}';"
print("sql:", sql)
self.df_id_index = self.spark.sql(sql).cache()
self.df_id_index.show(10)
print(f"self.df_id_index.count(): {self.df_id_index.count()}")
def handle_data(self):
pass
def save_data(self):
# starrocks_url = "jdbc:mysql://192.168.10.151:19030/selection"
# properties = {
# "user": "fangxingjun",
# "password": "fangxingjun12345",
# "driver": "com.mysql.cj.jdbc.Driver",
# # "driver": "com.mysql.cj.jdbc.Driver",
# }
# self.df_id_index.write.jdbc(url=starrocks_url, table="image_id_index", mode="overwrite", properties=properties)
# self.df_id_index = self.df_id_index.withColumn('created_time', F.lit(datetime.datetime.now()))
# self.df_id_index = self.df_id_index.withColumn("img_type", F.col("img_type").cast("int"))
# StarRocksHelper.spark_export(df_save=self.df_id_index, db_name='selection', table_name='image_id_index')
df_save = self.df_id_index.toPandas()
df_save.to_sql(self.table_save, con=self.engine_doris, if_exists="append", index=False, chunksize=10000)
if __name__ == '__main__':
site_name = sys.argv[1]
img_type = sys.argv[2]
handle_obj = ImgIdIndexToDoris(site_name=site_name, img_type=img_type)
handle_obj.run()
import os
import sys
import time
import traceback
import pandas as pd
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.templates_mysql import TemplatesMysql
from utils.db_util import DbTypes, DBUtil
class AsinImageLocalPath(object):
def __init__(self, site_name='us', img_type='amazon_self'):
self.site_name = site_name
self.img_type = img_type
self.first_local_dir, self.image_table = self.get_first_local_dir()
# self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
def get_first_local_dir(self):
if self.img_type == 'amazon_self':
first_local_dir = f"/mnt/data/img_data/amazon_self/{self.site_name}"
image_table = f'{self.site_name}_self_asin_image'
elif self.img_type == 'amazon':
first_local_dir = f"/mnt/data/img_data/amazon/{self.site_name}"
image_table = f'{self.site_name}_amazon_image'
elif self.img_type == 'amazon_inv':
first_local_dir = f"/mnt/data/img_data/amazon_inv/{self.site_name}"
image_table = f'{self.site_name}_inv_img'
else:
first_local_dir = ""
image_table = ""
return first_local_dir, image_table
def read_data(self):
# sql = f"select img_unique from selection.{self.image_table} where state=1;"
# df = pd.read_sql(sql, con=self.engine_doris)
sql = f"SELECT * from us_inv_img_info WHERE updated_at>= CURDATE() - INTERVAL 3 DAY and state =3;"
df = pd.read_sql(sql, con=self.engine_mysql)
print(f"sql: {sql}", df.shape)
return df
def handle_data(self, df):
# f"{id_segment}_{id}_{img_id}_{img_type}.jpg"
# df['img_unique'] = df['id_segment'] + "_" + df['id'] + "_" + df['img_id'] + "_" + df['img_type']
df['img_unique'] = df['id_segment'].astype(str) + "_" + df['id'].astype(str) + "_" + df['img_id'].astype(str) + "_" + df['img_type']
if self.img_type in ['amazon_self', 'amazon']:
df["local_path"] = df.img_unique.apply(lambda x: f"{self.first_local_dir}/{x[:1]}/{x[:2]}/{x[:3]}/{x[:4]}/{x[:5]}/{x[:6]}/{x}.jpg")
elif self.img_type in ['amazon_inv']:
df["local_path"] = df.img_unique.apply(lambda x: f"{self.first_local_dir}/{x.split('_')[0]}/{x}.jpg")
df["img_type"] = self.img_type
df["site_name"] = self.site_name
df["state"] = 1
df = df.loc[:, ["img_unique", "site_name", "local_path", "img_type", "state"]]
print(f"此次更新图片的数量: {df.shape}", df.head(5))
# quit()
return df
def save_data(self, df):
df.to_sql("img_local_path", con=self.engine_doris, if_exists="append", index=False)
def run(self):
df = self.read_data()
df = self.handle_data(df)
self.save_data(df)
if __name__ == '__main__':
site_name = sys.argv[1] # arg 1: site
img_type = sys.argv[2] # arg 2: image source type
# site_name = 'us'
# img_type = "amazon_inv"
handle_obj = AsinImageLocalPath(site_name=site_name, img_type=img_type)
handle_obj.run()
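# Hedged worked example of the img_unique / local_path construction in handle_data above,
# with illustrative values for id_segment, id, img_id and img_type:
row = {"id_segment": 12, "id": 345, "img_id": "IMG678", "img_type": "amazon_inv"}
img_unique = f"{row['id_segment']}_{row['id']}_{row['img_id']}_{row['img_type']}"
local_path = f"/mnt/data/img_data/amazon_inv/us/{img_unique.split('_')[0]}/{img_unique}.jpg"
print(img_unique)  # 12_345_IMG678_amazon_inv
print(local_path)  # /mnt/data/img_data/amazon_inv/us/12/12_345_IMG678_amazon_inv.jpg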
import os
import re
import sys
import traceback
import numpy as np
import pandas as pd
import faiss
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0])) # parent directory
from utils.templates_mysql import TemplatesMysql
from utils.db_util import DbTypes, DBUtil
from vgg_model import VGGNet
from curl_cffi import requests, Curl
from lxml import etree
class ImgSearch():
def __init__(self, site_name='us', search_key='asin', search_value='B0BBNQCXZL', top_k=100):
self.site_name = site_name
self.search_key = search_key
self.search_value = search_value
self.top_k = top_k
self.server_ip = "113.100.143.162:8000"
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.vgg_model = VGGNet()
self.index_path = rf"/mnt/data/img_data/img_index/us/amazon_inv/knn.index"
self.my_index = self.load_index() # load the index -- to be refreshed later
def load_index(self):
my_index = faiss.read_index(self.index_path)
print("加载索引完成, type(my_index):", type(my_index))
return my_index
def get_img_feature(self, search_key, search_value, file_path, site_name='us', img_type='amazon_inv'):
print("search_key, search_value, file_path:", search_key, search_value, file_path)
if search_key == 'img_unique':
sql = f"select * from img_features where site_name='{site_name}' and img_type='{img_type}' and img_unique='{search_value}';"
df = pd.read_sql(sql, con=self.engine_doris)
if df.shape[0] == 0:
if len(search_value) == 10:
query_vector = self.download_img(asin=search_value, site=site_name)
else:
query_vector = []
else:
query_vector = eval(list(df.features)[0])
return query_vector
elif search_key == 'file':
query_vector = self.vgg_model.vgg_extract_feat(file=file_path, file_type='bytes')
return query_vector
else:
print("不符合的查询方式,仅支持img_unique和图片file两种方式")
def calculate_similarity(self, index_list, query_vector, site_name='us', img_type='amazon_inv'):
sql = f"""
SELECT main.img_unique, a.features FROM
(select img_unique from img_id_index where `index` in {index_list}) main
inner join img_features a
on main.img_unique=a.img_unique
where a.site_name='{site_name}' and img_type='{img_type}';
"""
df = pd.read_sql(sql, con=self.engine_doris)
df = df.drop_duplicates(['img_unique'])
all_vecs_dict = {img_unique: eval(features) for img_unique, features in zip(df.img_unique, df.features)}
similarities = self.cosine_similarity_matrix(query_vec=query_vector, all_vecs=list(all_vecs_dict.values()))
# similarities_dict = {asin: (similarity, title, f"http://{self.server_ip}/images/{site_name}/{asin[:1]}/{asin[:2]}/{asin[:3]}/{asin[:4]}/{asin[:5]}/{asin[:6]}/{asin}.jpg") for similarity, title, asin in zip(similarities, df.title, df.asin) if asin}
# similarities_dict = {
# img_unique: (similarity,
# f"http://{self.server_ip}/images/{img_type}/{site_name}/{img_unique.split('_')[0]}/{img_unique}.jpg")
# for similarity, img_unique in zip(similarities, df.img_unique) if img_unique
# }
similarities_list = []
for similarity, img_unique in zip(similarities, df.img_unique):
img_url = f"http://{self.server_ip}/images/{img_type}/{site_name}/{img_unique.split('_')[0]}/{img_unique}.jpg"
img_unique = re.sub(r'_', '@@', img_unique, count=3)
# print(f"img_unique: {img_unique}")
if len(img_unique.split("@@")) == 4:
# if len(img_unique.split("@@")) == 4 and similarity > 0:
img_id, img_type_ = img_unique.split("@@")[-2], img_unique.split("@@")[-1]
similarities_list.append(
{
"img_id": img_id,
"img_type": img_type_,
"similarity": similarity,
"img_url": img_url,
}
)
df_similarities = pd.DataFrame(similarities_list)
return df_similarities
def cosine_similarity_matrix(self, query_vec, all_vecs):
# Compute the similarities
query_vec_norm = np.linalg.norm(query_vec)
all_vecs_norm = np.linalg.norm(all_vecs, axis=1)
print(query_vec_norm.shape)
print(all_vecs_norm.shape)
dot_products = np.dot(all_vecs, query_vec)
similarities = dot_products / (query_vec_norm * all_vecs_norm)
# Convert the similarities to percentages
similarities_percentage = similarities * 100
# Keep the desired number of decimal places, e.g. two
similarities_percentage = np.round(similarities_percentage, 2)
return similarities_percentage
def cosine_similarity_two_img(self, file1_path, file2_path):
query_vector1 = self.vgg_model.vgg_extract_feat(file=file1_path, file_type='bytes')
query_vector2 = self.vgg_model.vgg_extract_feat(file=file2_path, file_type='bytes')
query_vector1_norm = np.linalg.norm(query_vector1) # norm of vector 1
query_vector2_norm = np.linalg.norm(query_vector2) # norm of vector 2
dot_product = np.dot(query_vector1, query_vector2) # dot product of the two vectors
similarity = dot_product / (query_vector1_norm * query_vector2_norm)
# similarity = np.float32(similarity)
similarity = np.round(similarity * 100, 2)
return similarity
def search_api(self, search_key, search_value, file_path, site_name='us', img_type='amazon_inv', top_k=100):
# Get the feature vector of the asin / image file to query
query_vector = self.get_img_feature(search_key, search_value, file_path, site_name, img_type=img_type)
self.query_vector = np.array(query_vector)
self.query_vector = self.query_vector.reshape(1, -1)
distances, indices = self.my_index.search(self.query_vector, top_k)
index_list = tuple(indices.tolist()[0])
# Compute the similarities
similarities_dict = self.calculate_similarity(site_name=site_name, img_type=img_type, index_list=index_list, query_vector=query_vector)
return similarities_dict
def download_img(self, asin, site='us'):
for i in range(5):
try:
if site == 'us':
asin_url = f'https://www.amazon.com/dp/{asin}'
elif site == 'uk':
asin_url = f'https://www.amazon.co.uk/dp/{asin}' # site url
elif site == 'de':
asin_url = f'https://www.amazon.de/dp/{asin}'
elif site == 'fr':
asin_url = f'https://www.amazon.fr/dp/{asin}'
elif site == 'es':
asin_url = f'https://www.amazon.es/dp/{asin}'
elif site == 'it':
asin_url = f'https://www.amazon.it/dp/{asin}'
headers = {
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
'Accept-Encoding': 'gzip, deflate, br', 'Accept-Language': 'zh-CN,zh;q=0.9',
'Cache-Control': 'no-cache',
'Pragma': 'no-cache',
'Sec-Ch-Ua-Mobile': '?0', 'Sec-Ch-Ua-Platform': '"Windows"',
'Sec-Ch-Ua-Platform-Version': '"10.0.0"',
'Sec-Fetch-Dest': 'document', 'Sec-Fetch-Mode': 'navigate', 'Sec-Fetch-Site': 'none',
'Sec-Fetch-User': '?1',
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36',
'Viewport-Width': '1920'}
curl = Curl(cacert="/path/to/your/cert")
session = requests.Session(curl=curl)
print('asin_url::', asin_url)
resp = session.get(asin_url, headers=headers, timeout=10, verify=False, impersonate="chrome110") # fetch the page
image_url = self.xpath_imgurl(resp.text)
if image_url:
str_ = re.findall(r'\._(.*)\.jpg', image_url)[0]
_url = image_url.replace('._' + str_, '')
print('Image url:', _url)
resp_img = session.get(_url, headers=headers, timeout=10, verify=False,
impersonate="chrome110") # fetch the image
asin_upper = asin.upper()
path_1 = fr"/mnt/data/img_data/tmp_img/"
if not os.path.exists(path_1): # check whether the directory exists
os.makedirs(path_1)
asin_path = rf"{path_1}{asin_upper}.jpg"
with open(asin_path, 'wb') as f: # open the path for binary writing; f is the file handle
f.write(resp_img.content) # write the response's binary content to f
query_vec = self.vgg_model.vgg_extract_feat(file=asin_path, file_type='bytes')
return query_vec
except Exception as e:
print(e, f"\n{traceback.format_exc()}")
# return query_vec
def xpath_imgurl(self, resp):
print(2222222222222222222)
img_url_list = ["//div[@id='imgTagWrapperId']/img/@src", '//div[@id="img-canvas"]/img/@src',
'//div[@id="imgTagWrapperId"]/img/@src', '//div[@id="img-canvas"]/img/@src',
'//div[@class="image-wrapper"]/img/@src', '//div[@id="mainImageContainer"]/img/@src',
'//img[@id="js-masrw-main-image"]/@src', '//div[@id="a2s-dp-main-content"]//img/@src',
'//div[@id="ppd-left"]//img/@src',
'//div[@id="ebooks-img-canvas"]//img[@class="a-dynamic-image frontImage"]/@src',
'//div[@id="ppd-left"]//img/@src', '//div[@class="a-column a-span12 a-text-center"]/img/@src',
'//img[0]/@src', '//img[@id="seriesImageBlock"]/@src', '//img[@class="main-image"]/@src',
'//img[@class="mainImage"]/@src', '//img[@id="gc-standard-design-image"]/@src',
'//div[@class="a-row a-spacing-medium"]//img[1]/@src',
'//div[@class="a-image-container a-dynamic-image-container greyBackground"]/img/@src']
response_s = etree.HTML(resp)
for i in img_url_list:
image = response_s.xpath(i)
if image:
image_url = image[0]
break
else:
image_url = None
return image_url
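# Hedged numeric check of the cosine_similarity_matrix formula above, standalone with
# illustrative 2-d vectors (identical, orthogonal, and 45-degree cases):
import numpy as np

query_vec = np.array([1.0, 0.0])
all_vecs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
sims = np.dot(all_vecs, query_vec) / (np.linalg.norm(query_vec) * np.linalg.norm(all_vecs, axis=1))
print(np.round(sims * 100, 2))  # [100.    0.   70.71]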
import os
import pandas as pd
from flask import Flask, request, jsonify
from img_search import ImgSearch
import time
img_search = ImgSearch()
app = Flask(__name__)
UPLOAD_FOLDER = '/mnt/data/img_data/tmp_img/'
if not os.path.exists(UPLOAD_FOLDER):
os.makedirs(UPLOAD_FOLDER)
@app.route('/img_search/similarity', methods=['POST'])
def search_amazon_similarity():
print(request.form) # log all received form data
print(request.files) # log all received files
file_list = ['file1', 'file2']
files = []
for file_key in file_list:
if file_key in request.files:
files.append(request.files[file_key])
if len(files) != 2:
return jsonify({"error": "只能上传 2 个文件"}), 400
else:
# file1_path
timestamp = time.time()
milli_timestamp = round(timestamp * 1000) # convert the second-level timestamp to milliseconds
filename = f"{milli_timestamp}_{files[0].filename}"
file1_path = os.path.join(UPLOAD_FOLDER, filename)
files[0].save(file1_path)
# file2_path
timestamp = time.time()
milli_timestamp = round(timestamp * 1000) # convert the second-level timestamp to milliseconds
filename = f"{milli_timestamp}_{files[1].filename}"
file2_path = os.path.join(UPLOAD_FOLDER, filename)
files[1].save(file2_path)
similarity = img_search.cosine_similarity_two_img(
file1_path=file1_path,
file2_path=file2_path
)
return jsonify(
{"similarity": similarity}
)
@app.route('/img_search', methods=['POST'])
def search_amazon():
print(request.form) # log all received form data
print(request.files) # log all received files
site_name = request.form.get('site_name', 'us') # defaults to 'us'
img_type = request.form.get('img_type', 'amazon_inv') # defaults to 'amazon_inv'
# search_key = request.form.get('search_key', 'asin') # defaults to 'asin'
search_key = request.form.get('search_key', 'img_unique') # defaults to 'img_unique' -- or 'file'
search_value = request.form.get('search_value', 'B0BBNQCXZL') # defaults to 'B0BBNQCXZL'
top_k = int(request.form.get('top_k', 100)) # defaults to 100
df_similarities_list = []
# Fetch and store the uploaded files
if search_key == 'file':
# file = request.files['file'] # get the file stream parameter named "file"
file_list = ['file1', 'file2', 'file3', 'file4', 'file5']
files = []
for file_key in file_list:
if file_key in request.files:
files.append(request.files[file_key])
# If more than 5 files were uploaded, return an error or warning
if len(files) > 5:
return jsonify({"error": "At most 5 files can be uploaded"}), 400
# files = request.files.getlist('file') # get all files named "file"
for file in files:
timestamp = time.time()
milli_timestamp = round(timestamp * 1000) # convert the second-level timestamp to milliseconds
filename = f"{milli_timestamp}_{file.filename}"
file_path = os.path.join(UPLOAD_FOLDER, filename)
file.save(file_path)
# Log for easier debugging
print("file saved at:", file_path)
# Run the search with each uploaded image
df_similarities = img_search.search_api(
site_name=site_name, img_type=img_type, search_key=search_key,
search_value=search_value, top_k=top_k, file_path=file_path
)
df_similarities_list.append(df_similarities)
# similarities_dict_list.append({
# "similarities_dict": similarities_dict
# })
else:
file_path = ''
# Run the search (no uploaded image in this branch)
df_similarities = img_search.search_api(
site_name=site_name, img_type=img_type, search_key=search_key,
search_value=search_value, top_k=top_k, file_path=file_path
)
df_similarities_list.append(df_similarities)
# similarities_dict_list.append({
# "similarities_dict": similarities_dict
# })
df_result = pd.concat(df_similarities_list)
df_result = df_result.loc[~(df_result.similarity.isna())] # drop records whose similarity is NaN
print(f"df_result.columns: {df_result.columns}")
if df_result.shape[0]:
df_result.sort_values(["similarity"], ascending=[False], inplace=True)
# Add an id column starting from 1
df_result['id'] = range(1, len(df_result) + 1)
df_result = df_result.iloc[: top_k]
result = []
for _, row in df_result.iterrows():
row_dict = {'id': row['id']} # build a dict with id as the first key
for col in df_result.columns:
if col != 'id': # skip the id column
row_dict[col] = row[col] # add the remaining columns to the dict
result.append(row_dict)
return jsonify(
result
)
if __name__ == '__main__':
app.run(host='0.0.0.0', port=10001)
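# Hedged request sketch for the /img_search/similarity endpoint above (host/port taken
# from app.run; file paths are placeholders):
import requests

resp = requests.post(
    "http://127.0.0.1:10001/img_search/similarity",
    files={
        "file1": open("/path/to/a.jpg", "rb"),
        "file2": open("/path/to/b.jpg", "rb"),
    },
)
print(resp.status_code, resp.json())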
import ast
import json
import os
import sys
import threading
import time
import traceback
import uuid
import pandas as pd
import logging
import redis
from sqlalchemy import text
sys.path.append("/opt/module/spark-3.2.0-bin-hadoop3.2/demo/py_demo/")
from utils.db_util import DbTypes, DBUtil
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)
class ImgToh7(object):
def __init__(self, site_name='us', thread_num=10, limit=100, img_type='amazon_self'):
self.site_name = site_name
self.thread_num = thread_num
self.limit = limit
self.img_type = img_type
self.img_table = self.get_table_name_and_dir_name()
# self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
self.df_self_asin_image = pd.DataFrame()
self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='yswg2023')
self.local_name = f"{self.site_name}_img_to_h7"
def get_table_name_and_dir_name(self):
if self.img_type == 'amazon_self':
image_table = f'{self.site_name}_self_asin_img'
elif self.img_type == 'amazon_inv':
image_table = f'{self.site_name}_inv_img'
elif self.img_type == 'amazon':
image_table = f'{self.site_name}_amazon_img'
else:
image_table = ""
return image_table
def get_dir_name(self, img_unique_value=''):
if self.img_type in ['amazon_self', 'amazon']:
dir_name = rf"/mnt/data/img_data/{self.img_type}/{self.site_name}/{img_unique_value[:1]}/{img_unique_value[:2]}/{img_unique_value[:3]}/{img_unique_value[:4]}/{img_unique_value[:5]}/{img_unique_value[:6]}"
elif self.img_type == 'amazon_inv':
dir_name = rf"/mnt/data/img_data/{self.img_type}/{self.site_name}/{img_unique_value.split('_')[0]}"
else:
dir_name = ""
return dir_name
def acquire_lock(self, lock_name, timeout=10):
"""
Try to acquire the distributed lock; returns True when the lock is set, otherwise None.
lock_name: the lock key; it is recommended to keep it the same as the task name.
"""
lock_value = str(uuid.uuid4())
lock_acquired = self.client.set(lock_name, lock_value, nx=True, ex=timeout) # the timeout can also be omitted
# lock_acquired = self.client.set(lock_name, lock_value, nx=True)
return lock_acquired, lock_value
def release_lock(self, lock_name, lock_value):
"""释放分布式锁"""
script = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"""
result = self.client.eval(script, 1, lock_name, lock_value)
return result
def read_data(self):
while True:
try:
lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name, timeout=100)
if lock_acquired:
with self.engine_doris.begin() as conn:
# sql_read = text(f"SELECT id, img_unique, local_path, img_type FROM selection.{self.img_table} WHERE site_name='{self.site_name}' and img_type='{self.img_type}' and state=1 LIMIT {self.limit};")
# sql_read = f"""
# SELECT id, img_unique, image_file from selection.{self.img_table} WHERE id in
# (
# select id from selection.{self.img_table} WHERE state=1 limit {self.limit}
# );
# """
sql_read = f"""
SELECT id, img_unique, image_file from selection.{self.img_table} WHERE id in
(
select id from selection.us_inv_img WHERE state=1 and id
BETWEEN (SELECT min(id) from selection.{self.img_table} WHERE state=1 )
and (SELECT min(id)+{self.limit} from selection.{self.img_table} WHERE state=1 )
);
"""
# a = conn.execute(sql_read)
# df = pd.DataFrame(a, columns=['id', 'img_unique', 'image_file'])
df = pd.read_sql(sql_read, con=self.engine_doris)
df_img_file_null = df.loc[df.image_file.isna()]
df_img_file_not_null = df.loc[~df.image_file.isna()]
img_unique_tuple = tuple(df.img_unique)
img_unique_tuple_null = tuple(df_img_file_null.img_unique)
img_unique_tuple_not_null = tuple(df_img_file_not_null.img_unique)
print(f"sql_read: {sql_read}, {df.shape}, df_img_file_null.shape: {df_img_file_null.shape}", img_unique_tuple[:10])
# Update states:
if img_unique_tuple_not_null:
img_unique_tuple_not_null_str = f"('{img_unique_tuple_not_null[0]}')" if len(img_unique_tuple_not_null) == 1 else f"{img_unique_tuple_not_null}"
# sql_update = f"update selection.{self.img_table} set state=2 where img_unique in ({','.join(map(str, img_unique_tuple_not_null))});" # 解析存储中
sql_update = f"update selection.{self.img_table} set state=2 where img_unique in {img_unique_tuple_not_null_str};" # 解析存储中
print("sql_update--2:", sql_update[:100])
conn.execute(sql_update)
# Update states:
if img_unique_tuple_null:
img_unique_tuple_null_str = f"('{img_unique_tuple_null[0]}')" if len(img_unique_tuple_null) == 1 else f"{img_unique_tuple_null}"
# sql_update = f"update selection.{self.img_table} set state=9 where img_unique in ({','.join(map(str, img_unique_tuple))});" # img_file字段为空值
sql_update = f"update selection.{self.img_table} set state=9 where img_unique in {img_unique_tuple_null_str};" # img_file字段为空值
print("sql_update--9:", sql_update[:100])
conn.execute(sql_update)
self.release_lock(lock_name=self.local_name, lock_value=lock_value)
return df_img_file_not_null, df_img_file_null
else:
print(f"当前有其它进程占用redis的锁, 等待5秒继续获取数据")
time.sleep(5) # 等待5s继续访问锁
continue
except Exception as e:
print(f"读取数据错误: {e}", traceback.format_exc())
time.sleep(5)
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
continue
def read_data_old(self):
with self.engine_doris.begin() as conn:
sql_update = f"""
UPDATE {self.img_table}
SET state = 2
WHERE id in (
SELECT id
FROM {self.img_table}
WHERE state = 1
LIMIT {self.limit}
);
"""
print(f"sql_update: {sql_update}")
conn.execute(sql_update)
sql_read = f"select id, img_unique, image_file from {self.img_table} where state=2 limit {self.limit};"
# self.df_self_asin_image = pd.read_sql(sql_read, con=self.engine_doris)
a = conn.execute(sql_read)
df = pd.DataFrame(a, columns=['id', 'img_unique', 'image_file'])
id_tuple = tuple(df.id)
print(f"sql_read: {sql_read}, {df.shape}", id_tuple[:10])
return df
# if id_tuple:
# id_tuple_str = f"({id_tuple[0]})" if len(id_tuple) == 1 else f"{id_tuple}"
# sql_update = f"update {self.site_name}_self_asin_image set state=2 where id in {id_tuple_str};"
# conn.execute(sql_update)
def handle_data(self, df):
img_unique_list = list(df.img_unique)
img_str_list = list(df.image_file)
print(f"{len(img_unique_list)}, {img_unique_list[:10]}")
for img_unique_value, img_str in zip(img_unique_list, img_str_list):
# print(f"img_unique_value, img_str: {img_unique_value, img_str}")
# print(f"img_str: {type(img_str)}, {img_str[:20]}")
# img_str = json.loads(img_str)
input_bytes = ast.literal_eval(img_str)
# input_bytes = img_str
dir_name = self.get_dir_name(img_unique_value=img_unique_value)
# Make sure the directory exists
os.makedirs(dir_name, exist_ok=True)
file_name = rf"{dir_name}/{img_unique_value}.jpg"
# print(f"file_name: {file_name}")
with open(file_name, 'wb') as f:
f.write(input_bytes)
# break
def save_data(self, df):
with self.engine_doris.begin() as conn:
img_unique_tuple = tuple(df.img_unique)
if img_unique_tuple:
img_unique_tuple_str = f"('{img_unique_tuple[0]}')" if len(img_unique_tuple) == 1 else f"{img_unique_tuple}"
sql_update = f"update {self.img_table} set state=3 where img_unique in {img_unique_tuple_str};"
print(f"sql_update: {sql_update[:100]}")
conn.execute(sql_update)
def run(self):
while True:
try:
df, df_null = self.read_data()
if df.shape[0] or df_null.shape[0]:
self.handle_data(df)
self.save_data(df)
# break
else:
break
except Exception as e:
print(f"error: {e}", traceback.format_exc())
self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
time.sleep(10)
def run_thread(self):
thread_list = []
for thread_id in range(self.thread_num):
thread = threading.Thread(target=self.run)
thread_list.append(thread)
thread.start()
for thread in thread_list:
thread.join()
logging.info("所有线程处理完成")
if __name__ == '__main__':
site_name = sys.argv[1] # arg 1: site
img_type = sys.argv[2] # arg 2: image source type
limit = int(sys.argv[3]) # arg 3: rows per read -- 1000
thread_num = int(sys.argv[4]) # arg 4: thread count -- 5
# handle_obj = ImgToh7(site_name='us', img_type='amazon_inv', limit=1, thread_num=1)
handle_obj = ImgToh7(site_name=site_name, img_type=img_type, limit=limit, thread_num=thread_num)
handle_obj.run_thread()
import requests
# Server address
url = 'http://192.168.10.217:10001/img_search'
# List of file paths; several files are uploaded
file_paths = [
r'D:\Amazon-Selection\pyspark_job\image_search\img/1.jpg',
# r'D:\Amazon-Selection\pyspark_job\image_search\img/1.png',
r'D:\Amazon-Selection\pyspark_job\image_search\img/2.png',
]
# Add the files to the request; the field names must be file1..file5 so that the
# /img_search endpoint above can find them in request.files
files = [(f'file{i + 1}', open(file_path, 'rb')) for i, file_path in enumerate(file_paths)]
# Other form data
data = {
'site_name': 'us',
'img_type': 'amazon_inv',
'search_key': 'file', # query by file
'search_value': '', # ignored for file queries
'top_k': 5 # number of results to return
}
# Send the request
response = requests.post(url, files=files, data=data)
# Print the response
if response.status_code == 200:
print(response.json())
else:
print(f"Error: {response.status_code}, {response.text}")
"""
Model for extracting image feature vectors: VGG16
"""
import io
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from numpy import linalg as LA
from keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
class VGGNet:
def __init__(self):
self.input_shape = (224, 224, 3)
self.weight = 'imagenet'
self.pooling = 'max'
self.model_vgg = VGG16(weights=self.weight,
input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
pooling=self.pooling, include_top=False)
# self.model_vgg.predict(np.zeros((1, 224, 224, 3)))
# Extract features from the last VGG16 convolutional block
def vgg_extract_feat(self, file, file_type='file'):
# if file_type == 'bytes':
# img = image.load_img(io.BytesIO(file.read()), target_size=(self.input_shape[0], self.input_shape[1]))
# else:
# # file_type = 'file'
# img = image.load_img(file, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.load_img(file, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input_vgg(img)
feat = self.model_vgg.predict(img)
# print(feat.shape)
norm_feat = feat[0] / LA.norm(feat[0])
# print("norm_feat:", norm_feat)
return list(norm_feat)
# # Extract features from the last VGG16 convolutional block
# def vgg_extract_feat(self, image_bytes):
# try:
# img = image.load_img(io.BytesIO(image_bytes), target_size=(self.input_shape[0], self.input_shape[1]))
# img = image.img_to_array(img)
# img = np.expand_dims(img, axis=0)
# img = preprocess_input_vgg(img)
# feat = self.model_vgg.predict(img)
# norm_feat = feat[0] / LA.norm(feat[0])
# return norm_feat.tolist()
# except Exception as e:
# print(e, traceback.format_exc())
# return list(np.zeros(shape=(512,)))
if __name__ == '__main__':
handle_obj = VGGNet()
handle_obj.vgg_extract_feat(file='')
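# Hedged usage sketch: vgg_extract_feat expects a readable image path (or bytes once the
# commented-out file_type='bytes' branch is restored). The path below is a placeholder.
import numpy as np

model = VGGNet()
vec = model.vgg_extract_feat(file="/mnt/data/img_data/tmp_img/example.jpg")
print(len(vec), round(float(np.linalg.norm(vec)), 4))  # 512 features, L2 norm ~= 1.0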