# ImageRecognized.py (7.16 KB), last committed by chenyuanjie
# author : wangrui
# date : 2024/3/26 14:16

# 新增图片
import logging
import os
import time
from logging.handlers import TimedRotatingFileHandler

import requests
import torch
from flask import Flask, request, jsonify
from googletrans import Translator
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from PIL import Image
from requests.exceptions import RequestException
from transformers import DetrImageProcessor, DetrForObjectDetection


app = Flask("getImageElements")

# Logger used by the route and every helper in this module.
logger = logging.getLogger("getImageElements")
logger.setLevel(logging.INFO)

# Log file location. The default is the production path; it can be overridden
# through the environment so the service also starts on other machines.
log_filename = os.environ.get("IMAGE_ELEMENTS_LOG_FILE", "/home/wangrui/logs/app.log")
# TimedRotatingFileHandler does not create missing parent directories and would
# raise at import time, so make sure the directory exists first.
os.makedirs(os.path.dirname(log_filename) or ".", exist_ok=True)

# Rotating file handler:
#   when="midnight", interval=1 -> start a new file every day at midnight
#   backupCount=7               -> keep the 7 most recent rotated files
#   encoding='utf-8'            -> the log messages contain Chinese text
handler = TimedRotatingFileHandler(log_filename, when="midnight", interval=1, backupCount=7, encoding='utf-8')

# Log line format: timestamp - level - message.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Attach the file handler to the logger.
logger.addHandler(handler)

# Model locations. The defaults are the deployment paths; for local development
# override them through the environment instead of editing the code, e.g.
#   DETR_MODEL_PATH=E:\Model\model_new\detr-resnet-50
#   CLASSIFICATION_MODEL_PATH=E:\Model\model_new\cv_nextvit-small_image-classification_Dailylife-labels
DETR_MODEL_PATH = os.environ.get("DETR_MODEL_PATH", "/home/wangrui/detr-resnet-50")
CLASSIFICATION_MODEL_PATH = os.environ.get(
    "CLASSIFICATION_MODEL_PATH",
    "/home/wangrui/cv_nextvit-small_image-classification_Dailylife-labels",
)

# DETR object detector (primary model; produces English labels).
processor = DetrImageProcessor.from_pretrained(DETR_MODEL_PATH, revision="no_timm")
model = DetrForObjectDetection.from_pretrained(DETR_MODEL_PATH, revision="no_timm")
# ModelScope image classifier (fallback model; produces Chinese labels).
image_classification = pipeline(Tasks.image_classification, model=CLASSIFICATION_MODEL_PATH)


# Endpoint: recognise the elements contained in one or more images.
@app.route('/image/getImageElements', methods=['POST', 'GET'])
def getImageElement():
    """Return English and Chinese element labels for each requested image.

    Expects a JSON object body ``{"input_data": [<object key>, ...]}``. Each key
    is appended to the ``soundasia`` OSS bucket URL. A single key may also be
    sent as a plain string. ``None`` / empty-string keys are skipped, and images
    for which no labels were produced are left out of the result.

    Returns:
        ``({"result": [{"input_data", "english_elements", "chinese_elements"}, ...]}, 200)``
        on success, or ``({"error": ...}, 400)`` when the body is not a JSON
        object, ``input_data`` is missing or empty, or processing raised.
    """
    logger.info(f"Request: {request.remote_addr} - {request.method} {request.url}")
    # silent=True: a missing / non-JSON body yields None instead of raising,
    # so the client gets our JSON 400 rather than a 415 or a 500.
    request_data = request.get_json(silent=True)
    # Require a JSON object: `in` on a JSON string/list would do substring or
    # element matching instead of a key lookup.
    if not isinstance(request_data, dict) or 'input_data' not in request_data:
        logger.error('No input data provided')
        return jsonify({"error": "No input data provided"}), 400
    input_datas = request_data['input_data']
    if not input_datas:
        logger.error('Image is Empty')
        return jsonify({"error": "Image is Empty"}), 400
    # Accept a single key too; iterating a bare string would otherwise build
    # one URL per character.
    if isinstance(input_datas, str):
        input_datas = [input_datas]
    url_pre = "http://soundasia.oss-cn-shenzhen.aliyuncs.com/"
    try:
        result = []
        for input_data in input_datas:
            if input_data is None or input_data == "":
                continue
            complete_url = url_pre + input_data
            # Run the recognition models on the image.
            english_elements, chinese_elements = getElementsFromModel(complete_url)
            if english_elements is not None and chinese_elements is not None:
                result.append({
                    "input_data": input_data,
                    "english_elements": english_elements,
                    "chinese_elements": chinese_elements
                })
                logger.info(f'图片 {input_data} 元素解析完成')
        return jsonify({"result": result}), 200
    except Exception as e:
        # logger.exception also records the traceback for diagnosis.
        logger.exception(f"An error occured: {e}")
        return jsonify({"error": "Image Parse Failed"}), 400


# Google Translate (googletrans) helper with retries.
def translate_with_retry(translator, text, src='en', dest='zh-CN', max_retries=5, retry_delay=0.5):
    """Translate *text* with *translator*, retrying failed attempts.

    Any exception raised by ``translator.translate`` counts as a failed
    attempt. googletrans does not raise ``requests`` exceptions; depending on
    its version and on Google's response it raises HTTP-client, JSON or
    attribute errors. Catching only ``RequestException`` let those escape,
    which aborted the caller's whole recognition result.

    Args:
        translator: object with ``translate(text, src=..., dest=...)`` returning
            an object with a ``.text`` attribute (a googletrans ``Translator``
            in this module).
        text: text to translate.
        src: source language code.
        dest: target language code.
        max_retries: total number of attempts; ``<= 0`` makes no attempt.
        retry_delay: seconds to wait between attempts; ``0`` retries at once.

    Returns:
        The translated text, or ``None`` if every attempt failed.
    """
    for attempt in range(1, max_retries + 1):
        try:
            return translator.translate(text, src=src, dest=dest).text
        except Exception as e:
            logger.error(f"Request failed: {e}")
            if attempt < max_retries:
                logger.warning(f"Retrying... ({attempt}/{max_retries})")
                # Brief pause so transient failures / rate limits can clear.
                if retry_delay > 0:
                    time.sleep(retry_delay)
            else:
                logger.error('Max retries reached. Translation failed.')
    return None


# Combine both models into one image -> labels lookup.
def getElementsFromModel(img_url):
    """Return comma-separated English and Chinese element labels for an image.

    The DETR detector (getFromModel1) is tried first. The ModelScope classifier
    (getFromModel2) is consulted only when the detector produced no English
    label or failed.

    Labels are sorted before joining so the same image always yields the same
    strings. Iterating a set of strings has no stable order across processes.
    The English and Chinese lists are sorted independently, so positions in the
    two strings do not correspond to each other.

    Args:
        img_url: full URL of the image.

    Returns:
        ``(english, chinese)`` comma-joined strings, or ``(None, None)`` when
        neither model produced an English label.
    """
    logger.info(f"解析的图片地址为: {img_url}")
    translator = Translator()
    english_labels, chinese_labels = getFromModel1(img_url, translator)
    if not english_labels:
        # Detector found nothing (or failed): fall back to the classifier.
        english_labels, chinese_labels = getFromModel2(img_url, translator)
        if not english_labels:
            return None, None
    return ",".join(sorted(english_labels)), ",".join(sorted(chinese_labels))


def _fetch_rgb_image(img_url, timeout=30):
    """Download *img_url* and return it as a fully decoded RGB PIL image.

    Raises ``requests.HTTPError`` on a non-2xx response, so an error page is
    not fed to PIL. PIL errors propagate for undecodable data.
    """
    # The context manager closes the streamed connection, which the original
    # code never released.
    with requests.get(img_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Let urllib3 undo any gzip/deflate Content-Encoding on the raw stream.
        response.raw.decode_content = True
        image = Image.open(response.raw)
        # Decode fully while the response is still open.
        image.load()
    # Normalise the pixel format expected by the DETR processor.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


# Primary recogniser: DETR (detr-resnet-50) object detection; labels are English.
def getFromModel1(img_url, translator, threshold=0.7):
    """Detect objects in the image at *img_url* and translate their labels.

    Args:
        img_url: full URL of the image.
        translator: passed to translate_with_retry for English -> Chinese.
        threshold: minimum detection confidence to keep a box (default 0.7).

    Returns:
        ``(english_label_set, chinese_label_set)``, or ``(None, None)`` if any
        step raised. A Chinese label is missing when its translation failed.
    """
    try:
        english_label_set = set()
        chinese_label_set = set()
        image = _fetch_rgb_image(img_url)
        # Convert the image into model inputs.
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: no_grad stops autograd from recording the forward
        # pass and holding its activations in memory.
        with torch.no_grad():
            outputs = model(**inputs)
        # Reverse PIL's (width, height) size into (height, width) for post-processing.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
        for label_id in results['labels']:
            label = model.config.id2label[label_id.item()]
            # The same class is often detected several times; translate each
            # distinct label only once instead of one network call per box.
            if not label or label in english_label_set:
                continue
            english_label_set.add(label)
            # English label -> Chinese label.
            chinese_label = translate_with_retry(translator, label, src='en', dest='zh-CN')
            if chinese_label:
                chinese_label_set.add(chinese_label)
        logger.info("模型1解析完成!")
        return english_label_set, chinese_label_set
    except Exception as e:
        # Best effort: the caller falls back to the second model on None.
        logger.exception(f"模型1解析失败! {e}")
        return None, None


# Fallback recogniser: cv_nextvit-small Dailylife image classifier; labels are Chinese.
def getFromModel2(img_url, translator):
    """Classify the image at *img_url* and translate confident labels to English.

    Predictions with a score above 0.5 are kept. Each kept Chinese label is
    stored even when its English translation fails.

    Returns:
        ``(english_label_set, chinese_label_set)``, or ``(None, None)`` if
        classification or translation raised.
    """
    try:
        en_labels = set()
        zh_labels = set()
        prediction = image_classification(img_url)
        for confidence, zh_label in zip(prediction['scores'], prediction['labels']):
            # Only predictions scoring above 0.5 are trusted.
            if not confidence > 0.5:
                continue
            # Chinese label -> English label.
            en_label = translate_with_retry(translator, zh_label, src='zh-CN', dest='en')
            if en_label:
                en_labels.add(en_label)
            zh_labels.add(zh_label)
        logger.info("模型2解析完成!")
        return en_labels, zh_labels
    except Exception as e:
        logger.error(f"模型2解析失败! {e}")
        return None, None


if __name__ == '__main__':
    # Script entry point: serve the Flask app on all interfaces, port 6699.
    app.run(port=6699, host="0.0.0.0")