# asin_image_features.py
# (last committed by chenyuanjie)
import os
import sys
import threading
import time
import traceback
import socket
import uuid

import numpy as np
import pandas as pd
import redis
import logging

# Set before the project imports below; presumably read by pyarrow/pyspark-based
# code pulled in through the project utilities — TODO confirm which module needs it.
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
# Add the parent of sys.path[0] to the import path; presumably this is what lets
# the `utils` package imported below be resolved when the file is run as a script.
sys.path.append(os.path.dirname(sys.path[0]))  # parent directory
from utils.templates import Templates
from sqlalchemy import text
from vgg_model import VGGNet
from utils.db_util import DbTypes, DBUtil
# Configure root logging once at import time: timestamp, logger name, level, message; INFO and above.
logging.basicConfig(format='%(asctime)s %(name)s %(levelname)s %(message)s', level=logging.INFO)


class AsinImageFeatures(Templates):
    """Batch job that turns locally stored ASIN images into VGG feature vectors.

    Each worker repeatedly claims a batch of rows from
    ``selection.asin_image_local_path`` (``state`` 1 -> 2) while holding a
    Redis lock, extracts a feature vector for every image, appends the result
    to ``asin_image_features`` and finally marks the rows as done
    (``state`` 3).
    """

    def __init__(self, site_name='us', asin_type=1, thread_num=10, limit=1000):
        """Create the job.

        Args:
            site_name: Site whose rows are processed; also part of the lock key.
            asin_type: Value of the ``asin_type`` column to select.
            thread_num: Number of worker threads started by ``run_thread``.
            limit: Maximum number of rows claimed per batch.
        """
        super(AsinImageFeatures, self).__init__()
        self.site_name = site_name
        self.asin_type = asin_type
        self.thread_num = thread_num
        self.limit = limit
        self.engine_srs = self._new_srs_engine()
        self.client = self._new_redis_client()
        # Redis key used as the lock name for claiming batches.
        self.local_name = f"{self.site_name}_asin_image_features"
        self.vgg_model = VGGNet()
        self.hostname = socket.gethostname()

    def _new_srs_engine(self):
        """Build a fresh database engine for the configured site."""
        return DBUtil.get_db_engine(db_type=DbTypes.srs.name, site_name=self.site_name)

    def _new_redis_client(self):
        """Build a fresh Redis client for the lock server."""
        # NOTE(review): host and password are hard-coded in source; consider
        # moving them to configuration or environment variables.
        return redis.Redis(host='192.168.10.224', port=6379, db=9, password='yswg2023')

    def acquire_lock(self, lock_name, timeout=10):
        """Try to take the distributed lock.

        Args:
            lock_name: Redis key of the lock (ideally the task name).
            timeout: Expiry of the lock key in seconds.

        Returns:
            A ``(lock_acquired, lock_value)`` tuple. ``lock_acquired`` is
            truthy when the key was set and ``None`` when another holder
            already owns it; ``lock_value`` is the random token that must be
            passed to ``release_lock``.
        """
        lock_value = str(uuid.uuid4())
        # NX: only set when the key does not exist; EX: expire automatically.
        lock_acquired = self.client.set(lock_name, lock_value, nx=True, ex=timeout)
        return lock_acquired, lock_value

    def release_lock(self, lock_name, lock_value):
        """Release the distributed lock if it is still owned by ``lock_value``.

        The compare-and-delete runs as a single Lua script, so a lock that
        expired and was re-acquired by someone else is not deleted.

        Returns:
            The script result: the ``del`` result when the token matched,
            otherwise 0.
        """
        script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        result = self.client.eval(script, 1, lock_name, lock_value)
        return result

    def _release_lock_quietly(self, lock_name, lock_value):
        """Release the lock, logging (not raising) on failure.

        The lock key expires on its own, so a failed release must not discard
        a batch that has already been claimed.
        """
        try:
            self.release_lock(lock_name=lock_name, lock_value=lock_value)
        except Exception as e:
            print(f"release lock failed: {e}", traceback.format_exc())

    def _claim_batch(self):
        """Select up to ``self.limit`` pending rows and mark them state=2.

        Returns:
            DataFrame with columns ``id, asin, local_path, asin_type``
            (possibly empty).
        """
        print("self.hostname:", self.hostname)
        with self.engine_srs.begin() as conn:
            sql_read = text(f"SELECT id, asin, local_path, asin_type FROM selection.asin_image_local_path WHERE site_name='{self.site_name}' and asin_type={self.asin_type} and state=1 LIMIT {self.limit};")
            df = pd.read_sql(sql=sql_read, con=self.engine_srs)
            id_list = list(df.id)
            print(f"sql_read: {sql_read}, {df.shape}", id_list[:10])
            if id_list:
                sql_update = text(f"UPDATE selection.asin_image_local_path SET state=2 WHERE id IN ({','.join(map(str, id_list))});")
                print("sql_update:", sql_update)
                conn.execute(sql_update)
        return df

    def read_data(self):
        """Claim the next batch of pending rows under the distributed lock.

        Blocks until a batch could be read: waits 5 seconds when the lock is
        held elsewhere, and on any error waits 5 seconds, rebuilds the DB
        engine and Redis client, and retries.

        Returns:
            DataFrame of claimed rows; empty when nothing is pending.
        """
        while True:
            try:
                lock_acquired, lock_value = self.acquire_lock(lock_name=self.local_name)
                if not lock_acquired:
                    print("当前有其它进程占用redis的锁, 等待5秒继续获取数据")
                    time.sleep(5)  # wait before trying the lock again
                    continue
                # NOTE(review): the lock expires after the acquire timeout
                # (10s by default); a slower query could overlap with another
                # worker's claim.
                try:
                    df = self._claim_batch()
                finally:
                    # Always give the lock back, also when the claim failed,
                    # so other workers need not wait for the key to expire.
                    self._release_lock_quietly(lock_name=self.local_name, lock_value=lock_value)
                return df
            except Exception as e:
                print(f"读取数据错误: {e}", traceback.format_exc())
                time.sleep(5)
                self.engine_srs = self._new_srs_engine()
                self.client = self._new_redis_client()
                continue

    def handle_data(self, df, thread_id):
        """Extract a feature vector for every row of ``df``.

        Args:
            df: Claimed rows with at least ``id``, ``asin`` and ``local_path``.
            thread_id: Identifier of the calling worker, used only for logging.

        Returns:
            DataFrame with columns ``id, asin, features, asin_type, site_name``
            where ``features`` is the stringified vector. Images whose
            extraction fails get a 512-element zero vector.
        """
        # On hosts other than hadoop5-8 the image paths are rewritten from
        # /mnt to /home; the check does not depend on the row, so do it once.
        remap_path = self.hostname not in ('hadoop5', 'hadoop6', 'hadoop7', 'hadoop8')
        data_list = []
        rows = zip(list(df.id), list(df.asin), list(df.local_path))
        for index, (row_id, asin, local_path) in enumerate(rows):
            print(f"thread_id, index, id, asin, local_path: {thread_id, index, row_id, asin, local_path}")
            if remap_path:
                local_path = local_path.replace("/mnt", "/home")
            try:
                features = self.vgg_model.vgg_extract_feat(file=local_path)
            except Exception as e:
                print(e, traceback.format_exc())
                # tolist() yields plain Python floats, so str() below renders
                # "[0.0, 0.0, ...]" regardless of the NumPy scalar repr.
                features = np.zeros(shape=(512,)).tolist()
            data_list.append([row_id, asin, str(features), self.asin_type, self.site_name])
        columns = ['id', 'asin', 'features', 'asin_type', 'site_name']
        df_save = pd.DataFrame(data_list, columns=columns)
        return df_save

    def save_data(self, df):
        """Append extracted features and mark the source rows as done (state=3).

        Args:
            df: Frame produced by ``handle_data``.
        """
        df.to_sql("asin_image_features", con=self.engine_srs, if_exists="append", index=False)
        id_list = list(df.id)
        if not id_list:
            return
        # Wrap in text(): Connection.execute() does not accept raw strings on
        # SQLAlchemy 2.x. Ids are formatted the same way as in the claim query.
        sql_update = text(f"update selection.asin_image_local_path set state=3 where id in ({','.join(map(str, id_list))});")
        print(f"sql_update: {sql_update}")
        with self.engine_srs.begin() as conn:
            conn.execute(sql_update)

    def run(self, thread_id=1):
        """Worker loop: claim, process and save batches until none are left.

        On any error the DB engine, Redis client and VGG model are rebuilt and
        the loop resumes after 20 seconds.

        Args:
            thread_id: Identifier of this worker, used only for logging.
        """
        while True:
            try:
                df = self.read_data()
                if not df.shape[0]:
                    break
                df_save = self.handle_data(df=df, thread_id=thread_id)
                self.save_data(df=df_save)
            except Exception as e:
                # NOTE(review): rows of the failed batch stay in state=2 and
                # are not picked up again by this job.
                print(e, traceback.format_exc())
                self.engine_srs = self._new_srs_engine()
                self.client = self._new_redis_client()
                self.vgg_model = VGGNet()
                time.sleep(20)
                continue

    def run_thread(self):
        """Start ``self.thread_num`` workers running ``run`` and wait for all of them."""
        thread_list = []
        for thread_id in range(self.thread_num):
            thread = threading.Thread(target=self.run, args=(thread_id, ))
            thread_list.append(thread)
            thread.start()
        for thread in thread_list:
            thread.join()
        logging.info("所有线程处理完成")


if __name__ == '__main__':
    # Fixed settings for this run: US site, asin_type 1, three worker
    # threads, 500 rows claimed per batch.
    job_settings = {
        "site_name": 'us',
        "asin_type": 1,
        "thread_num": 3,
        "limit": 500,
    }
    extractor = AsinImageFeatures(**job_settings)
    # Fan out to the worker threads and block until all of them finish.
    extractor.run_thread()