# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat
from timm.models.vision_transformer import PatchEmbed, Block
from transformers import AutoTokenizer

class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets '''

    def __init__(self, gamma=1.0, alpha=None, ignore_index=-100, reduction='none'):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        cross_entropy = super().forward(input_, target)
        # Temporarily map the ignore index to '0' so it is a valid gather index.
        # These positions do not contribute to the final loss because their
        # cross_entropy contribution is zero.
        target = target * (target != self.ignore_index).long()
        input_prob = torch.gather(F.softmax(input_, 1), 1, target.unsqueeze(1)).squeeze(1)
        loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy
        return torch.mean(loss) if self.reduction == 'mean' \
            else torch.sum(loss) if self.reduction == 'sum' \
            else loss
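

# A minimal usage sketch (not part of the original model code): FocalLoss expects
# flat (N, num_classes) logits and integer targets; the sizes and the helper name
# `_demo_focal_loss` are illustrative assumptions.
def _demo_focal_loss():
    criterion = FocalLoss(gamma=1.0, ignore_index=0, reduction='mean')
    logits = torch.randn(8, 30522)           # (N, vocab-sized classes)
    targets = torch.randint(1, 30522, (8,))  # token ids; 0 would be ignored as padding
    return criterion(logits, targets)        # scalar focal-weighted cross entropy
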

# helper functions
import math
from functools import reduce


def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob


def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()

def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d
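

# A minimal sketch (assumed example, not from the original repo) of the masking
# helpers above: protect pad/[CLS]/[SEP] (ids 0/101/102, the BERT convention used
# later in this file), then sample ~15% of the remaining positions.
def _demo_token_masking():
    tokens = torch.tensor([[101, 2023, 2003, 1037, 7953, 102, 0, 0]])
    no_mask = mask_with_tokens(tokens, (0, 101, 102))      # True where masking is forbidden
    return get_mask_subset_with_prob(~no_mask, prob=0.15)  # True where a token was chosen
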

# normalization
# they use layernorm without bias, something that pytorch does not offer

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


# residual

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


# rotary positional embedding
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
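

# A minimal sketch (assumed shapes) of applying the rotary embedding to a per-head
# query tensor of shape (batch, heads, seq, dim_head); dim_head=64 matches the
# default used by ParallelTransformerBlock below.
def _demo_rotary():
    rotary = RotaryEmbedding(dim=64)
    q = torch.randn(2, 8, 16, 64)       # (b, h, n, d)
    pos = rotary(16, device=q.device)   # (n, d) rotation angles
    return apply_rotary_pos_emb(pos, q)  # same shape as q
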

# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GELU for gating the feedforward
# https://arxiv.org/abs/2002.05202

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4, attn_drop_rate=0.0):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )
        self.attn_drop_rate = attn_drop_rate

        # for caching causal mask and rotary embeddings
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)
    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        Performs self attention and feedforward

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale
        q = q * self.scale

        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # extra attention mask - for masking out attention from text CLS token to padding
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
        # attention dropout on the logits: randomly knock out a fraction of the
        # still-visible positions (device-agnostic rand instead of cuda.FloatTensor)
        if self.attn_drop_rate != 0. and self.training:
            drop_ind = sim != -torch.finfo(sim.dtype).max
            dropout_mask = torch.rand(sim[drop_ind].shape, device=sim.device) > self.attn_drop_rate
            sim[drop_ind] = sim[drop_ind].masked_fill(~dropout_mask, -torch.finfo(sim.dtype).max)
        # attention
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)
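

# A minimal sketch (assumed sizes) of the unimodal text block: causal self-attention
# plus a parallel feedforward over a batch of token embeddings. The optional
# attn_mask has shape (b, n, n) and marks which positions may attend to each other.
def _demo_parallel_block():
    block = ParallelTransformerBlock(dim=512, dim_head=64, heads=8, ff_mult=4)
    tokens = torch.randn(2, 10, 512)                # (b, n, dim)
    keep = torch.ones(2, 10, 10, dtype=torch.bool)  # no extra masking beyond causal
    return block(tokens, attn_mask=keep)            # (b, n, dim)
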

# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False,
        dropout=0.0,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = dropout

        # whether to have parallel feedforward
        ff_inner_dim = ff_mult * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        Use text as query, and image as key / value

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)

        # get queries
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)

        # scale
        q = q * self.scale

        # get key / values
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)
        # dropout on the attention logits (device-agnostic rand instead of cuda.FloatTensor)
        if self.training and self.dropout > 0.:
            dropout_mask = torch.rand(sim.shape, device=sim.device) > self.dropout
            sim = sim.masked_fill(~dropout_mask, -torch.finfo(sim.dtype).max)

        # attention
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        # aggregate
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)

        return out
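

# A minimal sketch (assumed sizes) of the multimodal direction used by this model:
# text tokens act as multi-headed queries, encoded image patches provide the
# single-headed key/value context.
def _demo_cross_attention():
    xattn = CrossAttention(dim=512, context_dim=768, dim_head=64, heads=8, parallel_ff=True)
    text = torch.randn(2, 20, 512)    # (b, text_len, dim)
    image = torch.randn(2, 50, 768)   # (b, num_patches + 1, context_dim)
    return xattn(text, image)         # (b, text_len, dim)
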

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
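

# A minimal sketch: the fixed positional table the MAE encoder uses for a
# 224 / 16 = 14x14 patch grid with a [CLS] slot prepended (sizes assumed to match
# the ViT-Base configuration below).
def _demo_sincos_pos_embed():
    pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
    return torch.from_numpy(pos).float().unsqueeze(0)  # (1, 1 + 14*14, 768)
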

class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=True,
                 unimodal_depth=2, multimodal_depth=8, dim_head=64, heads=8,
                 ff_mult=4, extract_multi_level=False, use_focal_loss=False, focal_gamma=1.0,
                 less_u=False, use_weak_negative=False, use_label_smooth=False, ls_coef=0.1,
                 use_maximum_entropy=False, ce_additional=False, use_word_weights=False, use_token_pos=False,
                 use_expect_k=False, use_top_k=False, mae_decoder_caption=False, decoder_slot_depth=2, disable_decoder_vis_token_grad=False,
                 cross_attn_dropout=0.0, predict_next_k_words=False, next_k=3, masked_text=False, masked_text_ratio=0.25, text_length=70,
                 projector_layer=0, uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4, text_drop_attn=0.):
        super().__init__()
        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.mae_decoder_depth = decoder_depth
        self.mae_decoder_caption = mae_decoder_caption
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        if self.mae_decoder_caption:
            self.decoder_slot_layers = nn.ModuleList([])
            for _ in range(decoder_slot_depth):
                self.decoder_slot_layers.append(
                    Residual(CrossAttention(dim=decoder_embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
                    # Residual(CrossAttention(dim=decoder_embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult))
                )
            self.decoder_caption_proj = nn.Linear(decoder_embed_dim, embed_dim)
        self.disable_decoder_vis_token_grad = disable_decoder_vis_token_grad

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        # --------------------------------------------------------------------------
        # captioner specifics
        # unimodal layer is for text tokens.
        # multimodal layer is for text to query from image.
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        # token embeddings
        # NOTE: +1 for the mask token used by the MLM objective
        self.token_emb = nn.Embedding(len(self.tokenizer.vocab) + 1, uni_dim)
        self.text_cls_token = nn.Parameter(torch.randn(uni_dim))
        self.embed_dim = embed_dim
        self.uni_dim = uni_dim
        # unimodal layers
        # TODO: search on the four parameters
        # uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4
        self.text_drop_attn = text_drop_attn
        self.unimodal_layers = nn.ModuleList([])
        for _ in range(unimodal_depth):
            self.unimodal_layers.append(
                Residual(ParallelTransformerBlock(dim=uni_dim, dim_head=uni_dim_head,
                                                  heads=uni_heads, ff_mult=uni_ff_mult,
                                                  attn_drop_rate=self.text_drop_attn)),
            )

        self.need_uni_2_mul_proj = False
        if uni_dim != embed_dim:
            self.need_uni_2_mul_proj = True
            self.uni_2_mul_proj = nn.Linear(uni_dim, embed_dim)

        # multimodal layers
        self.multimodal_layers = nn.ModuleList([])
        self.less_u = less_u
        if less_u:
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))
        else:
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(ParallelTransformerBlock(dim=embed_dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))

        # to logits: for softmax caption loss
        self.to_logits = nn.Sequential(
            LayerNorm(embed_dim),
            nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
        )

        self.ce_additional = ce_additional
        if ce_additional:
            # to logits: for other losses
            self.to_logits_1 = nn.Sequential(
                LayerNorm(embed_dim),
                nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
            )

        nn.init.normal_(self.token_emb.weight, std=0.02)

        self.pad_id = 0
        self.cls_id = 101
        self.sep_id = 102
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.extract_multi_level = extract_multi_level
        if self.extract_multi_level:
            self.projectors = nn.ModuleList([nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.GELU(),
                nn.Linear(embed_dim // 2, embed_dim),
                norm_layer(embed_dim)
            ) for _ in [2, 5, 8]])
        # --------------------------------------------------------------------------

        self.use_focal_loss = use_focal_loss
        self.use_weak_negative = use_weak_negative
        self.use_label_smooth = use_label_smooth
        self.ls_coef = ls_coef
        self.use_entropy = use_maximum_entropy
        self.use_word_weights = use_word_weights
        self.use_token_pos = use_token_pos
        self.predict_next_k_words = predict_next_k_words
        self.next_k = next_k
        self.pad = torch.nn.ReplicationPad1d((0, self.next_k - 1))
        self.use_expect_k = use_expect_k
        self.use_top_k = use_top_k

        if self.use_word_weights or self.use_token_pos:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='none')
        else:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='mean')

        self.masked_text = masked_text
        self.masked_text_ratio = masked_text_ratio
        # self.text_mask_token = nn.Parameter(torch.randn(embed_dim))
        self.mask_token_id = len(self.tokenizer.vocab)
        self.text_position_embed = nn.Parameter(torch.zeros(1, text_length, embed_dim), requires_grad=False)
        self.text_length = text_length

        self.latent_projector_layer = projector_layer
        if self.latent_projector_layer != 0:
            # build (projector_layer - 1) Linear+ReLU stages followed by a final Linear;
            # construct each stage separately so no weights are shared between stages
            layers = []
            for _ in range(self.latent_projector_layer - 1):
                layers.extend([nn.Linear(embed_dim, embed_dim), nn.ReLU()])
            layers.append(nn.Linear(embed_dim, embed_dim))
            self.latent_projector = nn.Sequential(*layers)

        self.initialize_weights()
    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # text_pos_embed = get_1d_sincos_pos_embed_from_grid(self.embed_dim, )
        # torch.nn.init.xavier_normal_(self.text_position_embed)  # learnable text position embedding

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        # torch.nn.init.normal_(self.text_mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_keep
    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.extract_multi_level:
            multi_level_feats = []
            # apply Transformer blocks
            for blk_idx, blk in enumerate(self.blocks):
                x = blk(x)
                if blk_idx in [2, 5, 8]:
                    multi_level_feats.append(self.projectors[[2, 5, 8].index(blk_idx)](x))
            x = self.norm(x)
            multi_level_feats.append(x)
            return multi_level_feats, mask, ids_restore

        # apply Transformer blocks
        for blk_idx, blk in enumerate(self.blocks):
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore, ids_keep
    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)
        # non_mask_token = x

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        decoder_feat = []
        for idx, blk in enumerate(self.decoder_blocks):
            x = blk(x)
            if idx == self.mae_decoder_depth // 2:
                decoder_feat.append(x)
        x = self.decoder_norm(x)
        # use the output from decoder to do captioning

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]
        return x, decoder_feat
    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss
    def embed_text(self, text):
        batch, device = text.shape[0], text.device
        seq = text.shape[1]

        text_tokens = self.token_emb(text)

        # append text cls tokens
        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)

        # create specific mask for text cls token at the end
        # to prevent it from attending to padding
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        # go through unimodal layers
        for attn_ff in self.unimodal_layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)

        if self.need_uni_2_mul_proj:
            text_tokens = self.uni_2_mul_proj(text_tokens)

        # get text cls token
        text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1]
        return text_tokens
    def forward(self, imgs, caption_ids=None, attention_mask=None, mask_ratio=0.75,
                freeze_bert=False, teacher_forcing=False, caption_only=False,
                encoder_only=False, word_weights=None, syn_count=None):
        latent, mask, ids_restore, ids_keep = self.forward_encoder(imgs, mask_ratio)
        if not caption_only:
            pred, decoder_feat = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
            mae_loss = self.forward_loss(imgs, pred, mask)
        else:
            mae_loss = 0.

        if self.latent_projector_layer != 0:
            latent = self.latent_projector(latent)

        # latent: visual info: N, L, C
        # caption_ids: N, Len
        text, labels = caption_ids[:, :-1], caption_ids[:, 1:]
        seq = text.shape[1]
        text_tokens = self.embed_text(text)  # N, Len, C

        # create specific mask for text cls token at the end
        # to prevent it from attending to padding
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        unimodal_text_tokens = text_tokens
        if not self.less_u:
            for attn_ff, cross_attn in self.multimodal_layers:
                text_tokens = attn_ff(text_tokens, attn_mask=attn_mask[:, :-1, :-1])
                text_tokens = cross_attn(text_tokens, latent)
        else:
            for cross_attn1, cross_attn2 in self.multimodal_layers:
                text_tokens = cross_attn1(text_tokens, latent)
                text_tokens = cross_attn2(text_tokens, latent)

        logits = self.to_logits(text_tokens)  # N, Len, NVocab
        logits = logits.reshape(-1, len(self.tokenizer.vocab))
        labels = labels.reshape(-1)
        caption_loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id)

        return mae_loss, caption_loss, None

def mae_vit_small_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=384, depth=12, num_heads=6,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
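

# A minimal end-to-end sketch (assumed sizes and hyperparameters; downloading the
# BERT tokenizer requires network access): build the ViT-Base variant and run one
# forward pass on dummy images and caption ids.
if __name__ == "__main__":
    model = mae_vit_base_patch16(uni_dim=768, uni_dim_head=64, uni_heads=8, uni_ff_mult=4)
    imgs = torch.randn(2, 3, 224, 224)
    caption_ids = torch.randint(1, 1000, (2, 32))  # dummy token ids; 0 is reserved for padding
    mae_loss, caption_loss, _ = model(imgs, caption_ids=caption_ids, mask_ratio=0.75)
    print(mae_loss.item(), caption_loss.item())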