Recognizer 类
Classes
Recognizer
Methods
__init__(self, model_name='densenet_lite_136-fc', *, cand_alphabet=None, context='cpu', model_fp=None, model_backend='onnx', root='/home/docs/.cnocr', vocab_fp=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/cnocr/checkouts/latest/cnocr/label_cn.txt'), **kwargs)
special
识别模型初始化函数。
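Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_name` | `str` | 模型名称。默认为 `densenet_lite_136-fc`。 | `'densenet_lite_136-fc'` |
| `cand_alphabet` | `Optional[Union[Collection, str]]` | 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围。 | `None` |
| `context` | `str` | `'cpu'` 或 `'gpu'`,表明预测时是使用 CPU 还是 GPU。默认为 `'cpu'`。此参数仅在 `model_backend=='pytorch'` 时有效。 | `'cpu'` |
| `model_fp` | `Optional[str]` | 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件(`.ckpt` 文件)。 | `None` |
| `model_backend` | `str` | `'pytorch'` 或 `'onnx'`,表明预测时是使用 PyTorch 版本模型,还是使用 ONNX 版本模型。同样的模型,ONNX 版本的预测速度一般是 PyTorch 版本的 2 倍左右。默认为 `'onnx'`。 | `'onnx'` |
| `root` | `Union[str, pathlib.Path]` | 模型文件所在的根目录。Linux/Mac 下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.1/densenet_lite_136-fc`;Windows 下默认值为 `C:/Users/<username>/AppData/Roaming/cnocr`。 | `data_dir()`(在文档构建环境中显示为 `'/home/docs/.cnocr'`) |
| `vocab_fp` | `Union[str, pathlib.Path]` | 字符集合的文件路径,即 `label_cn.txt` 文件路径。若训练的自有模型更改了字符集,可通过此参数传入新的字符集文件路径。 | `VOCAB_FP`(在文档构建环境中显示为 `PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/cnocr/checkouts/latest/cnocr/label_cn.txt')`) |
| `**kwargs` | | 目前未被使用。 | `{}` |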
Examples:
使用默认参数:
>>> rec = Recognizer()
使用指定模型:
>>> rec = Recognizer(model_name='densenet_lite_136-fc')
识别时只考虑数字:
>>> rec = Recognizer(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')
Source code in cnocr/recognizer.py
def __init__(
    self,
    model_name: str = 'densenet_lite_136-fc',
    *,
    cand_alphabet: Optional[Union[Collection, str]] = None,
    context: str = 'cpu',  # ['cpu', 'gpu', 'cuda']
    model_fp: Optional[str] = None,
    model_backend: str = 'onnx',  # ['pytorch', 'onnx']
    root: Union[str, Path] = data_dir(),
    vocab_fp: Union[str, Path] = VOCAB_FP,
    **kwargs,
):
    """
    Initialize the text-line recognition model.

    Args:
        model_name (str): model name. Defaults to `densenet_lite_136-fc`.
        cand_alphabet (Optional[Union[Collection, str]]): candidate characters the
            recognizer is restricted to. Defaults to `None`, meaning no restriction.
        context (str): 'cpu' or 'gpu' ('gpu' is stored as 'cuda'); the device used
            for prediction. Defaults to `cpu`.
            Only effective when `model_backend=='pytorch'`.
        model_fp (Optional[str]): path of a custom model file (a '.ckpt' file) to use
            instead of the built-in model.
        model_backend (str): 'pytorch' or 'onnx' (case-insensitive); whether the
            PyTorch or the ONNX version of the model is used for prediction.
            For the same model, the ONNX version is usually about 2x faster than
            the PyTorch version. Defaults to 'onnx'. If preparing the model files
            for the requested backend raises `NotImplementedError`, the other
            backend is tried once.
        root (Union[str, Path]): root directory of the model files.
            On Linux/Mac the default is `~/.cnocr`, i.e. model files live in a
            folder such as `~/.cnocr/2.1/densenet_lite_136-fc`.
            On Windows the default is `C:/Users/<username>/AppData/Roaming/cnocr`.
        vocab_fp (Union[str, Path]): path of the charset file, i.e. `label_cn.txt`.
            If your self-trained model uses a different charset, pass the path of
            the new charset file here.
        **kwargs: currently unused. Passing `name` only logs a deprecation warning.

    Raises:
        ValueError: if `model_backend` is neither 'pytorch' nor 'onnx'.
        NotImplementedError: propagated from `_assert_and_prepare_model_files`
            when it also fails for the fallback backend.

    Examples:
        Use the default parameters:
        >>> rec = Recognizer()

        Use a specific model:
        >>> rec = Recognizer(model_name='densenet_lite_136-fc')

        Only recognize digits:
        >>> rec = Recognizer(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')
    """
    model_backend = model_backend.lower()
    if model_backend not in ('pytorch', 'onnx'):
        # Explicit check instead of `assert`: asserts are stripped under `python -O`,
        # which would let an invalid backend silently fall through to the fallback below.
        raise ValueError(
            'model_backend must be "pytorch" or "onnx", but got %r' % model_backend
        )
    if 'name' in kwargs:
        logger.warning(
            'param `name` is useless and deprecated since version %s'
            % MODEL_VERSION
        )
    check_model_name(model_name)
    check_context(context)
    self._model_name = model_name
    self._model_backend = model_backend
    # 'gpu' is accepted as a user-facing alias of 'cuda'.
    if context == 'gpu':
        context = 'cuda'
    self.context = context
    try:
        self._assert_and_prepare_model_files(model_fp, root)
    except NotImplementedError:
        # No model available for the requested backend: switch to the other
        # backend and retry once (a second failure propagates to the caller).
        logger.warning(
            'no available model is found for name %s and backend %s'
            % (self._model_name, self._model_backend)
        )
        self._model_backend = (
            'onnx' if self._model_backend == 'pytorch' else 'pytorch'
        )
        logger.warning(
            'trying to use name %s and backend %s'
            % (self._model_name, self._model_backend)
        )
        self._assert_and_prepare_model_files(model_fp, root)
    self._vocab, self._letter2id = read_charset(vocab_fp)
    self.postprocessor = CTCPostProcessor(vocab=self._vocab)
    # Initialized before set_cand_alphabet so the attribute always exists.
    self._candidates = None
    self.set_cand_alphabet(cand_alphabet)
    self._model = self._get_model(context)
ocr_for_single_line(self, img_fp)
Recognize characters from an image with only one-line characters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `img_fp` | `Union[str, pathlib.Path, torch.Tensor, numpy.ndarray]` | image file path; or image torch.Tensor or np.ndarray, with shape [height, width] or [height, width, channel]. The optional channel should be 1 (gray image) or 3 (color image). | *required* |

Returns:

| Type | Description |
|---|---|
| `Tuple[List[str], float]` | tuple: (list of chars, prob), such as (['你', '好'], 0.80) |
Source code in cnocr/recognizer.py
def ocr_for_single_line(
    self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]
) -> Tuple[str, float]:
    """
    Recognize characters from an image that contains a single line of text.

    Args:
        img_fp (Union[str, Path, torch.Tensor, np.ndarray]):
            image file path; or image torch.Tensor or np.ndarray,
            with shape [height, width] or [height, width, channel].
            The optional channel should be 1 (gray image) or 3 (color image).

    Returns:
        tuple: (text, prob), such as ('你好', 0.80) -- the same per-image
        format that `recognize` returns.
    """
    img = self._prepare_img(img_fp)
    # Delegate to the batch method `recognize` so that single-image and batch
    # calls share one code path and one result format.
    return self.recognize([img])[0]
recognize(self, img_list, batch_size=1)
Batch recognize characters from a list of one-line-characters images.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `img_list` | `List[Union[str, pathlib.Path, torch.Tensor, numpy.ndarray]]` | list of images, in which each element should be a line image array, with type torch.Tensor or np.ndarray. Each element should be a tensor with values ranging from 0 to 255, and with shape [height, width] or [height, width, channel]. The optional channel should be 1 (gray image) or 3 (RGB-format color image). 注:img_list 不宜包含太多图片,否则同时导入这些图片会消耗很多内存。 | *required* |
| `batch_size` | `int` | 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`。 | `1` |

Returns:

| Type | Description |
|---|---|
| `List[Tuple[str, float]]` | list of (chars, prob), such as [('第一行', 0.80), ('第二行', 0.75), ('第三行', 0.9)] |
Source code in cnocr/recognizer.py
def recognize(
    self,
    img_list: List[Union[str, Path, torch.Tensor, np.ndarray]],
    batch_size: int = 1,
) -> List[Tuple[str, float]]:
    """
    Batch-recognize characters from a list of single-line images.

    Args:
        img_list (List[Union[str, Path, torch.Tensor, np.ndarray]]):
            list of images, in which each element should be a line image array,
            with type torch.Tensor or np.ndarray.
            Each element should be a tensor with values ranging from 0 to 255,
            and with shape [height, width] or [height, width, channel].
            The optional channel should be 1 (gray image) or 3 (RGB-format color image).
            Note: img_list should not contain too many images, since all of them
            are loaded into memory at the same time.
        batch_size (int): when there are many images they are processed in
            batches of this size. Must be >= 1. Defaults to `1`.

    Returns:
        list: one (text, prob) tuple per input image, in input order, such as
            [('第一行', 0.80), ('第二行', 0.75), ('第三行', 0.9)].
            The vocab token '<space>' is rendered as ' '. Every image of a batch
            whose prediction raises gets ('', 0.0).

    Raises:
        ValueError: if `batch_size` < 1 (such values previously looped forever).
    """
    if len(img_list) == 0:
        return []
    if batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, but got %r' % batch_size
        )

    # Prepare and transform each image in one pass, so the intermediate
    # prepared images do not all have to be held in memory at once.
    imgs = [self._transform_img(self._prepare_img(img)) for img in img_list]
    num_imgs = len(imgs)

    if batch_size > 1 and num_imgs // batch_size > 1:
        # Sort by `shape[2]` (presumably the width of a [C, H, W] transformed
        # image) so that images of similar size share a batch, for efficiency.
        order = sorted(range(num_imgs), key=lambda i: imgs[i].shape[2])
    else:
        order = list(range(num_imgs))
    ordered_imgs = [imgs[i] for i in order]

    ordered_preds = []
    for start in range(0, num_imgs, batch_size):
        batch = ordered_imgs[start : start + batch_size]
        try:
            batch_out = self._predict(batch)
        except Exception as e:
            # Very small images (e.g. width < 8) make prediction fail. Keep the
            # best-effort behavior (empty results for this batch) but leave a trace.
            logger.warning(
                'failed to recognize a batch of %d images, '
                'using empty results for them: %s',
                len(batch),
                e,
            )
            batch_out = {'preds': [([''], 0.0)] * len(batch)}
        ordered_preds.extend(batch_out['preds'])

    # Undo the sorting so results line up with the input order.
    out = [None] * len(ordered_preds)
    for orig_idx, pred in zip(order, ordered_preds):
        out[orig_idx] = pred

    res = []
    for chars, prob in out:
        text = ''.join(' ' if c == '<space>' else c for c in chars)
        res.append((text, prob))
    return res
set_cand_alphabet(self, cand_alphabet)
设置待识别字符的候选集合。
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `cand_alphabet` | `Optional[Union[Collection, str]]` | 待识别字符所在的候选集合。传入 `None` 表示不限定识别字符范围。 | *required* |

Returns:

| Type | Description |
|---|---|
| `None` | 无返回值。 |
Source code in cnocr/recognizer.py
def set_cand_alphabet(self, cand_alphabet: Optional[Union[Collection, str]]):
    """
    Restrict recognition to a set of candidate characters.

    Args:
        cand_alphabet (Optional[Union[Collection, str]]): the candidate characters.
            `None` removes any restriction. A plain space ' ' is mapped to the vocab
            token '<space>'. Characters missing from the vocab are dropped with a
            warning; if none remain, the restriction is removed (`None`).

    Returns:
        None
    """
    if cand_alphabet is not None:
        # The vocab stores the space character as the '<space>' token.
        normalized = ['<space>' if ch == ' ' else ch for ch in cand_alphabet]
        unknown = {ch for ch in normalized if ch not in self._letter2id}
        if unknown:
            logger.warning(
                'chars in candidates are not in the vocab, ignoring them: %s'
                % unknown
            )
        known = [ch for ch in normalized if ch in self._letter2id]
        # An empty candidate list means "no restriction", stored as None.
        self._candidates = known if known else None
    else:
        self._candidates = None
    logger.debug('candidate chars: %s' % self._candidates)