Spaces:
Running
Running
| import random | |
| import numpy as np | |
| from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode | |
class CPPDLabelEncode(BaseRecLabelEncode):
    """Encode text labels into the training targets used by CPPD.

    ``__call__`` rewrites ``data['label']`` into a padded index sequence and
    adds ``length`` and ``label_node``. It also adds either ``label_index``
    (when ``ch`` is set) or ``label_order`` (otherwise).
    """

    def __init__(
            self,
            max_text_length,
            character_dict_path=None,
            use_space_char=False,
            ch=False,
            ignore_index=100,
            use_sos=False,
            pos_len=False,
            ch_node_num=37,
            **kwargs):
        """
        Args:
            max_text_length: forwarded to ``BaseRecLabelEncode``. Labels with
                more dictionary characters than ``self.max_text_len`` are
                rejected.
            character_dict_path: forwarded to ``BaseRecLabelEncode``.
            use_space_char: forwarded to ``BaseRecLabelEncode``.
            ch: if True, use ``encodech`` (counts for a fixed-size sample of
                character ids) instead of ``encode`` (counts for every id).
            ignore_index: padding value appended to ``label`` up to length
                ``max_text_len + 1``.
            use_sos: if True, ``add_special_char`` prepends ``'<s>'`` before
                ``'</s>'``.
            pos_len: if True, the positional-node vector is a countdown of
                remaining length; otherwise it is a 1/0 mask.
            ch_node_num: number of character ids returned by ``encodech``
                (ids present in the label plus randomly sampled absent ones).
                Defaults to 37, the value that was previously hard-coded.
        """
        # Set before the base initializer because add_special_char reads it;
        # presumably BaseRecLabelEncode.__init__ calls add_special_char
        # (TODO confirm against the base class).
        self.use_sos = use_sos
        super().__init__(max_text_length, character_dict_path, use_space_char)
        self.ch = ch
        self.ignore_index = ignore_index
        self.pos_len = pos_len
        self.ch_node_num = ch_node_num

    def _pos_node(self, length):
        """Return the positional-node vector (size max_text_len + 1) for a label of `length` chars."""
        pad = self.max_text_len - length
        if self.pos_len:
            # Countdown length, length-1, ..., 0; unused slots are filled with 100.
            return list(range(length, -1, -1)) + [100] * pad
        # 1 for each character plus the end slot, 0 for the unused slots.
        return [1] * (length + 1) + [0] * pad

    def _pad_label(self, text_list):
        """Append end index 0, then pad with ignore_index to length max_text_len + 1."""
        padded = text_list + [0]
        return padded + [self.ignore_index] * (self.max_text_len + 1 - len(padded))

    def __call__(self, data):
        """Encode ``data['label']`` in place.

        Returns ``data`` with the encoded fields, or ``None`` when the label
        cannot be encoded (empty, no dictionary characters, or too long).
        """
        text = data['label']
        if self.ch:
            text, node_index, node_num = self.encodech(text)
        else:
            text, node_num, ch_order = self.encode(text)
        if text is None or len(text) > self.max_text_len:
            return None
        data['length'] = np.array(len(text))
        data['label'] = np.array(self._pad_label(text))
        data['label_node'] = np.array(node_num + self._pos_node(len(text)))
        if self.ch:
            data['label_index'] = np.array(node_index)
        else:
            data['label_order'] = np.array(ch_order)
        return data

    def add_special_char(self, dict_character):
        """Prepend the special tokens to ``dict_character`` and record its size.

        Index 0 is ``'<s>'`` when ``use_sos`` is set, otherwise ``'</s>'``.
        Sets ``self.num_character`` to the length of the returned list.
        """
        specials = ['<s>', '</s>'] if self.use_sos else ['</s>']
        dict_character = specials + dict_character
        self.num_character = len(dict_character)
        return dict_character

    def encode(self, text):
        """Convert a text label into CPPD targets (non-``ch`` mode).

        Args:
            text: the label string. Characters missing from ``self.dict``
                are skipped.

        Returns:
            ``(text_list, text_node, ch_order)``, or ``(None, None, None)``
            when the label is empty, has no dictionary characters, or has
            more than ``max_text_len`` dictionary characters.

            - ``text_list``: dictionary index of each kept character, in order.
            - ``text_node``: list of length ``num_character`` where entry ``i``
              counts occurrences of index ``i``. Entry 0 starts at 1.
            - ``ch_order``: sorted list of at most ``max_text_len + 1``
              ``[index, occurrence_count, position]`` triples. Kept characters
              use their 1-based position. A random subset of the characters
              not found in ``text`` is included with ``[index, 1, 0]``.
        """
        if len(text) == 0:
            return None, None, None
        if self.lower:
            text = text.lower()
        text_node = [0] * self.num_character
        text_node[0] = 1
        text_list = []
        ch_order = []
        for char in text:
            if char not in self.dict:
                continue
            idx = self.dict[char]
            text_list.append(idx)
            text_node[idx] += 1
            # [char index, running occurrence count, 1-based position]
            ch_order.append([idx, text_node[idx], len(text_list)])
        if len(text_list) == 0 or len(text_list) > self.max_text_len:
            return None, None, None
        # Characters not occurring in the label, in random order, so that the
        # truncation below keeps a random subset of them.
        no_ch_order = [[self.dict[char], 1, 0]
                       for char in self.character if char not in text]
        random.shuffle(no_ch_order)
        ch_order = (ch_order + no_ch_order)[:self.max_text_len + 1]
        # Sort in place and return the list itself. list.sort() returns None,
        # and the previous code returned that None.
        ch_order.sort()
        return text_list, text_node, ch_order

    def encodech(self, text):
        """Convert a text label into CPPD targets (``ch`` mode).

        Args:
            text: the label string. Characters missing from ``self.dict``
                are skipped.

        Returns:
            ``(text_list, text_node_index, text_node_num)``, or
            ``(None, None, None)`` when the label is empty, has no dictionary
            characters, or has more than ``max_text_len`` dictionary
            characters.

            - ``text_list``: dictionary index of each kept character, in order.
            - ``text_node_index``: ``ch_node_num`` sorted character indices.
              These are index 0, every index in the label, and randomly
              sampled absent indices.
            - ``text_node_num``: occurrence count for each entry of
              ``text_node_index``. Entry 0 starts at 1 and sampled absent
              indices get 0.

        Raises:
            ValueError: from ``random.sample`` if the label has more than
                ``ch_node_num - 1`` distinct characters, or if the dictionary
                is too small to fill ``ch_node_num`` slots.
        """
        if len(text) == 0:
            return None, None, None
        if self.lower:
            text = text.lower()
        # Occurrence count per character index; index 0 is always present.
        text_node_dict = {0: 1}
        text_list = []
        for char in text:
            if char not in self.dict:
                continue
            i_c = self.dict[char]
            text_list.append(i_c)
            text_node_dict[i_c] = text_node_dict.get(i_c, 0) + 1
        # Validate before sampling. An over-long label can have more distinct
        # characters than ch_node_num, which made random.sample raise instead
        # of the sample being filtered out.
        if len(text_list) == 0 or len(text_list) > self.max_text_len:
            return None, None, None
        absent_index = [i for i in range(self.num_character)
                        if i not in text_node_dict]
        for ic in random.sample(absent_index,
                                self.ch_node_num - len(text_node_dict)):
            text_node_dict[ic] = 0
        text_node_index = sorted(text_node_dict)
        text_node_num = [text_node_dict[k] for k in text_node_index]
        return text_list, text_node_index, text_node_num