albertchristopher committed on
Commit e146d8f · verified · 1 Parent(s): d51faa2

Update src/utils.py

Files changed (1)
  1. src/utils.py +65 -132
src/utils.py CHANGED
@@ -1,148 +1,81 @@
  # utils.py
- from typing import List, Optional
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- DEFAULT_MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
-
- SYSTEM_PROMPT = (
-     "You are an expert writing assistant. Summarize the user's text clearly and faithfully. "
-     "Write 2-4 concise sentences capturing the main points, avoiding speculation and numbers not found in the text."
- )
-
- CHUNK_PROMPT = (
-     "Summarize the following passage in 1-3 sentences, preserving key facts and names.\n\n"
-     "PASSAGE:\n{chunk}\n\nSUMMARY:"
- )
-
- REDUCE_PROMPT = (
-     "You are merging partial summaries of a longer document. Combine them into one cohesive summary "
-     "of 3-6 sentences covering the overall thrust of the original text, with no contradictions or hallucinations.\n\n"
-     "PARTIAL SUMMARIES:\n{partials}\n\nFINAL SUMMARY:"
  )

- def device_and_dtype():
-     """Select an appropriate device and dtype based on availability."""
-     if torch.cuda.is_available():
-         return "auto", torch.bfloat16
-     # CPU fallback
-     return None, torch.float32
-
-
- def load_bitnet_model(model_id: str = DEFAULT_MODEL_ID):
-     """Load tokenizer and model with reasonable defaults for BitNet."""
-     device_map, torch_dtype = device_and_dtype()
-     tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)

-     # Ensure pad token exists
-     if tok.pad_token is None:
-         tok.pad_token = tok.eos_token

-     model = AutoModelForCausalLM.from_pretrained(
-         model_id,
-         torch_dtype=torch_dtype,
-         device_map=device_map,
-     )
-     return tok, model
-
-
- def chunk_by_tokens(text: str, tokenizer: AutoTokenizer, max_tokens: int = 900, overlap: int = 60) -> List[str]:
-     """Greedy token chunking with overlap to preserve context for long docs."""
-     ids = tokenizer.encode(text, add_special_tokens=False)
-     chunks = []
-     i = 0
-     while i < len(ids):
-         j = min(i + max_tokens, len(ids))
-         chunk_ids = ids[i:j]
-         chunks.append(tokenizer.decode(chunk_ids))
-         if j == len(ids):
-             break
-         i = j - overlap  # step with overlap
-         if i < 0:
-             i = 0
-     return chunks

- def generate_summary(
-     tokenizer,
-     model,
-     prompt: str,
-     max_new_tokens: int = 192,
-     temperature: float = 0.3,
-     top_p: float = 0.95,
-     repetition_penalty: float = 1.05,
- ) -> str:
-     """Generic text generation helper for causal LMs."""
-     inputs = tokenizer(
-         prompt,
-         return_tensors="pt",
-         padding=True,
-         truncation=True,
-     )

-     if torch.cuda.is_available():
-         inputs = {k: v.to(model.device) for k, v in inputs.items()}

-     gen_ids = model.generate(
-         **inputs,
-         do_sample=(temperature > 0.0),
-         temperature=temperature,
-         top_p=top_p,
-         max_new_tokens=max_new_tokens,
-         repetition_penalty=repetition_penalty,
-         eos_token_id=tokenizer.eos_token_id,
-         pad_token_id=tokenizer.pad_token_id,
-     )
-     out = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

-     # Return only the completion after the prompt if possible
-     if out.startswith(prompt):
-         out = out[len(prompt):]
-     return out.strip()

- def map_reduce_summarize(
-     text: str,
-     tokenizer,
-     model,
-     max_chunk_tokens: int = 900,
-     overlap: int = 60,
-     chunk_max_new_tokens: int = 128,
-     final_max_new_tokens: int = 220,
-     temperature: float = 0.2,
-     top_p: float = 0.9,
- ) -> str:
-     """Summarize long text by chunking -> summarizing -> reducing."""
-     chunks = chunk_by_tokens(text, tokenizer, max_tokens=max_chunk_tokens, overlap=overlap)

-     # Short texts: single pass
-     if len(chunks) == 1:
-         prompt = f"{SYSTEM_PROMPT}\n\n{CHUNK_PROMPT.format(chunk=chunks[0])}"
-         return generate_summary(tokenizer, model, prompt, max_new_tokens=final_max_new_tokens,
-                                 temperature=temperature, top_p=top_p)

-     partials: List[str] = []
-     for ck in chunks:
-         p = f"{SYSTEM_PROMPT}\n\n{CHUNK_PROMPT.format(chunk=ck)}"
-         s = generate_summary(
-             tokenizer,
-             model,
-             p,
-             max_new_tokens=chunk_max_new_tokens,
-             temperature=temperature,
-             top_p=top_p,
-         )
-         partials.append(s)

-     merged = "\n- ".join(partials)
-     reduce_prompt = f"{SYSTEM_PROMPT}\n\n{REDUCE_PROMPT.format(partials='- ' + merged)}"
-     final = generate_summary(
-         tokenizer,
-         model,
-         reduce_prompt,
-         max_new_tokens=final_max_new_tokens,
-         temperature=max(0.1, temperature - 0.1),
-         top_p=top_p,
-     )
-     return final.strip()
 
  # utils.py
+     prompt: str,
+     max_new_tokens: int = 192,
+     temperature: float = 0.3,
+     top_p: float = 0.95,
+     repetition_penalty: float = 1.05,
+ ) -> str:
+     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+     if torch.cuda.is_available():
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+     gen_ids = model.generate(
+         **inputs,
+         do_sample=(temperature > 0.0),
+         temperature=temperature,
+         top_p=top_p,
+         max_new_tokens=max_new_tokens,
+         repetition_penalty=repetition_penalty,
+         eos_token_id=tokenizer.eos_token_id,
+         pad_token_id=tokenizer.pad_token_id,
      )
+     out = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
+     if out.startswith(prompt):
+         out = out[len(prompt):]
+     return out.strip()


+ def map_reduce_summarize(
+     text: str,
+     tokenizer,
+     model,
+     max_chunk_tokens: int = 900,
+     overlap: int = 60,
+     chunk_max_new_tokens: int = 128,
+     final_max_new_tokens: int = 220,
+     temperature: float = 0.2,
+     top_p: float = 0.9,
+ ) -> str:
+     chunks = chunk_by_tokens(text, tokenizer, max_tokens=max_chunk_tokens, overlap=overlap)
+     if len(chunks) == 1:
+         prompt = f"{SYSTEM_PROMPT}\n\n{CHUNK_PROMPT.format(chunk=chunks[0])}"
+         return generate_summary(tokenizer, model, prompt, max_new_tokens=final_max_new_tokens,
+                                 temperature=temperature, top_p=top_p)

+     partials: List[str] = []
+     for ck in chunks:
+         p = f"{SYSTEM_PROMPT}\n\n{CHUNK_PROMPT.format(chunk=ck)}"
+         s = generate_summary(
+             tokenizer,
+             model,
+             p,
+             max_new_tokens=chunk_max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+         )
+         partials.append(s)

+     merged = "\n- ".join(partials)
+     reduce_prompt = f"{SYSTEM_PROMPT}\n\n{REDUCE_PROMPT.format(partials='- ' + merged)}"
+     final = generate_summary(
+         tokenizer,
+         model,
+         reduce_prompt,
+         max_new_tokens=final_max_new_tokens,
+         temperature=max(0.1, temperature - 0.1),
+         top_p=top_p,
+     )
+     return final.strip()
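
For reference, a minimal usage sketch of the chunk-then-reduce summarization pipeline as defined on the removed side of this diff (load_bitnet_model plus map_reduce_summarize). The import path and the sample input file are illustrative assumptions, not part of this commit:

# Hypothetical usage example; assumes the pre-commit src/utils.py is importable.
from src.utils import load_bitnet_model, map_reduce_summarize

tok, model = load_bitnet_model()  # defaults to microsoft/bitnet-b1.58-2B-4T

with open("article.txt", encoding="utf-8") as f:  # placeholder path to a long document
    long_text = f.read()

# Chunks the text by tokens, summarizes each chunk, then merges the partial summaries.
summary = map_reduce_summarize(long_text, tok, model)
print(summary)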