File size: 4,496 Bytes
ae522ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import segmentation_models_pytorch as smp
from mwm import logger

# Utils for model architecture
def freeze_encoder(model, freeze_encoder_layers):
    """
    Freeze (set ``requires_grad = False`` on) parameters of ``model.encoder``.

    Args:
        model: A model exposing an ``encoder`` submodule with ``parameters()``
            and ``named_children()`` (e.g. an smp ``Unet``).
        freeze_encoder_layers: Which encoder parts to freeze.
            - empty / None: nothing is frozen, all layers stay trainable.
            - contains ``"all"``: every encoder parameter is frozen.
            - otherwise: names of direct children of ``model.encoder``
              (as yielded by ``named_children()``) to freeze.
            A single string is treated as a one-element list.

    Requested names that do not match any encoder child are reported with a
    warning instead of being silently ignored.
    """
    if isinstance(freeze_encoder_layers, str):
        # A bare string would otherwise be used for substring membership tests.
        freeze_encoder_layers = [freeze_encoder_layers]

    if not freeze_encoder_layers:  # None or empty
        logger.info("No Encoder layer is frozen. All layers are trainable.")
        return

    if "all" in freeze_encoder_layers:
        for param in model.encoder.parameters():
            param.requires_grad = False
        logger.info("All Encoder layers are frozen.")
        return

    frozen = []
    for name, child in model.encoder.named_children():
        if name in freeze_encoder_layers:
            for param in child.parameters():
                param.requires_grad = False
            frozen.append(name)

    # Surface typos / names from a different encoder instead of claiming success.
    unknown = [name for name in freeze_encoder_layers if name not in frozen]
    if unknown:
        logger.warning(f"Encoder has no child layer(s) named {unknown}; they were not frozen.")
    logger.info(f"Encoder layers: {frozen} are selectively frozen.")


# Supported network names -> smp encoder name. Every network is a U-Net with
# 3 input channels, 2 output classes and a sigmoid activation; only the
# encoder differs. "tu-" names are smp's timm-universal encoders.
_UNET_2CH_ENCODERS = {
    "unet_resnet34_2ch": "resnet34",
    "unet_resnet101_2ch": "resnet101",
    "unet_resnet152_2ch": "resnet152",
    "unet_efficientnet_b0_2ch": "efficientnet-b0",
    "unet_efficientnet_b2_2ch": "efficientnet-b2",
    "unet_efficientnet_b3_2ch": "efficientnet-b3",
    "unet_efficientnet_b4_2ch": "efficientnet-b4",
    "unet_dpn92_2ch": "dpn92",
    "unet_inceptionresnetv2_2ch": "inceptionresnetv2",
    "unet_regnety_040_2ch": "tu-regnety_040",
    "unet_convnext_base_2ch": "tu-convnext_base",
    "unet_convnextv2_base_2ch": "tu-convnextv2_base",
    "unet_swinv2_small_window8_256_2ch": "tu-swinv2_small_window8_256",
    "unet_edgenext_small_2ch": "tu-edgenext_small",
    "unet_edgenext_base_2ch": "tu-edgenext_base",
    "unet_efficientformerv2_s1_2ch": "tu-efficientformerv2_s1",
    "unet_efficientformerv2_l_2ch": "tu-efficientformerv2_l",
}


def make_model(network_name, encoder_weights):
    """
    Build a 2-class ``smp.Unet`` for a named network configuration.

    Args:
        network_name: One of the keys of ``_UNET_2CH_ENCODERS``
            (e.g. ``"unet_resnet34_2ch"``).
        encoder_weights: Passed straight through to ``smp.Unet`` as
            ``encoder_weights`` (pretrained-weights identifier or None).

    Returns:
        The constructed ``smp.Unet`` (in_channels=3, classes=2,
        activation="sigmoid").

    Raises:
        ValueError: If ``network_name`` is not a supported network.
    """
    encoder_name = _UNET_2CH_ENCODERS.get(network_name)
    if encoder_name is None:
        logger.error(f"Invalid network: {network_name}")
        raise ValueError(f"Invalid network: {network_name}")

    model = smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=3,
        classes=2,
        activation="sigmoid",
    )
    logger.info(f"Model: {network_name} successfully created. ")

    # Useful for choosing names to pass to freeze_encoder().
    child_names = [name for name, _ in model.encoder.named_children()]
    logger.info(f"Named children in Encoder: {child_names}")
    return model

# Classes for customized model architecture