Update modeling_minicpmv.py
#12 by qianyuchen - opened

modeling_minicpmv.py CHANGED (+20 -10)
@@ -4,11 +4,12 @@ import json
 import timm
 import torch
 import torchvision
+import deepspeed
 from PIL import Image
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from torchvision import transforms
 from transformers import LlamaTokenizer
-
+from transformers.integrations import is_deepspeed_zero3_enabled
 from .configuration_minicpm import MiniCPMVConfig
 from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
 from .resampler import Resampler
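For context on the two new imports: under DeepSpeed ZeRO stage 3 every parameter is sharded across data-parallel ranks, so code that reads a parameter outside the engine-managed forward pass has to gather it first. is_deepspeed_zero3_enabled() reports whether the active HF/DeepSpeed config enables stage 3, and deepspeed.zero.GatheredParameters is the context manager that reassembles the shards. A minimal sketch of the pattern these imports enable (the helper name is mine, not the PR's):

import torch
import deepspeed
from transformers.integrations import is_deepspeed_zero3_enabled

def read_full_param(param: torch.nn.Parameter) -> torch.Tensor:
    # Outside a gather context, a ZeRO-3 param's .data is a zero-size
    # placeholder; GatheredParameters temporarily restores the full tensor.
    if is_deepspeed_zero3_enabled():
        with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
            return param.detach().clone()
    return param.detach().clone()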
@@ -74,15 +75,24 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
     def get_vision_embedding(self, pixel_values):
         res = []
-        dtype = self.vpm.pos_embed.data.dtype
-        for pixel_value in pixel_values:
+        dtype = self.llm.lm_head.weight.dtype
+        def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
             H, W = pixel_value.shape[-2:]
-            tgt_size = (
-                math.ceil(H / self.vpm.patch_embed.patch_size[0]), math.ceil(W / self.vpm.patch_embed.patch_size[0]))
+            target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
             vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
-            if hasattr(self.vpm, 'num_prefix_tokens') and self.vpm.num_prefix_tokens > 0:
-                vision_embedding = vision_embedding[:, self.vpm.num_prefix_tokens:]
-            res.append(self.resampler(vision_embedding, tgt_size))
+            if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
+                vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
+            return resampler(vision_embedding, target_size)
+
+        if is_deepspeed_zero3_enabled():
+            with deepspeed.zero.GatheredParameters(self.vpm.pos_embed, modifier_rank=0):
+                for pixel_value in pixel_values:
+                    result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+                    res.append(result)
+        else:
+            for pixel_value in pixel_values:
+                result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+                res.append(result)
         return torch.vstack(res)
 
     def get_vllm_embedding(self, data):
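The ZeRO-3 and non-ZeRO-3 branches above run an identical loop; only the surrounding context manager differs. A behaviour-preserving condensation using contextlib.nullcontext, as a sketch assuming the same class attributes the PR uses (self.llm, self.vpm, self.config.patch_size, self.resampler):

import contextlib
import math
import deepspeed
import torch
from transformers.integrations import is_deepspeed_zero3_enabled

def get_vision_embedding(self, pixel_values):
    dtype = self.llm.lm_head.weight.dtype
    # Gather the sharded pos_embed only when ZeRO-3 is active; otherwise
    # the null context makes the with-block a no-op.
    gather = (deepspeed.zero.GatheredParameters(self.vpm.pos_embed, modifier_rank=0)
              if is_deepspeed_zero3_enabled() else contextlib.nullcontext())
    res = []
    with gather:
        for pixel_value in pixel_values:
            H, W = pixel_value.shape[-2:]
            target_size = (math.ceil(H / self.config.patch_size),
                           math.ceil(W / self.config.patch_size))
            emb = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
            if getattr(self.vpm, 'num_prefix_tokens', 0) > 0:
                # Drop CLS/register tokens so only patch tokens reach the resampler.
                emb = emb[:, self.vpm.num_prefix_tokens:]
            res.append(self.resampler(emb, target_size))
    return torch.vstack(res)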
@@ -93,8 +103,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 if len(pixel_values) > 0:
                     vision_hidden_states.append(self.get_vision_embedding(pixel_values))
                 elif self.training:
-                    dtype = self.vpm.pos_embed.data.dtype
-                    device = self.vpm.pos_embed.data.device
+                    dtype = self.llm.lm_head.weight.dtype
+                    device = self.llm.lm_head.weight.device
                     dummy_image = torch.zeros(
                         (1, 3, 224, 224), device=device, dtype=dtype
                     )
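The dummy image keeps the vision branch in the training graph for batches that contain no images. Taking dtype and device from self.llm.lm_head.weight, rather than from self.vpm.pos_embed.data (which under ZeRO-3 is only a zero-size placeholder), mirrors the switch made in get_vision_embedding above. The general pattern, as a standalone sketch with hypothetical names:

import torch

def make_dummy_image(ref: torch.Tensor, size: int = 224) -> torch.Tensor:
    # Allocate the all-zero placeholder on the reference weight's device and
    # dtype so the dummy vision forward never hits a dtype/device mismatch.
    return torch.zeros((1, 3, size, size), device=ref.device, dtype=ref.dtype)

# Usage inside the model would be: make_dummy_image(self.llm.lm_head.weight)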