Spaces:
Runtime error
Runtime error
Commit
·
b0c3beb
1
Parent(s):
6d20fa3
drop device setting for already parallel models
Browse files- dmx_perplexity.py +7 -2
dmx_perplexity.py
CHANGED
|
@@ -40,6 +40,7 @@ Examples:
|
|
| 40 |
46.05925369262695
|
| 41 |
"""
|
| 42 |
|
|
|
|
| 43 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
| 44 |
class DmxPerplexity(evaluate.Metric):
|
| 45 |
def _info(self):
|
|
@@ -89,9 +90,13 @@ class DmxPerplexity(evaluate.Metric):
|
|
| 89 |
max_seq_len = model.config.n_positions
|
| 90 |
else:
|
| 91 |
max_seq_len = 2048
|
| 92 |
-
|
| 93 |
-
if not hasattr(model, "hf_device_map")
|
|
|
|
|
|
|
| 94 |
model = model.to(device)
|
|
|
|
|
|
|
| 95 |
encodings = tokenizer("\n\n".join(references), return_tensors="pt")
|
| 96 |
|
| 97 |
stride = max_seq_len
|
|
|
|
| 40 |
46.05925369262695
|
| 41 |
"""
|
| 42 |
|
| 43 |
+
|
| 44 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
| 45 |
class DmxPerplexity(evaluate.Metric):
|
| 46 |
def _info(self):
|
|
|
|
| 90 |
max_seq_len = model.config.n_positions
|
| 91 |
else:
|
| 92 |
max_seq_len = 2048
|
| 93 |
+
|
| 94 |
+
if not hasattr(model, "hf_device_map") and (
|
| 95 |
+
not hasattr(model, "model_parallel") or not model.model_parallel
|
| 96 |
+
):
|
| 97 |
model = model.to(device)
|
| 98 |
+
|
| 99 |
+
model.eval()
|
| 100 |
encodings = tokenizer("\n\n".join(references), return_tensors="pt")
|
| 101 |
|
| 102 |
stride = max_seq_len
|