Commit f7c24024 by fangxingjun

no message

parent 781a04a7
import os
import sys

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates import Templates
from utils.db_util import DbTypes, DBUtil


class ImgAlterTableName(Templates):
    def __init__(self, site_name='us', img_type="amazon_inv"):
        super(ImgAlterTableName, self).__init__()
        self.site_name = site_name
        self.img_type = img_type
        self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")

    def read_data(self):
        pass

    def handle_data(self):
        # Swap table names in three renames: img_id_index -> temp, copy -> img_id_index, temp -> copy
        with self.engine_doris.begin() as conn:
            sql1 = "ALTER TABLE img_id_index RENAME img_id_index_temp;"
            conn.execute(sql1)
            sql2 = "ALTER TABLE img_id_index_copy RENAME img_id_index;"
            conn.execute(sql2)
            sql3 = "ALTER TABLE img_id_index_temp RENAME img_id_index_copy;"
            conn.execute(sql3)
            print(f"Table name swap completed--sql1: {sql1}\nsql2: {sql2}\nsql3: {sql3}")

    def save_data(self):
        pass


if __name__ == '__main__':
    site_name = sys.argv[1]
    img_type = sys.argv[2]
    handle_obj = ImgAlterTableName(site_name=site_name, img_type=img_type)
    handle_obj.run()
import os

from autofaiss import build_index
from pyspark.sql import SparkSession  # pylint: disable=import-outside-toplevel
from pyspark import SparkConf, SparkContext


def create_spark_session():
    # this must be a path that is available on all worker nodes
    # os.environ['PYSPARK_PYTHON'] = "/opt/module/spark/demo/py_demo/img_search/autofaiss.pex"
    spark = (
        SparkSession.builder
        .config("spark.executorEnv.PEX_ROOT", "./.pex")
        .config("spark.executor.cores", "4")
        .config("spark.executor.memory", "20G")  # make sure to increase this if you're using more cores per executor
        .config("spark.num.executors", "10")
        .config("spark.yarn.queue", "spark")
        .master("local")  # this should point to your master node, if using the tunnelling version, keep this to localhost
        .appName("autofaiss-create-index")
        .getOrCreate()
    )
    return spark


spark = create_spark_session()

index, index_infos = build_index(
    # embeddings="hdfs://nameservice1:8020/home/img_search/us/amazon_inv/parquet",
    embeddings="hdfs://nameservice1:8020/home/img_search/img_parquet/us/amazon_inv",
    distributed="pyspark",
    file_format="parquet",
    max_index_memory_usage="80G",  # 16G
    current_memory_available="120G",  # 24G
    temporary_indices_folder="hdfs://nameservice1:8020/home/img_search/img_tmp/us/amazon_inv//distributed_autofaiss_indices",
    index_path="hdfs://nameservice1:8020/home/img_search/img_index/us/amazon_inv/knn.index",
    index_infos_path="hdfs://nameservice1:8020/home/img_search/img_index/us/amazon_inv/infos.json",
)
print("index, index_infos:", index, index_infos)
import ast
import os
import sys

import pandas as pd

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates import Templates
# from ..utils.templates import Templates
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType


class PicturesDimFeaturesSlice(Templates):
    def __init__(self, site_name='us', img_type='amazon_inv'):
        super(PicturesDimFeaturesSlice, self).__init__()
        self.site_name = site_name
        self.img_type = img_type
        self.db_save = 'img_dim_features_slice'
        self.spark = self.create_spark_object(
            app_name=f"{self.db_save}: {self.site_name}")
        self.df_asin_features = self.spark.sql("select 1+1;")
        self.df_save = self.spark.sql("select 1+1;")
        # self.partitions_by = ['site_name', 'block']
        self.partitions_by = ['site_name', 'img_type']
        self.partitions_num = 10

    def read_data(self):
        # sql = f"select id, asin, img_vector as embedding from ods_asin_extract_features;"
        sql = f"select id, img_unique, features, img_type from img_ods_features where site_name='{self.site_name}' and img_type='{self.img_type}';"
        print("sql:", sql)
        self.df_save = self.spark.sql(sql).cache()
        self.df_save.show(10)
        print(f"self.df_save.count(): {self.df_save.count()}")
        # The array type does not need to be generated at this step
        # partitions_num = self.df_asin_features.rdd.getNumPartitions()
        # print("number of partitions:", partitions_num)  # 642
        # # self.partitions_num = 1000
        # self.df_save = self.df_save.repartition(self.partitions_num)
        # print("repartitioned to:", self.partitions_num)  # 642

    def handle_data(self):
        # UDF that converts the feature string into a list of floats
        # str_to_list_udf = F.udf(lambda s: ast.literal_eval(s), ArrayType(FloatType()))
        # # Apply the UDF to the DataFrame column
        # self.df_save = self.df_save.withColumn("embedding", str_to_list_udf(self.df_save["embedding"]))
        self.df_save = self.df_save.withColumn('site_name', F.lit(self.site_name))


if __name__ == '__main__':
    site_name = sys.argv[1]  # arg 1: site
    handle_obj = PicturesDimFeaturesSlice(site_name=site_name)
    handle_obj.run()
import os
import socket
import sys
import threading
import logging
import time
import traceback
import uuid

import pandas as pd
import redis
import requests
from sqlalchemy import text

logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.db_util import DbTypes, DBUtil


class ImgDownload(object):
    def __init__(self, site_name='us', img_type="amazon_inv", thread_num=10, limit=200):
        self.site_name = site_name
        self.img_type = img_type
        self.thread_num = thread_num
        self.limit = limit
        self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
        self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
        self.hostname = socket.gethostname()
        self.first_local_dir, self.read_table = self.get_first_local_dir()
        # self.read_table = f"{self.site_name}_inv_img_info"
        self.local_name = self.read_table

    def get_first_local_dir(self):
        if self.img_type == 'amazon_self':
            first_local_dir = f"/mnt/data/img_data/amazon_self/{self.site_name}"
            image_table = f'{self.site_name}_self_asin_image'
        elif self.img_type == 'amazon':
            first_local_dir = f"/mnt/data/img_data/amazon/{self.site_name}"
            image_table = f'{self.site_name}_amazon_image'
        elif self.img_type == 'amazon_inv':
            first_local_dir = f"/mnt/data/img_data/amazon_inv/{self.site_name}"
            image_table = f'{self.site_name}_inv_img_info'
        else:
            first_local_dir = ""
            image_table = ""
        return first_local_dir, image_table

    def acquire_lock(self, lock_name, timeout=100):
        """
        Try to acquire the distributed lock; returns True if the lock was set, otherwise None.
        lock_name: key of the lock; it is recommended to keep it consistent with the task name.
        """
        lock_value = str(uuid.uuid4())
        lock_acquired = self.client.set(lock_name, lock_value, nx=True, ex=timeout)  # the timeout is optional
        # lock_acquired = self.client.set(lock_name, lock_value, nx=True)
        return lock_acquired, lock_value

    def release_lock(self, lock_name, lock_value):
        """Release the distributed lock only if we still own it."""
        script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        result = self.client.eval(script, 1, lock_name, lock_value)
        return result
    @staticmethod
    def img_download(img_url, img_path, img_name):
        file_path = f"{img_path}{img_name}"
        for d_num in range(5):
            try:
                response = requests.get(img_url)
                if response.status_code == 200:
                    # Open a file in binary write mode
                    with open(file_path, 'wb') as file:
                        file.write(response.content)
                    # print("Image downloaded successfully.")
                    return True
                else:
                    continue
            except Exception as e:
                error = "No such file or directory"
                if error in str(e):
                    os.makedirs(img_path)
                print(f"attempt {d_num} -- image download failed, path: {file_path}, url: {img_url}, \nerror: {e, traceback.format_exc()}")
                time.sleep(2)
        return False

    def update_state(self, id_list, state, state_value="success"):
        if id_list:
            while True:
                try:
                    with self.engine_mysql.begin() as conn:
                        id_tuple = tuple(id_list)
                        print(f"{state_value}--id_tuple: {len(id_tuple)}, {id_tuple[:10]}")
                        if id_tuple:
                            id_tuple_str = f"('{id_tuple[0]}')" if len(id_tuple) == 1 else f"{id_tuple}"
                            sql_update = f"UPDATE {self.read_table} SET state={state} WHERE id IN {id_tuple_str};"
                            print("sql_update:", sql_update[:150])
                            conn.execute(sql_update)
                    break
                except Exception as e:
                    print(f"Error reading data: {e}", traceback.format_exc())
                    time.sleep(20)
                    self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
                    self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
                    continue

    def read_data(self):
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if lock_acquired:
                    print("self.hostname:", self.hostname)
                    with self.engine_mysql.begin() as conn:
                        sql_read = text(f"SELECT id, img_id, img_type, img_url, id_segment FROM {self.read_table} WHERE state=1 LIMIT {self.limit};")
                        df = pd.read_sql(sql=sql_read, con=self.engine_mysql)
                        id_tuple = tuple(df.id)
                        print(f"sql_read: {sql_read}, {df.shape}", id_tuple[:10])
                        if id_tuple:
                            id_tuple_str = f"('{id_tuple[0]}')" if len(id_tuple) == 1 else f"{id_tuple}"
                            sql_update = f"UPDATE {self.read_table} SET state=2 WHERE id IN {id_tuple_str};"
                            print("sql_update:", sql_update[:150])
                            conn.execute(sql_update)
                    self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                    return df
                else:
                    print("Another process is holding the redis lock; waiting 10 seconds before trying again")
                    time.sleep(10)  # wait 10s before retrying the lock
                    continue
            except Exception as e:
                print(f"Error reading data: {e}", traceback.format_exc())
                time.sleep(20)
                self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
                self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
                continue

    def handle_data(self, df, thread_id):
        # 1. Download the images
        img_success_id_list = []
        img_failed_id_list = []
        id_list = list(df.id)
        id_len = len(id_list)
        for id_segment, id, img_id, img_type, img_url in zip(df.id_segment, df.id, df.img_id, df.img_type, df.img_url):
            img_path = f"{self.first_local_dir}/{id_segment}/"
            img_name = f"{id_segment}_{id}_{img_id}_{img_type}.jpg"
            if self.hostname not in ['hadoop5', 'hadoop6', 'hadoop7', 'hadoop8']:
                img_path = img_path.replace("/mnt", "/home")
            d_flag = self.img_download(img_url=img_url, img_path=img_path, img_name=img_name)
            id_index = id_list.index(id)
            print(f"self.hostname: {self.hostname}, thread: {thread_id}, success: {d_flag}, id_index: {id_index}, progress: {round(id_index / id_len * 100, 2)}%, img_path: {img_path}{img_name}")
            if d_flag:
                img_success_id_list.append(id)
            else:
                img_failed_id_list.append(id)
        # 2. Update the state -- success=3, failure=4
        print(f"success: {len(img_success_id_list)}, failed: {len(img_failed_id_list)}")
        self.update_state(id_list=img_success_id_list, state=3, state_value="success")
        self.update_state(id_list=img_failed_id_list, state=4, state_value="failed")

    def save_data(self):
        pass

    def run(self, thread_id=1):
        while True:
            try:
                df = self.read_data()
                if df.shape[0]:
                    self.handle_data(df=df, thread_id=thread_id)
                    self.save_data()
                    # break
                else:
                    break
            except Exception as e:
                print(e, traceback.format_exc())
                self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)
                self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
                time.sleep(20)
                continue

    def run_thread(self):
        logging.info("All threads started")
        thread_list = []
        for thread_id in range(self.thread_num):
            thread = threading.Thread(target=self.run, args=(thread_id,))
            thread_list.append(thread)
            thread.start()
        for thread in thread_list:
            thread.join()
        logging.info("All threads finished")


if __name__ == '__main__':
    # handle_obj = PicturesFeatures(self_flag='_self')
    # site_name = int(sys.argv[1])  # arg 1: site
    # site_name = 'us'
    # img_type = "amazon_inv"
    # limit = 100
    # thread_num = 1
    site_name = sys.argv[1]        # arg 1: site
    img_type = sys.argv[2]         # arg 2: image source type
    limit = int(sys.argv[3])       # arg 3: rows to read per batch -- 1000
    thread_num = int(sys.argv[4])  # arg 4: number of threads -- 5
    handle_obj = ImgDownload(site_name=site_name, img_type=img_type, thread_num=thread_num, limit=limit)
    # handle_obj.run()
    handle_obj.run_thread()
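# --- Usage sketch (not part of the commit): the acquire/release lock pattern used by ---
# --- ImgDownload.read_data(), shown standalone. Host/db values are placeholders, not ---
# --- the production settings.                                                        ---
import uuid

import redis

client = redis.Redis(host="127.0.0.1", port=6379, db=0)  # hypothetical redis instance

lock_name = "us_inv_img_info"       # same convention: lock key == task/table name
lock_value = str(uuid.uuid4())      # unique token so we only delete our own lock
acquired = client.set(lock_name, lock_value, nx=True, ex=100)  # NX: only if absent, EX: auto-expire

if acquired:
    try:
        pass  # critical section: read a batch and mark it state=2, as read_data() does
    finally:
        # release only if we still own the lock (same Lua script as release_lock above)
        release_script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        client.eval(release_script, 1, lock_name, lock_value)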
import multiprocessing
import os
import sys
import time
import traceback

import pandas as pd

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates import Templates
from utils.db_util import DbTypes, DBUtil


class JudgeFinished(Templates):
    def __init__(self, site_name='us', img_type="amazon_inv"):
        super(JudgeFinished, self).__init__()
        self.site_name = site_name
        self.img_type = img_type
        self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
        self.tn_pics_hdfs_index = "img_hdfs_index"

    def judge(self):
        sql = f"select * from {self.tn_pics_hdfs_index} where state in (1, 2) and site_name='{self.site_name}' and img_type='{self.img_type}';"
        df = pd.read_sql(sql, con=self.engine_doris)
        print(f"sql: {sql}, {df.shape}")
        result_flag = True if df.shape[0] else False
        return result_flag


def main(site_name='us', img_type='amazon_inv', p_num=0):
    while True:
        try:
            judge_obj = JudgeFinished(site_name=site_name, img_type=img_type)
            result_flag = judge_obj.judge()
            if result_flag:
                print(f"Continuing, result_flag: {result_flag}")
                os.system(f"/opt/module/spark/bin/spark-submit --master yarn --driver-memory 1g --executor-memory 4g --executor-cores 1 --num-executors 1 --queue spark /opt/module/spark/demo/py_demo/img_search/img_dwd_id_index.py {site_name} {img_type}")
            else:
                print(f"Finished, result_flag: {result_flag}")
                break
        except Exception as e:
            print(e, traceback.format_exc())
            time.sleep(20)
            error = "ValueError: Length mismatch: Expected axis has 0 elements"
            if error in traceback.format_exc():
                print(f"All block id-to-index mappings have been processed; exiting process {p_num}")
                quit()
            continue


if __name__ == "__main__":
    site_name = sys.argv[1]
    img_type = sys.argv[2]
    process_num = int(sys.argv[3])  # arg 3: number of processes
    processes = []
    for p_num in range(process_num):  # controls how many processes are started
        process = multiprocessing.Process(target=main, args=(site_name, img_type, p_num))
        process.start()
        processes.append(process)
    # Wait for all processes to finish
    for process in processes:
        process.join()
import os
import sys
import threading
import time
import traceback
import socket
import uuid
import logging

import numpy as np
import pandas as pd
import redis

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
# from utils.templates import Templates
from sqlalchemy import text
from vgg_model import VGGNet
from utils.db_util import DbTypes, DBUtil

logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)


class ImgExtractFeatures(object):
    def __init__(self, site_name='us', img_type="amazon_inv", thread_num=10, limit=1000):
        # super(ImgFeatures, self).__init__()
        self.site_name = site_name
        self.img_type = img_type
        self.thread_num = thread_num
        self.limit = limit
        self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
        self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
        self.local_name = f"{self.site_name}_img_features"
        self.vgg_model = VGGNet()
        self.hostname = socket.gethostname()
        self.read_table = "img_local_path"
        self.save_table = "img_features"

    def acquire_lock(self, lock_name, timeout=100):
        """
        Try to acquire the distributed lock; returns True if the lock was set, otherwise None.
        lock_name: key of the lock; it is recommended to keep it consistent with the task name.
        """
        lock_value = str(uuid.uuid4())
        lock_acquired = self.client.set(lock_name, lock_value, nx=True, ex=timeout)  # the timeout is optional
        # lock_acquired = self.client.set(lock_name, lock_value, nx=True)
        return lock_acquired, lock_value

    def release_lock(self, lock_name, lock_value):
        """Release the distributed lock only if we still own it."""
        script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        result = self.client.eval(script, 1, lock_name, lock_value)
        return result

    def read_data(self):
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if lock_acquired:
                    print("self.hostname:", self.hostname)
                    with self.engine_doris.begin() as conn:
                        sql_read = text(f"SELECT id, img_unique, local_path, img_type FROM selection.{self.read_table} WHERE site_name='{self.site_name}' and img_type='{self.img_type}' and state=1 LIMIT {self.limit};")
                        # result = conn.execute(sql_read)
                        # df = pd.DataFrame(result.fetchall())
                        df = pd.read_sql(sql=sql_read, con=self.engine_doris)
                        img_unique_tuple = tuple(df.img_unique)
                        print(f"sql_read: {sql_read}, {df.shape}", img_unique_tuple[:10])
                        if img_unique_tuple:
                            img_unique_tuple_str = f"('{img_unique_tuple[0]}')" if len(img_unique_tuple) == 1 else f"{img_unique_tuple}"
                            sql_update = text(f"UPDATE selection.{self.read_table} SET state=2 WHERE img_unique IN {img_unique_tuple_str};")
                            print("sql_update:", sql_update)
                            conn.execute(sql_update)
                    self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                    return df
                else:
                    print("Another process is holding the redis lock; waiting 5 seconds before trying again")
                    time.sleep(5)  # wait 5s before retrying the lock
                    continue
            except Exception as e:
                print(f"Error reading data: {e}", traceback.format_exc())
                time.sleep(5)
                self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
                continue

    def handle_data(self, df, thread_id):
        id_list = list(df.id)
        img_unique_list = list(df.img_unique)
        local_path_list = list(df.local_path)
        data_list = []
        for id, img_unique, local_path in zip(id_list, img_unique_list, local_path_list):
            index = id_list.index(id)
            print(f"thread_id, index, id, img_unique, local_path: {thread_id, index, id, img_unique, local_path}")
            if self.hostname not in ['hadoop5', 'hadoop6', 'hadoop7', 'hadoop8']:
                local_path = local_path.replace("/mnt", "/home")
            try:
                features = self.vgg_model.vgg_extract_feat(file=local_path)
            except Exception as e:
                print(e, traceback.format_exc())
                features = list(np.zeros(shape=(512,)))
            data_list.append([id, img_unique, str(features), self.img_type, self.site_name])
        columns = ['id', 'img_unique', 'features', 'img_type', 'site_name']
        df_save = pd.DataFrame(data_list, columns=columns)
        return df_save

    def save_data(self, df):
        df.to_sql(self.save_table, con=self.engine_doris, if_exists="append", index=False)
        with self.engine_doris.begin() as conn:
            img_unique_tuple = tuple(df.img_unique)
            if img_unique_tuple:
                img_unique_tuple_str = f"('{img_unique_tuple[0]}')" if len(img_unique_tuple) == 1 else f"{img_unique_tuple}"
                sql_update = f"update selection.{self.read_table} set state=3 where img_unique in {img_unique_tuple_str};"
                print(f"sql_update: {sql_update}")
                conn.execute(sql_update)

    def run(self, thread_id=1):
        while True:
            try:
                df = self.read_data()
                if df.shape[0]:
                    df_save = self.handle_data(df=df, thread_id=thread_id)
                    self.save_data(df=df_save)
                    # break
                else:
                    break
            except Exception as e:
                print(e, traceback.format_exc())
                self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
                self.client = redis.Redis(host='192.168.10.224', port=6379, db=9, password='xA9!wL3pZ@q2')
                self.vgg_model = VGGNet()
                time.sleep(20)
                continue

    def run_thread(self):
        thread_list = []
        for thread_id in range(self.thread_num):
            thread = threading.Thread(target=self.run, args=(thread_id,))
            thread_list.append(thread)
            thread.start()
        for thread in thread_list:
            thread.join()
        logging.info("All threads finished")


if __name__ == '__main__':
    # handle_obj = PicturesFeatures(self_flag='_self')
    # site_name = int(sys.argv[1])  # arg 1: site
    # site_name = 'us'
    # img_type = "amazon_inv"
    # limit = 100
    # thread_num = 1
    site_name = sys.argv[1]        # arg 1: site
    img_type = sys.argv[2]         # arg 2: image source type
    limit = int(sys.argv[3])       # arg 3: rows to read per batch -- 1000
    thread_num = int(sys.argv[4])  # arg 4: number of threads -- 5
    handle_obj = ImgExtractFeatures(site_name=site_name, img_type=img_type, thread_num=thread_num, limit=limit)
    # handle_obj.run()
    handle_obj.run_thread()
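# --- Usage sketch (not part of the commit): how the `features` column written above ---
# --- can be parsed back into a numeric vector. Values here are illustrative.        ---
import ast

import numpy as np

row_features = str(list(np.zeros(shape=(512,))))  # what handle_data() stores: str(list of 512 floats)

vector = np.array(ast.literal_eval(row_features), dtype="float32")  # back to a 512-d float array
assert vector.shape == (512,)
# This mirrors the commented-out str_to_list_udf in PicturesDimFeaturesSlice.handle_data().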
import os
import sys

import pandas as pd

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from utils.db_util import DbTypes, DBUtil


class ImgHdfsIndex(Templates):
    def __init__(self, site_name='us', img_type="amazon_inv"):
        super(ImgHdfsIndex, self).__init__()
        self.site_name = site_name
        self.img_type = img_type
        # self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
        self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
        self.db_save = 'img_hdfs_index'
        self.img_dim_features_slice = 'img_dim_features_slice'
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        self.df_features = self.spark.sql("select 1+1;")
        self.df_save = pd.DataFrame()
        # self.hdfs_file_path = f'hdfs://nameservice1:8020/home/big_data_selection/dim/{self.img_dim_features_slice}/site_name={self.site_name}/img_type={self.img_type}/'
        self.hdfs_file_path = 'hdfs://nameservice1:8020'
        self.hdfs_file_list = self.get_hdfs_file_list()
        self.index_count = 0

    # def get_hdfs_file_list(self):
    #     # Import the Hadoop classes
    #     java_import(self.spark._jvm, 'org.apache.hadoop.fs.Path')
    #     # fs = self.spark._jvm.org.apache.hadoop.fs.FileSystem.get(self.spark._jsc.hadoopConfiguration(self.hdfs_file_path))
    #     # status = fs.listStatus(self.spark._jvm.org.apache.hadoop.fs.Path())
    #
    #     fs = self.spark._jvm.org.apache.hadoop.fs.FileSystem.get(self.spark._jsc.hadoopConfiguration())
    #     path = self.spark._jvm.org.apache.hadoop.fs.Path(self.hdfs_file_path)
    #     status = fs.listStatus(path)
    #
    #     hdfs_file_list = [file_status.getPath().getName() for file_status in status]
    #     return hdfs_file_list

    def get_hdfs_file_list(self):
        # Run `hdfs dfs -ls` via os.popen and parse its output
        command = f"hdfs dfs -ls /home/big_data_selection/dim/img_dim_features_slice/site_name={self.site_name}/img_type={self.img_type}"
        result = os.popen(command).read()
        # Parse the command output
        file_list = []
        for line in result.split('\n'):
            if line:
                parts = line.split()
                if len(parts) > 7:
                    file_path = parts[-1]
                    file_list.append(file_path)
        print(f"file_list: {len(file_list)}", file_list[:3])
        return file_list

    def read_data(self, hdfs_path):
        df = self.spark.read.text(hdfs_path)
        index_count = df.count()
        return df, index_count

    def handle_data(self):
        pass

    def save_data(self):
        self.df_save.to_sql(self.db_save, con=self.engine_doris, if_exists="append", index=False)

    def run(self):
        data_list = []
        for hdfs_file in self.hdfs_file_list:
            index = self.hdfs_file_list.index(hdfs_file)
            hdfs_path = self.hdfs_file_path + hdfs_file
            df, index_count = self.read_data(hdfs_path)
            data_list.append([index, hdfs_path, index_count, self.index_count])
            print([index, hdfs_path, index_count, self.index_count])
            self.index_count += index_count
        self.df_save = pd.DataFrame(data=data_list, columns=['index', 'hdfs_path', 'current_counts', 'all_counts'])
        self.df_save["site_name"] = self.site_name
        self.df_save["img_type"] = self.img_type
        self.df_save["id"] = self.df_save["index"] + 1
        self.df_save["state"] = 1
        # self.df_save.to_csv("/root/hdfs_parquet_block_info.csl")
        self.save_data()


if __name__ == '__main__':
    # site_name = 'us'
    # img_type = 'amazon_inv'
    site_name = sys.argv[1]  # arg 1: site
    img_type = sys.argv[2]   # arg 2: image source type
    handle_obj = ImgHdfsIndex(site_name=site_name, img_type=img_type)
    handle_obj.run()
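# --- Illustration (not part of the commit): how the offsets written above could be used. ---
# --- Each row stores the row count of one HDFS block (current_counts) and the running   ---
# --- total before it (all_counts), so a global position can be mapped back to its block.---
# --- This is one possible reading of the table, shown with made-up numbers.             ---
blocks = [
    # (index, current_counts, all_counts)
    (0, 1000, 0),
    (1, 800, 1000),
    (2, 1200, 1800),
]


def find_block(position, blocks):
    """Return the block whose [all_counts, all_counts + current_counts) range holds position."""
    for index, current_counts, all_counts in blocks:
        if all_counts <= position < all_counts + current_counts:
            return index
    return None


print(find_block(1500, blocks))  # -> 1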
import ast
import datetime
import logging
import os
import re
import sys
import threading
import time
import traceback

import pandas as pd
import redis
from pyspark.sql.types import ArrayType, FloatType

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates_mysql import TemplatesMysql
from utils.templates import Templates
# from ..utils.templates import Templates
from py4j.java_gateway import java_import
from sqlalchemy import text
from pyspark.sql import functions as F
import pyarrow as pa
import pyarrow.parquet as pq
from multiprocessing import Process
from multiprocessing import Pool
import multiprocessing
from utils.db_util import DbTypes, DBUtil
from utils.StarRocksHelper import StarRocksHelper


class ImgIdIndexToDoris(Templates):
    def __init__(self, site_name='us', img_type="amazon_inv"):
        super(ImgIdIndexToDoris, self).__init__()
        self.site_name = site_name
        self.img_type = img_type
        self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
        self.spark = self.create_spark_object(app_name=f"{self.db_save}: {self.site_name}")
        self.df_id_index = self.spark.sql("select 1+1;")
        self.table_name = "img_dwd_id_index"
        self.table_save = "img_id_index_copy"

    def read_data(self):
        sql = f"select id, index, img_unique, site_name, img_type from {self.table_name} where site_name='{self.site_name}' and img_type = '{self.img_type}';"
        print("sql:", sql)
        self.df_id_index = self.spark.sql(sql).cache()
        self.df_id_index.show(10)
        print(f"self.df_id_index.count(): {self.df_id_index.count()}")

    def handle_data(self):
        pass

    def save_data(self):
        # starrocks_url = "jdbc:mysql://192.168.10.151:19030/selection"
        # properties = {
        #     "user": "fangxingjun",
        #     "password": "fangxingjun12345",
        #     "driver": "com.mysql.cj.jdbc.Driver",
        #     # "driver": "com.mysql.cj.jdbc.Driver",
        # }
        # self.df_id_index.write.jdbc(url=starrocks_url, table="image_id_index", mode="overwrite", properties=properties)
        # self.df_id_index = self.df_id_index.withColumn('created_time', F.lit(datetime.datetime.now()))
        # self.df_id_index = self.df_id_index.withColumn("img_type", F.col("img_type").cast("int"))
        # StarRocksHelper.spark_export(df_save=self.df_id_index, db_name='selection', table_name='image_id_index')
        df_save = self.df_id_index.toPandas()
        df_save.to_sql(self.table_save, con=self.engine_doris, if_exists="append", index=False, chunksize=10000)


if __name__ == '__main__':
    site_name = sys.argv[1]
    img_type = sys.argv[2]
    handle_obj = ImgIdIndexToDoris(site_name=site_name, img_type=img_type)
    handle_obj.run()
import os
import sys
import time
import traceback

import pandas as pd

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates_mysql import TemplatesMysql
from utils.db_util import DbTypes, DBUtil


class AsinImageLocalPath(object):
    def __init__(self, site_name='us', img_type='amazon_self'):
        self.site_name = site_name
        self.img_type = img_type
        self.first_local_dir, self.image_table = self.get_first_local_dir()
        # self.engine_srs = DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)
        self.engine_doris = DBUtil.get_db_engine(db_type=DbTypes.doris.name, site_name=self.site_name)
        self.engine_mysql = DBUtil.get_db_engine(db_type=DbTypes.mysql.name, site_name=self.site_name)

    def get_first_local_dir(self):
        if self.img_type == 'amazon_self':
            first_local_dir = f"/mnt/data/img_data/amazon_self/{self.site_name}"
            image_table = f'{self.site_name}_self_asin_image'
        elif self.img_type == 'amazon':
            first_local_dir = f"/mnt/data/img_data/amazon/{self.site_name}"
            image_table = f'{self.site_name}_amazon_image'
        elif self.img_type == 'amazon_inv':
            first_local_dir = f"/mnt/data/img_data/amazon_inv/{self.site_name}"
            image_table = f'{self.site_name}_inv_img'
        else:
            first_local_dir = ""
            image_table = ""
        return first_local_dir, image_table

    def read_data(self):
        # sql = f"select img_unique from selection.{self.image_table} where state=1;"
        # df = pd.read_sql(sql, con=self.engine_doris)
        sql = "SELECT * from us_inv_img_info WHERE updated_at >= CURDATE() - INTERVAL 3 DAY and state = 3;"
        df = pd.read_sql(sql, con=self.engine_mysql)
        print(f"sql: {sql}", df.shape)
        return df

    def handle_data(self, df):
        # img_unique follows the download naming scheme: f"{id_segment}_{id}_{img_id}_{img_type}.jpg"
        # df['img_unique'] = df['id_segment'] + "_" + df['id'] + "_" + df['img_id'] + "_" + df['img_type']
        df['img_unique'] = df['id_segment'].astype(str) + "_" + df['id'].astype(str) + "_" + df['img_id'].astype(str) + "_" + df['img_type']
        if self.img_type in ['amazon_self', 'amazon']:
            df["local_path"] = df.img_unique.apply(lambda x: f"{self.first_local_dir}/{x[:1]}/{x[:2]}/{x[:3]}/{x[:4]}/{x[:5]}/{x[:6]}/{x}.jpg")
        elif self.img_type in ['amazon_inv']:
            df["local_path"] = df.img_unique.apply(lambda x: f"{self.first_local_dir}/{x.split('_')[0]}/{x}.jpg")
        df["img_type"] = self.img_type
        df["site_name"] = self.site_name
        df["state"] = 1
        df = df.loc[:, ["img_unique", "site_name", "local_path", "img_type", "state"]]
        print(f"Number of images to update in this run: {df.shape}", df.head(5))
        # quit()
        return df

    def save_data(self, df):
        df.to_sql("img_local_path", con=self.engine_doris, if_exists="append", index=False)

    def run(self):
        df = self.read_data()
        df = self.handle_data(df)
        self.save_data(df)


if __name__ == '__main__':
    site_name = sys.argv[1]  # arg 1: site
    img_type = sys.argv[2]   # arg 2: image source type
    # site_name = 'us'
    # img_type = "amazon_inv"
    handle_obj = AsinImageLocalPath(site_name=site_name, img_type=img_type)
    handle_obj.run()
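# --- Usage sketch (not part of the commit): how handle_data() above maps img_unique ---
# --- to a local path for each img_type. Values are illustrative only.               ---
first_local_dir = "/mnt/data/img_data/amazon_inv/us"  # as returned by get_first_local_dir()

img_unique = "123_456789_987_amazon_inv"  # f"{id_segment}_{id}_{img_id}_{img_type}"

# amazon_inv layout: one directory per id_segment
inv_path = f"{first_local_dir}/{img_unique.split('_')[0]}/{img_unique}.jpg"
# -> /mnt/data/img_data/amazon_inv/us/123/123_456789_987_amazon_inv.jpg

# amazon / amazon_self layout: nested prefix directories built from the first 1..6 characters
prefix_path = "/".join(img_unique[:n] for n in range(1, 7))
self_path = f"/mnt/data/img_data/amazon_self/us/{prefix_path}/{img_unique}.jpg"
print(inv_path, self_path)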
import os
import time

import pandas as pd
from flask import Flask, request, jsonify

from img_search import ImgSearch

img_search = ImgSearch()

app = Flask(__name__)

UPLOAD_FOLDER = '/mnt/data/img_data/tmp_img/'
if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)


@app.route('/img_search/similarity', methods=['POST'])
def search_amazon_similarity():
    print(request.form)   # log all received form data
    print(request.files)  # log all received files
    file_list = ['file1', 'file2']
    files = []
    for file_key in file_list:
        if file_key in request.files:
            files.append(request.files[file_key])
    if len(files) != 2:
        return jsonify({"error": "exactly 2 files must be uploaded"}), 400
    else:
        # file1_path
        timestamp = time.time()
        milli_timestamp = round(timestamp * 1000)  # convert the second-level timestamp to milliseconds
        filename = f"{milli_timestamp}_{files[0].filename}"
        file1_path = os.path.join(UPLOAD_FOLDER, filename)
        files[0].save(file1_path)
        # file2_path
        timestamp = time.time()
        milli_timestamp = round(timestamp * 1000)  # convert the second-level timestamp to milliseconds
        filename = f"{milli_timestamp}_{files[1].filename}"
        file2_path = os.path.join(UPLOAD_FOLDER, filename)
        files[1].save(file2_path)
        similarity = img_search.cosine_similarity_two_img(
            file1_path=file1_path,
            file2_path=file2_path
        )
        return jsonify(
            {"similarity": similarity}
        )


@app.route('/img_search', methods=['POST'])
def search_amazon():
    print(request.form)   # log all received form data
    print(request.files)  # log all received files
    site_name = request.form.get('site_name', 'us')             # default: 'us'
    img_type = request.form.get('img_type', 'amazon_inv')       # default: 'amazon_inv'
    # search_key = request.form.get('search_key', 'asin')       # default: 'asin'
    search_key = request.form.get('search_key', 'img_unique')   # default: 'img_unique' -- or 'file'
    search_value = request.form.get('search_value', 'B0BBNQCXZL')  # default: 'B0BBNQCXZL'
    top_k = int(request.form.get('top_k', 100))                 # default: 100
    df_similarities_list = []
    # Fetch the uploaded files and store them
    if search_key == 'file':
        # file = request.files['file']  # single file stream parameter named "file"
        file_list = ['file1', 'file2', 'file3', 'file4', 'file5']
        files = []
        for file_key in file_list:
            if file_key in request.files:
                files.append(request.files[file_key])
        # If more than 5 files are uploaded, return an error
        if len(files) > 5:
            return jsonify({"error": "at most 5 files may be uploaded"}), 400
        # files = request.files.getlist('file')  # multiple files named "file"
        for file in files:
            timestamp = time.time()
            milli_timestamp = round(timestamp * 1000)  # convert the second-level timestamp to milliseconds
            filename = f"{milli_timestamp}_{file.filename}"
            file_path = os.path.join(UPLOAD_FOLDER, filename)
            file.save(file_path)
            # Log for debugging
            print("file saved at:", file_path)
            # Search with each uploaded image
            df_similarities = img_search.search_api(
                site_name=site_name, img_type=img_type, search_key=search_key,
                search_value=search_value, top_k=top_k, file_path=file_path
            )
            df_similarities_list.append(df_similarities)
            # similarities_dict_list.append({
            #     "similarities_dict": similarities_dict
            # })
    else:
        file_path = ''
        # Search by search_key / search_value instead of an uploaded image
        df_similarities = img_search.search_api(
            site_name=site_name, img_type=img_type, search_key=search_key,
            search_value=search_value, top_k=top_k, file_path=file_path
        )
        df_similarities_list.append(df_similarities)
        # similarities_dict_list.append({
        #     "similarities_dict": similarities_dict
        # })
    df_result = pd.concat(df_similarities_list)
    df_result = df_result.loc[~(df_result.similarity.isna())]  # drop records whose similarity is NaN
    print(f"df_result.columns: {df_result.columns}")
    if df_result.shape[0]:
        df_result.sort_values(["similarity"], ascending=[False], inplace=True)
        # Add an id column starting from 1
        df_result['id'] = range(1, len(df_result) + 1)
        df_result = df_result.iloc[:top_k]
    result = []
    for _, row in df_result.iterrows():
        row_dict = {'id': row['id']}  # build each record with id first
        for col in df_result.columns:
            if col != 'id':  # id is already set
                row_dict[col] = row[col]
        result.append(row_dict)
    return jsonify(
        result
    )


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=10001)
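# --- Usage sketch (not part of the commit): calling the /img_search/similarity route ---
# --- above with two local images. Host/port follow the client script below; the file ---
# --- paths are placeholders.                                                         ---
import requests

url = "http://192.168.10.217:10001/img_search/similarity"
with open("img/1.jpg", "rb") as f1, open("img/2.png", "rb") as f2:
    # the route expects exactly two files, under the keys file1 and file2
    resp = requests.post(url, files={"file1": f1, "file2": f2})
print(resp.status_code, resp.json())  # -> {"similarity": <float>}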
import requests

# Server address
url = 'http://192.168.10.217:10001/img_search'
# List of file paths to upload (multiple files)
file_paths = [
    r'D:\Amazon-Selection\pyspark_job\image_search\img/1.jpg',
    # r'D:\Amazon-Selection\pyspark_job\image_search\img/1.png',
    r'D:\Amazon-Selection\pyspark_job\image_search\img/2.png',
]
# Attach the files to the request as file1, file2, ... to match the keys the server expects
files = [(f'file{i + 1}', open(file_path, 'rb')) for i, file_path in enumerate(file_paths)]
# Other form data
data = {
    'site_name': 'us',
    'img_type': 'amazon_inv',
    'search_key': 'file',  # query by uploaded file
    'search_value': '',    # ignored for file-based queries
    'top_k': 5             # number of results to return
}
# Send the request
response = requests.post(url, files=files, data=data)
# Print the response
if response.status_code == 200:
    print(response.json())
else:
    print(f"Error: {response.status_code}, {response.text}")
"""
提取图片特征信息的模型:vgg16
"""
import io
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from numpy import linalg as LA
from keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
class VGGNet:
def __init__(self):
self.input_shape = (224, 224, 3)
self.weight = 'imagenet'
self.pooling = 'max'
self.model_vgg = VGG16(weights=self.weight,
input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
pooling=self.pooling, include_top=False)
# self.model_vgg.predict(np.zeros((1, 224, 224, 3)))
# 提取vgg16最后一层卷积特征
def vgg_extract_feat(self, file, file_type='file'):
# if file_type == 'bytes':
# img = image.load_img(io.BytesIO(file.read()), target_size=(self.input_shape[0], self.input_shape[1]))
# else:
# # file_type = 'file'
# img = image.load_img(file, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.load_img(file, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input_vgg(img)
feat = self.model_vgg.predict(img)
# print(feat.shape)
norm_feat = feat[0] / LA.norm(feat[0])
# print("norm_feat:", norm_feat)
return list(norm_feat)
# # 提取vgg16最后一层卷积特征
# def vgg_extract_feat(self, image_bytes):
# try:
# img = image.load_img(io.BytesIO(image_bytes), target_size=(self.input_shape[0], self.input_shape[1]))
# img = image.img_to_array(img)
# img = np.expand_dims(img, axis=0)
# img = preprocess_input_vgg(img)
# feat = self.model_vgg.predict(img)
# norm_feat = feat[0] / LA.norm(feat[0])
# return norm_feat.tolist()
# except Exception as e:
# print(e, traceback.format_exc())
# return list(np.zeros(shape=(512,)))
if __name__ == '__main__':
handle_obj = VGGNet()
handle_obj.vgg_extract_feat(file='')
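# --- Usage sketch (not part of the commit): comparing two images with the features ---
# --- above. Because vgg_extract_feat() L2-normalises its output, cosine similarity ---
# --- reduces to a dot product. File paths are placeholders.                        ---
import numpy as np

model = VGGNet()
feat_a = np.array(model.vgg_extract_feat(file="img/1.jpg"))  # 512-d, unit length
feat_b = np.array(model.vgg_extract_feat(file="img/2.png"))  # 512-d, unit length

cosine_similarity = float(np.dot(feat_a, feat_b))  # in [-1, 1]; 1 means identical direction
print("cosine_similarity:", cosine_similarity)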