File size: 7,027 Bytes
60465e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import sys
import os
import torch
import torchvision.transforms as T
from typing import List, Tuple
# Import the necessary functions for custom download
from torch.hub import download_url_to_file
import urllib.parse

# --- IMPORTANT: TorchHub dependencies ---
# Package names this hub entry-point file declares as required. Nothing here
# installs them; users must install these packages themselves.
# NOTE(review): torch.hub is documented to read a module-level `dependencies`
# list from hubconf.py and to refuse loading when a listed package cannot be
# imported, so this list is probably more than informational -- confirm.
dependencies = [
    'tomesd',
    'omegaconf',
    'numpy',
    'rich',
    'yapf',
    'addict',
    'tqdm',
    'packaging',
    'torchvision'
]

# Temporarily put the 'model_without_OpenMMLab' subdirectory (located next to
# this file) at the front of sys.path so the two imports below can resolve.
model_dir = os.path.join(os.path.dirname(__file__), 'model_without_OpenMMLab')
sys.path.insert(0, model_dir)

# Import the helpers that the entry points in this file wrap.
from segformer_plusplus.build_model import create_model
from segformer_plusplus.random_benchmark import random_benchmark

# Drop the first sys.path entry again to keep the list clean. This assumes the
# imports above did not themselves modify sys.path, i.e. that index 0 is still
# model_dir at this point.
sys.path.pop(0)


def _get_local_cache_path(url: str, filename: str) -> str:
    """

    Creates the full local path to the checkpoint file in the PyTorch Hub cache.

    """
    # Retrieves the root folder of the PyTorch Hub cache (~/.cache/torch/)
    torch_home = torch.hub._get_torch_home()

    # The default checkpoint directory
    checkpoint_dir = os.path.join(torch_home, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Adds a hash component for the URL to ensure uniqueness,
    # as the URL itself does not contain a unique file name.
    # We use the URL path as part of the hash.
    url_path_hash = urllib.parse.quote_plus(url)

    # The final local file name, including the base name + URL hash.
    local_filename = f"{filename}_{url_path_hash[:10]}.pt"

    return os.path.join(checkpoint_dir, local_filename)


# --- ENTRYPOINT 1: Main Model (ADJUSTED) ---
def segformer_plusplus(
        backbone: str = 'b5',
        tome_strategy: str = 'bsm_hq',
        out_channels: int = 19,
        pretrained: bool = True,
        checkpoint_url: str = None,
        **kwargs
) -> torch.nn.Module:
    """
    Segformer++: Efficient Token-Merging Strategies for High-Resolution Semantic Segmentation.

    Builds a SegFormer++ model via ``create_model`` and, if ``checkpoint_url``
    is given, tries to load that checkpoint into it. The checkpoint is
    downloaded once into the local PyTorch cache and reused afterwards.

    Install requirements via:
        pip install tomesd omegaconf numpy rich yapf addict tqdm packaging torchvision

    Args:
        backbone (str): The backbone type. Selectable from: ['b0', 'b1', 'b2', 'b3', 'b4', 'b5'].
        tome_strategy (str): The token merging strategy. Selectable from: ['bsm_hq', 'bsm_fast', 'n2d_2x2'].
        out_channels (int): Number of output classes (e.g., 19 for Cityscapes).
        pretrained (bool): Whether to load the ImageNet pre-trained weights
            (forwarded to ``create_model``).
        checkpoint_url (str, optional): URL of a checkpoint to load on top of
            the created model. The file is fetched with
            ``torch.hub.download_url_to_file()``.
        **kwargs: Accepted for call compatibility; not used by this function.

    Returns:
        torch.nn.Module: The instantiated SegFormer++ model. If the download or
        the checkpoint loading fails, the error is printed and the model is
        returned anyway instead of raising.
    """
    net = create_model(
        backbone=backbone,
        tome_strategy=tome_strategy,
        out_channels=out_channels,
        pretrained=pretrained,
    )

    # No checkpoint requested: the freshly created model is the result.
    if not checkpoint_url:
        return net

    # Cache location for this checkpoint; the backbone name is part of the
    # file name.
    local_filepath = _get_local_cache_path(
        url=checkpoint_url,
        filename=f"segformer_plusplus_{backbone}",
    )

    print(f"Attempting to load checkpoint from {checkpoint_url}...")

    already_cached = os.path.exists(local_filepath)
    if not already_cached:
        # Fetch the file into the cache; on any failure report it and hand
        # back the model as created above.
        try:
            print(f"File not in cache. Downloading to {local_filepath}...")
            download_url_to_file(checkpoint_url, local_filepath, progress=True)
            print("Download successful.")
        except Exception as e:
            print(f"Error downloading checkpoint from {checkpoint_url}. Check the URL or use a direct asset link. Error: {e}")
            return net

    # Read the cached file and apply its weights to the model.
    try:
        # NOTE(review): torch.load unpickles the downloaded file; only pass
        # checkpoint URLs from trusted sources.
        weights = torch.load(local_filepath, map_location='cpu')

        # Some checkpoints wrap the weights under a 'state_dict' key.
        if 'state_dict' in weights:
            weights = weights['state_dict']

        net.load_state_dict(weights, strict=True)
        print("Checkpoint loaded successfully.")
    except Exception as e:
        # Best effort: report the problem and still return the model.
        print(f"Error loading state dict from file {local_filepath}: {e}")
        print("The model was instantiated, but the checkpoint could not be loaded.")

    return net


# --- ENTRYPOINT 2: Data Processing ---
def data_transforms(
        resolution: Tuple[int, int] = (1024, 1024),
        mean: List[float] = [0.485, 0.456, 0.406],
        std: List[float] = [0.229, 0.224, 0.225],
) -> T.Compose:
    """
    Provides the image preprocessing pipeline: resize, tensor conversion, normalization.

    This function is an entry point to get the necessary preprocessing steps
    for images; the default statistics are the typical ImageNet values.

    Args:
        resolution (Tuple[int, int]): The desired output size, passed to
            ``torchvision.transforms.Resize`` which interprets a pair as
            (height, width). Defaults to (1024, 1024).
        mean (List[float]): The per-channel mean values for normalization.
            Defaults to the ImageNet means.
        std (List[float]): The per-channel standard deviations for
            normalization. Defaults to the ImageNet standard deviations.

    Returns:
        torchvision.transforms.Compose: ``Resize`` -> ``ToTensor`` ->
        ``Normalize``, applicable to input images.

    Example:
        >>> # Load transforms with default parameters
        >>> transform = torch.hub.load('user/repo_name', 'data_transforms')
        >>>
        >>> # Load transforms with resize to custom image resolution and default normalization
        >>> transform_small = torch.hub.load('user/repo_name', 'data_transforms', resolution=(512, 512))
    """
    # Hand Normalize its own copies: it keeps a reference to the sequences it
    # receives, so passing the default lists directly would share one list
    # object between every pipeline built with default arguments.
    normalize = T.Normalize(mean=list(mean), std=list(std))

    return T.Compose([
        T.Resize(resolution),
        T.ToTensor(),
        normalize,
    ])


# --- ENTRYPOINT 3: Random Benchmark ---
def random_benchmark_entrypoint(**kwargs):
    """
    Hub entry point for the SegFormer++ random benchmark.

    Forwards all keyword arguments unchanged to ``random_benchmark`` and
    returns whatever that call returns.
    """
    benchmark_result = random_benchmark(**kwargs)
    return benchmark_result