Hugo Flores committed · commit 9439b64 · 1 parent: b862275

fix sampling logic for paella
vampnet/modules/base.py  CHANGED  (+13 −18)

@@ -153,7 +153,6 @@ class VampBase(at.ml.BaseModel):
         sampling_steps: int = 12,
         start_tokens: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
-        device: str = "cpu",
         temperature: Union[float, Tuple[float, float]] = 1.0,
         top_k: int = None,
         sample: str = "gumbel",
@@ -164,7 +163,8 @@ class VampBase(at.ml.BaseModel):
         typical_min_tokens=1,
         return_signal=True,
     ):
-
+
+        r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(self.device)
         if renoise_steps == None:
             renoise_steps = sampling_steps - 1

@@ -186,7 +186,7 @@ class VampBase(at.ml.BaseModel):
         if self.noise_mode == "noise":
             z = torch.randint(
                 0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
-            ).to(device)
+            ).to(self.device)
         elif self.noise_mode == "mask":
             z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
         else:
@@ -197,19 +197,14 @@ class VampBase(at.ml.BaseModel):
         assert z.shape[0] == 1, f"batch size must be 1"

         if mask is None:
-            mask = torch.ones(z.shape[0], z.shape[-1]).to(device).int()
-
-
-
-                z.shape[0],
-                z.shape[-1],
-            ), f"mask must be shape (batch, seq_len), got {mask.shape}"
-            mask = mask[:, None, :]
-            mask = mask.repeat(1, z.shape[1], 1)
+            mask = torch.ones(z.shape[0], z.shape[-1]).to(self.device).int()
+            mask = mask[:, None, :]
+            mask = mask.repeat(1, z.shape[1], 1)
+
         mask[:, : self.n_conditioning_codebooks, :] = 0.0

-
-
+
+        z_true = z.clone()

         z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
         z_init = z.clone()
@@ -228,8 +223,8 @@ class VampBase(at.ml.BaseModel):

             z = self.sample_from_logits(
                 logits,
-
-
+                top_k=top_k,
+                temperature=tmpt,
                 sample=sample,
                 typical_filtering=typical_filtering,
                 typical_mass=typical_mass,
@@ -323,7 +318,7 @@ class VampBase(at.ml.BaseModel):
         # how many codebooks are we inferring vs conditioning on?
         n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks

-        for i in
+        for i in range(sampling_steps):
             # our current temperature
             tmpt = temperature[i]

@@ -450,7 +445,7 @@ class VampBase(at.ml.BaseModel):
             probs = torch.softmax(logits, dim=-1)
             inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
         elif sample == "argmax":
-            inferred = torch.softmax(
+            inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
         elif sample == "gumbel":
             inferred = gumbel_sample(logits, dim=-1)
         else:
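For context, here is a minimal, self-contained sketch of the pieces this commit touches: the linspace noise schedule built on the model's own device, the per-step temperature lookup inside `for i in range(sampling_steps)`, and the three `sample_from_logits` branches (argmax being the one fixed here). The `gumbel_sample` helper, the tensor shapes, and the simplified `sample_from_logits` signature (no `top_k` or typical filtering) are illustrative assumptions, not the repo's actual implementation.

# Hedged sketch, not vampnet's actual code: shapes, the gumbel_sample helper,
# and the simplified sample_from_logits signature are assumptions.
import torch

def gumbel_sample(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Gumbel-max trick: argmax over logits perturbed with Gumbel noise
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    return (logits + gumbel).argmax(dim=dim)

def sample_from_logits(logits, sample="gumbel", temperature=1.0):
    logits = logits / temperature
    if sample == "multinomial":
        probs = torch.softmax(logits, dim=-1)
        return torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
    elif sample == "argmax":
        # the branch fixed in this commit: most likely token per position
        return torch.softmax(logits, dim=-1).argmax(dim=-1)
    elif sample == "gumbel":
        return gumbel_sample(logits, dim=-1)
    else:
        raise ValueError(f"invalid sampling mode: {sample}")

device = "cpu"
sampling_steps = 12
# noise schedule: sampling_steps values in [0, 1), one per refinement step
r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(device)
temperature = torch.full((sampling_steps,), 0.8)

logits = torch.randn(1, 16, 1024)  # (batch, seq_len, vocab), dummy logits
for i in range(sampling_steps):
    tmpt = temperature[i]          # per-step temperature, as in the fixed loop
    z = sample_from_logits(logits, sample="gumbel", temperature=float(tmpt))
print(z.shape)  # torch.Size([1, 16])

The loop here reuses one dummy logits tensor purely for illustration; in the actual model the logits are presumably recomputed at each step from the partially unmasked z before being sampled again.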