lucid-hf's picture
CI: deploy Docker/PDM Space
98a3af2 verified
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583
"""
import torch
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
from .utils import IntermediateLayerGetter
from ..core import register
@register()
class TimmModel(torch.nn.Module):
    """Backbone wrapper exposing selected feature stages of a timm model.

    A timm model is built with ``timm.create_model`` and wrapped in
    ``IntermediateLayerGetter`` together with ``return_layers``. Stride and
    channel metadata for the requested layers is read from the model's
    ``feature_info``.

    Args:
        name: timm model name forwarded to ``timm.create_model``.
        return_layers: Names of feature modules to expose. Every name must
            appear in ``model.feature_info.module_name()``.
        pretrained: Forwarded to ``timm.create_model``.
        exportable: Forwarded to ``timm.create_model``.
        features_only: Forwarded to ``timm.create_model``.
        **kwargs: Any further keyword arguments for ``timm.create_model``.

    Raises:
        ValueError: If ``return_layers`` contains a name that is not listed in
            ``model.feature_info.module_name()``.
    """

    def __init__(
        self,
        name,
        return_layers,
        pretrained=False,
        exportable=True,
        features_only=True,
        **kwargs,
    ) -> None:
        super().__init__()
        # Imported here so timm is only needed when this backbone is built.
        import timm

        model = timm.create_model(
            name,
            pretrained=pretrained,
            exportable=exportable,
            features_only=features_only,
            **kwargs,
        )

        # The names are consumed several times below; materialize them once so a
        # one-shot iterable does not silently yield nothing on later passes.
        return_layers = list(return_layers)

        feature_info = model.feature_info
        module_names = feature_info.module_name()

        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if not set(return_layers).issubset(module_names):
            raise ValueError(
                f'return_layers should be a subset of {module_names}')

        # self.model: IntermediateLayerGetter wrapping the timm model.
        self.model = IntermediateLayerGetter(model, return_layers)

        # Position of each requested layer within feature_info, in the order
        # the names were given in `return_layers`.
        return_idx = [module_names.index(layer) for layer in return_layers]
        reductions = feature_info.reduction()
        feature_channels = feature_info.channels()

        # NOTE(review): strides/channels follow the order of `return_layers`.
        # If IntermediateLayerGetter emits outputs in module order instead,
        # an out-of-order `return_layers` would misalign them -- confirm.
        self.strides = [reductions[i] for i in return_idx]
        self.channels = [feature_channels[i] for i in return_idx]
        self.return_idx = return_idx
        self.return_layers = return_layers

    def forward(self, x: torch.Tensor):
        """Run the wrapped backbone on ``x``.

        Returns whatever ``IntermediateLayerGetter`` produces for the
        configured ``return_layers``. The demo under ``__main__`` iterates it
        as a sequence of tensors.
        """
        return self.model(x)
if __name__ == '__main__':
    # Smoke test: build a ResNet-34 backbone and print the shape of every
    # tensor it returns for one random 640x640 RGB image.
    backbone = TimmModel(name='resnet34', return_layers=['layer2', 'layer3'])
    dummy_batch = torch.rand(1, 3, 640, 640)
    feature_maps = backbone(dummy_batch)
    for feature_map in feature_maps:
        print(feature_map.shape)

# Example config entry for instantiating this backbone via the registry.
"""
model:
type: TimmModel
name: resnet34
return_layers: ['layer2', 'layer4']
"""