iamthewalrus67 committed
Commit 4aa8608 · 1 Parent(s): c0f3227

Calculate score

Files changed (1)
  1. app.py +24 -145
app.py CHANGED
@@ -11,8 +11,8 @@ import threading
 import spaces
 import gradio as gr
 import torch
-from PIL.Image import Image
-from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, TextIteratorStreamer
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, TextIteratorStreamer, AutoModel
 from kernels import get_kernel
 from typing import Any, Optional, Dict
 
@@ -31,8 +31,8 @@ login(token=HF_LE_LLM_READ_TOKEN)
 # MODEL_ID = "le-llm/lapa-v0.1-instruct"
 # MODEL_ID = "le-llm/lapa-v0.1-matt-instruction-5e06"
 # MODEL_ID = "le-llm/lapa-v0.1-reprojected"
-MODEL_ID = "le-llm/lapa-v0.1.1-instruct"
-# MODEL_ID = "le-llm/manipulative-score-model"
+# MODEL_ID = "le-llm/lapa-v0.1.1-instruct"
+MODEL_ID = "le-llm/manipulative-score-model"
 
 MAX_TOKENS = 4096
 TEMPERATURE = 0.7
@@ -56,7 +56,7 @@ def load_model():
     except Exception as err:  # pragma: no cover - informative fallback
         print(f"Warning: AutoProcessor not available ({err}). Falling back to tokenizer.")
 
-    model = AutoModelForCausalLM.from_pretrained(
+    model = AutoModel.from_pretrained(
         MODEL_ID,
         dtype=torch.bfloat16,  # if device == "cuda" else torch.float32,
         device_map="auto",  # if device == "cuda" else None,
@@ -70,7 +70,7 @@ def load_model():
 model, tokenizer, processor, device = load_model()
 
 
-def user(user_message, image_data: Image, history: list):
+def user(user_message, history: list):
     """Format user message with optional image."""
     import io
 
@@ -80,44 +80,14 @@ def user(user_message, image_data: Image, history: list):
 
     stripped_message = user_message.strip()
 
-    # If we have an image, save it to temp file for Gradio display
-    if image_data is not None:
-        image_data.thumbnail((IMAGE_MAX_SIZE, IMAGE_MAX_SIZE))
-
-        # Save to temp file for Gradio display
-        fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
-        os.close(fd)
-        image_data.save(tmp_path, format="JPEG")
-
-        # Also encode as base64 for model processing (stored in metadata)
-        buffered = io.BytesIO()
-        image_data.save(buffered, format="JPEG")
-
-        # TODO do we leave that message?
-        text_content = stripped_message if stripped_message else "Опиши це зображення"
-
-        # Store both text and image in a single message with base64 in metadata
-        updated_history.append({
-            "role": "user",
-            "content": text_content
-        })
-        updated_history.append({
-            "role": "user",
-            "content": {
-                "path": tmp_path,
-                "alt_text": "User uploaded image"
-            },
-        })
-        has_content = True
-    elif stripped_message:
-        updated_history.append({"role": "user", "content": stripped_message})
+    if stripped_message:
         has_content = True
 
     if not has_content:
         # Nothing to submit yet; keep inputs unchanged
-        return user_message, image_data, history
+        return user_message, history
 
-    return "", None, updated_history
+    return "", updated_history
 
 
 def append_example_message(x: gr.SelectData, history):
@@ -166,119 +136,29 @@ def _clean_history_for_display(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
 
 @spaces.GPU
 def bot(
-    history: list[dict[str, Any]]
+    input: list[dict[str, Any]]
 ):
-    """Generate bot response with support for text and images."""
+    """Generate bot response with support for text."""
 
     # Early return if no input
-    if not history:
+    if not input:
         return
 
-    # Extract last user message for logging
-    last_user_msg = next((msg for msg in reversed(history) if msg.get("role") == "user"), None)
-    user_message_text = _extract_text_from_content(last_user_msg.get("content")) if last_user_msg else ""
-    print('User message:', user_message_text)
-
-    # Check if any message contains images
-    has_images = any(
-        isinstance(msg.get("content"), tuple)
-        for msg in history
-    )
-
-    model_inputs = None
-
-    # Use processor if images are present
-    if processor is not None and has_images:
-        try:
-            processor_history = []
-            for msg in history:
-                role = msg.get("role", "user")
-                content = msg.get("content")
-
-                if isinstance(content, str):
-                    processor_history.append({"role": role, "content": [{"type": "text", "text": content}]})
-                elif isinstance(content, tuple):
-                    formatted_content = []
-                    tmp_path, _ = content
-                    image_input = {
-                        "type": "image",
-                        "url": f"{tmp_path}",
-                    }
-
-                    if processor_history[-1].get('role') == 'user':
-                        if isinstance(processor_history[-1].get('content'), str):
-                            previous_message = processor_history[-1].get('content')
-                            formatted_content.append({"type": "text", "text": previous_message})
-                            formatted_content.append(image_input)
-                            processor_history[-1]['content'] = formatted_content
-                        elif isinstance(processor_history[-1].get('content'), list):
-                            processor_history[-1]['content'].append(image_input)
-                    else:
-                        formatted_content.append(image_input)
-                        processor_history.append({"role": role, "content": formatted_content})
-
-            model_inputs = processor.apply_chat_template(
-                processor_history,
-                tokenize=True,
-                return_dict=True,
-                return_tensors="pt",
-                add_generation_prompt=True,
-            ).to(model.device)
-            print("Using processor for vision input")
-        except Exception as exc:
-            print(f"Processor failed: {exc}")
-            model_inputs = None
-
-    # Fallback to tokenizer for text-only
-    if model_inputs is None:
-        # Convert to text-only format for tokenizer
-        text_history = []
-        for msg in history:
-            role = msg.get("role", "user")
-            content = msg.get("content")
-            text_content = _extract_text_from_content(content)
-            if text_content:
-                text_history.append({"role": role, "content": text_content})
-
-        if text_history:
-            input_text = tokenizer.apply_chat_template(
-                text_history,
-                tokenize=False,
-                add_generation_prompt=True,
-            )
-            if input_text and tokenizer.bos_token:
-                input_text = input_text.replace(tokenizer.bos_token, "", 1)
-            model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
-            print("Using tokenizer for text-only input")
-
-    if model_inputs is None:
-        return
-
-    # Streamer setup
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
-
-    # Run model.generate in background thread
-    generation_kwargs = dict(
-        **model_inputs,
-        max_new_tokens=MAX_TOKENS,
-        temperature=TEMPERATURE,
-        top_p=TOP_P,
-        top_k=64,
-        do_sample=True,
-        streamer=streamer,
-    )
-    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    history.append({"role": "assistant", "content": ""})
-    # Yield tokens as they come in
-    for new_text in streamer:
-        history[-1]["content"] += new_text
-        yield _clean_history_for_display(history)
-
-    assistant_message = history[-1]["content"]
-    logger.log_interaction(user=user_message_text, answer=assistant_message)
+    clean_input = [f"query: {input}"]
+    batch_dict = tokenizer(clean_input, max_length=512, padding=True, truncation=True, return_tensors='pt')
+
+    outputs = model(**batch_dict)
+    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+    embeddings = F.normalize(embeddings, p=2, dim=1)
+    scores = (embeddings[:2] @ embeddings[2:].T) * 100
+    return scores.tolist()
+
+
+def average_pool(last_hidden_states: torch.Tensor,
+                 attention_mask: torch.Tensor) -> torch.Tensor:
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
 
 # --- drop-in UI compatible with older Gradio versions ---
 import os, tempfile, time
@@ -313,8 +193,7 @@ with gr.Blocks(theme=THEME, css=CSS, fill_height=True) as demo:
     gr.HTML(
         """
         <div id="app-header">
-            <div class="app-title">✨ LAPA</div>
-            <div class="app-subtitle">LLM for Ukrainian Language</div>
+            <div class="app-title">🤔 LAPA Quality Estimation</div>
        </div>
        """
    )
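
The replacement bot() follows the standard E5-style embedding scoring recipe visible in the added lines: prefix each text with "query: " (reference texts would get "passage: "), encode with AutoModel, average-pool the last hidden state over non-padding tokens, L2-normalize, and multiply the normalized embeddings to get cosine similarities scaled to 0-100. Below is a minimal standalone sketch of that flow; it assumes le-llm/manipulative-score-model is such an encoder (loading it may require the same HF token the Space logs in with), and the example texts are hypothetical. Note that the slicing embeddings[:2] @ embeddings[2:].T scores the first two rows (queries) against the remaining rows (passages), so a batch containing a single query and no passages would yield an empty score matrix.

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "le-llm/manipulative-score-model"  # assumption: an E5-style text encoder


def average_pool(last_hidden_states: torch.Tensor,
                 attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padding positions, then average over the real tokens only.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)
model.eval()

# E5 convention: "query: " for the texts being scored, "passage: " for references.
texts = [
    "query: example user message one",          # hypothetical
    "query: example user message two",          # hypothetical
    "passage: example manipulative reference",  # hypothetical
    "passage: example neutral reference",       # hypothetical
]

batch_dict = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(**batch_dict)

embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
embeddings = F.normalize(embeddings, p=2, dim=1)

# 2 queries x 2 passages -> a 2x2 matrix of cosine similarities scaled to 0-100,
# using the same slicing as bot().
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())

If the Space only ever scores a single message at a time, the same math collapses to (query_embedding @ passage_embeddings.T) * 100, a 1xN row of scores against whatever reference passages are encoded alongside the query.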