Spaces:
Running
Running
| import torch.nn as nn | |
# Public API of this module for ``from ... import *``: only the factory.
__all__ = [
    'build_decoder',
]
def build_decoder(config):
    """Instantiate a recognition decoder from a config mapping.

    Args:
        config: Mapping with a required ``'name'`` key naming one of the
            registered decoder classes. Every other key is forwarded to that
            class's constructor as a keyword argument. The mapping is copied
            before ``'name'`` is removed, so the caller's object is left
            unchanged and can be reused to build again.

    Returns:
        The constructed decoder module.

    Raises:
        KeyError: if ``config`` has no ``'name'`` key.
        ValueError: if ``config['name']`` is not a registered decoder.
    """
    # rec decoder
    # Imported inside the function, so these modules are only loaded when a
    # decoder is actually built (presumably to avoid import cycles / eager
    # loading - confirm before hoisting to module level).
    from .abinet_decoder import ABINetDecoder
    from .aster_decoder import ASTERDecoder
    from .cdistnet_decoder import CDistNetDecoder
    from .cppd_decoder import CPPDDecoder
    from .rctc_decoder import RCTCDecoder
    from .ctc_decoder import CTCDecoder
    from .dan_decoder import DANDecoder
    from .igtr_decoder import IGTRDecoder
    from .lister_decoder import LISTERDecoder
    from .lpv_decoder import LPVDecoder
    from .mgp_decoder import MGPDecoder
    from .nrtr_decoder import NRTRDecoder
    from .parseq_decoder import PARSeqDecoder
    from .robustscanner_decoder import RobustScannerDecoder
    from .sar_decoder import SARDecoder
    from .smtr_decoder import SMTRDecoder
    from .smtr_decoder_nattn import SMTRDecoderNumAttn
    from .srn_decoder import SRNDecoder
    from .visionlan_decoder import VisionLANDecoder
    from .matrn_decoder import MATRNDecoder
    from .cam_decoder import CAMDecoder
    from .ote_decoder import OTEDecoder
    from .bus_decoder import BUSDecoder

    # Explicit name -> class registry. This replaces the former
    # ``eval(module_name)``, so a config string can never be evaluated as
    # arbitrary Python. GTCDecoder / GTCDecoderTwo are defined in this module
    # and resolved at call time.
    support_dict = {
        'CTCDecoder': CTCDecoder,
        'NRTRDecoder': NRTRDecoder,
        'CPPDDecoder': CPPDDecoder,
        'ABINetDecoder': ABINetDecoder,
        'CDistNetDecoder': CDistNetDecoder,
        'VisionLANDecoder': VisionLANDecoder,
        'PARSeqDecoder': PARSeqDecoder,
        'IGTRDecoder': IGTRDecoder,
        'SMTRDecoder': SMTRDecoder,
        'LPVDecoder': LPVDecoder,
        'SARDecoder': SARDecoder,
        'RobustScannerDecoder': RobustScannerDecoder,
        'SRNDecoder': SRNDecoder,
        'ASTERDecoder': ASTERDecoder,
        'RCTCDecoder': RCTCDecoder,
        'LISTERDecoder': LISTERDecoder,
        'GTCDecoder': GTCDecoder,
        'GTCDecoderTwo': GTCDecoderTwo,
        'SMTRDecoderNumAttn': SMTRDecoderNumAttn,
        'MATRNDecoder': MATRNDecoder,
        'MGPDecoder': MGPDecoder,
        'DANDecoder': DANDecoder,
        'CAMDecoder': CAMDecoder,
        'OTEDecoder': OTEDecoder,
        'BUSDecoder': BUSDecoder,
    }

    # Work on a copy: popping 'name' from the caller's dict would make a
    # second build from the same config fail with KeyError.
    kwargs = dict(config)
    module_name = kwargs.pop('name')
    # Explicit raise (not ``assert``) so the check survives ``python -O``.
    if module_name not in support_dict:
        raise ValueError(
            'unsupported decoder {!r}; decoder only support {}'.format(
                module_name, list(support_dict)))
    return support_dict[module_name](**kwargs)
class GTCDecoder(nn.Module):
    """CTC decoder with an optional auxiliary GTC decoder on the same input.

    The CTC branch is always built. The GTC branch is built only when
    ``infer_gtc`` is true. When it exists, it runs in training mode, and
    also in eval mode if ``infer_gtc`` is true.

    Args:
        in_channels: Forwarded as ``in_channels`` to each sub-decoder config.
        gtc_decoder: Config mapping (``'name'`` plus constructor kwargs) for
            the GTC branch. Only read when ``infer_gtc`` is true.
        ctc_decoder: Config mapping for the CTC branch.
        detach: If true, the CTC branch receives ``x.detach()``, so no
            gradient flows from that branch back into ``x``.
        infer_gtc: Build the GTC branch and also run it outside training.
        out_channels: When ``infer_gtc`` is true, a ``(gtc, ctc)`` pair of
            output sizes (indexed ``[0]``/``[1]``). Otherwise it is passed
            unchanged as the CTC branch's ``out_channels``.
        **kwargs: Ignored; accepted so extra config keys do not raise.

    The config mappings are copied, so the caller's dicts are not modified.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 detach=True,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super(GTCDecoder, self).__init__()
        self.detach = detach
        self.infer_gtc = infer_gtc
        # Copy the sub-configs so the channel keys set here (and the 'name'
        # key consumed by build_decoder) never leak back to the caller.
        ctc_cfg = dict(ctc_decoder)
        ctc_cfg['in_channels'] = in_channels
        if infer_gtc:
            gtc_cfg = dict(gtc_decoder)
            gtc_cfg['out_channels'] = out_channels[0]
            ctc_cfg['out_channels'] = out_channels[1]
            gtc_cfg['in_channels'] = in_channels
            # Built before the CTC branch, as before, so parameter init and
            # state_dict ordering are unchanged.
            self.gtc_decoder = build_decoder(gtc_cfg)
        else:
            ctc_cfg['out_channels'] = out_channels
            # No GTC branch. Keep an explicit None so forward() can skip it
            # instead of raising AttributeError in training mode (which is
            # the default state of a freshly constructed nn.Module).
            self.gtc_decoder = None
        self.ctc_decoder = build_decoder(ctc_cfg)

    def forward(self, x, data=None):
        """Run the CTC branch and, when applicable, the GTC branch.

        Args:
            x: Input features, passed to the CTC branch as-is (or detached).
            data: Forwarded unchanged to each sub-decoder.

        Returns:
            ``{'gtc_pred': ..., 'ctc_pred': ...}`` when the GTC branch exists
            and the module is training or ``infer_gtc`` is set; otherwise
            only the CTC prediction.
        """
        ctc_pred = self.ctc_decoder(x.detach() if self.detach else x,
                                    data=data)
        run_gtc = self.gtc_decoder is not None and (self.training
                                                    or self.infer_gtc)
        if not run_gtc:
            return ctc_pred
        # flatten(2) merges all dims after the second, then transpose(1, 2)
        # moves dim 1 last: (d0, d1, *rest) -> (d0, prod(rest), d1).
        # Presumably (B, C, H, W) -> (B, H*W, C) - confirm with the encoder.
        gtc_pred = self.gtc_decoder(x.flatten(2).transpose(1, 2), data=data)
        return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
class GTCDecoderTwo(nn.Module):
    """CTC + GTC decoder pair that consume two separate input features.

    Both branches are always built. ``forward`` expects ``x`` to unpack into
    ``(x_ctc, x_gtc)``. Unlike ``GTCDecoder``, the CTC input is not detached.

    Args:
        in_channels: Forwarded as ``in_channels`` to both sub-decoder configs.
        gtc_decoder: Config mapping (``'name'`` plus constructor kwargs) for
            the GTC branch.
        ctc_decoder: Config mapping for the CTC branch.
        infer_gtc: Also run the GTC branch outside training mode.
        out_channels: ``(gtc, ctc)`` pair of output sizes, indexed ``[0]`` and
            ``[1]``. The default ``0`` is not subscriptable, so callers must
            supply it.
        **kwargs: Ignored; accepted so extra config keys do not raise.

    The config mappings are copied, so the caller's dicts are not modified.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super(GTCDecoderTwo, self).__init__()
        self.infer_gtc = infer_gtc
        # Copy so the channel keys set here (and the 'name' key consumed by
        # build_decoder) never leak back into the caller's config.
        gtc_cfg = dict(gtc_decoder)
        ctc_cfg = dict(ctc_decoder)
        gtc_cfg['out_channels'] = out_channels[0]
        ctc_cfg['out_channels'] = out_channels[1]
        gtc_cfg['in_channels'] = in_channels
        ctc_cfg['in_channels'] = in_channels
        self.gtc_decoder = build_decoder(gtc_cfg)
        self.ctc_decoder = build_decoder(ctc_cfg)

    def forward(self, x, data=None):
        """Run the CTC branch on ``x[0]`` and, when applicable, GTC on ``x[1]``.

        Args:
            x: Two-element iterable ``(x_ctc, x_gtc)``.
            data: Forwarded unchanged to each sub-decoder.

        Returns:
            ``{'gtc_pred': ..., 'ctc_pred': ...}`` in training mode or when
            ``infer_gtc`` is set; otherwise only the CTC prediction.
        """
        x_ctc, x_gtc = x
        ctc_pred = self.ctc_decoder(x_ctc, data=data)
        if self.training or self.infer_gtc:
            # (d0, d1, *rest) -> (d0, prod(rest), d1); presumably
            # (B, C, H, W) -> (B, H*W, C) - confirm with the encoder.
            gtc_pred = self.gtc_decoder(x_gtc.flatten(2).transpose(1, 2),
                                        data=data)
            return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
        else:
            return ctc_pred