Spaces:
Running
Running
| import torch | |
| from torch import nn | |
class MTB(nn.Module):
    """Optional convolutional stem that turns an image batch into a sequence.

    When ``cnn_num == 2`` two stages are stacked inside ``self.block``, each
    being ``Conv2d(kernel_size=3, stride=2, padding=1) -> ReLU -> BatchNorm2d``
    with 32 and then 64 output channels. ``forward`` then permutes the feature
    map to ``(b, w, h, c)`` and merges the last two axes, giving a tensor of
    shape ``(b, w, h * c)``. For any other ``cnn_num`` no layers are added, the
    empty ``nn.Sequential`` passes its input through, and ``forward`` returns
    the input unchanged.

    Args:
        cnn_num: Number of conv stages; only the value 2 builds any layers.
        in_channels: Channel count of the incoming images (input of the first
            convolution).

    Attributes:
        block: ``nn.Sequential`` holding the stages, registered as
            ``conv_{i}``, ``relu_{i}``, ``bn_{i}``.
        out_channels: Set to ``in_channels``. NOTE(review): on the
            ``cnn_num == 2`` path the returned last dimension is ``h * 64``
            (``h`` being the height after the two stride-2 convs), not
            ``in_channels`` -- confirm how callers consume this value.
        cnn_num: The ``cnn_num`` argument, re-checked in ``forward``.
    """

    def __init__(self, cnn_num, in_channels):
        super().__init__()
        self.block = nn.Sequential()
        self.out_channels = in_channels
        self.cnn_num = cnn_num
        if self.cnn_num != 2:
            # Any other setting leaves the stem empty, i.e. an identity map.
            return

        prev_width = in_channels
        for stage in range(self.cnn_num):
            width = 32 * 2 ** stage  # 32 for stage 0, 64 for stage 1
            self.block.add_module(
                f'conv_{stage}',
                nn.Conv2d(
                    in_channels=prev_width,
                    out_channels=width,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            )
            self.block.add_module(f'relu_{stage}', nn.ReLU())
            # BatchNorm is applied after the activation in this design.
            self.block.add_module(f'bn_{stage}', nn.BatchNorm2d(width))
            prev_width = width

    def forward(self, images):
        """Run the stem and, for ``cnn_num == 2``, flatten it into a sequence.

        Args:
            images: Batch in ``(b, c, h, w)`` layout, as ``nn.Conv2d`` expects.

        Returns:
            ``(b, w, h * c)`` when ``cnn_num == 2``; otherwise the output of the
            (empty) block, which is ``images`` itself.
        """
        feats = self.block(images)
        if self.cnn_num != 2:
            return feats
        # (b, c, h, w) -> (b, w, h, c): the width axis becomes the sequence axis.
        feats = feats.permute(0, 3, 2, 1)
        batch, width, height, channels = feats.shape
        return feats.reshape(batch, width, height * channels)