# asin_image_features.py (6.69 KB)
import os
import sys
import threading
import time
import traceback
import socket
import uuid

import numpy as np
import pandas as pd
import redis
import logging

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
sys.path.append(os.path.dirname(sys.path[0]))  # 上级目录
from utils.templates import Templates
from sqlalchemy import text
from vgg_model import VGGNet
from utils.db_util import DbTypes, DBUtil
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)


class AsinImageFeatures(Templates):
    """Batch worker that extracts VGG image features for ASIN images.

    Work is claimed from ``selection.asin_image_local_path``. Under a Redis
    lock, up to ``limit`` rows with ``state=1`` are read and switched to
    ``state=2``. Features are then extracted for each image, the results are
    appended to the ``asin_image_features`` table, and the claimed rows are
    switched to ``state=3``.
    """

    # Connection settings for the Redis instance that holds the lock.
    # NOTE(review): credentials are hard-coded; consider moving them to config/env.
    REDIS_CONFIG = {"host": "192.168.10.224", "port": 6379, "db": 9, "password": "yswg2023"}
    # On hosts NOT in this set, "/mnt" in stored image paths is rewritten to "/home".
    MNT_HOSTS = frozenset({'hadoop5', 'hadoop6', 'hadoop7', 'hadoop8'})
    # Length of the all-zero vector stored when feature extraction fails.
    FEATURE_DIM = 512
    # Lua script: delete the key only if it still holds our token (atomic check-and-delete).
    _RELEASE_LOCK_SCRIPT = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """

    def __init__(self, site_name='us', asin_type=1, thread_num=10, limit=1000):
        """
        :param site_name: site code; used in the SQL filters and in the lock key.
        :param asin_type: value of the ``asin_type`` column to process.
        :param thread_num: number of worker threads started by ``run_thread``.
        :param limit: maximum number of rows claimed per batch.
        """
        super().__init__()
        self.site_name = site_name
        self.asin_type = asin_type
        self.thread_num = thread_num
        self.limit = limit
        self.engine_srs = self._new_srs_engine()
        self.client = self._new_redis_client()
        self.local_name = f"{self.site_name}_asin_image_features"  # Redis lock key
        self.vgg_model = VGGNet()
        self.hostname = socket.gethostname()

    def _new_srs_engine(self):
        """Create a fresh engine for this site's ``srs`` database."""
        return DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)

    def _new_redis_client(self):
        """Create a fresh Redis client for the distributed lock."""
        return redis.Redis(**self.REDIS_CONFIG)

    def acquire_lock(self, lock_name, timeout=10):
        """Try to take the distributed lock ``lock_name`` without blocking.

        :param lock_name: Redis key for the lock; keep it aligned with the task name.
        :param timeout: lock expiry in seconds, so a crashed holder cannot block forever.
        :return: ``(acquired, token)``. ``acquired`` is True when the key was set
                 and None when it already exists. ``token`` must be passed to
                 ``release_lock``.
        """
        lock_value = str(uuid.uuid4())
        lock_acquired = self.client.set(lock_name, lock_value, nx=True, ex=timeout)
        return lock_acquired, lock_value

    def release_lock(self, lock_name, lock_value):
        """Release the lock only if it is still owned by ``lock_value``.

        :return: 1 if the key was deleted, 0 if it had expired or belongs to someone else.
        """
        return self.client.eval(self._RELEASE_LOCK_SCRIPT, 1, lock_name, lock_value)

    def read_data(self):
        """Claim the next batch of pending rows and return them.

        Retries every 5 s until the Redis lock is acquired. It then reads up to
        ``self.limit`` rows with ``state=1`` and marks them ``state=2`` inside a
        transaction. The lock is released even if the query fails. On any
        error, the engine and Redis client are recreated and the step is
        retried.

        :return: DataFrame with columns id, asin, local_path, asin_type;
                 empty when nothing is pending.
        """
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if not lock_acquired:
                    print("当前有其它进程占用redis的锁, 等待5秒继续获取数据")
                    time.sleep(5)  # wait 5 s before polling the lock again
                    continue
                try:
                    print("self.hostname:", self.hostname)
                    # NOTE(review): the lock expires after acquire_lock's default 10 s;
                    # if this read+claim runs longer, another worker may enter concurrently.
                    with self.engine_srs.begin() as conn:
                        sql_read = text(f"SELECT id, asin, local_path, asin_type FROM selection.asin_image_local_path WHERE site_name='{self.site_name}' and asin_type={self.asin_type} and state=1 LIMIT {self.limit};")
                        df = pd.read_sql(sql=sql_read, con=self.engine_srs)
                        id_list = list(df.id)
                        print(f"sql_read: {sql_read}, {df.shape}", id_list[:10])
                        if id_list:
                            sql_update = text(f"UPDATE selection.asin_image_local_path SET state=2 WHERE id IN ({','.join(map(str, id_list))});")
                            print("sql_update:", sql_update)
                            conn.execute(sql_update)
                finally:
                    # Release on every path; otherwise all other workers stall until expiry.
                    self.release_lock(lock_name=self.local_name, lock_value=lock_value)
                return df
            except Exception as e:
                print(f"读取数据错误: {e}", traceback.format_exc())
                time.sleep(5)
                self.engine_srs = self._new_srs_engine()
                self.client = self._new_redis_client()

    def handle_data(self, df, thread_id):
        """Extract features for every row of a claimed batch.

        A row whose image cannot be processed gets an all-zero vector of
        length ``FEATURE_DIM``, so the rest of the batch is not lost.

        :param df: batch returned by ``read_data`` (uses columns id, asin, local_path).
        :param thread_id: worker id, used only in progress output.
        :return: DataFrame with columns id, asin, features (``str()`` of the
                 extracted features), asin_type, site_name.
        """
        rewrite_path = self.hostname not in self.MNT_HOSTS
        rows = []
        for index, (row_id, asin, local_path) in enumerate(zip(df.id, df.asin, df.local_path)):
            print(f"thread_id, index, id, asin, local_path: {thread_id, index, row_id, asin, local_path}")
            try:
                # Inside the try so a NULL/odd path only zeroes this row instead
                # of aborting the batch (which would leave every row stuck at state=2).
                if rewrite_path:
                    local_path = local_path.replace("/mnt", "/home")
                features = self.vgg_model.vgg_extract_feat(file=local_path)
            except Exception as e:
                print(e, traceback.format_exc())
                # Plain floats: str() gives "[0.0, ...]" on every NumPy version
                # (NumPy >= 2 would render np.float64 elements as "np.float64(0.0)").
                features = [0.0] * self.FEATURE_DIM
            rows.append([row_id, asin, str(features), self.asin_type, self.site_name])
        columns = ['id', 'asin', 'features', 'asin_type', 'site_name']
        return pd.DataFrame(rows, columns=columns)

    def save_data(self, df):
        """Append extracted features and mark their source rows ``state=3``.

        :param df: output of ``handle_data``.
        """
        df.to_sql("asin_image_features", con=self.engine_srs, if_exists="append", index=False)
        id_list = list(df.id)
        if not id_list:
            return
        # str() per element (not tuple repr), so NumPy >= 2 scalars stay plain numbers.
        ids_sql = ','.join(map(str, id_list))
        with self.engine_srs.begin() as conn:
            # Wrapped in text(): SQLAlchemy 2.0 no longer executes plain strings.
            sql_update = text(f"update selection.asin_image_local_path set state=3 where id in ({ids_sql});")
            print(f"sql_update: {sql_update}")
            conn.execute(sql_update)

    def run(self, thread_id=1):
        """Worker loop: claim, process and save batches until none are pending.

        On any error, the engine, Redis client and model are recreated, and the
        loop waits 20 s before continuing.
        NOTE(review): these are shared attributes, so a reconnect triggered by
        one thread also replaces them for the other threads.
        """
        while True:
            try:
                df = self.read_data()
                if not df.shape[0]:
                    break
                df_save = self.handle_data(df=df, thread_id=thread_id)
                self.save_data(df=df_save)
            except Exception as e:
                print(e, traceback.format_exc())
                self.engine_srs = self._new_srs_engine()
                self.client = self._new_redis_client()
                self.vgg_model = VGGNet()
                time.sleep(20)

    def run_thread(self):
        """Run ``thread_num`` workers (``run``) in parallel and wait for all of them."""
        threads = [threading.Thread(target=self.run, args=(thread_id,)) for thread_id in range(self.thread_num)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        logging.info("所有线程处理完成")


if __name__ == '__main__':
    # Optional positional CLI arguments: site_name asin_type thread_num limit.
    # When omitted, each falls back to the previously hard-coded default.
    cli_args = sys.argv[1:]
    site_name = cli_args[0] if len(cli_args) > 0 else 'us'
    asin_type = int(cli_args[1]) if len(cli_args) > 1 else 1
    thread_num = int(cli_args[2]) if len(cli_args) > 2 else 3
    limit = int(cli_args[3]) if len(cli_args) > 3 else 500
    handle_obj = AsinImageFeatures(site_name=site_name, asin_type=asin_type, thread_num=thread_num, limit=limit)
    handle_obj.run_thread()