Julian Bilcke committed
Commit e0464d7 · Parent: 16386fc

fix for test.py

Files changed (4)
  1. Dockerfile +1 -1
  2. README.md +3 -0
  3. requirements.txt +1 -0
  4. test.py +31 -9
Dockerfile CHANGED
@@ -59,7 +59,7 @@ RUN pip install -r requirements.txt
 COPY --chown=user . .
 
 # temporary skip model download, to make things faster
-# RUN git clone https://huggingface.co/WizardLM/WizardCoder-15B-V1.0
+RUN git clone https://huggingface.co/WizardLM/WizardCoder-15B-V1.0
 
 # help Pythonia by giving it the path to Python
 ENV PYTHON_BIN /usr/bin/python3
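The image now fetches the WizardCoder-15B-V1.0 checkpoint at build time via `git clone` (the "temporary skip" comment above it is left over from the previous state). As a sketch of an alternative, not part of this commit: the same files can be pulled with `huggingface_hub`, which avoids needing git-lfs inside the image; the `local_dir` below is an assumption chosen to match the `base_model` path in test.py.

```python
# Sketch only (not in this commit): fetch the checkpoint without git/git-lfs,
# into the ./models/... directory that test.py reads as base_model.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="WizardLM/WizardCoder-15B-V1.0",
    local_dir="./models/WizardCoder-15B-V1.0",  # assumed to match test.py
)
```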
README.md CHANGED
@@ -46,8 +46,11 @@ To install those dependencies, first you should create and activate a new virtua
 python -m venv .venv
 source .venv/bin/activate
 pip install --upgrade pip
+pip install torch
 ```
 
+Note: the Dockerfile will install pytorch itself
+
 Then install the dependencies in it:
 ```bash
 pip install -r requirements.txt
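A quick, generic way to confirm the manually installed torch works inside the venv (a sketch, not part of the commit or the README):

```python
# Sketch: sanity-check the torch install that the README now asks for.
import torch

print(torch.__version__)          # should print the installed version
print(torch.cuda.is_available())  # True inside the CUDA image, often False locally
```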
requirements.txt CHANGED
@@ -1 +1,2 @@
+accelerate
 transformers
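`accelerate` is presumably added because transformers needs it to honor the `device_map` argument that test.py passes to `from_pretrained`; without accelerate installed, that call raises an error asking for it. A minimal sketch of the loading path it unlocks (path, dtype, and device map taken from test.py's CUDA branch):

```python
# Sketch: this is the from_pretrained pattern that requires accelerate --
# transformers uses it to dispatch weights according to device_map.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "./models/WizardCoder-15B-V1.0",
    device_map={"": "cuda"},      # as in test.py
    torch_dtype=torch.float16,
)
```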
test.py CHANGED
@@ -1,12 +1,5 @@
-import sys
-import os
-import fire
 import torch
-import transformers
-import json
-import jsonlines
-
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 if torch.cuda.is_available():
     device = "cuda"
@@ -21,9 +14,27 @@ except:
 
 print("device: " + device)
 
+
+def evaluate(instruction, tokenizer, model):
+    prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+{instruction}
+
+### Response:"""
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
+    input_ids = inputs["input_ids"].to(device)
+
+    with torch.no_grad():
+        generation_output = model.generate(input_ids)
+    s = generation_output
+    output = tokenizer.decode(s[0], skip_special_tokens=True)
+    return output.split("### Response:")[1].strip()
+
 base_model = "./models/WizardCoder-15B-V1.0"
 load_8bit = False
 
+print("loading tokenizer..")
 tokenizer = AutoTokenizer.from_pretrained(base_model)
 if device == "cuda":
     model = AutoModelForCausalLM.from_pretrained(
@@ -38,10 +49,21 @@ elif device == "mps":
         device_map={"": device},
         torch_dtype=torch.float16,
     )
+print("loaded tokenizer")
 model.config.pad_token_id = tokenizer.pad_token_id
 if not load_8bit:
     model.half()
 
+print("calling model.eval()")
 model.eval()
+
 if torch.__version__ >= "2" and sys.platform != "win32":
-    model = torch.compile(model)
+    print("calling torch.compile(model)")
+    model = torch.compile(model)
+
+instruction = "Write a short summary about AI."
+print("calling evaluate..")
+result = evaluate(instruction, tokenizer, model)
+
+print("result: ")
+print(result)
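For reference, a condensed sketch of the flow the updated test.py now follows (this is not the file itself). Two hedged notes: `max_new_tokens` below is an assumption, since the commit calls `model.generate(input_ids)` with library defaults; and `import sys` is kept because the `torch.compile` guard still reads `sys.platform` even though the import cleanup in the first hunk removed it.

```python
# Condensed sketch of the updated test.py flow (assumptions marked inline).
import sys  # still referenced by the torch.compile guard below

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

base_model = "./models/WizardCoder-15B-V1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map={"": device},
    torch_dtype=torch.float16,
)
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

prompt = "Write a short summary about AI."
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    # max_new_tokens is an assumption; the commit relies on generate() defaults
    output_ids = model.generate(inputs["input_ids"].to(device), max_new_tokens=256)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```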