import sys
import os
import torch
import torchvision.transforms as T
from typing import List, Optional, Tuple
# Import the necessary functions for custom download
from torch.hub import download_url_to_file
import hashlib
# --- IMPORTANT: TorchHub Dependencies ---
# torch.hub only checks that these packages are importable; it does not
# install them. Users must install them manually.
dependencies = [
'tomesd',
'omegaconf',
'numpy',
'rich',
'yapf',
'addict',
'tqdm',
'packaging',
'torchvision'
]
# Adds the path to the 'model_without_OpenMMLab' subdirectory to the sys.path list.
model_dir = os.path.join(os.path.dirname(__file__), 'model_without_OpenMMLab')
sys.path.insert(0, model_dir)
# Imports all entry points from the subdirectory.
from segformer_plusplus.build_model import create_model
from segformer_plusplus.random_benchmark import random_benchmark
# Removes the added path again to keep the sys.path list clean.
sys.path.pop(0)
def _get_local_cache_path(url: str, filename: str) -> str:
"""
Creates the full local path to the checkpoint file in the PyTorch Hub cache.
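
    Example (illustrative; the 10-character suffix is a short hash of the URL,
    and the URL below is a placeholder):
        _get_local_cache_path('https://example.com/model_b5.pt', 'segformer_plusplus_b5')
        # -> '<torch cache>/checkpoints/segformer_plusplus_b5_<hash>.pt'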
"""
# Retrieves the root folder of the PyTorch Hub cache (~/.cache/torch/)
torch_home = torch.hub._get_torch_home()
# The default checkpoint directory
checkpoint_dir = os.path.join(torch_home, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
    # Append a short hash of the URL to the file name so that different
    # checkpoint URLs map to different cache entries, since the URL itself
    # does not necessarily end in a unique file name.
    url_hash = hashlib.sha256(url.encode('utf-8')).hexdigest()
    # The final local file name: base name + first 10 hex digits of the URL hash.
    local_filename = f"{filename}_{url_hash[:10]}.pt"
return os.path.join(checkpoint_dir, local_filename)
# --- ENTRYPOINT 1: Main Model ---
def segformer_plusplus(
backbone: str = 'b5',
tome_strategy: str = 'bsm_hq',
out_channels: int = 19,
pretrained: bool = True,
    checkpoint_url: Optional[str] = None,
**kwargs
) -> torch.nn.Module:
"""
Segformer++: Efficient Token-Merging Strategies for High-Resolution Semantic Segmentation.
Loads a SegFormer++ model with the specified backbone and head architecture.
Install requirements via:
pip install tomesd omegaconf numpy rich yapf addict tqdm packaging torchvision
Args:
backbone (str): The backbone type. Selectable from: ['b0', 'b1', 'b2', 'b3', 'b4', 'b5'].
tome_strategy (str): The token merging strategy. Selectable from: ['bsm_hq', 'bsm_fast', 'n2d_2x2'].
out_channels (int): Number of output classes (e.g., 19 for Cityscapes).
pretrained (bool): Whether to load the ImageNet pre-trained weights.
        checkpoint_url (str, optional): A URL to a specific checkpoint.
            **Important:** The download uses torch.hub.download_url_to_file(),
            which requires a direct download link to the checkpoint file
            (e.g. a release asset URL), not a link to an HTML landing page.
Returns:
torch.nn.Module: The instantiated SegFormer++ model.
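
    Example:
        >>> # 'user/repo_name' is a placeholder for the actual repository,
        >>> # following the convention used in the data_transforms example below.
        >>> model = torch.hub.load('user/repo_name', 'segformer_plusplus',
        ...                        backbone='b2', tome_strategy='bsm_fast',
        ...                        out_channels=19, pretrained=True)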
"""
model = create_model(
backbone=backbone,
tome_strategy=tome_strategy,
out_channels=out_channels,
pretrained=pretrained
)
if checkpoint_url:
# Generate a unique file path in the PyTorch cache
# We use the backbone name as part of the file name
local_filepath = _get_local_cache_path(
url=checkpoint_url,
filename=f"segformer_plusplus_{backbone}"
)
print(f"Attempting to load checkpoint from {checkpoint_url}...")
if not os.path.exists(local_filepath):
# Use download_url_to_file for the non-direct download
try:
print(f"File not in cache. Downloading to {local_filepath}...")
# This replaces load_state_dict_from_url and saves the file in the cache
download_url_to_file(
checkpoint_url,
local_filepath,
progress=True
)
print("Download successful.")
except Exception as e:
print(f"Error downloading checkpoint from {checkpoint_url}. Check the URL or use a direct asset link. Error: {e}")
# If the download fails, we return an un-loaded model
return model
# Load the state dictionary from the downloaded file
try:
state_dict = torch.load(local_filepath, map_location='cpu')
            # Clean up the state dict if necessary, e.g. when the weights are
            # nested under a 'state_dict' key (common for checkpoints saved by
            # training frameworks).
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']
model.load_state_dict(state_dict, strict=True)
print("Checkpoint loaded successfully.")
except Exception as e:
print(f"Error loading state dict from file {local_filepath}: {e}")
# Again, return the un-loaded/ImageNet pre-trained model
print("The model was instantiated, but the checkpoint could not be loaded.")
return model
# --- ENTRYPOINT 2: Data Processing ---
def data_transforms(
resolution: Tuple[int, int] = (1024, 1024),
mean: List[float] = [0.485, 0.456, 0.406],
std: List[float] = [0.229, 0.224, 0.225],
) -> T.Compose:
"""
    Provides the preprocessing transforms for input images.

    This entry point returns the resize, tensor-conversion, and normalization
    steps, using ImageNet statistics by default.
Args:
        resolution (Tuple[int, int]): The desired output size as (height, width),
            as expected by torchvision.transforms.Resize. Defaults to (1024, 1024).
mean (List[float]): The mean values for normalization. Defaults to the
ImageNet means.
std (List[float]): The standard deviations for normalization. Defaults to the
ImageNet standard deviations.
Returns:
torchvision.transforms.Compose: A composition of transforms
that can be applied to input images.
Example:
>>> # Load transforms with default parameters
>>> transform = torch.hub.load('user/repo_name', 'data_transforms')
>>>
        >>> # Load transforms with a custom resize resolution and default normalization
        >>> transform_small = torch.hub.load('user/repo_name', 'data_transforms', resolution=(512, 512))
"""
transform = T.Compose([
T.Resize(resolution),
T.ToTensor(),
T.Normalize(mean=mean, std=std)
])
return transform
# --- ENTRYPOINT 3: Random Benchmark ---
def random_benchmark_entrypoint(**kwargs):
"""
Runs a random benchmark for SegFormer++.
"""
    return random_benchmark(**kwargs)
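

# --- Usage sketch (illustrative only) ---
# A minimal example of combining the entry points above when this file is run
# directly. It assumes the returned model accepts a batched image tensor of
# shape (N, 3, H, W); for real images, apply data_transforms() to a PIL image
# first, since T.ToTensor() expects a PIL image or ndarray rather than a tensor.
if __name__ == '__main__':
    demo_model = segformer_plusplus(backbone='b0', pretrained=False)
    demo_model.eval()
    # Random stand-in for a preprocessed 512x512 RGB image batch.
    dummy_batch = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        output = demo_model(dummy_batch)
    print(f"Output type: {type(output)}")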