import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity, Mlp
from openrec.modeling.decoders.nrtr_decoder import Embeddings


class CrossAttention(nn.Module):

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        N, C = kv.shape[1:]
        QN = q.shape[1]
        q = self.q(q).reshape([-1, QN, self.num_heads,
                               C // self.num_heads]).transpose(1, 2)
        q = q * self.scale
        k, v = self.kv(kv).reshape(
            [-1, N, 2, self.num_heads,
             C // self.num_heads]).permute(2, 0, 3, 1, 4)
        attn = q.matmul(k.transpose(2, 3))
        if key_mask is not None:
            attn = attn + key_mask.unsqueeze(1)
        attn = F.softmax(attn, -1)
        if not self.training:
            self.attn_map = attn
        attn = self.attn_drop(attn)
        x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
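
# Shape sanity-check for CrossAttention (illustrative sketch, not part of the
# original file): 5 query tokens attend over 10 key/value tokens and keep
# their own length and width.
#
#     attn = CrossAttention(dim=64, num_heads=8)
#     out = attn(torch.randn(2, 5, 64), torch.randn(2, 10, 64))
#     assert out.shape == (2, 5, 64)
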
class EdgeDecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=[0.0, 0.0],
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        super().__init__()

        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        # NOTE: drop path for stochastic depth; we shall see if this is better
        # than dropout here.
        self.drop_path1 = DropPath(
            drop_path[0]) if drop_path[0] > 0.0 else Identity()
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)

        # self.c = nn.Linear(dim, dim*2)
        self.p = nn.Linear(dim, dim)
        self.cv = nn.Linear(dim, dim)
        self.pv = nn.Linear(dim, dim)

        self.dim = dim
        self.num_heads = num_heads
        self.p_proj = nn.Linear(dim, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, p, cv, pv):
        pN = p.shape[1]
        vN = cv.shape[1]
        p_shortcut = p

        p1 = self.p(p).reshape(
            [-1, pN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        cv1 = self.cv(cv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        pv1 = self.pv(pv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)

        edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1)  # B h N N
        p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))

        x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
        x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
        return x
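
# Illustrative usage of EdgeDecoderLayer (a sketch, not part of the original
# file): the positional stream p attends over pv to build edge weights that
# are then applied to the content values cv.
#
#     layer = EdgeDecoderLayer(dim=64, num_heads=8)
#     p = torch.randn(2, 5, 64)
#     cv, pv = torch.randn(2, 10, 64), torch.randn(2, 10, 64)
#     assert layer(p, cv, pv).shape == (2, 5, 64)
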
class DecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        super().__init__()
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.normkv = eval(norm_layer)(dim, eps=epsilon)
        self.mixer = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth; we shall see if this is better
        # than dropout here.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, q, kv, key_mask=None):
        x1 = q + self.drop_path(
            self.mixer(self.norm1(q), self.normkv(kv), key_mask))
        x = x1 + self.drop_path(self.mlp(self.norm2(x1)))
        return x
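
# DecoderLayer is a pre-norm residual cross-attention block (sketch, not part
# of the original file): queries and key/values are normalized separately
# before mixing, so the output shape follows the queries.
#
#     block = DecoderLayer(dim=64, num_heads=2)
#     q, kv = torch.randn(2, 5, 64), torch.randn(2, 10, 64)
#     assert block(q, kv).shape == (2, 5, 64)
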
class CMFFLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        epsilon=1e-6,
    ):
        super().__init__()
        self.normq1 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
        self.images_to_question_cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.normq2 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
        self.question_to_images_cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth; we shall see if this is better
        # than dropout here.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.normmlp = nn.LayerNorm(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, question_f, prompt_f, visual_f, mask=None):
        query_add = torch.concat([question_f, prompt_f, visual_f], 1)

        query_add = query_add + self.drop_path(
            self.images_to_question_cross_attn(self.normq1(query_add),
                                               self.normkv1(prompt_f), mask))
        query_add = query_add + self.drop_path(
            self.question_to_images_cross_attn(
                self.normq2(query_add),
                self.normkv2(query_add[:, -visual_f.shape[1]:, :])))
        query_updated = query_add + self.drop_path(
            self.mlp(self.normmlp(query_add)))

        question_f_updated = query_updated[:, :question_f.shape[1], :]
        prompt_f_updated = query_updated[:, question_f.shape[1]:-visual_f.
                                         shape[1], :]
        visual_f_updated = query_updated[:, -visual_f.shape[1]:, :]

        return question_f_updated, prompt_f_updated, visual_f_updated
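
# Round trip through the cross-modal feature-fusion layer (sketch, not part of
# the original file): question, prompt, and visual streams are fused jointly
# and returned with their original shapes.
#
#     cmff = CMFFLayer(dim=64, num_heads=2)
#     q, p, v = torch.randn(2, 4, 64), torch.randn(2, 3, 64), torch.randn(2, 10, 64)
#     q2, p2, v2 = cmff(q, p, v)
#     assert q2.shape == q.shape and p2.shape == p.shape and v2.shape == v.shape
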
class IGTRDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 dim,
                 out_channels,
                 num_layer=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 ch=False,
                 ar=False,
                 refine_iter=0,
                 quesall=True,
                 next_pred=False,
                 ds=False,
                 pos2d=False,
                 check_search=False,
                 max_size=[8, 32],
                 **kwargs):
        super(IGTRDecoder, self).__init__()

        self.out_channels = out_channels
        self.dim = dim
        self.max_len = max_len + 3  # max_len + eos + bos
        self.ch = ch
        self.char_embed = Embeddings(d_model=dim,
                                     vocab=self.out_channels,
                                     scale_embedding=True)
        self.ignore_index = out_channels - 1
        self.ar = ar
        self.refine_iter = refine_iter
        self.bos = self.out_channels - 2
        self.eos = 0
        self.next_pred = next_pred
        self.quesall = quesall
        self.check_search = check_search
        dpr = np.linspace(0, drop_path_rate, num_layer + 2)

        self.cmff_decoder = nn.ModuleList([
            CMFFLayer(dim=dim,
                      num_heads=dim // 32,
                      mlp_ratio=4.0,
                      qkv_bias=True,
                      drop_path=dpr[i]) for i in range(num_layer)
        ])

        self.answer_to_question_layer = DecoderLayer(dim=dim,
                                                     num_heads=dim // 32,
                                                     mlp_ratio=4.0,
                                                     qkv_bias=True,
                                                     drop_path=dpr[-2])
        self.answer_to_image_layer = DecoderLayer(dim=dim,
                                                  num_heads=dim // 32,
                                                  mlp_ratio=4.0,
                                                  qkv_bias=True,
                                                  drop_path=dpr[-1])

        self.char_pos_embed = nn.Parameter(torch.zeros([self.max_len, dim],
                                                       dtype=torch.float32),
                                           requires_grad=True)
        self.appear_num_embed = nn.Parameter(torch.zeros([self.max_len, dim],
                                                         dtype=torch.float32),
                                             requires_grad=True)
        self.ds = ds
        self.pos2d = pos2d
        if not ds:
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)
        elif pos2d:
            pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
                                    dtype=torch.float32)
            trunc_normal_(pos_embed, mean=0, std=0.02)
            self.vis_pos_embed = nn.Parameter(
                pos_embed.transpose(1, 2).reshape(1, dim, max_size[0],
                                                  max_size[1]),
                requires_grad=True,
            )
        self.prompt_pos_embed = nn.Parameter(torch.zeros([1, 6, dim],
                                                         dtype=torch.float32),
                                             requires_grad=True)

        self.answer_query = nn.Parameter(torch.zeros([1, 1, dim],
                                                     dtype=torch.float32),
                                         requires_grad=True)
        self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
        self.ques1_head = nn.Linear(dim, self.out_channels - 2)
        self.ques2_head = nn.Linear(dim, self.max_len, bias=False)
        self.ques3_head = nn.Linear(dim, self.max_len - 1)
        self.ques4_head = nn.Linear(dim, self.max_len - 1)

        trunc_normal_(self.char_pos_embed, std=0.02)
        trunc_normal_(self.appear_num_embed, std=0.02)
        trunc_normal_(self.answer_query, std=0.02)
        trunc_normal_(self.prompt_pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def no_weight_decay(self):
        return {
            'char_pos_embed', 'vis_pos_embed', 'appear_num_embed',
            'answer_query', 'char_embed'
        }
    def question_encoder(self, targets, train_i):
        (
            prompt_pos_idx,
            prompt_char_idx,
            ques_pos_idx,
            ques1_answer,
            ques2_char_idx,
            ques2_answer,
            ques4_char_num,
            ques_len,
            ques2_len,
            prompt_len,
        ) = targets
        max_ques_len = torch.max(ques_len)
        max_ques2_len = torch.max(ques2_len)
        max_prompt_len = torch.max(prompt_len)
        if self.next_pred and (train_i == 2 or train_i == 3):
            prompt_pos = self.prompt_pos_embed
            prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
        else:
            prompt_pos = F.embedding(
                prompt_pos_idx[:, :max_prompt_len], self.char_pos_embed
            )  # bs lp [ 0, 4, 3, 12, 12, 12, 12, 12, 12, 12, 12]
            prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
        prompt_char = self.char_embed(prompt_char_idx)  # bs lp
        prompt = prompt_pos + prompt_char

        mask_1234 = torch.where(prompt_char_idx == self.ignore_index,
                                float('-inf'), 0)
        ques1 = F.embedding(ques_pos_idx[:, :max_ques_len],
                            self.char_pos_embed)  # bs lq1 dim
        ques1_answer = ques1_answer[:, :max_ques_len]
        if self.quesall or train_i == 0:
            ques2_char = self.char_embed(ques2_char_idx[:, :max_ques2_len, 1])
            ques2 = ques2_char + F.embedding(
                ques2_char_idx[:, :max_ques2_len, 0],
                self.char_pos_embed)  # bs lq2 dim
            ques2_answer = ques2_answer[:, :max_ques2_len]
            ques2_head = F.embedding(ques2_char_idx[:, :max_ques2_len, 0],
                                     self.ques2_head.weight)

            ques4_char = self.char_embed(ques1_answer)
            ques4_ap_num = F.embedding(ques4_char_num[:, :max_ques_len],
                                       self.appear_num_embed)
            ques4 = ques4_char + ques4_ap_num
            ques4_answer = ques_pos_idx[:, :max_ques_len]

            return (
                prompt,
                ques1,
                ques2,
                ques2_head,
                ques4,
                ques1_answer,
                ques2_answer,
                ques4_answer,
                mask_1234.unsqueeze(1),
            )
        else:
            return prompt, ques1, ques1_answer, mask_1234.unsqueeze(1)
    def forward(self, x, data=None):
        if self.training:
            return self.forward_train(x, data)
        else:
            return self.forward_test(x)
    def forward_test(self, x):
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = x.flatten(2).transpose(1, 2)
        else:
            visual_f = x
        bs = x.shape[0]

        prompt_bos = self.char_embed(
            torch.full([bs, 1], self.bos, dtype=torch.long, device=x.device)
        ) + self.char_pos_embed[:1, :].unsqueeze(0)  # BOS prompt
        ques_all = torch.tile(self.char_pos_embed.unsqueeze(0), (bs, 1, 1))

        if not self.ar:
            if self.check_search:
                tgt_in = torch.full((bs, self.max_len),
                                    self.ignore_index,
                                    dtype=torch.long,
                                    device=x.device)
                tgt_in[:, 0] = self.bos
                logits = []
                for j in range(1, self.max_len):
                    visual_f_check = visual_f
                    ques_check_i = ques_all[:, j:j + 1, :] + self.char_embed(
                        torch.arange(self.out_channels - 2,
                                     device=x.device)).unsqueeze(0)
                    prompt_check = ques_all[:, :j] + self.char_embed(
                        tgt_in[:, :j])
                    # prompt_check = prompt_bos
                    mask = torch.where(
                        (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
                        float('-inf'), 0)
                    for layer in self.cmff_decoder:
                        ques_check_i, prompt_check, visual_f_check = layer(
                            ques_check_i, prompt_check, visual_f_check,
                            mask.unsqueeze(1))
                    answer_query_i = self.answer_to_question_layer(
                        ques_check_i, prompt_check, mask.unsqueeze(1))
                    answer_pred_i = self.norm_pred(
                        self.answer_to_image_layer(
                            answer_query_i, visual_f_check))  # B, 26, 37
                    # The next-token probability sits at the output's i-th
                    # token position.
                    fc_2 = self.ques2_head.weight[j:j + 1].unsqueeze(0)
                    fc_2 = fc_2.tile([bs, 1, 1])
                    p_i = fc_2 @ answer_pred_i.transpose(1, 2)
                    # p_i = p_i[:, 0, :]
                    logits.append(p_i)
                    if j < self.max_len - 1:
                        # Greedy decoding: append the predicted token to the
                        # target input.
                        tgt_in[:, j] = p_i.squeeze().argmax(-1)
                        # Efficient batch decoding: stop once every sequence
                        # in the batch contains at least one EOS token.
                        if (tgt_in == self.eos).any(dim=-1).all():
                            break
                logits = torch.cat(logits, dim=1)
            else:
                ques_pd = ques_all[:, 1:, :]
                prompt_pd = prompt_bos
                visual_f_pd = visual_f
                for layer in self.cmff_decoder:
                    ques_pd, prompt_pd, visual_f_pd = layer(
                        ques_pd, prompt_pd, visual_f_pd)
                answer_query_pd = self.answer_to_question_layer(
                    ques_pd, prompt_pd)
                answer_feats_pd = self.norm_pred(
                    self.answer_to_image_layer(answer_query_pd,
                                               visual_f_pd))  # B, 26, 37
                logits = self.ques1_head(answer_feats_pd)
        elif self.next_pred:
            ques_pd_1 = ques_all[:, 1:2, :]
            prompt_pd = prompt_bos
            visual_f_pd = visual_f
            for layer in self.cmff_decoder:
                ques_pd_1, prompt_pd, visual_f_pd = layer(
                    ques_pd_1, prompt_pd, visual_f_pd)
            answer_query_pd = self.answer_to_question_layer(
                ques_pd_1, prompt_pd)
            answer_feats_pd = self.norm_pred(
                self.answer_to_image_layer(answer_query_pd,
                                           visual_f_pd))  # B, 26, 37
            logits_pd_1 = self.ques1_head(answer_feats_pd)

            ques_next = self.char_pos_embed[-2:-1, :].unsqueeze(0).tile(
                [bs, 1, 1])
            prompt_next_bos = (self.char_embed(
                torch.full([bs, 1], self.bos, dtype=torch.long,
                           device=x.device)) + self.prompt_pos_embed[:, :1, :])
            pred_prob, pred_id = F.softmax(logits_pd_1, -1).max(-1)
            pred_prob_list = [pred_prob]
            pred_id_list = [pred_id]
            for j in range(1, 70):
                prompt_next_1 = self.char_embed(
                    pred_id) + self.prompt_pos_embed[:,
                                                     -1 * pred_id.shape[1]:, :]
                prompt_next = torch.concat([prompt_next_bos, prompt_next_1], 1)
                ques_next_i = ques_next
                visual_f_i = visual_f
                for layer in self.cmff_decoder:
                    ques_next_i, prompt_next, visual_f_i = layer(
                        ques_next_i, prompt_next, visual_f_i)
                answer_query_next_i = self.answer_to_question_layer(
                    ques_next_i, prompt_next)
                answer_feats_next_i = self.norm_pred(
                    self.answer_to_image_layer(answer_query_next_i,
                                               visual_f_i))  # B, 26, 37
                logits_next_i = self.ques1_head(answer_feats_next_i)
                # pred_id = logits_next_i.argmax(-1)
                pred_prob_i, pred_id_i = F.softmax(logits_next_i, -1).max(-1)
                pred_prob_list.append(pred_prob_i)
                pred_id_list.append(pred_id_i)
                if (torch.concat(pred_id_list,
                                 1) == self.eos).any(dim=-1).all():
                    break
                if pred_id.shape[1] >= 5:
                    pred_id = torch.concat([pred_id[:, 1:], pred_id_i], 1)
                else:
                    pred_id = torch.concat([pred_id, pred_id_i], 1)
            return [
                torch.concat(pred_id_list, 1),
                torch.concat(pred_prob_list, 1)
            ]
        else:
            tgt_in = torch.full((bs, self.max_len),
                                self.ignore_index,
                                dtype=torch.long,
                                device=x.device)
            tgt_in[:, 0] = self.bos
            logits = []
            for j in range(1, self.max_len):
                visual_f_ar = visual_f
                ques_i = ques_all[:, j:j + 1, :]
                prompt_ar = ques_all[:, :j] + self.char_embed(tgt_in[:, :j])
                mask = torch.where(
                    (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
                    float('-inf'), 0)
                for layer in self.cmff_decoder:
                    ques_i, prompt_ar, visual_f_ar = layer(
                        ques_i, prompt_ar, visual_f_ar, mask.unsqueeze(1))
                answer_query_i = self.answer_to_question_layer(
                    ques_i, prompt_ar, mask.unsqueeze(1))
                answer_pred_i = self.norm_pred(
                    self.answer_to_image_layer(answer_query_i,
                                               visual_f_ar))  # B, 26, 37
                # The next-token probability sits at the output's i-th token
                # position.
                p_i = self.ques1_head(answer_pred_i)
                logits.append(p_i)
                if j < self.max_len - 1:
                    # Greedy decoding: append the predicted token to the
                    # target input.
                    tgt_in[:, j] = p_i.squeeze().argmax(-1)
                    # Efficient batch decoding: stop once every sequence in
                    # the batch contains at least one EOS token.
                    if (tgt_in == self.eos).any(dim=-1).all():
                        break
            logits = torch.cat(logits, dim=1)
        if self.refine_iter > 0:
            pred_probs, pred_idxs = F.softmax(logits, -1).max(-1)
            for i in range(self.refine_iter):
                mask_check = (pred_idxs == self.eos).int().cumsum(-1) <= 1

                ques_check_all = self.char_embed(
                    pred_idxs) + ques_all[:, 1:pred_idxs.shape[1] + 1, :]
                prompt_check = prompt_bos
                visual_f_check = visual_f
                ques_check = ques_check_all
                for layer in self.cmff_decoder:
                    ques_check, prompt_check, visual_f_check = layer(
                        ques_check, prompt_check, visual_f_check)
                answer_query_check = self.answer_to_question_layer(
                    ques_check, prompt_check)
                answer_pred_check = self.norm_pred(
                    self.answer_to_image_layer(answer_query_check,
                                               visual_f_check))  # B, 26, 37

                ques2_head = self.ques2_head.weight[1:pred_idxs.shape[1] +
                                                    1, :]
                ques2_head = torch.tile(ques2_head.unsqueeze(0), [bs, 1, 1])
                answer2_pred = answer_pred_check.matmul(
                    ques2_head.transpose(1, 2))
                diag_mask = torch.eye(answer2_pred.shape[1],
                                      device=x.device).unsqueeze(0).tile(
                                          [bs, 1, 1])
                answer2_pred = torch.sigmoid(
                    (answer2_pred * diag_mask).sum(-1)) * mask_check
                check_result = answer2_pred < 0.9  # pred_probs < 0.99

                prompt_refine = torch.concat([prompt_bos, ques_check_all], 1)
                mask_refine = torch.where(
                    check_result, float('-inf'), 0) + torch.where(
                        (pred_idxs == self.eos).int().cumsum(-1) < 1, 0,
                        float('-inf'))
                mask_refine = torch.concat(
                    [torch.zeros([bs, 1], device=x.device), mask_refine],
                    1).unsqueeze(1)
                ques_refine = ques_all[:, 1:pred_idxs.shape[1] + 1, :]
                visual_f_refine = visual_f
                for layer in self.cmff_decoder:
                    ques_refine, prompt_refine, visual_f_refine = layer(
                        ques_refine, prompt_refine, visual_f_refine,
                        mask_refine)
                answer_query_refine = self.answer_to_question_layer(
                    ques_refine, prompt_refine, mask_refine)
                answer_pred_refine = self.norm_pred(
                    self.answer_to_image_layer(answer_query_refine,
                                               visual_f_refine))  # B, 26, 37
                answer_refine = self.ques1_head(answer_pred_refine)
                refine_probs, refine_idxs = F.softmax(answer_refine,
                                                      -1).max(-1)
                pred_idxs_refine = torch.where(check_result, refine_idxs,
                                               pred_idxs)
                pred_idxs = torch.where(mask_check, pred_idxs_refine,
                                        pred_idxs)
                pred_probs_refine = torch.where(check_result, refine_probs,
                                                pred_probs)
                pred_probs = torch.where(mask_check, pred_probs_refine,
                                         pred_probs)

            return [pred_idxs, pred_probs]

        return F.softmax(logits, -1)
    def forward_train(self, x, targets=None):
        bs = x.shape[0]
        answer_token = torch.tile(self.answer_query, (bs, 1, 1))
        if self.ch:
            ques3 = self.char_embed(
                targets[7][:, :, 0]) + answer_token  # bs nc dim
            ques3_answer = targets[7][:, :, 1]
        else:
            ques3 = self.char_embed(
                torch.arange(self.out_channels - 2, device=x.device)
            ).unsqueeze(0) + answer_token  # bs nc dim
            ques3_answer = targets[7]
        loss1_list = []
        loss2_list = []
        loss3_list = []
        loss4_list = []
        sampler1_num = 0
        sampler2_num = 0
        sampler3_num = 0
        sampler4_num = 0
        if not self.ds:
            visual_f = x + self.vis_pos_embed
        elif self.pos2d:
            x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
            visual_f = x.flatten(2).transpose(1, 2)
        else:
            visual_f = x
        train_i = 0
        for target_ in zip(
                targets[1].transpose(0, 1),
                targets[2].transpose(0, 1),
                targets[3].transpose(0, 1),
                targets[4].transpose(0, 1),
                targets[5].transpose(0, 1),
                targets[6].transpose(0, 1),
                targets[8].transpose(0, 1),
                targets[9].transpose(0, 1),
                targets[10].transpose(0, 1),
                targets[11].transpose(0, 1),
        ):
            # target_ = (prompt_pos_idx, prompt_char_idx, ques_pos_idx,
            #            ques1_answer, ques2_char_idx, ques2_answer,
            #            ques4_char_num, ques_len, ques2_len, prompt_len)
            visual_f_1234 = visual_f
            if self.quesall or train_i == 0:
                (
                    prompt,
                    ques1,
                    ques2,
                    ques2_head,
                    ques4,
                    ques1_answer,
                    ques2_answer,
                    ques4_answer,
                    mask_1234,
                ) = self.question_encoder(target_, train_i)
                prompt_1234 = prompt
                ques_1234 = torch.concat([ques1, ques2, ques3, ques4], 1)
                for layer in self.cmff_decoder:
                    ques_1234, prompt_1234, visual_f_1234 = layer(
                        ques_1234, prompt_1234, visual_f_1234, mask_1234)
                answer_query_1234 = self.answer_to_question_layer(
                    ques_1234, prompt_1234, mask_1234)
                answer_feats_1234 = self.norm_pred(
                    self.answer_to_image_layer(answer_query_1234,
                                               visual_f_1234))  # B, 26, 37

                answer_feats_1 = answer_feats_1234[:, :ques1.shape[1], :]
                answer_feats_2 = answer_feats_1234[:, ques1.shape[1]:(
                    ques1.shape[1] + ques2.shape[1]), :]
                answer_feats_3 = answer_feats_1234[:, (
                    ques1.shape[1] + ques2.shape[1]):-ques4.shape[1], :]
                answer_feats_4 = answer_feats_1234[:, -ques4.shape[1]:, :]

                answer1_pred = self.ques1_head(answer_feats_1)
                if train_i == 0:
                    logits = answer1_pred

                n = (ques1_answer != self.ignore_index).sum().item()
                loss1 = n * F.cross_entropy(
                    answer1_pred.flatten(0, 1),
                    ques1_answer.flatten(0, 1),
                    ignore_index=self.ignore_index,
                    reduction='mean',
                )
                sampler1_num += n
                loss1_list.append(loss1)

                answer2_pred = answer_feats_2.matmul(
                    ques2_head.transpose(1, 2))
                diag_mask = torch.eye(answer2_pred.shape[1],
                                      device=x.device).unsqueeze(0).tile(
                                          [bs, 1, 1])
                answer2_pred = (answer2_pred * diag_mask).sum(-1)
                ques2_answer = ques2_answer.flatten(0, 1)
                non_pad_mask = torch.not_equal(ques2_answer,
                                               self.ignore_index)
                n = non_pad_mask.sum().item()
                ques2_answer = torch.where(ques2_answer == self.ignore_index,
                                           0, ques2_answer)
                loss2_none = F.binary_cross_entropy_with_logits(
                    answer2_pred.flatten(0, 1), ques2_answer,
                    reduction='none')
                loss2 = n * loss2_none.masked_select(non_pad_mask).mean()
                sampler2_num += n
                loss2_list.append(loss2)

                answer3_pred = self.ques3_head(answer_feats_3)
                n = (ques3_answer != self.ignore_index).sum().item()
                loss3 = n * F.cross_entropy(answer3_pred.flatten(0, 1),
                                            ques3_answer.flatten(0, 1),
                                            reduction='mean')
                sampler3_num += n
                loss3_list.append(loss3)

                answer4_pred = self.ques4_head(answer_feats_4)
                n = (ques4_answer != self.max_len - 1).sum().item()
                loss4 = n * F.cross_entropy(
                    answer4_pred.flatten(0, 1),
                    ques4_answer.flatten(0, 1),
                    ignore_index=self.max_len - 1,
                    reduction='mean',
                )
                sampler4_num += n
                loss4_list.append(loss4)
            else:
                prompt, ques1, ques1_answer, mask_1234 = self.question_encoder(
                    target_, train_i)
                prompt_1234 = prompt
                for layer in self.cmff_decoder:
                    ques1, prompt_1234, visual_f_1234 = layer(
                        ques1, prompt_1234, visual_f_1234, mask_1234)
                answer_query_1 = self.answer_to_question_layer(
                    ques1, prompt_1234, mask_1234)
                answer_feats_1 = self.norm_pred(
                    self.answer_to_image_layer(answer_query_1,
                                               visual_f_1234))  # B, 26, 37
                answer1_pred = self.ques1_head(answer_feats_1)

                n = (ques1_answer != self.ignore_index).sum().item()
                loss1 = n * F.cross_entropy(
                    answer1_pred.flatten(0, 1),
                    ques1_answer.flatten(0, 1),
                    ignore_index=self.ignore_index,
                    reduction='mean',
                )
                sampler1_num += n
                loss1_list.append(loss1)
            train_i += 1

        loss_list = [
            sum(loss1_list) / sampler1_num,
            sum(loss2_list) / sampler2_num,
            sum(loss3_list) / sampler3_num,
            sum(loss4_list) / sampler4_num,
        ]
        loss = {
            'loss': sum(loss_list),
            'loss1': loss_list[0],
            'loss2': loss_list[1],
            'loss3': loss_list[2],
            'loss4': loss_list[3],
        }
        return [loss, logits]
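

if __name__ == '__main__':
    # Minimal inference smoke test (illustrative sketch, not part of the
    # original file). The shapes are assumptions: encoder features of shape
    # (batch, vis_seq, dim) with the default vis_seq=50, and a toy vocabulary
    # giving out_channels=39 (37 character classes + BOS + padding/ignore).
    decoder = IGTRDecoder(in_channels=256, dim=256, out_channels=39).eval()
    feats = torch.randn(2, 50, 256)
    with torch.no_grad():
        probs = decoder(feats)  # default config takes the parallel (non-AR) path
    # Per-position character probabilities: (batch, self.max_len - 1, out_channels - 2)
    print(probs.shape)  # torch.Size([2, 27, 37])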