[feature] support decode from embeddings
1 changed file: modeling_vqvae.py (+17 -0)
@@ -78,6 +78,11 @@ class VQVAE(PreTrainedModel):
         h = self.post_vq_conv(shift_dim(h, -1, 1))
         return self.decoder(h)
 
+    def decode_from_embeddings(self, embeddings):
+        # embeddings: [b, c, t, h, w]
+        encodings = self.codebook.search_indices(embeddings)
+        return self.decode(encodings)
+
     def forward(self, x):
         z = self.pre_vq_conv(self.encoder(x))
         vq_output = self.codebook(z)
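The new decode_from_embeddings entry point snaps continuous vectors in codebook space to their nearest codes (via Codebook.search_indices, added below) and runs them through the existing decoder. A minimal usage sketch; the checkpoint path and the embedding size of 256 are illustrative placeholders, not taken from the diff:

    import torch

    # Hypothetical setup: VQVAE subclasses PreTrainedModel, so from_pretrained applies.
    model = VQVAE.from_pretrained("path/to/checkpoint")
    model.eval()

    # Continuous vectors in codebook space, e.g. predicted by a prior model
    # rather than produced by the encoder: [b, c, t, h, w]
    embeddings = torch.randn(1, 256, 4, 16, 16)

    with torch.no_grad():
        video = model.decode_from_embeddings(embeddings)  # nearest codes -> decoder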
@@ -159,6 +164,18 @@ class Codebook(nn.Module):
         self.z_avg.data.copy_(_k_rand)
         self.N.data.copy_(torch.ones(self.n_codes))
 
+    def search_indices(self, z):
+        # z: [b, c, t, h, w]
+        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
+                    - 2 * flat_inputs @ self.embeddings.t() \
+                    + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
+
+        encoding_indices = torch.argmin(distances, dim=1)
+        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
+        return encoding_indices
+
+
     def forward(self, z):
         # z: [b, c, t, h, w]
         if self._need_init and self.training:
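search_indices is a batched nearest-neighbour lookup: shift_dim(z, 1, -1) (a repo helper) moves the channel axis to the end so that, after flattening, each row of flat_inputs is the c-dimensional vector at one (t, h, w) position, and the squared distances to all codes are computed with the expansion ||x - e||^2 = ||x||^2 - 2<x, e> + ||e||^2 rather than by materialising pairwise differences. A self-contained sketch of the same trick, checked against torch.cdist (shapes are illustrative):

    import torch

    flat_inputs = torch.randn(10, 64)  # 10 query vectors of dimension 64
    embeddings = torch.randn(512, 64)  # a codebook of 512 codes

    # ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2, broadcast to a [10, 512] matrix
    distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
                - 2 * flat_inputs @ embeddings.t() \
                + (embeddings.t() ** 2).sum(dim=0, keepdim=True)

    # Matches explicit pairwise distances up to floating-point error
    reference = torch.cdist(flat_inputs, embeddings) ** 2
    assert torch.allclose(distances, reference, atol=1e-4)

    indices = torch.argmin(distances, dim=1)  # nearest code per query vector

The expansion avoids allocating an [N, n_codes, c] difference tensor, which matters at video resolutions where N = b*t*h*w.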