Spaces:
Sleeping
Sleeping
| """Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583 | |
| """ | |
| import torch | |
| from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor | |
| from .utils import IntermediateLayerGetter | |
| from ..core import register | |
class TimmModel(torch.nn.Module):
    """Backbone wrapper around a model built by ``timm.create_model``.

    The timm model is wrapped in ``IntermediateLayerGetter`` together with
    ``return_layers``; the forward pass returns whatever that wrapper
    produces for the requested layers.

    Attributes:
        model: ``IntermediateLayerGetter`` wrapping the timm model.
        strides: Entry of ``feature_info.reduction()`` for each returned layer.
        channels: Entry of ``feature_info.channels()`` for each returned layer.
        return_idx: Index of each returned layer within
            ``feature_info.module_name()``.
        return_layers: The layer names supplied by the caller.
    """

    def __init__(
        self,
        name,
        return_layers,
        pretrained=False,
        exportable=True,
        features_only=True,
        **kwargs
    ) -> None:
        """Create the timm model and record metadata for the selected layers.

        Args:
            name: Model name forwarded to ``timm.create_model``.
            return_layers: Names of the layers to return; every name must
                appear in ``model.feature_info.module_name()``.
            pretrained: Forwarded to ``timm.create_model``.
            exportable: Forwarded to ``timm.create_model``.
            features_only: Forwarded to ``timm.create_model``.
            **kwargs: Extra keyword arguments forwarded to
                ``timm.create_model``.

        Raises:
            ValueError: If ``return_layers`` contains a name that is not
                listed in ``model.feature_info.module_name()``.
        """
        super().__init__()

        # Imported lazily so timm is only required when this class is used.
        import timm

        model = timm.create_model(
            name,
            pretrained=pretrained,
            exportable=exportable,
            features_only=features_only,
            **kwargs
        )

        # Query the feature metadata once instead of on every lookup.
        module_names = model.feature_info.module_name()

        # Explicit raise rather than ``assert``: assertions are removed under
        # ``python -O``, which would let an invalid layer name slip through
        # to a less descriptive failure further down.
        if not set(return_layers).issubset(module_names):
            raise ValueError(
                f'return_layers should be a subset of {module_names}'
            )

        self.model = IntermediateLayerGetter(model, return_layers)

        reductions = model.feature_info.reduction()
        channels = model.feature_info.channels()
        return_idx = [module_names.index(layer) for layer in return_layers]

        self.strides = [reductions[i] for i in return_idx]
        self.channels = [channels[i] for i in return_idx]
        self.return_idx = return_idx
        self.return_layers = return_layers

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the wrapped model and return its outputs."""
        return self.model(x)
if __name__ == '__main__':
    # Smoke test: build a resnet34 backbone, feed it one random 640x640 RGB
    # batch, and print the shape of every output it returns.
    backbone = TimmModel(name='resnet34', return_layers=['layer2', 'layer3'])
    dummy_input = torch.rand(1, 3, 640, 640)
    for feature in backbone(dummy_input):
        print(feature.shape)
| """ | |
| model: | |
| type: TimmModel | |
| name: resnet34 | |
| return_layers: ['layer2', 'layer4'] | |
| """ | |