# SegformerPlusPlus / hubconf.py
import sys
import os
import torch
import torchvision.transforms as T
from typing import List, Optional, Tuple
# Import the helper used for the custom checkpoint download
from torch.hub import download_url_to_file
import hashlib
# --- IMPORTANT: TorchHub dependencies ---
# torch.hub only checks that these packages are importable; users must install them themselves.
dependencies = [
'tomesd',
'omegaconf',
'numpy',
'rich',
'yapf',
'addict',
'tqdm',
'packaging',
'torchvision'
]
# Add the 'model_without_OpenMMLab' subdirectory to sys.path so its modules can be imported.
model_dir = os.path.join(os.path.dirname(__file__), 'model_without_OpenMMLab')
sys.path.insert(0, model_dir)
try:
    # Import the entry-point helpers from the subdirectory.
    from segformer_plusplus.build_model import create_model
    from segformer_plusplus.random_benchmark import random_benchmark
finally:
    # Remove the added path again to keep sys.path clean, even if an import fails.
    sys.path.pop(0)
def _get_local_cache_path(url: str, filename: str) -> str:
"""
Creates the full local path to the checkpoint file in the PyTorch Hub cache.
"""
    # Root of the PyTorch Hub cache (defaults to ~/.cache/torch/hub)
    hub_dir = torch.hub.get_dir()
    # The default checkpoint directory
    checkpoint_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Append a short hash of the URL to ensure uniqueness,
    # since the URL itself may not contain a unique file name.
    url_hash = hashlib.sha256(url.encode('utf-8')).hexdigest()[:10]
    # The final local file name: base name + URL hash.
    local_filename = f"{filename}_{url_hash}.pt"
    return os.path.join(checkpoint_dir, local_filename)
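
# Example of the cache path produced above (hypothetical URL; the suffix is the
# first 10 hex characters of the URL's SHA-256 hash):
#   _get_local_cache_path("https://example.com/ckpts/b5.pth", "segformer_plusplus_b5")
#   -> <hub_dir>/checkpoints/segformer_plusplus_b5_<hash>.pt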
# --- ENTRYPOINT 1: Main model ---
def segformer_plusplus(
backbone: str = 'b5',
tome_strategy: str = 'bsm_hq',
out_channels: int = 19,
pretrained: bool = True,
    checkpoint_url: Optional[str] = None,
**kwargs
) -> torch.nn.Module:
"""
Segformer++: Efficient Token-Merging Strategies for High-Resolution Semantic Segmentation.
Loads a SegFormer++ model with the specified backbone and head architecture.
Install requirements via:
pip install tomesd omegaconf numpy rich yapf addict tqdm packaging torchvision
Args:
backbone (str): The backbone type. Selectable from: ['b0', 'b1', 'b2', 'b3', 'b4', 'b5'].
tome_strategy (str): The token merging strategy. Selectable from: ['bsm_hq', 'bsm_fast', 'n2d_2x2'].
out_channels (int): Number of output classes (e.g., 19 for Cityscapes).
pretrained (bool): Whether to load the ImageNet pre-trained weights.
        checkpoint_url (str, optional): URL to a specific checkpoint. The file is
            downloaded once with torch.hub.download_url_to_file() and cached in the
            PyTorch Hub checkpoint directory; the URL should point directly to a
            downloadable checkpoint file.
        **kwargs: Additional keyword arguments forwarded to create_model().
    Returns:
        torch.nn.Module: The instantiated SegFormer++ model.
    """
    model = create_model(
        backbone=backbone,
        tome_strategy=tome_strategy,
        out_channels=out_channels,
        pretrained=pretrained,
        **kwargs
    )
if checkpoint_url:
# Generate a unique file path in the PyTorch cache
# We use the backbone name as part of the file name
local_filepath = _get_local_cache_path(
url=checkpoint_url,
filename=f"segformer_plusplus_{backbone}"
)
print(f"Attempting to load checkpoint from {checkpoint_url}...")
if not os.path.exists(local_filepath):
# Use download_url_to_file for the non-direct download
try:
print(f"File not in cache. Downloading to {local_filepath}...")
# This replaces load_state_dict_from_url and saves the file in the cache
download_url_to_file(
checkpoint_url,
local_filepath,
progress=True
)
print("Download successful.")
except Exception as e:
print(f"Error downloading checkpoint from {checkpoint_url}. Check the URL or use a direct asset link. Error: {e}")
                # If the download fails, return the model without the checkpoint weights
                return model
# Load the state dictionary from the downloaded file
try:
state_dict = torch.load(local_filepath, map_location='cpu')
            # Unwrap the checkpoint if the weights are nested under a 'state_dict' key
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']
model.load_state_dict(state_dict, strict=True)
print("Checkpoint loaded successfully.")
except Exception as e:
print(f"Error loading state dict from file {local_filepath}: {e}")
            # Fall back to the model as instantiated (ImageNet pre-trained if pretrained=True)
            print("The model was instantiated, but the checkpoint could not be loaded.")
return model
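
# Usage sketch for the entry point above (comments only, not executed). The local
# repository path and argument values below are placeholders, not verified values:
#
#   model = torch.hub.load('.', 'segformer_plusplus', source='local',
#                          backbone='b2', tome_strategy='bsm_fast',
#                          out_channels=19, pretrained=False)
#   model.eval()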
# --- ENTRYPOINT 2: Data Processing ---
def data_transforms(
resolution: Tuple[int, int] = (1024, 1024),
mean: List[float] = [0.485, 0.456, 0.406],
std: List[float] = [0.229, 0.224, 0.225],
) -> T.Compose:
"""
    Provides standard image preprocessing transforms for SegFormer++ inputs.
    This entry point returns the preprocessing steps (resize, tensor conversion,
    and normalization) based on typical ImageNet statistics.
    Args:
        resolution (Tuple[int, int]): The desired output size as (height, width),
            as expected by torchvision.transforms.Resize. Defaults to (1024, 1024).
mean (List[float]): The mean values for normalization. Defaults to the
ImageNet means.
std (List[float]): The standard deviations for normalization. Defaults to the
ImageNet standard deviations.
Returns:
torchvision.transforms.Compose: A composition of transforms
that can be applied to input images.
Example:
>>> # Load transforms with default parameters
>>> transform = torch.hub.load('user/repo_name', 'data_transforms')
>>>
>>> # Load transforms with resize to custom image resolution and default normalization
>>> transform_small = torch.hub.load('user/repo_name', 'data_transforms', resolution=(512, 512))
"""
transform = T.Compose([
T.Resize(resolution),
T.ToTensor(),
T.Normalize(mean=mean, std=std)
])
return transform
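
# Usage sketch for data_transforms (comments only, not executed). 'image.png' is a
# placeholder path; the transform expects a PIL image (or compatible) as input:
#
#   from PIL import Image
#   transform = data_transforms(resolution=(1024, 1024))
#   batch = transform(Image.open('image.png').convert('RGB')).unsqueeze(0)
#   # batch has shape [1, 3, 1024, 1024] and is normalized with ImageNet statistics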
# --- ENTRYPOINT 3: Random Benchmark ---
def random_benchmark_entrypoint(**kwargs):
"""
    Runs the SegFormer++ benchmark on randomly generated input data.
    All keyword arguments are forwarded to segformer_plusplus.random_benchmark().
"""
return random_benchmark(**kwargs)
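
# Minimal smoke test when this file is executed directly (not via torch.hub).
# This is a sketch: it assumes create_model() returns a standard nn.Module whose
# forward accepts a [B, 3, H, W] tensor; no checkpoint is downloaded.
if __name__ == "__main__":
    model = segformer_plusplus(backbone='b0', pretrained=False)
    model.eval()
    # A random tensor stands in for an image batch produced by data_transforms()
    dummy = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        out = model(dummy)
    print(type(out))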