Update seagull/model/layer.py
Browse files- seagull/model/layer.py +5 -5
seagull/model/layer.py
CHANGED
|
@@ -77,8 +77,8 @@ class MaskExtractor(nn.Module): # Mask-based Feature Extractor
|
|
| 77 |
return mask_feat, global_mask
|
| 78 |
|
| 79 |
def forward(self, feats, masks, cropped_img):
|
| 80 |
-
|
| 81 |
-
|
| 82 |
num_imgs = len(masks)
|
| 83 |
|
| 84 |
for idx in range(num_imgs):
|
|
@@ -108,16 +108,16 @@ class MaskExtractor(nn.Module): # Mask-based Feature Extractor
|
|
| 108 |
mask_feats_linear = self.feat_linear(mask_feats) #(1, q, 4096)
|
| 109 |
|
| 110 |
query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
|
| 111 |
-
|
| 112 |
|
| 113 |
cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
|
| 114 |
global_features = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # q, 1, 32, 32
|
| 115 |
global_features = global_features.reshape(-1, 1, 32 * 32) # q, 1, 32 * 32
|
| 116 |
pos_feat = self.mlp(self.sa(global_features, global_features, global_features).squeeze(1)) # q, output
|
| 117 |
|
| 118 |
-
|
| 119 |
|
| 120 |
-
return
|
| 121 |
|
| 122 |
class MaskPooling(nn.Module):
|
| 123 |
def __init__(self):
|
|
|
|
| 77 |
return mask_feat, global_mask
|
| 78 |
|
| 79 |
def forward(self, feats, masks, cropped_img):
|
| 80 |
+
global_features_list = []
|
| 81 |
+
local_features_list = []
|
| 82 |
num_imgs = len(masks)
|
| 83 |
|
| 84 |
for idx in range(num_imgs):
|
|
|
|
| 108 |
mask_feats_linear = self.feat_linear(mask_feats) #(1, q, 4096)
|
| 109 |
|
| 110 |
query_feat = self.final_mlp(torch.cat((global_masks_linear, mask_feats_linear), dim=-1))
|
| 111 |
+
global_features_list.append(query_feat) # global
|
| 112 |
|
| 113 |
cropped_ = cropped_.to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype)
|
| 114 |
global_features = self.global_vit(cropped_).to(device=self.feat_linear.weight.device, dtype=self.feat_linear.weight.dtype) # q, 1, 32, 32
|
| 115 |
global_features = global_features.reshape(-1, 1, 32 * 32) # q, 1, 32 * 32
|
| 116 |
pos_feat = self.mlp(self.sa(global_features, global_features, global_features).squeeze(1)) # q, output
|
| 117 |
|
| 118 |
+
local_features_list.append(pos_feat) #(imgs_num, 1, q, 4096) # local
|
| 119 |
|
| 120 |
+
return global_features_list, local_features_list
|
| 121 |
|
| 122 |
class MaskPooling(nn.Module):
|
| 123 |
def __init__(self):
|