The following provides minimal code for loading and exporting the Tessera geospatial foundation model. The original checkpoint file best_model_fsdp_20250427_084307.pt hosted on Google Drive was ~7GB however repackaging only the model weights results in a 350MB checkpoint file model.pt. Further, the model is also exported with torch.export to model_exported_program.pt2 so that the model code itself is not needed to run inference.

import torch


class AttentionPooling(torch.nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.query = torch.nn.Linear(input_dim, 1)

    def forward(self, x):
        # x: (B, seq_len, dim)
        w = torch.softmax(self.query(x), dim=1)  # (B, seq_len, 1)
        return (w * x).sum(dim=1)


class TemporalAwarePooling(torch.nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.query = torch.nn.Linear(input_dim, 1)
        self.temporal_context = torch.nn.GRU(input_dim, input_dim, batch_first=True)

    def forward(self, x):
        # First capture temporal context through RNN
        x_context, _ = self.temporal_context(x)
        # Then calculate attention weights
        w = torch.softmax(self.query(x_context), dim=1)
        return (w * x).sum(dim=1)


class TemporalEncoding(torch.nn.Module):
    def __init__(self, d_model, num_freqs=64):
        super().__init__()
        self.num_freqs = num_freqs
        self.d_model = d_model

        # Learnable frequency parameters (more flexible than fixed frequencies)
        self.freqs = torch.nn.Parameter(torch.exp(torch.linspace(0, np.log(365.0), num_freqs)))

        # Project Fourier features to the target dimension through a linear layer
        self.proj = torch.nn.Linear(2 * num_freqs, d_model)
        self.phase = torch.nn.Parameter(torch.zeros(1, 1, d_model))  # Learnable phase offset

    def forward(self, doy):
        # doy: (B, seq_len, 1)
        t = doy / 365.0 * 2 * torch.pi  # Normalize to the 0-2ฯ€ range

        # Generate multi-frequency sine/cosine features
        t_scaled = t * self.freqs.view(1, 1, -1)  # (B, seq_len, num_freqs)
        sin = torch.sin(t_scaled + self.phase[..., :self.num_freqs])
        cos = torch.cos(t_scaled + self.phase[..., self.num_freqs:2*self.num_freqs])

        # Concatenate and project to the target dimension
        encoding = torch.cat([sin, cos], dim=-1)  # (B, seq_len, 2*num_freqs)
        return self.proj(encoding)  # (B, seq_len, d_model)


class TemporalPositionalEncoder(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, doy):
        # doy: [B, T] tensor containing DOY values (0-365)
        position = doy.unsqueeze(-1).float()  # Ensure float type
        div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float) * -(torch.log(torch.tensor(10000.0)) / self.d_model))
        div_term = div_term.to(doy.device)

        pe = torch.zeros(doy.shape[0], doy.shape[1], self.d_model, device=doy.device)
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe


class TransformerEncoder(torch.nn.Module):
    def __init__(self, band_num, latent_dim, nhead=8, num_encoder_layers=4,
                dim_feedforward=512, dropout=0.1, max_seq_len=20):
        super().__init__()
        # Total input dimension: bands
        input_dim = band_num

        # Embedding to increase dimension
        self.embedding = torch.nn.Sequential(
            torch.nn.Linear(input_dim, latent_dim*4),
            torch.nn.ReLU(),
            torch.nn.Linear(latent_dim*4, latent_dim*4)
        )

        # Temporal Encoder for DOY as position encoding
        self.temporal_encoder = TemporalPositionalEncoder(d_model=latent_dim*4)

        # Transformer Encoder Layer
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=latent_dim*4,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu",
            batch_first=True,
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Temporal Aware Pooling
        self.attn_pool = TemporalAwarePooling(latent_dim*4)

    def forward(self, x):
        # x: (B, seq_len, 10 bands + 1 doy)
        # Split bands and doy
        bands = x[:, :, :-1]  # All columns except last one
        doy = x[:, :, -1]     # Last column is DOY
        # Embedding of bands
        bands_embedded = self.embedding(bands)  # (B, seq_len, latent_dim*4)
        temporal_encoding = self.temporal_encoder(doy)
        # Add temporal encoding to embedded bands (instead of random positional encoding)
        x = bands_embedded + temporal_encoding
        x = self.transformer_encoder(x)
        x = self.attn_pool(x)
        return x


class Tessera(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.s2_backbone = TransformerEncoder(
            band_num=10,
            latent_dim=128,
            nhead=8,
            num_encoder_layers=8,
            dim_feedforward=4096,
            dropout=0.1,
            max_seq_len=40
        )
        self.s1_backbone = TransformerEncoder(
            band_num=2,
            latent_dim=128,
            nhead=8,
            num_encoder_layers=8,
            dim_feedforward=4096,
            dropout=0.1,
            max_seq_len=40
        )
        self.dim_reducer = torch.nn.Sequential(torch.nn.Linear(128 * 8, 128))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape(b, t, c) where c=14, the first 11 channels are 
                sentinel-2 (10 bands + 1 doy features) and the last 3 channels are 
                sentinel-1 (2 bands + 1 doy features)
        """
        assert x.shape[-1] == 14
        s2_x, s1_x = x[..., :11], x[..., 11:]
        s2_feat = self.s2_backbone(s2_x)  # (b, d)
        s1_feat = self.s1_backbone(s1_x)  # (b, d)
        fused = torch.cat([s2_feat, s1_feat], dim=-1)  # (b, 2d)
        fused = self.dim_reducer(fused)  # (b, 128)
        return fused


# Load the pretrained model for inference only without the projection using the pretrained config
model = Tessera()
model.eval()

b, t = 2, 10
s2 = torch.randn(b, t, 10)
s2_doy = torch.randint(1, 365, (b, t, 1))
s1 = torch.randn(b, t, 2)
s1_doy = torch.randint(1, 365, (b, t, 1))

x = torch.cat([s2, s2_doy, s1, s1_doy], dim=-1)
print(model(x).shape)

# Load and extract only the model state dict then save to model.pt
path = "best_model_fsdp_20250427_084307.pt"
ckpt = torch.load(path, map_location="cpu", weights_only=False)
modules = ["s2_backbone", "s1_backbone", "dim_reducer"]
state_dict = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model_state_dict"].items()}
state_dict = {k: v for k, v in state_dict.items() if k.split(".")[0] in modules}
model.load_state_dict(state_dict, strict=True)
torch.save(model.state_dict(), "model.pt")

# Export the model and save to model_exported_program.pt2
from torch.export.dynamic_shapes import Dim

example_inputs = torch.randn(1, 10, 14)
dims = (Dim.AUTO, Dim.AUTO, 14)
model_program = torch.export.export(mod=model, args=(example_inputs,), dynamic_shapes={"x": dims})
torch.export.save(model_program, 'model_exported_program.pt2')
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. ๐Ÿ™‹ Ask for provider support