Tas01 committed
Commit 969c67b · verified · 1 Parent(s): d39ea48

Update app.py

Files changed (1)
  1. app.py +273 -57
app.py CHANGED
@@ -8,9 +8,115 @@ import rembg
 from io import BytesIO
 import os
 
+# class ImageStoryteller:
+#     def __init__(self):
+#         print("Initializing Image Storyteller with CLIP-ViT...")
+
+#         # Load CLIP model for image understanding
+#         try:
+#             self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+#             self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+#             print("CLIP-ViT model loaded successfully!")
+#         except Exception as e:
+#             print(f"CLIP loading failed: {e}")
+#             self.clip_model = None
+#             self.clip_processor = None
+
+#         # Common objects for scene understanding
+#         self.common_objects = [
+#             'person', 'people', 'human', 'man', 'woman', 'child', 'baby',
+#             'dog', 'cat', 'animal', 'bird', 'horse', 'cow', 'sheep',
+#             'car', 'vehicle', 'bus', 'truck', 'bicycle', 'motorcycle',
+#             'building', 'house', 'skyscraper', 'architecture',
+#             'tree', 'forest', 'nature', 'mountain', 'sky', 'clouds',
+#             'water', 'ocean', 'river', 'lake', 'beach',
+#             'food', 'fruit', 'vegetable', 'meal',
+#             'indoor', 'outdoor', 'urban', 'rural'
+#         ]
+
+#         # Scene categories for classification
+#         self.scene_categories = [
+#             "portrait", "landscape", "cityscape", "indoor scene", "outdoor scene",
+#             "nature", "urban", "beach", "mountain", "forest", "street",
+#             "party", "celebration", "sports", "action", "still life",
+#             "abstract", "art", "architecture", "wildlife", "pet"
+#         ]
+
+#     def analyze_image_with_clip(self, image):
+#         """Analyze image using CLIP to understand content and scene"""
+#         if self.clip_model is None or self.clip_processor is None:
+#             return self.fallback_image_analysis(image)
+
+#         try:
+#             # Convert PIL to RGB
+#             image_rgb = image.convert('RGB')
+
+#             # Analyze objects in the image
+#             object_inputs = self.clip_processor(
+#                 text=self.common_objects,
+#                 images=image_rgb,
+#                 return_tensors="pt",
+#                 padding=True
+#             )
+
+#             with torch.no_grad():
+#                 object_outputs = self.clip_model(**object_inputs)
+#                 object_logits = object_outputs.logits_per_image
+#                 object_probs = object_logits.softmax(dim=1)
+
+#             # Get top objects
+#             top_object_indices = torch.topk(object_probs, 5, dim=1).indices[0]
+#             detected_objects = []
+#             for idx in top_object_indices:
+#                 obj_name = self.common_objects[idx]
+#                 confidence = object_probs[0][idx].item()
+#                 if confidence > 0.1:  # Confidence threshold
+#                     detected_objects.append({
+#                         'name': obj_name,
+#                         'confidence': confidence
+#                     })
+
+#             # Analyze scene type
+#             scene_inputs = self.clip_processor(
+#                 text=self.scene_categories,
+#                 images=image_rgb,
+#                 return_tensors="pt",
+#                 padding=True
+#             )
+
+#             with torch.no_grad():
+#                 scene_outputs = self.clip_model(**scene_inputs)
+#                 scene_logits = scene_outputs.logits_per_image
+#                 scene_probs = scene_logits.softmax(dim=1)
+
+#             top_scene_indices = torch.topk(scene_probs, 3, dim=1).indices[0]
+#             scene_types = []
+#             for idx in top_scene_indices:
+#                 scene_name = self.scene_categories[idx]
+#                 confidence = scene_probs[0][idx].item()
+#                 scene_types.append({
+#                     'type': scene_name,
+#                     'confidence': confidence
+#                 })
+
+#             return {
+#                 'objects': detected_objects,
+#                 'scenes': scene_types,
+#                 'success': True
+#             }
+
+#         except Exception as e:
+#             print(f"CLIP analysis failed: {e}")
+#             return self.fallback_image_analysis(image)
+import torch
+from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, AutoModelForCausalLM
+from huggingface_hub import login
+from PIL import Image
+import gradio as gr
+
 class ImageStoryteller:
-    def __init__(self):
-        print("Initializing Image Storyteller with CLIP-ViT...")
+    def __init__(self, llm_model_id="microsoft/phi-2"):
+        print("Initializing Image Storyteller with CLIP-ViT and LLM...")
 
         # Load CLIP model for image understanding
         try:
@@ -22,6 +128,26 @@ class ImageStoryteller:
             self.clip_model = None
             self.clip_processor = None
 
+        # Load LLM for story generation
+        try:
+            # For Gemma, you need to login first (uncomment if using Gemma)
+            # login()  # Only for Gemma models
+
+            # Choose your LLM (phi-2 doesn't require login)
+            self.llm_model_id = llm_model_id
+            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
+            self.llm_model = AutoModelForCausalLM.from_pretrained(
+                llm_model_id,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True if "phi" in llm_model_id else False
+            )
+            print(f"LLM model {llm_model_id} loaded successfully!")
+        except Exception as e:
+            print(f"LLM loading failed: {e}")
+            self.llm_model = None
+            self.tokenizer = None
+
         # Common objects for scene understanding
         self.common_objects = [
             'person', 'people', 'human', 'man', 'woman', 'child', 'baby',
@@ -110,46 +236,136 @@ class ImageStoryteller:
             return self.fallback_image_analysis(image)
 
     def fallback_image_analysis(self, image):
-        """Fallback image analysis when CLIP fails"""
-        img_np = np.array(image)
-        height, width = img_np.shape[:2]
+        """Fallback analysis when CLIP fails"""
+        return {
+            'objects': [{'name': 'scene', 'confidence': 1.0}],
+            'scenes': [{'type': 'general image', 'confidence': 1.0}],
+            'success': False
+        }
+
+    def generate_story_from_analysis(self, analysis_result, creativity_level=0.7):
+        """Generate a story based on detected objects and scene"""
+        if self.llm_model is None:
+            return "Story generation model not available."
 
-        # Simple color-based analysis
-        hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
-
-        objects = []
-        scenes = []
-
-        # Detect blue areas (sky/water)
-        blue_mask = cv2.inRange(hsv, (100, 50, 50), (130, 255, 255))
-        if np.sum(blue_mask) > height * width * 0.1:
-            objects.append({'name': 'sky', 'confidence': 0.6})
-            scenes.append({'type': 'outdoor scene', 'confidence': 0.7})
-
-        # Detect green areas (nature)
-        green_mask = cv2.inRange(hsv, (35, 50, 50), (85, 255, 255))
-        if np.sum(green_mask) > height * width * 0.1:
-            objects.append({'name': 'nature', 'confidence': 0.6})
-            scenes.append({'type': 'nature', 'confidence': 0.7})
-
-        # Detect skin tones (people)
-        skin_mask = cv2.inRange(hsv, (0, 30, 60), (20, 150, 255))
-        if np.sum(skin_mask) > 1000:
-            objects.append({'name': 'person', 'confidence': 0.5})
-            scenes.append({'type': 'portrait', 'confidence': 0.6})
-
-        # Detect edges (buildings/structures)
-        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
-        edges = cv2.Canny(gray, 50, 150)
-        if np.sum(edges) > height * width * 0.05:
-            objects.append({'name': 'building', 'confidence': 0.5})
-            scenes.append({'type': 'urban', 'confidence': 0.6})
+        try:
+            # Extract detected objects and scene
+            objects = [obj['name'] for obj in analysis_result['objects']]
+            scenes = [scene['type'] for scene in analysis_result['scenes']]
+
+            # Create a prompt for the LLM
+            objects_str = ", ".join(objects[:3])  # Use top 3 objects
+            scene_str = scenes[0] if scenes else "general scene"
+
+            # Different prompt templates for creativity
+            if creativity_level > 0.8:
+                prompt = f"""Based on this image containing {objects_str} in a {scene_str}, write a creative and imaginative short story (3-4 paragraphs).
+                Make it engaging and add interesting details about the scene."""
+            elif creativity_level > 0.5:
+                prompt = f"""Create a short story about an image with {objects_str} in a {scene_str}.
+                Write 2-3 paragraphs that describe what might be happening in this scene."""
+            else:
+                prompt = f"""Describe what you see in an image containing {objects_str} in a {scene_str}.
+                Write a simple 1-2 paragraph description."""
+
+            # Format for the specific LLM
+            if "phi" in self.llm_model_id:
+                # Phi-2 specific formatting
+                formatted_prompt = f"Instruct: {prompt}\nOutput:"
+            elif "gemma" in self.llm_model_id:
+                # Gemma specific formatting
+                formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+            else:
+                # Generic formatting
+                formatted_prompt = f"Write a story: {prompt}\n\nStory:"
+
+            # Tokenize and generate
+            inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)
+
+            with torch.no_grad():
+                outputs = self.llm_model.generate(
+                    **inputs,
+                    max_new_tokens=250,  # Shorter for faster generation
+                    temperature=creativity_level,
+                    do_sample=True,
+                    top_p=0.9,
+                    repetition_penalty=1.1,
+                    pad_token_id=self.tokenizer.eos_token_id
+                )
+
+            # Decode and clean up
+            story = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            # Remove the prompt from the beginning if present
+            if story.startswith(formatted_prompt):
+                story = story[len(formatted_prompt):].strip()
+
+            return story
+
+        except Exception as e:
+            print(f"Story generation failed: {e}")
+            return f"Failed to generate story. Detected objects: {objects_str} in a {scene_str}."
+
+    def process_image_and_generate_story(self, image, creativity_level=0.7):
+        """Complete pipeline: analyze image and generate story"""
+        print("Analyzing image...")
+        analysis = self.analyze_image_with_clip(image)
+
+        print("Generating story...")
+        story = self.generate_story_from_analysis(analysis, creativity_level)
+
+        # Return both analysis and story
+        detected_objects = [obj['name'] for obj in analysis['objects']]
+        scene_type = analysis['scenes'][0]['type'] if analysis['scenes'] else "unknown"
 
         return {
-            'objects': objects,
-            'scenes': scenes,
-            'success': False
+            'detected_objects': detected_objects,
+            'scene_type': scene_type,
+            'story': story,
+            'analysis_success': analysis['success']
         }
+
+    # def fallback_image_analysis(self, image):
+    #     """Fallback image analysis when CLIP fails"""
+    #     img_np = np.array(image)
+    #     height, width = img_np.shape[:2]
+
+    #     # Simple color-based analysis
+    #     hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
+
+    #     objects = []
+    #     scenes = []
+
+    #     # Detect blue areas (sky/water)
+    #     blue_mask = cv2.inRange(hsv, (100, 50, 50), (130, 255, 255))
+    #     if np.sum(blue_mask) > height * width * 0.1:
+    #         objects.append({'name': 'sky', 'confidence': 0.6})
+    #         scenes.append({'type': 'outdoor scene', 'confidence': 0.7})
+
+    #     # Detect green areas (nature)
+    #     green_mask = cv2.inRange(hsv, (35, 50, 50), (85, 255, 255))
+    #     if np.sum(green_mask) > height * width * 0.1:
+    #         objects.append({'name': 'nature', 'confidence': 0.6})
+    #         scenes.append({'type': 'nature', 'confidence': 0.7})
+
+    #     # Detect skin tones (people)
+    #     skin_mask = cv2.inRange(hsv, (0, 30, 60), (20, 150, 255))
+    #     if np.sum(skin_mask) > 1000:
+    #         objects.append({'name': 'person', 'confidence': 0.5})
+    #         scenes.append({'type': 'portrait', 'confidence': 0.6})
+
+    #     # Detect edges (buildings/structures)
+    #     gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
+    #     edges = cv2.Canny(gray, 50, 150)
+    #     if np.sum(edges) > height * width * 0.05:
+    #         objects.append({'name': 'building', 'confidence': 0.5})
+    #         scenes.append({'type': 'urban', 'confidence': 0.6})
+
+    #     return {
+    #         'objects': objects,
+    #         'scenes': scenes,
+    #         'success': False
+    #     }
 
     def create_analysis_overlay(self, image, analysis_result):
         """Create analysis overlay in bottom left with white text on black background"""
@@ -201,28 +417,28 @@ class ImageStoryteller:
 
         return Image.fromarray(overlay)
 
-    def generate_story(self, analysis_result, image_size):
-        """Generate story based on CLIP analysis"""
-        # Prepare context from analysis
-        objects_text = ", ".join([obj['name'] for obj in analysis_result['objects'][:5]])
-        scenes_text = analysis_result['scenes'][0]['type'] if analysis_result['scenes'] else "unknown scene"
+    # def generate_story(self, analysis_result, image_size):
+    #     """Generate story based on CLIP analysis"""
+    #     # Prepare context from analysis
+    #     objects_text = ", ".join([obj['name'] for obj in analysis_result['objects'][:5]])
+    #     scenes_text = analysis_result['scenes'][0]['type'] if analysis_result['scenes'] else "unknown scene"
 
-        width, height = image_size
+    #     width, height = image_size
 
-        # Create story based on analysis
-        if 'person' in objects_text.lower():
-            story = f"In this captivating {width}x{height} {scenes_text}, we see {objects_text}. A story unfolds where human presence meets the environment, creating moments of connection and experience that speak to the heart of what it means to be alive in this visual narrative."
+    #     # Create story based on analysis
+    #     if 'person' in objects_text.lower():
+    #         story = f"In this captivating {width}x{height} {scenes_text}, we see {objects_text}. A story unfolds where human presence meets the environment, creating moments of connection and experience that speak to the heart of what it means to be alive in this visual narrative."
 
-        elif 'nature' in objects_text.lower():
-            story = f"This breathtaking {width}x{height} {scenes_text} reveals {objects_text}. Nature's timeless beauty tells a story of growth, change, and the enduring power of the natural world, where every element harmonizes to create a symphony of visual poetry."
+    #     elif 'nature' in objects_text.lower():
+    #         story = f"This breathtaking {width}x{height} {scenes_text} reveals {objects_text}. Nature's timeless beauty tells a story of growth, change, and the enduring power of the natural world, where every element harmonizes to create a symphony of visual poetry."
 
-        elif 'building' in objects_text.lower() or 'urban' in scenes_text.lower():
-            story = f"Architectural elegance defines this {width}x{height} {scenes_text} featuring {objects_text}. The structures stand as silent witnesses to countless stories, their forms telling tales of human ingenuity, community, and the relentless march of progress through time."
+    #     elif 'building' in objects_text.lower() or 'urban' in scenes_text.lower():
+    #         story = f"Architectural elegance defines this {width}x{height} {scenes_text} featuring {objects_text}. The structures stand as silent witnesses to countless stories, their forms telling tales of human ingenuity, community, and the relentless march of progress through time."
 
-        else:
-            story = f"In this compelling {width}x{height} composition showing {objects_text}, visual elements converge to create a unique narrative. The {scenes_text} invites contemplation, asking viewers to explore the relationships between forms, colors, and spaces that together tell a story beyond words."
+    #     else:
+    #         story = f"In this compelling {width}x{height} composition showing {objects_text}, visual elements converge to create a unique narrative. The {scenes_text} invites contemplation, asking viewers to explore the relationships between forms, colors, and spaces that together tell a story beyond words."
 
-        return story
+    #     return story
 
     # def create_story_overlay(self, image, story):
     #     """Create story overlay in bottom left with white text on black background"""
@@ -435,9 +651,9 @@ def load_selected_example(evt: gr.SelectData):
     return None
 
 # Create Gradio interface
-with gr.Blocks(title="CLIP-ViT Image Analyzer", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🎨 CLIP-ViT Image Analyzer")
-    gr.Markdown("**Upload an image to analyze content and generate stories**")
+with gr.Blocks(title="Who says AI isn’t creative? Watch it turn a single image into a beautifully written story", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# Image Story Teller")
+    gr.Markdown("**Upload an image to analyse content and generate stories**")
 
     # Load example images
     example_images_list = get_example_images()
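
Below is a minimal usage sketch of the pipeline this commit introduces; it is not part of the diff. It assumes the updated app.py exposes ImageStoryteller as shown above, that the CLIP and phi-2 weights can be downloaded, and it uses a placeholder image path.

# Usage sketch (assumptions: app.py is importable from the working directory;
# importing it will also construct the Gradio UI defined at module level).
from PIL import Image
from app import ImageStoryteller

# phi-2 is the default LLM in this commit
storyteller = ImageStoryteller(llm_model_id="microsoft/phi-2")

# "example.jpg" is a hypothetical input path used only for illustration
image = Image.open("example.jpg")

# Full pipeline: CLIP analysis followed by LLM story generation
result = storyteller.process_image_and_generate_story(image, creativity_level=0.7)
print(result['scene_type'], result['detected_objects'])
print(result['story'])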