Spaces:
Running
Running
| import io | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from .abinet_label_encode import ABINetLabelEncode | |
| from .ar_label_encode import ARLabelEncode | |
| from .ce_label_encode import CELabelEncode | |
| from .char_label_encode import CharLabelEncode | |
| from .cppd_label_encode import CPPDLabelEncode | |
| from .ctc_label_encode import CTCLabelEncode | |
| from .ep_label_encode import EPLabelEncode | |
| from .igtr_label_encode import IGTRLabelEncode | |
| from .mgp_label_encode import MGPLabelEncode | |
| from .rec_aug import ABINetAug | |
| from .rec_aug import BaseDataAugmentation as BDA | |
| from .rec_aug import PARSeqAug, PARSeqAugPIL, SVTRAug | |
| from .resize import (ABINetResize, CDistNetResize, LongResize, RecTVResize, | |
| RobustScannerRecResizeImg, SliceResize, SliceTVResize, | |
| SRNRecResizeImg, SVTRResize, VisionLANResize, | |
| RecDynamicResize) | |
| from .smtr_label_encode import SMTRLabelEncode | |
| from .srn_label_encode import SRNLabelEncode | |
| from .visionlan_label_encode import VisionLANLabelEncode | |
| from .cam_label_encode import CAMLabelEncode | |
| class KeepKeys(object): | |
| def __init__(self, keep_keys, **kwargs): | |
| self.keep_keys = keep_keys | |
| def __call__(self, data): | |
| data_list = [] | |
| for key in self.keep_keys: | |
| data_list.append(data[key]) | |
| return data_list | |
| def transform(data, ops=None): | |
| """transform.""" | |
| if ops is None: | |
| ops = [] | |
| for op in ops: | |
| data = op(data) | |
| if data is None: | |
| return None | |
| return data | |
| class Fasttext(object): | |
| def __init__(self, path='None', **kwargs): | |
| # pip install fasttext==0.9.1 | |
| import fasttext | |
| self.fast_model = fasttext.load_model(path) | |
| def __call__(self, data): | |
| label = data['label'] | |
| fast_label = self.fast_model[label] | |
| data['fast_label'] = fast_label | |
| return data | |
| class DecodeImage(object): | |
| """decode image.""" | |
| def __init__(self, | |
| img_mode='RGB', | |
| channel_first=False, | |
| ignore_orientation=False, | |
| **kwargs): | |
| self.img_mode = img_mode | |
| self.channel_first = channel_first | |
| self.ignore_orientation = ignore_orientation | |
| def __call__(self, data): | |
| img = data['image'] | |
| assert type(img) is bytes and len( | |
| img) > 0, "invalid input 'img' in DecodeImage" | |
| img = np.frombuffer(img, dtype='uint8') | |
| if self.ignore_orientation: | |
| img = cv2.imdecode( | |
| img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR) | |
| else: | |
| img = cv2.imdecode(img, 1) | |
| if img is None: | |
| return None | |
| if self.img_mode == 'GRAY': | |
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |
| elif self.img_mode == 'RGB': | |
| assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( | |
| img.shape) | |
| img = img[:, :, ::-1] | |
| if self.channel_first: | |
| img = img.transpose((2, 0, 1)) | |
| data['image'] = img | |
| return data | |
| class DecodeImagePIL(object): | |
| """decode image.""" | |
| def __init__(self, img_mode='RGB', **kwargs): | |
| self.img_mode = img_mode | |
| def __call__(self, data): | |
| img = data['image'] | |
| assert type(img) is bytes and len( | |
| img) > 0, "invalid input 'img' in DecodeImage" | |
| img = data['image'] | |
| buf = io.BytesIO(img) | |
| img = Image.open(buf).convert('RGB') | |
| if self.img_mode == 'Gray': | |
| img = img.convert('L') | |
| elif self.img_mode == 'BGR': | |
| img = np.array(img)[:, :, ::-1] # 将图片转为numpy格式,并将最后一维通道倒序 | |
| img = Image.fromarray(np.uint8(img)) | |
| data['image'] = img | |
| return data | |
| def create_operators(op_param_list, global_config=None): | |
| """create operators based on the config. | |
| Args: | |
| params(list): a dict list, used to create some operators | |
| """ | |
| assert isinstance(op_param_list, list), 'operator config should be a list' | |
| ops = [] | |
| for operator in op_param_list: | |
| assert isinstance(operator, | |
| dict) and len(operator) == 1, 'yaml format error' | |
| op_name = list(operator)[0] | |
| param = {} if operator[op_name] is None else operator[op_name] | |
| if global_config is not None: | |
| param.update(global_config) | |
| op = eval(op_name)(**param) | |
| ops.append(op) | |
| return ops | |
| class GTCLabelEncode(): | |
| """Convert between text-label and text-index.""" | |
| def __init__(self, | |
| gtc_label_encode, | |
| max_text_length, | |
| character_dict_path=None, | |
| use_space_char=False, | |
| **kwargs): | |
| self.gtc_label_encode = eval(gtc_label_encode['name'])( | |
| max_text_length=max_text_length, | |
| character_dict_path=character_dict_path, | |
| use_space_char=use_space_char, | |
| **gtc_label_encode) | |
| self.ctc_label_encode = CTCLabelEncode(max_text_length, | |
| character_dict_path, | |
| use_space_char) | |
| def __call__(self, data): | |
| data_ctc = self.ctc_label_encode({'label': data['label']}) | |
| data = self.gtc_label_encode(data) | |
| if data_ctc is None or data is None: | |
| return None | |
| data['ctc_label'] = data_ctc['label'] | |
| data['ctc_length'] = data_ctc['length'] | |
| return data | |