yu-val-weiss
commited on
Commit
·
995725e
1
Parent(s):
c26f589
fix probability logic
Browse files
blimp.py
CHANGED
|
@@ -296,21 +296,22 @@ def get_batch_probabilities(
|
|
| 296 |
with torch.no_grad():
|
| 297 |
outputs = model(**inputs)
|
| 298 |
|
| 299 |
-
|
|
|
|
| 300 |
|
| 301 |
-
#
|
| 302 |
-
log_probs = torch.nn.functional.log_softmax(
|
| 303 |
|
| 304 |
-
#
|
| 305 |
token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
|
| 306 |
|
| 307 |
if batch_size > 1:
|
| 308 |
-
#
|
| 309 |
mask = (labels != tokenizer.pad_token_id).float()
|
| 310 |
token_log_probs *= mask
|
| 311 |
|
| 312 |
# sum log probabilities
|
| 313 |
-
sequence_log_probs =
|
| 314 |
|
| 315 |
probs.extend(sequence_log_probs.cpu().tolist())
|
| 316 |
|
|
|
|
| 296 |
with torch.no_grad():
|
| 297 |
outputs = model(**inputs)
|
| 298 |
|
| 299 |
+
logits = outputs.logits[..., :-1, :].contiguous()
|
| 300 |
+
labels = inputs.input_ids[..., 1:].contiguous()
|
| 301 |
|
| 302 |
+
# compute log probabilities
|
| 303 |
+
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
| 304 |
|
| 305 |
+
# get per-token probability
|
| 306 |
token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
|
| 307 |
|
| 308 |
if batch_size > 1:
|
| 309 |
+
# mask padding tokens
|
| 310 |
mask = (labels != tokenizer.pad_token_id).float()
|
| 311 |
token_log_probs *= mask
|
| 312 |
|
| 313 |
# sum log probabilities
|
| 314 |
+
sequence_log_probs = token_log_probs.sum(dim=1)
|
| 315 |
|
| 316 |
probs.extend(sequence_log_probs.cpu().tolist())
|
| 317 |
|