Spaces:
Runtime error
Runtime error
| import math | |
| import types | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision.transforms import Compose, Resize, InterpolationMode | |
| import open_clip | |
| from open_clip.transformer import VisionTransformer | |
| from open_clip.timm_model import TimmModel | |
| from einops import rearrange | |
| from .utils import ( | |
| hooked_attention_timm_forward, | |
| hooked_resblock_forward, | |
| hooked_attention_forward, | |
| hooked_resblock_timm_forward, | |
| hooked_attentional_pooler_timm_forward, | |
| vit_dynamic_size_forward, | |
| min_max, | |
| hooked_torch_multi_head_attention_forward, | |
| ) | |
class LeWrapper(nn.Module):
    """
    Wrapper around OpenCLIP to add LeGrad to OpenCLIP's model while keeping all the
    functionalities of the original model.

    On construction the wrapped model's attributes are copied onto the wrapper, the
    forward passes of the last vision blocks are swapped for the hooked versions imported
    from ``.utils``, and only the parameters of those blocks keep ``requires_grad=True``.
    ``compute_legrad`` then turns gradients of the text/image similarity with respect to
    the stored attention tensors into heatmaps.

    Attributes set by ``_activate_hooks``:
        model_type: ``"clip"``, ``"coca"``, ``"timm_siglip"`` or ``"timm_clip"``.
        patch_size: first entry of the vision backbone's patch size.
        starting_depth: index of the first hooked vision block.
    """

    def __init__(self, model, layer_index=-2):
        """
        Args:
            model: model whose non-dunder attributes are copied onto the wrapper. Its
                ``visual`` attribute must be an open_clip ``VisionTransformer`` or
                ``TimmModel``.
            layer_index: first vision block to hook; negative values count from the end.

        Raises:
            ValueError: if ``model.visual`` is neither supported backbone type.
        """
        super().__init__()
        # ------------ copy of model's attributes and methods ------------
        # Names used below but not defined in this class (`visual`, `encode_image`,
        # `encode_text`, `logit_scale`, ...) are expected to come from this copy.
        for attr in dir(model):
            if not attr.startswith("__"):
                setattr(self, attr, getattr(model, attr))
        # ------------ activate hooks & gradient ------------
        self._activate_hooks(layer_index=layer_index)

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _resolve_depth(layer_index, num_blocks):
        """Return the absolute block index for a possibly negative ``layer_index``."""
        return layer_index if layer_index >= 0 else num_blocks + layer_index

    def _freeze_below_starting_depth(self, block_prefix, always_trainable=()):
        """
        Freeze every parameter except the hooked blocks (and optional extras).

        Args:
            block_prefix: dotted parameter-name prefix of the block container, e.g.
                ``"visual.trunk.blocks"``. The name component right after it is parsed
                as the block depth.
            always_trainable: name prefixes that keep ``requires_grad=True``
                regardless of depth.
        """
        dotted_prefix = block_prefix + "."
        extra_prefixes = tuple(always_trainable)
        for name, param in self.named_parameters():
            trainable = name.startswith(extra_prefixes)
            if name.startswith(dotted_prefix):
                depth = int(name[len(dotted_prefix):].split(".", 1)[0])
                trainable = trainable or depth >= self.starting_depth
            param.requires_grad = trainable

    @staticmethod
    def _stored_attention(attn_module):
        """
        Return the attention tensor a hooked attention forward left on ``attn_module``.

        NOTE(review): the hooks live in ``.utils`` and are not visible here. This file
        reads both ``attention_map`` and ``attention_maps``, so either spelling is accepted.

        Raises:
            AttributeError: if neither attribute is present (no hooked forward pass ran).
        """
        for attr_name in ("attention_map", "attention_maps"):
            stored = getattr(attn_module, attr_name, None)
            if stored is not None:
                return stored
        raise AttributeError(
            "No stored attention found; run a forward pass through the hooked model first"
        )

    def _attn_pool_relevance(self, score, w, h):
        """
        Return the clamped gradient of ``score`` w.r.t. the pooler's ``attn_probs``.

        The gradient is averaged over dim 1, index 0 of the next dim is selected and the
        result is reshaped to ``[b, 1, w, h]``.
        """
        grad = torch.autograd.grad(
            score,
            [self.visual.trunk.attn_pool.attn_probs],
            retain_graph=True,
            create_graph=True,
        )[0]
        grad = torch.clamp(grad, min=0.0)
        # presumably: average over heads, then take the pooling query's row — TODO confirm
        relevance = grad.mean(dim=1)[:, 0]
        return rearrange(relevance, "b (w h) -> b 1 w h", w=w, h=h)

    # ------------------------------------------------------------ hook set-up
    def _activate_hooks(self, layer_index):
        """
        Detect the vision backbone type, enable dynamic image size and install hooks.

        Sets ``patch_size``, ``starting_depth`` and ``model_type``.

        Raises:
            ValueError: if ``self.visual`` is neither a ``VisionTransformer`` nor a
                ``TimmModel``.
        """
        # ------------ identify model's type ------------
        print("Activating necessary hooks and gradients ....")
        visual = self.visual
        if isinstance(visual, VisionTransformer):
            # --- Activate dynamic image size ---
            visual.forward = types.MethodType(vit_dynamic_size_forward, visual)
            self.patch_size = visual.patch_size[0]
            self.starting_depth = self._resolve_depth(
                layer_index, len(visual.transformer.resblocks)
            )
            if visual.attn_pool is None:
                self.model_type = "clip"
                self._activate_self_attention_hooks()
            else:
                self.model_type = "coca"
                self._activate_att_pool_hooks(layer_index=layer_index)
        elif isinstance(visual, TimmModel):
            trunk = visual.trunk
            # --- Activate dynamic image size ---
            trunk.dynamic_img_size = True
            trunk.patch_embed.dynamic_img_size = True
            trunk.patch_embed.strict_img_size = False
            trunk.patch_embed.flatten = False
            trunk.patch_embed.output_fmt = "NHWC"
            self.patch_size = trunk.patch_embed.patch_size[0]
            self.starting_depth = self._resolve_depth(layer_index, len(trunk.blocks))
            if getattr(trunk, "attn_pool", None) is not None:
                self.model_type = "timm_siglip"
                self._activate_timm_attn_pool_hooks(layer_index=layer_index)
            else:
                # Without an attentional pooler the SigLIP path cannot run (it calls
                # trunk.attn_pool). The "clip" substring makes compute_legrad route
                # this backbone to compute_legrad_clip instead.
                self.model_type = "timm_clip"
                self._activate_timm_self_attention_hooks()
        else:
            raise ValueError(
                "Model currently not supported, see legrad.list_pretrained() for a list of available models"
            )
        print("Hooks and gradients activated!")

    def _activate_self_attention_hooks(self):
        """
        Hook the self-attention and block forwards from ``starting_depth`` onwards.

        Raises:
            ValueError: if ``self.visual`` is not a supported backbone type.
        """
        if isinstance(self.visual, TimmModel):
            # timm blocks need the timm hook variants; reuse the dedicated method.
            self._activate_timm_self_attention_hooks()
            return
        if not isinstance(self.visual, VisionTransformer):
            raise ValueError("Unsupported model type for self-attention hooks")

        blocks = self.visual.transformer.resblocks
        # The prefix must match the container hooked below, otherwise every
        # parameter stays frozen.
        self._freeze_below_starting_depth("visual.transformer.resblocks")
        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(blocks)):
            block = blocks[layer]
            block.attn.forward = types.MethodType(hooked_attention_forward, block.attn)
            block.forward = types.MethodType(hooked_resblock_forward, block)

    def _activate_timm_self_attention_hooks(self):
        """Hook the timm trunk blocks (attention + block forward) from ``starting_depth``."""
        blocks = self.visual.trunk.blocks
        self._freeze_below_starting_depth("visual.trunk.blocks")
        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(blocks)):
            block = blocks[layer]
            block.attn.forward = types.MethodType(
                hooked_attention_timm_forward, block.attn
            )
            block.forward = types.MethodType(hooked_resblock_timm_forward, block)

    def _activate_att_pool_hooks(self, layer_index):
        """
        Hook the last resblocks and the attentional pooler of a ``VisionTransformer``.

        ``layer_index`` is accepted for interface compatibility; the depth actually
        used is ``self.starting_depth``.
        """
        resblocks = self.visual.transformer.resblocks
        self._freeze_below_starting_depth("visual.transformer.resblocks")
        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(resblocks)):
            resblocks[layer].forward = types.MethodType(
                hooked_resblock_forward, resblocks[layer]
            )
        # --- Apply hook on the attentional pooler ---
        pooler_attn = self.visual.attn_pool.attn
        pooler_attn.forward = types.MethodType(
            hooked_torch_multi_head_attention_forward, pooler_attn
        )

    def _activate_timm_attn_pool_hooks(self, layer_index):
        """
        Hook the timm trunk blocks and the trunk's attentional pooler.

        ``layer_index`` is accepted for interface compatibility; the depth actually
        used is ``self.starting_depth``.

        Raises:
            ValueError: if the trunk has no attentional pooler.
        """
        trunk = self.visual.trunk
        if getattr(trunk, "attn_pool", None) is None:
            raise ValueError("Attentional pooling not found in TimmModel")

        # timm attention modules get the timm hook variant, matching
        # _activate_timm_self_attention_hooks.
        for block in trunk.blocks:
            if hasattr(block, "attn"):
                block.attn.forward = types.MethodType(
                    hooked_attention_timm_forward, block.attn
                )
        # --- Deactivate gradient for module that don't need it ---
        self._freeze_below_starting_depth(
            "visual.trunk.blocks", always_trainable=("visual.trunk.attn_pool",)
        )
        # --- Activate the hooks for the specific layers by modifying the block's forward ---
        for layer in range(self.starting_depth, len(trunk.blocks)):
            trunk.blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, trunk.blocks[layer]
            )
        trunk.attn_pool.forward = types.MethodType(
            hooked_attentional_pooler_timm_forward, trunk.attn_pool
        )

    # ------------------------------------------------------------- explanations
    def compute_legrad(self, text_embedding, image=None, apply_correction=True):
        """
        Dispatch to the explanation routine matching ``self.model_type``.

        Args:
            text_embedding: text embedding(s) to explain.
            image: optional image batch; when given it is encoded first so the hooked
                blocks hold fresh intermediate features.
            apply_correction: forwarded to the SigLIP routine only.

        Raises:
            ValueError: if ``self.model_type`` matches no known routine.
        """
        # "siglip" does not contain "clip", so the order of these checks is safe.
        if "clip" in self.model_type:
            return self.compute_legrad_clip(text_embedding, image)
        if "siglip" in self.model_type:
            return self.compute_legrad_siglip(
                text_embedding, image, apply_correction=apply_correction
            )
        if "coca" in self.model_type:
            return self.compute_legrad_coca(text_embedding, image)
        raise ValueError(f"Unsupported model type: {self.model_type!r}")

    def compute_legrad_clip(self, text_embedding, image=None):
        """
        Explain a backbone hooked through its self-attention.

        Covers an open_clip ``VisionTransformer`` and a timm trunk without attentional
        pooler. Per hooked block, the mean token feature is projected and normalised,
        its similarity with ``text_embedding`` is differentiated w.r.t. the block's
        stored attention, and the clamped gradients are accumulated into one map.

        Returns:
            The accumulated map after ``min_max``, upsampled by ``patch_size``.
        """
        num_prompts = text_embedding.shape[0]
        if image is not None:
            # Ensure the image is passed through the model to get the intermediate features
            _ = self.encode_image(image)

        is_timm = isinstance(self.visual, TimmModel)
        if is_timm:
            blocks_list = list(self.visual.trunk.blocks.children())
            # hooked timm blocks are read batch-first here: [batch, num_patch, dim]
            token_dim = 1
        else:
            blocks_list = list(self.visual.transformer.resblocks.children())
            # compute_legrad_coca treats these features as [num_patch, batch, dim]
            token_dim = 0
        hooked_blocks = blocks_list[self.starting_depth:]

        image_features_list = []
        for blk in hooked_blocks:
            # Mean over the tokens
            pooled = blk.feat_post_mlp.mean(dim=token_dim)
            if is_timm:
                projected = self.visual.head(self.visual.trunk.norm(pooled))
            else:
                # NOTE(review): ln_post / proj are assumed to be the open_clip
                # VisionTransformer output head — confirm for the pinned version.
                projected = self.visual.ln_post(pooled) @ self.visual.proj
            image_features_list.append(F.normalize(projected, dim=-1))

        # The first token is dropped, presumably the [CLS] token — TODO confirm.
        num_tokens = blocks_list[-1].feat_post_mlp.shape[token_dim] - 1
        w = h = int(math.sqrt(num_tokens))

        # ----- Get explainability map
        accum_expl_map = 0
        for blk, img_feat in zip(hooked_blocks, image_features_list):
            self.visual.zero_grad()
            sim = text_embedding @ img_feat.transpose(-1, -2)
            # Identity mask: keeps only the prompt_i / sample_i similarities.
            one_hot = (
                F.one_hot(torch.arange(0, num_prompts))
                .float()
                .requires_grad_(True)
                .to(text_embedding.device)
            )
            one_hot = torch.sum(one_hot * sim)

            attn_map = self._stored_attention(blk.attn)
            # -------- Get explainability map --------
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            if grad.dim() == 3:
                # [batch * num_heads, N, N] -> [batch, num_heads, N, N]
                grad = rearrange(grad, "(b h) n m -> b h n m", b=num_prompts)
            grad = torch.clamp(grad, min=0.0)
            # average over dim 1 twice (heads, then queries), then drop the first token
            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]
            expl_map = rearrange(image_relevance, "b (w h) -> 1 b w h", w=w, h=h)
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map

        # Min-Max Norm
        return min_max(accum_expl_map)

    def compute_legrad_coca(self, text_embedding, image=None):
        """
        Explain a ``VisionTransformer`` that has an attentional pooler.

        Each hooked block's features are re-pooled with ``visual.attn_pool`` and the
        similarity with ``text_embedding`` is differentiated w.r.t. the pooler's
        stored ``attention_maps``.

        Returns:
            The accumulated map, min-max normalised over the whole tensor.
        """
        if image is not None:
            _ = self.encode_image(image)

        blocks_list = list(self.visual.transformer.resblocks.children())
        # [num_patch, batch, dim] -> [batch, num_patch, dim]
        image_features_list = [
            blk.feat_post_mlp.permute(1, 0, 2)
            for blk in blocks_list[self.starting_depth:]
        ]

        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
        w = h = int(math.sqrt(num_tokens))

        # ----- Get explainability map
        accum_expl_map = 0
        for img_feat in image_features_list:
            self.visual.zero_grad()
            # --- Apply attn_pool ---
            # we keep only the first pooled token as it is only this one trained
            # with the contrastive loss
            image_embedding = self.visual.attn_pool(img_feat)[:, 0]
            image_embedding = image_embedding @ self.visual.proj
            sim = text_embedding @ image_embedding.transpose(-1, -2)
            one_hot = torch.sum(sim)

            # [num_heads, num_latent, num_patch]
            attn_map = self.visual.attn_pool.attn.attention_maps
            # -------- Get explainability map --------
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            grad = torch.clamp(grad, min=0.0)
            # average attn over heads + select first latent
            image_relevance = grad.mean(dim=0)[0, 1:]
            expl_map = rearrange(image_relevance, "(w h) -> 1 1 w h", w=w, h=h)
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map

        # Min-Max Norm
        lowest = accum_expl_map.min()
        return (accum_expl_map - lowest) / (accum_expl_map.max() - lowest)

    def _init_empty_embedding(self):
        """Compute once and cache the normalised, transposed embedding of "a photo of a"."""
        if hasattr(self, "empty_embedding"):
            return
        # For the moment only SigLIP is supported & they all have the same tokenizer
        _tok = open_clip.get_tokenizer(model_name="ViT-B-16-SigLIP")
        empty_text = _tok(["a photo of a"]).to(self.logit_scale.data.device)
        empty_embedding = F.normalize(self.encode_text(empty_text), dim=-1)
        self.empty_embedding = empty_embedding.t()

    def compute_legrad_siglip(
        self,
        text_embedding,
        image=None,
        apply_correction=True,
        correction_threshold=0.8,
    ):
        """
        Explain a timm backbone that has an attentional pooler.

        Args:
            text_embedding: text embedding(s) to explain.
            image: optional image batch, encoded first when given.
            apply_correction: when true, also build the map for the "empty" prompt and
                zero out locations where its normalised value exceeds
                ``correction_threshold``.
            correction_threshold: cut-off applied to the normalised empty-prompt map.

        Returns:
            The ``min_max``-normalised map, upsampled by ``patch_size``.
        """
        # --- Forward CLIP ---
        blocks_list = list(self.visual.trunk.blocks.children())
        if image is not None:
            _ = self.encode_image(image)

        image_features_list = [
            blk.feat_post_mlp for blk in blocks_list[self.starting_depth:]
        ]
        # No token is subtracted here, unlike the clip / coca routines.
        num_tokens = blocks_list[-1].feat_post_mlp.shape[1]
        w = h = int(math.sqrt(num_tokens))

        if apply_correction:
            self._init_empty_embedding()
            accum_expl_map_empty = 0
        accum_expl_map = 0

        for img_feat in image_features_list:
            self.zero_grad()
            pooled_feat = self.visual.trunk.attn_pool(img_feat)
            pooled_feat = F.normalize(pooled_feat, dim=-1)
            # -------- Get explainability map --------
            sim = text_embedding @ pooled_feat.transpose(-1, -2)
            accum_expl_map += self._attn_pool_relevance(torch.sum(sim), w, h)

            if apply_correction:
                # -------- Get empty explainability map --------
                sim_empty = pooled_feat @ self.empty_embedding
                accum_expl_map_empty += self._attn_pool_relevance(
                    torch.sum(sim_empty), w, h
                )

        if apply_correction:
            heatmap_empty = min_max(accum_expl_map_empty)
            accum_expl_map[heatmap_empty > correction_threshold] = 0

        heatmap = min_max(accum_expl_map)
        return F.interpolate(heatmap, scale_factor=self.patch_size, mode="bilinear")
class LePreprocess(nn.Module):
    """
    Modify OpenCLIP preprocessing to accept arbitrary image size.

    The pipeline is a bicubic resize to ``(image_size, image_size)`` followed by the
    last three transforms of the given ``preprocess`` pipeline, in their original order.
    """

    def __init__(self, preprocess, image_size):
        """
        Args:
            preprocess: pipeline exposing a ``transforms`` sequence with at least
                three entries.
            image_size: side length of the square the image is resized to.
        """
        super().__init__()
        square_resize = Resize(
            (image_size, image_size), interpolation=InterpolationMode.BICUBIC
        )
        # Explicit negative indexing (not a slice) so a pipeline that is too short
        # still raises IndexError.
        trailing_steps = [preprocess.transforms[idx] for idx in (-3, -2, -1)]
        self.transform = Compose([square_resize, *trailing_steps])

    def forward(self, image):
        """Apply the resize + trailing preprocessing steps to ``image``."""
        processed = self.transform(image)
        return processed