import re
from .ctc_postprocess import BaseRecLabelDecode


# Regex cleanup rules applied after tokenizer decoding, as
# (pattern, replacement) pairs. `<|sn|>` appears to be UniRec's
# line-break marker.
rules = [
    (r'-<\|sn\|>', ''),   # hyphen right before a break: join the split word
    (r' <\|sn\|>', ' '),  # space followed by a break: keep a single space
    (r'<\|sn\|>', ' '),   # any remaining break: replace with a space
    (r'<\|unk\|>', ''),   # drop unknown-token markers
    (r'<s>', ''),         # drop sentence-start tag
    (r'</s>', ''),        # drop sentence-end tag
    (r'\uffff', ''),      # drop U+FFFF placeholder characters
    (r'_{4,}', '___'),    # collapse runs of 4+ underscores to three
    (r'\.{4,}', '...'),   # collapse runs of 4+ dots to an ellipsis
]
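
# For example, the first rule joins a word split by hyphenation across a
# break marker (pure regex behavior, shown on a made-up string):
#   re.sub(r'-<\|sn\|>', '', 'exam-<|sn|>ple')  ->  'example'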

def clean_special_tokens(text):
    """Strip byte-level BPE markers and special tokens from decoded text."""
    # Literal spaces are dropped: byte-level BPE encodes a real space as 'Ġ'
    # and a newline as 'Ċ', so those markers are mapped back instead.
    text = text.replace(' ', '').replace('Ġ', ' ').replace('Ċ', '\n')
    for token in ('<|bos|>', '<|eos|>', '<|pad|>'):
        text = text.replace(token, '')
    for pattern, repl in rules:
        text = re.sub(pattern, repl, text)
    return text
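
# A minimal illustration (the input string is hypothetical; the exact marker
# set depends on the UniRec tokenizer):
#   clean_special_tokens('<|bos|>ĠHelloĠworld<|sn|><|eos|>')
#   ->  ' Hello world '
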

class UniRecLabelDecode(BaseRecLabelDecode):
    """Decode UniRec token-id predictions into text via a HF tokenizer."""
    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 tokenizer_path='./configs/rec/unirec/unirec-0.1b',
                 **kwargs):
        super().__init__(character_dict_path, use_space_char)
        # Lazy import: `transformers` is only required when this decoder
        # is actually instantiated.
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, preds, batch=None, *args, **kwargs):
        # `preds` is the batch of token-id sequences produced by the model.
        texts = self.tokenizer.batch_decode(preds, skip_special_tokens=False)
        # Confidence is not recoverable from the ids alone, so report 0.0.
        return [(clean_special_tokens(text), 0.0) for text in texts]

    def add_special_char(self, dict_character):
        """Prepend the [GO] and [s] control tokens to the character dict."""
        return self.list_token + dict_character
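

# Usage sketch (hypothetical: this module uses a relative import, so call it
# from within the package, and the tokenizer files must exist at the default
# `tokenizer_path`):
#
#   decoder = UniRecLabelDecode()
#   results = decoder(pred_ids)  # pred_ids: batch of token-id sequences
#   for text, confidence in results:
#       print(text, confidence)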