import warnings
from dataclasses import dataclass
from typing import List

import torch
from einops import rearrange
from PIL import Image
from torch import nn
from transformers.models.bert import BertConfig, BertModel
from transformers.models.bloom import BloomConfig, BloomForCausalLM, BloomTokenizerFast
from transformers.models.convnext import ConvNextImageProcessor
from transformers.models.convnextv2 import ConvNextV2Config
from transformers.models.convnextv2.modeling_convnextv2 import ConvNextV2Model


# Copied from
# https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/collator.py#L13-L32
@dataclass
class ImageFeatureCollator:
    image_processor: ConvNextImageProcessor
    image_model: ConvNextV2Model

    def __call__(self, batch_image: List[Image.Image]):
        return self.tensorize_batch_image(batch_image=batch_image)

    def tensorize_batch_image(self, batch_image: List[Image.Image]):
        image_inputs = self.image_processor(batch_image, return_tensors="pt")
        with torch.no_grad():
            image_outputs = self.image_model(**image_inputs)

        # Flatten the (h, w) feature grid into a token sequence for cross-attention
        image_features = image_outputs["last_hidden_state"]
        image_features = rearrange(image_features, "b c h w -> b h w c")
        image_features = rearrange(image_features, "b h w c -> b (h w) c")
        image_attentions = torch.ones(image_features.size()[:-1], dtype=torch.long)
        return image_features, image_attentions
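

# --- Usage sketch (not part of the copied velvet code) --------------------------
# A minimal example of how ImageFeatureCollator is expected to be called. The
# checkpoint name matches setup_models() below; for a 224x224 input, ConvNeXtV2-large
# should yield a 7x7 grid of 1536-dim features, i.e. (batch, 49, 1536) after flattening.
def _example_image_features(image_path: str):
    processor = ConvNextImageProcessor.from_pretrained("facebook/convnextv2-large-22k-224")
    backbone = ConvNextV2Model.from_pretrained("facebook/convnextv2-large-22k-224")
    collator = ImageFeatureCollator(processor, backbone)
    image = Image.open(image_path).convert("RGB")
    image_features, image_attentions = collator([image])
    # Expected shapes: image_features (1, 49, 1536), image_attentions (1, 49)
    return image_features, image_attentions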


# Copied from
# https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/model/cutie.py#L6C1-L78C28
class IdentityForBertEmbeddings(nn.Module):
    """Skips all BertEmbeddings, because text embeddings provided by another model are used instead"""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, **bert_embeddings_args):
        inputs_embeds = bert_embeddings_args.get("inputs_embeds", None)
        return inputs_embeds


class Cutie(nn.Module):
    """Cutie - Qt - Query Transformer - Q-Former

    Cutie is motivated by the underlying theoretical foundations of Q-Former presented in
    BLIP-2 (https://arxiv.org/abs/2301.12597). It should be noted that Cutie differs from
    the specific approach described in that paper:

    - Both Cutie and Q-Former have query tokens.
    - Cutie uses the same unmodified BERT.
    - Q-Former modifies BERT to behave differently on some tasks.
    """

    def __init__(
        self,
        bert_config: BertConfig,
        max_query_length: int = 32,
        language_model_ignore_label: int = -100,
    ) -> None:
        assert bert_config.is_decoder, "BERT must be a decoder"
        assert bert_config.add_cross_attention, "BERT must have a cross-attention layer"
        super().__init__()

        self.bert_model = BertModel(bert_config, add_pooling_layer=False)
        self.bert_model.embeddings = IdentityForBertEmbeddings()

        # Learnable query tokens, plus fixed attention masks and ignore-index labels for them
        self.query_tokens = nn.Parameter(
            torch.zeros(1, max_query_length, bert_config.hidden_size)
        )
        self.query_tokens.data.normal_(mean=0.0, std=bert_config.initializer_range)
        self.query_attentions = torch.ones(
            self.query_tokens.size()[:-1], dtype=torch.long
        )
        self.query_labels = torch.full(
            self.query_tokens.size()[:-1], language_model_ignore_label, dtype=torch.long
        )

    def forward(
        self,
        image_features: torch.Tensor,
        image_attentions: torch.Tensor,
        instruction_embeds: torch.Tensor,
        instruction_attention_mask: torch.Tensor,
    ):
        batch_size = image_features.size(0)
        query_tokens = self.query_tokens.expand(batch_size, -1, -1).to(
            self.query_tokens.device
        )
        query_attentions = self.query_attentions.expand(batch_size, -1).to(
            self.query_tokens.device
        )

        # Self-attend over [query tokens; instruction], cross-attend to the image features
        cat_embeds = torch.cat([query_tokens, instruction_embeds], dim=1)
        cat_attentions = torch.cat(
            [query_attentions, instruction_attention_mask], dim=1
        )

        bert_outputs = self.bert_model(
            inputs_embeds=cat_embeds,
            attention_mask=cat_attentions,
            encoder_hidden_states=image_features,
            encoder_attention_mask=image_attentions,
        )
        # Keep only the query positions as the Cutie output
        cutie_output = bert_outputs.last_hidden_state[:, : query_tokens.size(1), :]
        return cutie_output
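

# --- Shape sketch (not part of the copied velvet code) --------------------------
# Illustrates the tensors Cutie expects: image features as a (batch, tokens, hidden)
# sequence that the BERT layers cross-attend to, and instruction embeddings already
# projected to BERT's hidden size. Only the query positions are returned, so the
# output is (batch, max_query_length, hidden). The tiny 2-layer config here is just
# for illustration; the real model uses the 6-layer config built in setup_models() below.
def _example_cutie_shapes():
    bert_config = BertConfig(
        hidden_size=1024,
        num_hidden_layers=2,
        num_attention_heads=16,
        is_decoder=True,
        add_cross_attention=True,
    )
    cutie = Cutie(bert_config)
    image_features = torch.randn(2, 49, 1024)
    image_attentions = torch.ones(2, 49, dtype=torch.long)
    instruction_embeds = torch.randn(2, 10, 1024)
    instruction_attention_mask = torch.ones(2, 10, dtype=torch.long)
    cutie_output = cutie(
        image_features, image_attentions, instruction_embeds, instruction_attention_mask
    )
    return cutie_output.shape  # expected: torch.Size([2, 32, 1024])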


# Copied from
# https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/model/visual_bloom.py#L12C1-L162C31
class VisualBloom(nn.Module):
    """A BLOOM-based model that can take image inputs"""

    def __init__(
        self,
        convnextv2_config: ConvNextV2Config,
        bert_config: BertConfig,
        bloom_config: BloomConfig,
        bloom_name: str,
        use_frozen_bloom: bool = True,
    ) -> None:
        super().__init__()

        # Create linear projections only when the three embedding sizes differ
        if (
            convnextv2_config.hidden_sizes[-1]
            == bert_config.hidden_size
            == bloom_config.hidden_size
        ):
            self.use_projection = False
            warnings.warn(
                "All embedding dimensions are equal. No linear projection layers are created."
            )
        else:
            self.use_projection = True
            self.text_to_cutie = nn.Linear(
                bloom_config.hidden_size, bert_config.hidden_size
            )
            self.image_to_cutie = nn.Linear(
                convnextv2_config.hidden_sizes[-1], bert_config.hidden_size
            )
            self.cutie_to_text = nn.Linear(
                bert_config.hidden_size, bloom_config.hidden_size
            )

        self.cutie_model = Cutie(bert_config)

        # Load and freeze BLOOM model
        if use_frozen_bloom:
            self.bloom_model = BloomForCausalLM.from_pretrained(bloom_name)
            for param in self.bloom_model.parameters():
                param.requires_grad = False
        else:
            self.bloom_model = BloomForCausalLM(bloom_config)

    def forward(
        self,
        # Image model outputs - Q-Former inputs
        image_features: torch.Tensor,
        image_attentions: torch.Tensor,
        # Q-Former inputs
        instruction_input_ids: torch.Tensor,
        instruction_attention_mask: torch.Tensor,
        # Frozen language model inputs
        language_model_input_ids: torch.Tensor,
        language_model_attention_mask: torch.Tensor,
        language_model_labels: torch.Tensor,
    ):
        # Embed the instruction with BLOOM's embedding layers
        instruction_embeds = self.bloom_model.transformer.word_embeddings(
            instruction_input_ids
        )
        instruction_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
            instruction_embeds
        )

        if self.use_projection:
            image_features = self.image_to_cutie(image_features)
            instruction_embeds = self.text_to_cutie(instruction_embeds)

        cutie_output = self.cutie_model(
            image_features=image_features,
            image_attentions=image_attentions,
            instruction_embeds=instruction_embeds,
            instruction_attention_mask=instruction_attention_mask,
        )
        if self.use_projection:
            cutie_output = self.cutie_to_text(cutie_output)

        cutie_attentions = self.cutie_model.query_attentions.expand(
            cutie_output.size(0), -1
        ).to(cutie_output.device)
        cutie_labels = self.cutie_model.query_labels.expand(
            cutie_output.size(0), -1
        ).to(cutie_output.device)

        language_model_embeds = self.bloom_model.transformer.word_embeddings(
            language_model_input_ids
        )
        language_model_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
            language_model_embeds
        )

        # Prepend the query outputs (labelled with the ignore index) to the text inputs
        cat_embeds = torch.cat([cutie_output, language_model_embeds], dim=1)
        cat_attentions = torch.cat(
            [cutie_attentions, language_model_attention_mask], dim=1
        )
        cat_labels = torch.cat([cutie_labels, language_model_labels], dim=1)

        bloom_outputs = self.bloom_model(
            inputs_embeds=cat_embeds, attention_mask=cat_attentions, labels=cat_labels
        )
        return bloom_outputs

    def generate(
        self,
        # Image model outputs - Q-Former inputs
        image_features: torch.Tensor,
        image_attentions: torch.Tensor,
        # Q-Former inputs
        instruction_input_ids: torch.Tensor,
        instruction_attention_mask: torch.Tensor,
    ):
        instruction_embeds = self.bloom_model.transformer.word_embeddings(
            instruction_input_ids
        )
        instruction_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
            instruction_embeds
        )

        if self.use_projection:
            image_features = self.image_to_cutie(image_features)
            cutie_instruction_embeds = self.text_to_cutie(instruction_embeds)
        else:
            # Fall back to the unprojected embeddings when no projection layers exist
            cutie_instruction_embeds = instruction_embeds

        cutie_output = self.cutie_model(
            image_features=image_features,
            image_attentions=image_attentions,
            instruction_embeds=cutie_instruction_embeds,
            instruction_attention_mask=instruction_attention_mask,
        )
        if self.use_projection:
            cutie_output = self.cutie_to_text(cutie_output)

        cutie_attentions = self.cutie_model.query_attentions.expand(
            cutie_output.size(0), -1
        ).to(cutie_output.device)

        # Prepend the query outputs to the instruction, then let BLOOM continue the text
        cat_embeds = torch.cat([cutie_output, instruction_embeds], dim=1)
        cat_attentions = torch.cat(
            [cutie_attentions, instruction_attention_mask], dim=1
        )

        language_output = self.bloom_model.generate(
            inputs_embeds=cat_embeds,
            attention_mask=cat_attentions,
            max_length=96,
            penalty_alpha=0.6,
            top_k=4,
        )
        return language_output
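

# --- Training-forward sketch (not part of the copied velvet code) ---------------
# Shows how VisualBloom.forward is wired for training: the 32 Cutie outputs are
# prepended to the caption embeddings and carry ignore-index (-100) labels, so only
# the caption tokens contribute to the causal-LM loss. The dummy tensors below are
# placeholders with the shapes implied by the ConvNeXtV2-large / BLOOMZ-1b7 configs
# used in setup_models(); a real batch would come from a tokenizer and the image collator.
def _example_training_step(visual_bloom: VisualBloom):
    batch_size, image_tokens, instruction_length, caption_length = 2, 49, 8, 16
    language_model_input_ids = torch.randint(0, 1000, (batch_size, caption_length))
    outputs = visual_bloom(
        image_features=torch.randn(batch_size, image_tokens, 1536),
        image_attentions=torch.ones(batch_size, image_tokens, dtype=torch.long),
        instruction_input_ids=torch.randint(0, 1000, (batch_size, instruction_length)),
        instruction_attention_mask=torch.ones(batch_size, instruction_length, dtype=torch.long),
        language_model_input_ids=language_model_input_ids,
        language_model_attention_mask=torch.ones(batch_size, caption_length, dtype=torch.long),
        # Labels mirror the caption ids; padding positions would normally also be set to -100
        language_model_labels=language_model_input_ids.clone(),
    )
    return outputs.loss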


def setup_models(visual_bloom_state_dict_path: str):
    image_model_name = "facebook/convnextv2-large-22k-224"
    image_config = ConvNextV2Config.from_pretrained(image_model_name)
    image_processor = ConvNextImageProcessor.from_pretrained(image_model_name)
    image_model = ConvNextV2Model.from_pretrained(image_model_name)
    image_feature_collator = ImageFeatureCollator(image_processor, image_model)

    bloom_model_name = "bigscience/bloomz-1b7"
    bloom_config = BloomConfig.from_pretrained(bloom_model_name)
    tokenizer = BloomTokenizerFast.from_pretrained(bloom_model_name)
    tokenizer.padding_side = "right"

    bert_config = BertConfig(
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        is_decoder=True,
        add_cross_attention=True,
    )

    visual_bloom = VisualBloom(
        image_config,
        bert_config,
        bloom_config,
        bloom_model_name,
        use_frozen_bloom=False,
    )
    visual_bloom.load_state_dict(torch.load(visual_bloom_state_dict_path))
    visual_bloom = visual_bloom.eval()

    return {
        "visual_bloom": visual_bloom,
        "tokenizer": tokenizer,
        "image_feature_collator": image_feature_collator,
    }
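

# --- End-to-end inference sketch (not part of the original Space code) ----------
# Ties the pieces together: build the models, featurize one image, tokenize an
# instruction, and let VisualBloom.generate continue the text. The state-dict path,
# image path, and instruction string are placeholders; the exact prompt format the
# checkpoint was trained with may differ.
def _example_generate(visual_bloom_state_dict_path: str, image_path: str, instruction: str):
    models = setup_models(visual_bloom_state_dict_path)
    visual_bloom = models["visual_bloom"]
    tokenizer = models["tokenizer"]
    image_feature_collator = models["image_feature_collator"]

    image_features, image_attentions = image_feature_collator(
        [Image.open(image_path).convert("RGB")]
    )
    instruction_inputs = tokenizer(instruction, return_tensors="pt")

    with torch.no_grad():
        output_ids = visual_bloom.generate(
            image_features=image_features,
            image_attentions=image_attentions,
            instruction_input_ids=instruction_inputs["input_ids"],
            instruction_attention_mask=instruction_inputs["attention_mask"],
        )
    # When called with inputs_embeds, BLOOM's generate returns only the newly generated tokens
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)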