Tas01 committed
Commit 4888ca6 · verified · 1 Parent(s): bc37146

Update app.py

Files changed (1)
  app.py  +65 -112
app.py CHANGED
@@ -273,43 +273,22 @@ class ImageStoryteller:
             scenes = [scene['type'] for scene in analysis_result['scenes']]

             # Create a prompt for the LLM
-            objects_str = ", ".join(objects) # Use top 3 objects
             scene_str = scenes[0] if scenes else "general scene"

-            # FIXED: Convert creativity_level to float if it's a tuple
             if isinstance(creativity_level, (tuple, list)):
                 creativity_level = float(creativity_level[0])

-            # Enhanced prompt with caption generation
             if creativity_level > 0.8:
-                prompt = f"""Based on this image containing {objects_str} in a {scene_str}:
-
-1. First, write a catchy 5-7 word YouTube-style caption (engaging, attention-grabbing)
-2. Then, write a creative and imaginative short story (3-4 paragraphs)
-
-Format exactly like this:
-CAPTION: [your catchy caption here]
-STORY: [your creative story here]"""
             elif creativity_level > 0.5:
-                prompt = f"""For an image with {objects_str} in a {scene_str}:
-
-1. Create a short, interesting caption (5-7 words)
-2. Write a 2-3 paragraph story about what's happening in this scene
-
-Format:
-CAPTION: [your caption here]
-STORY: [your story here]"""
             else:
-                prompt = f"""Describe an image containing {objects_str} in a {scene_str}:
-
-1. Give a simple, descriptive caption
-2. Write a 1-2 paragraph description
-
-Format:
-CAPTION: [caption here]
-STORY: [description here]"""

-            # QWEN 1.8B SPECIFIC FORMATTING
             if "qwen" in self.llm_model_id.lower():
                 formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
             elif "phi" in self.llm_model_id:
@@ -317,103 +296,77 @@ class ImageStoryteller:
             elif "gemma" in self.llm_model_id:
                 formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
             else:
-                formatted_prompt = f"{prompt}\n\n"

             # Tokenize and generate
             inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)

             with torch.no_grad():
-                if "qwen" in self.llm_model_id.lower():
-                    outputs = self.llm_model.generate(
-                        **inputs,
-                        max_new_tokens=350,  # Increased for caption + story
-                        temperature=creativity_level,
-                        do_sample=True,
-                        top_p=0.9,
-                        repetition_penalty=1.1,
-                        eos_token_id=self.tokenizer.eos_token_id,
-                        pad_token_id=self.tokenizer.eos_token_id,
-                        no_repeat_ngram_size=3
-                    )
-                else:
-                    outputs = self.llm_model.generate(
-                        **inputs,
-                        max_new_tokens=300,
-                        temperature=creativity_level,
-                        do_sample=True,
-                        top_p=0.9,
-                        repetition_penalty=1.1,
-                        pad_token_id=self.tokenizer.eos_token_id
-                    )

             # Decode and clean up
-            story = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

-            # Clean up Qwen specific tokens
-            if "qwen" in self.llm_model_id.lower():
-                story = story.replace(formatted_prompt, "").strip()
-                story = story.replace("<|im_end|>", "").strip()
-                story = story.replace("<|im_start|>", "").strip()
-                story = story.replace("<|endoftext|>", "").strip()
-            elif story.startswith(formatted_prompt):
-                story = story[len(formatted_prompt):].strip()
-
-            # Additional cleanup
-            story = story.strip()
-
-            # Ensure proper formatting for caption and story
-            lines = story.split('\n')
-            formatted_lines = []
-            for line in lines:
-                line = line.strip()
-                if line and not line.startswith('CAPTION:') and not line.startswith('STORY:'):
-                    # If we have caption/story markers but missing the prefix
-                    if 'caption:' in line.lower() and 'caption:' not in line:
-                        line = 'CAPTION: ' + line.split('caption:')[-1].strip()
-                    elif 'story:' in line.lower() and 'story:' not in line:
-                        line = 'STORY: ' + line.split('story:')[-1].strip()
-                formatted_lines.append(line)
-
-            story = '\n'.join(formatted_lines)
-
-            # Add visual separator if not already present
-            if 'STORY:' in story:
-                parts = story.split('STORY:', 1)
-                if len(parts) == 2:
-                    caption_part = parts[0].replace('CAPTION:', '').strip()
-                    story_part = parts[1].strip()
-                    # Format with separator
-                    story = f"{caption_part}\n{'─' * 40}\n{story_part}"
-
-            # Fallback if generation is too short
-            if len(story.split()) < 15:
-                fallback_prompt = f"Create a caption and story for {objects_str} in {scene_str}."
-                simple_inputs = self.tokenizer(fallback_prompt, return_tensors="pt").to(self.llm_model.device)
-                with torch.no_grad():
-                    simple_outputs = self.llm_model.generate(
-                        **simple_inputs,
-                        max_new_tokens=250,
-                        temperature=0.8,
-                        do_sample=True
-                    )
-                story = self.tokenizer.decode(simple_outputs[0], skip_special_tokens=True)
-                story = story.replace(fallback_prompt, "").strip()
-                # Add separator
-                sentences = story.split('. ')
-                if sentences:
-                    caption = sentences[0].strip()
-                    if not caption.endswith('.'):
-                        caption += '.'
-                    rest_of_story = '. '.join(sentences[1:]) if len(sentences) > 1 else story
-                    story = f"{caption}\n{'─' * 40}\n{rest_of_story}"
-
-            return story

         except Exception as e:
             print(f"Story generation failed: {e}")
             objects_str = ", ".join(objects) if 'objects' in locals() else "unknown"
             scene_str = scenes[0] if 'scenes' in locals() and scenes else "unknown scene"
-            return f"Failed to generate story. Detected objects: {objects_str} in a {scene_str}. Error: {str(e)}"

     def process_image_and_generate_story(self, image, creativity_level=0.7):
         """Complete pipeline: analyze image and generate story"""
 
@@ -273,43 +273,22 @@ class ImageStoryteller:
             scenes = [scene['type'] for scene in analysis_result['scenes']]

             # Create a prompt for the LLM
+            objects_str = ", ".join(objects)
             scene_str = scenes[0] if scenes else "general scene"

+            # Convert creativity_level to float if it's a tuple
             if isinstance(creativity_level, (tuple, list)):
                 creativity_level = float(creativity_level[0])

+            # SIMPLIFIED PROMPT - No numbered lists or complex formatting
             if creativity_level > 0.8:
+                prompt = f"Write a catchy 5-7 word YouTube-style caption, then a creative 3-4 paragraph story about {objects_str} in a {scene_str}."
             elif creativity_level > 0.5:
+                prompt = f"Create a short caption and a 2-3 paragraph story about {objects_str} in a {scene_str}."
             else:
+                prompt = f"Write a caption and a 1-2 paragraph description of {objects_str} in a {scene_str}."

+            # QWEN FORMATTING
             if "qwen" in self.llm_model_id.lower():
                 formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
             elif "phi" in self.llm_model_id:
@@ -317,103 +296,77 @@ class ImageStoryteller:
             elif "gemma" in self.llm_model_id:
                 formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
             else:
+                formatted_prompt = f"User: {prompt}\nAssistant:"
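As an aside: recent transformers releases let a tokenizer carry its own chat template, which could collapse the per-model branches above. A minimal sketch, assuming self.tokenizer ships a chat template (not part of this commit):

    # Prefer the tokenizer's built-in chat template when one is available,
    # otherwise fall back to the plain "User:/Assistant:" format used above.
    if getattr(self.tokenizer, "chat_template", None):
        formatted_prompt = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        formatted_prompt = f"User: {prompt}\nAssistant:"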

             # Tokenize and generate
             inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)

             with torch.no_grad():
+                outputs = self.llm_model.generate(
+                    **inputs,
+                    max_new_tokens=300,
+                    temperature=creativity_level,
+                    do_sample=True,
+                    top_p=0.9,
+                    repetition_penalty=1.1,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    no_repeat_ngram_size=3
+                )
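Note that generate() is called with do_sample=True and temperature=creativity_level; sampling needs a strictly positive temperature, so a creativity level of 0 would be rejected. A small defensive clamp, shown as a sketch (hypothetical helper, not in the commit):

    # Keep the sampling temperature strictly positive and in a sane range
    # before handing it to generate().
    safe_temperature = min(max(float(creativity_level), 0.1), 1.5)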

             # Decode and clean up
+            raw_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

+            # Extract only the assistant's response
+            if "assistant" in raw_output.lower():
+                parts = raw_output.lower().split("assistant")
+                if len(parts) > 1:
+                    story = parts[-1].strip()
+                else:
+                    story = raw_output
+            elif "Assistant:" in raw_output:
+                parts = raw_output.split("Assistant:")
+                story = parts[-1].strip() if len(parts) > 1 else raw_output
+            else:
+                story = raw_output
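One caveat in the branch above: when the marker word "assistant" is present, the story is taken from raw_output.lower(), so the returned text loses its original capitalization. A case-preserving variant of the same extraction, as a sketch (hypothetical, not what this commit does):

    # Find the marker on a lowercased copy, but slice the original string
    # so the extracted story keeps its capitalization.
    idx = raw_output.lower().rfind("assistant")
    story = raw_output[idx + len("assistant"):].strip() if idx != -1 else raw_output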
+
+            # Clean Qwen tokens if present
+            qwen_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"]
+            for token in qwen_tokens:
+                story = story.replace(token, "").strip()
+
+            # Clean any remaining prompt text
+            story = story.replace(prompt, "").strip()
+
+            # Extract or create caption from the story
+            sentences = story.split('. ')
+            if sentences:
+                # Take first sentence as caption
+                caption = sentences[0].strip()
+                if not caption.endswith('.'):
+                    caption += '.'
+
+                # Rest of the story
+                if len(sentences) > 1:
+                    story_text = '. '.join(sentences[1:])
+                else:
+                    story_text = story.replace(caption, "").strip()
+
+                # Format with caption at top and separator
+                formatted_output = f"{caption}\n{'─' * 40}\n{story_text}"
+            else:
+                formatted_output = story
+
+            # Clean up any extra whitespace
+            formatted_output = '\n'.join([line.strip() for line in formatted_output.split('\n') if line.strip()])
+
+            return formatted_output
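The caption heuristic above splits only on '. ', so a story whose first sentence ends in '!' or '?' yields an over-long caption. A hypothetical variant using a regex sentence split:

    import re

    # Treat '.', '!' and '?' followed by whitespace as sentence boundaries.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', story) if s]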

         except Exception as e:
             print(f"Story generation failed: {e}")
             objects_str = ", ".join(objects) if 'objects' in locals() else "unknown"
             scene_str = scenes[0] if 'scenes' in locals() and scenes else "unknown scene"
+            return f"Caption: Analysis of {objects_str}\n{'─' * 40}\nFailed to generate story. Detected: {objects_str} in {scene_str}."

     def process_image_and_generate_story(self, image, creativity_level=0.7):
         """Complete pipeline: analyze image and generate story"""