# OpenOCR-UniRec-Demo / tools/infer_rec.py
# duyongkun
# update app
# 5de2f8f
import os
from pathlib import Path
import sys
import time
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import numpy as np
from tools.engine.config import Config
from tools.utility import ArgsParser
from tools.utils.logging import get_logger
from tools.utils.utility import get_image_file_list
# Module-wide logger used by every function in this file.
logger = get_logger()
# Directory containing this file; the default paths below are built relative to it.
root_dir = Path(__file__).resolve().parent
# Default recognition config selected when OpenRecognizer is created with mode == 'server'.
DEFAULT_CFG_PATH_REC_SERVER = str(root_dir /
'../configs/rec/svtrv2/svtrv2_ch.yml')
# Default recognition config selected for any other mode (the default mode is 'mobile').
DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')
# Character dictionary forced onto the config for the SVTRv2_mobile / SVTRv2_server algorithms.
DEFAULT_DICT_PATH_REC = str(root_dir / './utils/ppocr_keys_v1.txt')
MODEL_NAME_REC = './openocr_repsvtr_ch.pth' # model file name (used for SVTRv2_mobile)
DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth' # model file URL
MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth' # model file name (used for SVTRv2_server)
DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth' # model file URL
MODEL_NAME_REC_ONNX = './openocr_rec_model.onnx' # model file name (ONNX backend fallback)
DOWNLOAD_URL_REC_ONNX = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_rec_model.onnx' # model file URL
def check_and_download_model(model_name: str, url: str):
    """Return a local path to a pretrained model, downloading it if missing.

    Lookup order:
      1. ``model_name`` exactly as given; returned unchanged when it exists.
      2. ``~/.cache/openocr/<model_name>``; returned when it exists.
      3. Otherwise the file is downloaded from ``url`` into that cache
         location.

    The download is streamed into a temporary ``*.part`` file and renamed to
    its final name only after it completed. An interrupted download therefore
    never leaves a truncated file that a later call would accept as a valid
    model.

    Args:
        model_name (str): Model file name or path, e.g. "model.pt".
        url (str): URL the model file is downloaded from.

    Returns:
        str: Path of the model file.

    Raises:
        RuntimeError: If the download fails. The original exception is
            attached as ``__cause__``.
    """
    if os.path.exists(model_name):
        return model_name

    # The cache location is fixed to ".cache/openocr" under the user's home.
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name

    # A previously cached copy is reused as-is.
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    logger.info(f'Model not found. Downloading from {url}...')

    # Create the target directory (also covers model names with sub-folders).
    model_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = model_path.with_name(model_path.name + '.part')

    try:
        import shutil
        import urllib.request

        # Stream in chunks instead of holding the whole model in memory.
        with urllib.request.urlopen(url) as response, open(tmp_path,
                                                           'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        # Publish the file only once it is complete.
        os.replace(tmp_path, model_path)
        logger.info(f'Model downloaded and saved at: {model_path}')
        return str(model_path)
    except Exception as e:
        # Best-effort removal of the partial download.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        logger.error(f'Error downloading the model: {e}')
        # Tell the user how to fetch the model manually.
        logger.error(
            f'Unable to download the model automatically. '
            f'Please download the model manually from the following URL:\n{url}\n'
            f'and save it to: {model_name} or {model_path}')
        raise RuntimeError(
            f'Failed to download the model. Please download it manually from {url} '
            f'and save it to {model_path}') from e
class RatioRecTVReisze(object):
    """Resize an image to an aspect-ratio dependent size and normalise it.

    The width/height ratio of the input is quantised to an integer, capped at
    ``max_ratio``. Ratios 1-4 are mapped to the matching ``base_shape`` entry
    (each entry is ``[width, height]``); larger ratios use
    ``[base_h * ratio, base_h]``. The resized image is then passed through
    torchvision's ``ToTensor`` and ``Normalize(0.5, 0.5)``.

    Config keys read (all optional):
        cfg['Eval']['loader']['max_ratio'] (default 12)
        cfg['Eval']['dataset']['base_shape']
        cfg['Eval']['dataset']['base_h'] (default 32)
        cfg['Eval']['dataset']['ceil'] (default False)
    """

    def __init__(self, cfg):
        self.max_ratio = cfg['Eval']['loader'].get('max_ratio', 12)
        dataset_cfg = cfg['Eval']['dataset']
        self.base_shape = dataset_cfg.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = dataset_cfg.get('base_h', 32)
        # Imported lazily so that torchvision is only needed when this
        # transform is actually built.
        from torchvision import transforms as T
        from torchvision.transforms import functional as F
        self.F = F
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        # NOTE: a trailing comma used to turn this into the tuple `(value,)`,
        # which is always truthy, so the configured flag was ignored and the
        # `ceil` branch was always taken.
        self.ceil = dataset_cfg.get('ceil', False)

    def _target_size(self, w, h):
        """Return the ``(height, width)`` to resize a ``w`` x ``h`` image to."""
        if self.ceil:
            gen_ratio = int(float(w) / float(h)) + 1
        else:
            gen_ratio = max(1, round(float(w) / float(h)))
        ratio_resize = min(gen_ratio, self.max_ratio)
        if ratio_resize <= 4:
            imgW, imgH = self.base_shape[ratio_resize - 1]
        else:
            imgW, imgH = self.base_h * ratio_resize, self.base_h
        return imgH, imgW

    def __call__(self, data):
        """Replace ``data['image']`` with its resized, normalised tensor.

        Args:
            data (dict): Must hold the input image under ``'image'``; the
                image must expose ``.size`` as ``(width, height)``
                (presumably a PIL image -- confirm against the decode op).

        Returns:
            dict: The same ``data`` dict, with ``'image'`` replaced.
        """
        img = data['image']
        w, h = img.size
        imgH, imgW = self._target_size(w, h)
        resized_image = self.F.resize(img, (imgH, imgW),
                                      interpolation=self.interpolation)
        data['image'] = self.transforms(resized_image)
        return data
def build_rec_process(cfg):
    """Derive the inference-time preprocessing op list from the Eval config.

    Walks ``cfg['Eval']['dataset']['transforms']`` and:
      * drops every op whose name contains ``'Label'``;
      * sets ``infer_mode = True`` on ``RecResizeImg``;
      * narrows ``KeepKeys.keep_keys`` to ``['image']``, or to
        ``['image', 'valid_ratio']`` for the SAR / RobustScanner algorithms
        when ``'valid_ratio'`` was requested.

    The op dicts are modified in place and reused in the returned list.

    Args:
        cfg (dict): Full configuration dictionary.

    Returns:
        tuple: ``(ops, ratio_resize_flag)``; the flag is ``True`` only when
        no op name contains ``'Resize'``.
    """
    kept_ops = []
    needs_ratio_resize = True
    for op_cfg in cfg['Eval']['dataset']['transforms']:
        name = list(op_cfg)[0]
        # Any configured resize op (even one skipped below) disables the flag.
        if 'Resize' in name:
            needs_ratio_resize = False
        if 'Label' in name:
            continue
        if name == 'RecResizeImg':
            op_cfg[name]['infer_mode'] = True
        elif name == 'KeepKeys':
            keys = ['image']
            wants_valid_ratio = (
                cfg['Architecture']['algorithm'] in ['SAR', 'RobustScanner']
                and 'valid_ratio' in op_cfg[name]['keep_keys'])
            if wants_valid_ratio:
                keys.append('valid_ratio')
            op_cfg[name]['keep_keys'] = keys
        kept_ops.append(op_cfg)
    return kept_ops, needs_ratio_resize
def set_device(device, numId=0):
    """Map the configured device name to a ``torch.device``.

    Args:
        device (str): ``'gpu'`` requests CUDA; any other value selects the CPU.
        numId (int, optional): CUDA device index used when a GPU is selected.
            Defaults to 0.

    Returns:
        torch.device: ``cuda:<numId>`` when ``device == 'gpu'`` and CUDA is
        available, otherwise ``cpu``.
    """
    import torch
    if device == 'gpu':
        if torch.cuda.is_available():
            return torch.device(f'cuda:{numId}')
        # Only report the fallback when a GPU was actually requested;
        # the message used to be logged for explicit CPU requests as well.
        logger.info('GPU is not available, using CPU.')
    return torch.device('cpu')
class OpenRecognizer:
    """Text recognizer that runs either a PyTorch model or an ONNX model.

    Calling an instance runs preprocessing, batched inference and
    post-processing, and returns one result dict per input image.
    """

    def __init__(self,
                 config=None,
                 mode='mobile',
                 backend='torch',
                 onnx_model_path=None,
                 numId=0):
        """
        Args:
            config (dict, optional): Configuration dictionary. When None, the
                default server or mobile config file is loaded according to
                ``mode``. Defaults to None.
            mode (str, optional): 'server' or 'mobile'. Only used when
                ``config`` is None. Defaults to 'mobile'.
            backend (str): 'torch' or 'onnx'. Overridden by
                ``config['Global']['backend']`` when that is set.
            onnx_model_path (str): Path of the ONNX model (only used when the
                backend is 'onnx'). Overridden by
                ``config['Global']['onnx_model_path']`` when that is set.
            numId (int, optional): Device index. Defaults to 0.

        Raises:
            ValueError: If the backend is neither 'torch' nor 'onnx', or if
                the ONNX backend has no model path and the algorithm is not
                'SVTRv2_mobile' (the only one with a downloadable default).
        """
        if config is None:
            config_file = DEFAULT_CFG_PATH_REC_SERVER if mode == 'server' else DEFAULT_CFG_PATH_REC
            config = Config(config_file).cfg
        self.cfg = config
        # Initialisation shared by both backends.
        self._init_common()
        # A backend set in the config wins over the constructor argument.
        backend = backend if config['Global'].get(
            'backend', None) is None else config['Global']['backend']
        self.backend = backend
        if backend == 'torch':
            import torch
            self.torch = torch
            self._init_torch_model(numId)
        elif backend == 'onnx':
            from tools.infer.onnx_engine import ONNXEngine
            onnx_model_path = onnx_model_path if config['Global'].get(
                'onnx_model_path',
                None) is None else config['Global']['onnx_model_path']
            if not onnx_model_path:
                if self.cfg['Architecture']['algorithm'] == 'SVTRv2_mobile':
                    onnx_model_path = check_and_download_model(
                        MODEL_NAME_REC_ONNX, DOWNLOAD_URL_REC_ONNX)
                else:
                    raise ValueError('ONNX模式需要指定onnx_model_path参数')
            self.onnx_rec_engine = ONNXEngine(
                onnx_model_path, use_gpu=config['Global']['device'] == 'gpu')
        else:
            raise ValueError("backend参数必须是'torch'或'onnx'")

    def _init_common(self):
        """Build the post-processor and the preprocessing op list."""
        from openrec.postprocess import build_post_process
        from openrec.preprocess import create_operators, transform
        self.transform = transform
        algorithm_name = self.cfg['Architecture']['algorithm']
        # The bundled SVTRv2 models always use the dictionary shipped with
        # this package, whatever the config says.
        if algorithm_name in ['SVTRv2_mobile', 'SVTRv2_server']:
            self.cfg['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC
        self.post_process_class = build_post_process(self.cfg['PostProcess'],
                                                     self.cfg['Global'])
        # The decoder output size is taken from the post-processor's
        # character count.
        char_num = self.post_process_class.get_character_num()
        self.cfg['Architecture']['Decoder']['out_channels'] = char_num
        # Build the preprocessing pipeline.
        transforms, ratio_resize_flag = build_rec_process(self.cfg)
        self.ops = create_operators(transforms, self.cfg['Global'])
        if ratio_resize_flag:
            # No resize op was configured: insert the ratio-based resize
            # just before the last op.
            ratio_resize = RatioRecTVReisze(cfg=self.cfg)
            self.ops.insert(-1, ratio_resize)

    def _init_torch_model(self, numId):
        """Build the PyTorch model, load its weights and move it to the device.

        Args:
            numId (int): Device index passed on to ``set_device``.
        """
        from tools.utils.ckpt import load_ckpt
        if self.cfg['Global'].get('use_transformers', False):
            algorithm_name = 'unirec'  # transformers-based model
            from openrec.modeling.transformers_modeling.modeling_unirec import UniRecForConditionalGenerationNew
            from openrec.modeling.transformers_modeling.configuration_unirec import UniRecConfig
            cfg_model = UniRecConfig.from_pretrained(self.cfg['Global']['vlm_ocr_config'])
            cfg_model._attn_implementation = 'eager'
            self.model = UniRecForConditionalGenerationNew(config=cfg_model)
        else:
            # Initialisation specific to the native PyTorch models.
            algorithm_name = self.cfg['Architecture']['algorithm']
            if algorithm_name in ['SVTRv2_mobile', 'SVTRv2_server']:
                # A missing or unset checkpoint path falls back to the
                # downloadable default (os.path.exists(None) would raise).
                pretrained_model = self.cfg['Global'].get('pretrained_model')
                if not pretrained_model or not os.path.exists(
                        pretrained_model):
                    if algorithm_name == 'SVTRv2_mobile':
                        pretrained_model = check_and_download_model(
                            MODEL_NAME_REC, DOWNLOAD_URL_REC)
                    else:
                        pretrained_model = check_and_download_model(
                            MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER)
                    self.cfg['Global']['pretrained_model'] = pretrained_model
            from openrec.modeling import build_model as build_rec_model
            self.model = build_rec_model(self.cfg['Architecture'])
        load_ckpt(self.model, self.cfg)
        self.device = set_device(self.cfg['Global']['device'], numId)
        self.model.to(self.device)
        self.model.eval()
        if algorithm_name == 'SVTRv2_mobile':
            # NOTE(review): presumably folds BatchNorm layers for inference;
            # confirm in tools.infer_det.replace_batchnorm.
            from tools.infer_det import replace_batchnorm
            replace_batchnorm(self.model.encoder)

    def _inference_onnx(self, images):
        """Run the ONNX engine on ``images`` (a numpy array)."""
        return self.onnx_rec_engine.run(images)

    @staticmethod
    def _pad_batch(batch_data):
        """Stack images into one zero-padded float32 array.

        Args:
            batch_data (list): Non-empty list of image arrays/tensors whose
                last two dimensions are height and width.

        Returns:
            numpy.ndarray: Array of shape ``(len(batch_data), 3, max_h,
            max_w)``; each image occupies the top-left corner of its slot.
        """
        max_height = max(img.shape[-2] for img in batch_data)
        max_width = max(img.shape[-1] for img in batch_data)
        padded_batch = np.zeros(
            (len(batch_data), 3, max_height, max_width), dtype=np.float32)
        for i, img in enumerate(batch_data):
            h, w = img.shape[-2:]
            padded_batch[i, :, :h, :w] = img
        return padded_batch

    def __call__(self,
                 img_path=None,
                 img_numpy_list=None,
                 img_numpy=None,
                 batch_num=1):
        """Recognise text in the given image(s).

        Exactly one input source is used. When several are supplied the
        precedence is ``img_numpy`` > ``img_path`` > ``img_numpy_list``.

        Args:
            img_path (str, optional): Image path, expanded with
                ``get_image_file_list``. Defaults to None.
            img_numpy_list (list, optional): List of image numpy arrays.
                Defaults to None.
            img_numpy (numpy.ndarray, optional): A single image numpy array.
                Defaults to None.
            batch_num (int, optional): Number of images per inference batch.
                Defaults to 1.

        Returns:
            list: One dict per image with keys 'text', 'score' and 'elapse'
            (the inference time of the batch the image belonged to), plus
            'file' when the input came from ``img_path``.

        Raises:
            Exception: If no image path or numpy input is provided.
        """
        # Settle on a single input source up front, so that the image count,
        # the loading code and the result format can never disagree.
        if img_numpy is not None:
            img_numpy_list = [img_numpy]
            img_path = None
        elif img_path is not None:
            img_path = get_image_file_list(img_path)
            img_numpy_list = None
        elif img_numpy_list is None:
            raise Exception('No input image path or numpy array.')
        from_files = img_path is not None
        num_img = len(img_path) if from_files else len(img_numpy_list)

        # These algorithms take the valid-width ratio as an extra model input.
        needs_valid_ratio = self.cfg['Architecture']['algorithm'] in [
            'SAR', 'RobustScanner'
        ]

        results = []
        for start_idx in range(0, num_img, batch_num):
            end_idx = min(start_idx + batch_num, num_img)
            batch_data = []
            batch_others = []
            batch_file_names = []
            # Prepare batch data
            for img_idx in range(start_idx, end_idx):
                if from_files:
                    file_name = img_path[img_idx]
                    with open(file_name, 'rb') as f:
                        data = {'image': f.read()}
                    # Raw file bytes go through the first op; numpy inputs
                    # skip it.
                    data = self.transform(data, self.ops[:1])
                    batch_file_names.append(file_name)
                else:
                    data = {'image': img_numpy_list[img_idx]}
                batch = self.transform(data, self.ops[1:])
                if needs_valid_ratio:
                    batch_others.append(np.expand_dims(batch[-1], axis=0))
                batch_data.append(batch[0])

            padded_batch = self._pad_batch(batch_data)
            others = np.concatenate(batch_others,
                                    axis=0) if batch_others else None

            t_start = time.time()
            if self.backend == 'torch':
                images = self.torch.from_numpy(padded_batch).to(
                    device=self.device)
                with self.torch.no_grad():
                    if self.cfg['Global'].get('use_transformers', False):
                        # Inference with the transformers model.
                        inputs = {
                            'pixel_values': images,
                            'input_ids': None,
                            'attention_mask': None
                        }
                        preds = self.model.generate(**inputs)
                    else:
                        # Inference with the native PyTorch model.
                        preds = self.model(images,
                                           others)  # bs, len, num_classes
                torch_tensor = True
            elif self.backend == 'onnx':
                # Inference with the ONNX engine.
                preds = self._inference_onnx(padded_batch)
                preds = preds[0]  # bs, len, num_classes
                torch_tensor = False
            else:
                # Only reachable if `backend` was modified after construction.
                raise ValueError("backend参数必须是'torch'或'onnx'")
            t_cost = time.time() - t_start

            post_results = self.post_process_class(preds,
                                                   torch_tensor=torch_tensor)
            for i, post_result in enumerate(post_results):
                if from_files:
                    info = {
                        'file': batch_file_names[i],
                        'text': post_result[0],
                        'score': post_result[1],
                        'elapse': t_cost
                    }
                else:
                    info = {
                        'text': post_result[0],
                        'score': post_result[1],
                        'elapse': t_cost
                    }
                results.append(info)
        return results
def main(cfg):
    """Recognise every image under ``cfg['Global']['infer_img']``.

    One line per image (``<file>\\t<text>\\t<score>``) is written to
    ``./rec_results/rec_results.txt``, and timing statistics are logged.

    Args:
        cfg (dict): Configuration dictionary. ``cfg['Global']`` must provide
            'infer_img' and 'max_text_length'.
    """
    model = OpenRecognizer(cfg)
    save_res_path = './rec_results/'
    # exist_ok avoids the check-then-create race.
    os.makedirs(save_res_path, exist_ok=True)
    result_file = os.path.join(save_res_path, 'rec_results.txt')

    t_sum = 0
    sample_num = 0
    max_len = cfg['Global']['max_text_length']
    # Accumulated inference time / sample count per recognised text length;
    # texts of max_len - 1 characters or more share the last bucket.
    text_len_time = [0 for _ in range(max_len)]
    text_len_num = [0 for _ in range(max_len)]

    with open(result_file, 'wb') as fout:
        for file in get_image_file_list(cfg['Global']['infer_img']):
            preds_result = model(img_path=file, batch_num=1)[0]
            rec_text = preds_result['text']
            score = preds_result['score']
            t_cost = preds_result['elapse']
            info = rec_text + '\t' + str(score)
            len_bucket = min(max_len - 1, len(rec_text))
            text_len_num[len_bucket] += 1
            text_len_time[len_bucket] += t_cost
            logger.info(
                f'{sample_num} {file}\t result: {info}, time cost: {t_cost}')
            otstr = file + '\t' + info + '\n'
            t_sum += t_cost
            fout.write(otstr.encode())
            sample_num += 1

    logger.info(f'Results saved to {result_file}.')

    if sample_num == 0:
        # Nothing to average; the statistics below would divide by zero.
        logger.warning('No images were processed; skipping time statistics.')
        return

    print(text_len_num)
    # Mean inference time of each non-empty text-length bucket.
    w_avg_t_cost = [
        l_t_cost / l_num
        for l_t_cost, l_num in zip(text_len_time, text_len_num) if l_num != 0
    ]
    print(w_avg_t_cost)
    w_avg_t_cost = sum(w_avg_t_cost) / len(w_avg_t_cost)
    logger.info(
        f'Sample num: {sample_num}, Weighted Avg time cost: {t_sum/sample_num}, Avg time cost: {w_avg_t_cost}'
    )
    logger.info('success!')
if __name__ == '__main__':
    # Script entry point: load the config file named on the command line,
    # overlay the parsed CLI arguments and the extra `opt` overrides (in that
    # order), then run recognition.
    cli_args = ArgsParser().parse_args()
    config = Config(cli_args.config)
    cli_overrides = vars(cli_args)
    extra_options = cli_overrides.pop('opt')
    config.merge_dict(cli_overrides)
    config.merge_dict(extra_options)
    main(config.cfg)