add device
Browse files- dmxMetric.py +6 -1
dmxMetric.py
CHANGED
|
@@ -3,6 +3,7 @@ import lm_eval
|
|
| 3 |
from typing import Union, List, Optional
|
| 4 |
from dmx.compressor.dmx import config_rules, DmxModel
|
| 5 |
import datasets
|
|
|
|
| 6 |
|
| 7 |
_DESCRIPTION = """
|
| 8 |
Evaluation function using lm-eval with d-Matrix integration.
|
|
@@ -54,6 +55,7 @@ class DmxMetric(evaluate.Metric):
|
|
| 54 |
batch_size: Optional[Union[int, str]] = None,
|
| 55 |
max_batch_size: Optional[int] = None,
|
| 56 |
limit: Optional[Union[int, float]] = None,
|
|
|
|
| 57 |
revision: str = "main",
|
| 58 |
trust_remote_code: bool = False,
|
| 59 |
log_samples: bool = True,
|
|
@@ -63,7 +65,10 @@ class DmxMetric(evaluate.Metric):
|
|
| 63 |
"""
|
| 64 |
Evaluate a model on multiple tasks and metrics using lm-eval with optional d-Matrix integration.
|
| 65 |
"""
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
|
| 69 |
model_args,
|
|
|
|
| 3 |
from typing import Union, List, Optional
|
| 4 |
from dmx.compressor.dmx import config_rules, DmxModel
|
| 5 |
import datasets
|
| 6 |
+
import torch
|
| 7 |
|
| 8 |
_DESCRIPTION = """
|
| 9 |
Evaluation function using lm-eval with d-Matrix integration.
|
|
|
|
| 55 |
batch_size: Optional[Union[int, str]] = None,
|
| 56 |
max_batch_size: Optional[int] = None,
|
| 57 |
limit: Optional[Union[int, float]] = None,
|
| 58 |
+
device: Optional[str] = None,
|
| 59 |
revision: str = "main",
|
| 60 |
trust_remote_code: bool = False,
|
| 61 |
log_samples: bool = True,
|
|
|
|
| 65 |
"""
|
| 66 |
Evaluate a model on multiple tasks and metrics using lm-eval with optional d-Matrix integration.
|
| 67 |
"""
|
| 68 |
+
if device is None:
|
| 69 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 70 |
+
|
| 71 |
+
model_args = f"pretrained={model},revision={revision},trust_remote_code={str(trust_remote_code)},device={device}"
|
| 72 |
|
| 73 |
lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
|
| 74 |
model_args,
|