Upload ModularStarEncoder
Browse files- modularStarEncoder.py +3 -1
modularStarEncoder.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from transformers import Starcoder2Model
|
| 2 |
import sys
|
| 3 |
-
from
|
| 4 |
import os
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Optional, Tuple, Union, List
|
|
@@ -311,6 +311,8 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
|
|
| 311 |
|
| 312 |
|
| 313 |
DEVICE = source_embedding[-1].get_device()
|
|
|
|
|
|
|
| 314 |
|
| 315 |
try:
|
| 316 |
projection_fn = self.starEncoder2.module.projection_heads
|
|
|
|
| 1 |
from transformers import Starcoder2Model
|
| 2 |
import sys
|
| 3 |
+
from config import ModularStarEncoderConfig
|
| 4 |
import os
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Optional, Tuple, Union, List
|
|
|
|
| 311 |
|
| 312 |
|
| 313 |
DEVICE = source_embedding[-1].get_device()
|
| 314 |
+
if DEVICE<0:
|
| 315 |
+
DEVICE = "cpu"
|
| 316 |
|
| 317 |
try:
|
| 318 |
projection_fn = self.starEncoder2.module.projection_heads
|