import os
from io import BytesIO

import cv2
import numpy as np
import torch
import gradio as gr
import rembg
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

class ImageStoryteller:
    # Alternative LLMs: "microsoft/phi-2", "microsoft/phi-3-mini-4k-instruct"
    def __init__(self, llm_model_id="Qwen/Qwen1.5-1.8B-Chat"):
        print("Initializing Image Storyteller with CLIP-ViT and LLM...")
        
        # Load CLIP model for image understanding
        try:
            self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            print("CLIP-ViT model loaded successfully!")
        except Exception as e:
            print(f"CLIP loading failed: {e}")
            self.clip_model = None
            self.clip_processor = None
        
        # Load LLM for story generation
        try:
            # The chosen LLM (the default Qwen1.5-1.8B-Chat needs no HF login)
            self.llm_model_id = llm_model_id
            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                llm_model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                # trust_remote_code is only needed for models that ship custom code
                # (e.g. phi, qwen, yi, deepseek) and is harmless otherwise
                trust_remote_code=True,
            )

            print(f"LLM model {llm_model_id} loaded successfully!")
        except Exception as e:
            print(f"LLM loading failed: {e}")
            self.llm_model = None
            self.tokenizer = None
        
        # Common objects for scene understanding
        self.common_objects = [
            'person', 'people', 'human', 'man', 'woman', 'child', 'baby',
            'dog', 'cat', 'animal', 'bird', 'horse', 'cow', 'sheep',
            'car', 'vehicle', 'bus', 'truck', 'bicycle', 'motorcycle',
            'building', 'house', 'skyscraper', 'architecture',
            'tree', 'forest', 'nature', 'mountain', 'sky', 'clouds',
            'water', 'ocean', 'river', 'lake', 'beach',
            'food', 'fruit', 'vegetable', 'meal',
            'indoor', 'outdoor', 'urban', 'rural'
        ]
        
        # Scene categories for classification
        self.scene_categories = [
            "portrait", "landscape", "cityscape", "indoor scene", "outdoor scene",
            "nature", "urban", "beach", "mountain", "forest", "street",
            "party", "celebration", "sports", "action", "still life",
            "abstract", "art", "architecture", "wildlife", "pet"
        ]
    
    def analyze_image_with_clip(self, image):
        """Analyze image using CLIP to understand content and scene"""
        if self.clip_model is None or self.clip_processor is None:
            return self.fallback_image_analysis(image)
        
        try:
            # Convert PIL to RGB
            image_rgb = image.convert('RGB')
            
            # Analyze objects in the image
            object_inputs = self.clip_processor(
                text=self.common_objects, 
                images=image_rgb, 
                return_tensors="pt", 
                padding=True
            )
            
            with torch.no_grad():
                object_outputs = self.clip_model(**object_inputs)
                object_logits = object_outputs.logits_per_image
                object_probs = object_logits.softmax(dim=1)
            
            # Get top objects
            top_object_indices = torch.topk(object_probs, 5, dim=1).indices[0]
            detected_objects = []
            for idx in top_object_indices:
                obj_name = self.common_objects[idx]
                confidence = object_probs[0][idx].item()
                if confidence > 0.1:  # Confidence threshold
                    detected_objects.append({
                        'name': obj_name,
                        'confidence': confidence
                    })
            
            # Analyze scene type
            scene_inputs = self.clip_processor(
                text=self.scene_categories,
                images=image_rgb,
                return_tensors="pt",
                padding=True
            )
            
            with torch.no_grad():
                scene_outputs = self.clip_model(**scene_inputs)
                scene_logits = scene_outputs.logits_per_image
                scene_probs = scene_logits.softmax(dim=1)
            
            top_scene_indices = torch.topk(scene_probs, 3, dim=1).indices[0]
            scene_types = []
            for idx in top_scene_indices:
                scene_name = self.scene_categories[idx]
                confidence = scene_probs[0][idx].item()
                scene_types.append({
                    'type': scene_name,
                    'confidence': confidence
                })
            
            return {
                'objects': detected_objects,
                'scenes': scene_types,
                'success': True
            }
            
        except Exception as e:
            print(f"CLIP analysis failed: {e}")
            return self.fallback_image_analysis(image)
    
    def fallback_image_analysis(self, image):
        """Fallback analysis when CLIP fails"""
        return {
            'objects': [{'name': 'scene', 'confidence': 1.0}],
            'scenes': [{'type': 'general image', 'confidence': 1.0}],
            'success': False
        }
    
    

    def generate_story(self, analysis_result, creativity_level=0.7):
        """Generate a story with caption based on detected objects and scene using Qwen"""
        if self.llm_model is None:
            return "Story generation model not available."
        
        try:
            # Extract detected objects and scene
            objects = [obj['name'] for obj in analysis_result['objects']]
            scenes = [scene['type'] for scene in analysis_result['scenes']]
            
            # Create a prompt for the LLM
            objects_str = ", ".join(objects)
            scene_str = scenes[0] if scenes else "general scene"
            
            # Convert creativity_level to float if it's a tuple
            if isinstance(creativity_level, (tuple, list)):
                creativity_level = float(creativity_level[0])
            
            # SIMPLIFIED PROMPT - No numbered lists or complex formatting
            if creativity_level > 0.8:
                prompt = f"Write a catchy 5-7 word YouTube-style caption, then a creative 3-4 paragraph story about {objects_str} in a {scene_str}."
            elif creativity_level > 0.5:
                prompt = f"Create a short caption and a 2-3 paragraph story about {objects_str} in a {scene_str}."
            else:
                prompt = f"Write a caption and a 1-2 paragraph description of {objects_str} in a {scene_str}."
            
            # Per-model prompt formatting
            if "qwen" in self.llm_model_id.lower():
                formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
            elif "phi" in self.llm_model_id.lower():
                formatted_prompt = f"Instruct: {prompt}\nOutput:"
            elif "gemma" in self.llm_model_id.lower():
                formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
            else:
                formatted_prompt = f"User: {prompt}\nAssistant:"
            
            # Tokenize and generate
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)
            
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=300,
                    temperature=creativity_level,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=3
                )
            
            # Decode and clean up
            raw_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract only the assistant's response (match case-insensitively,
            # but keep the original casing of the generated text)
            lowered = raw_output.lower()
            if "assistant" in lowered:
                cut = lowered.rfind("assistant") + len("assistant")
                story = raw_output[cut:].lstrip(": \n").strip()
            else:
                story = raw_output
            
            # Clean Qwen tokens if present
            qwen_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"]
            for token in qwen_tokens:
                story = story.replace(token, "").strip()
            
            # Clean any remaining prompt text
            story = story.replace(prompt, "").strip()
            
            # Extract or create caption from the story
            sentences = story.split('. ')
            if sentences:
                # Take first sentence as caption
                caption = sentences[0].strip()
                if not caption.endswith('.'):
                    caption += '.'
                
                # Rest of the story
                if len(sentences) > 1:
                    story_text = '. '.join(sentences[1:])
                else:
                    story_text = story.replace(caption, "").strip()
                
                # Format with caption at top and separator
                formatted_output = f"{caption}\n{'─' * 40}\n{story_text}"
            else:
                formatted_output = story
            
            # Clean up any extra whitespace
            formatted_output = '\n'.join([line.strip() for line in formatted_output.split('\n') if line.strip()])
            
            return formatted_output
            
        except Exception as e:
            print(f"Story generation failed: {e}")
            objects_str = ", ".join(objects) if 'objects' in locals() else "unknown"
            scene_str = scenes[0] if 'scenes' in locals() and scenes else "unknown scene"
            return f"Caption: Analysis of {objects_str}\n{'─' * 40}\nFailed to generate story. Detected: {objects_str} in {scene_str}."
    
    def process_image_and_generate_story(self, image, creativity_level=0.7):
        """Complete pipeline: analyze image and generate story"""
        print("Analyzing image...")
        analysis = self.analyze_image_with_clip(image)
        
        print("Generating story...")
        story = self.generate_story(analysis, creativity_level)
        
        # Return both analysis and story
        detected_objects = [obj['name'] for obj in analysis['objects']]
        scene_type = analysis['scenes'][0]['type'] if analysis['scenes'] else "unknown"
        
        return story, detected_objects, scene_type

    def create_story_overlay(self, image, story):
        """Format the caption and story text for display in the textbox"""
        
        separator = "─" * 40
        
        # generate_story already prefixes its own caption and separator line;
        # avoid stacking a second caption on top of it
        if separator in story:
            return story
        
        # Otherwise take the first sentence of the story as the caption
        caption = ""
        sentences = story.split('. ')
        if sentences:
            caption = sentences[0].strip()
            if not caption.endswith('.'):
                caption += '.'
        
        # Caption on top, separated from the story by a line of dashes
        return f"{caption}\n{separator}\n{story}"
    
    
    
    def remove_background(self, image):
        """Remove background using rembg"""
        try:
            # Convert PIL image to bytes
            img_byte_arr = BytesIO()
            image.save(img_byte_arr, format='PNG')
            img_byte_arr = img_byte_arr.getvalue()
            
            # Remove background
            output = rembg.remove(img_byte_arr)
            
            # Convert back to PIL Image
            result_image = Image.open(BytesIO(output))
            
            return result_image
            
        except Exception as e:
            print(f"Background removal failed: {e}")
            return image
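
    # Note (hedged): some rembg releases also accept a PIL.Image directly and
    # expose rembg.new_session(...) for choosing the segmentation model; the
    # bytes-based call above is kept because it works across versions.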
    
    def remove_foreground(self, image):
        """Remove foreground and keep only background using inpainting"""
        try:
            # First remove background to get foreground mask
            img_byte_arr = BytesIO()
            image.save(img_byte_arr, format='PNG')
            img_byte_arr = img_byte_arr.getvalue()
            
            # Remove background to get alpha channel
            output = rembg.remove(img_byte_arr)
            foreground_image = Image.open(BytesIO(output))
            
            # Convert to numpy arrays
            original_np = np.array(image.convert('RGB'))
            foreground_np = np.array(foreground_image.convert('RGBA'))
            
            # Create mask where foreground exists (alpha > 0)
            mask = foreground_np[:, :, 3] > 0
            
            # Create background-only image by filling foreground areas
            background_np = original_np.copy()
            
            # Simple inpainting: fill foreground areas with average background color
            # Calculate average background color from areas without foreground
            bg_pixels = original_np[~mask]
            if len(bg_pixels) > 0:
                avg_color = np.mean(bg_pixels, axis=0)
                background_np[mask] = avg_color.astype(np.uint8)
            
            return Image.fromarray(background_np)
            
        except Exception as e:
            print(f"Foreground removal failed: {e}")
            return image
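
    # A hedged alternative to the average-colour fill used above: OpenCV's
    # inpainting (cv2 is already imported at the top of this file) usually gives
    # a smoother background. Sketch only, not wired into the UI; `mask` is
    # assumed to be a boolean array that is True on foreground pixels.
    def inpaint_foreground_cv2(self, image, mask):
        """Fill masked foreground pixels with OpenCV Telea inpainting (sketch)."""
        original_np = np.array(image.convert('RGB'))
        mask_u8 = mask.astype(np.uint8) * 255  # OpenCV expects an 8-bit mask
        filled = cv2.inpaint(original_np, mask_u8, 3, cv2.INPAINT_TELEA)
        return Image.fromarray(filled)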
    
    def process_image(self, image):
        """Main processing function: analyse the image and return the story text"""
        try:
            # Analyze image with CLIP-ViT
            analysis_result = self.analyze_image_with_clip(image)
            
            # Generate story
            story = self.generate_story(analysis_result, creativity_level=0.7)
            
            # Format caption + story for the output textbox
            return self.create_story_overlay(image, story)
            
        except Exception as e:
            error_msg = f"An error occurred: {str(e)}"
            print(error_msg)
            # Return the error text so it appears in the story textbox
            # (the click handler has a single Textbox output)
            return error_msg
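
# A minimal standalone sketch of the zero-shot CLIP scoring used in
# analyze_image_with_clip above, kept outside the class so it can be tried in
# isolation. It is not called by the app; the image path and candidate labels
# are placeholders.
def clip_zero_shot_demo(image_path, labels=("a dog", "a cat", "a city street", "a landscape")):
    """Score one image against free-form text labels with CLIP (sketch)."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    image = Image.open(image_path).convert("RGB")
    inputs = processor(text=list(labels), images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        probs = model(**inputs).logits_per_image.softmax(dim=1)[0]
    # Highest-probability label first
    return sorted(zip(labels, probs.tolist()), key=lambda pair: pair[1], reverse=True)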

# Initialize the storyteller
storyteller = ImageStoryteller()
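
# Hedged usage sketch: the same pipeline can also be driven without the Gradio
# UI, e.g. from a script or notebook. Defined but never called by the app;
# "photo.jpg" is a placeholder path.
def run_storyteller_on_file(image_path="photo.jpg", creativity_level=0.6):
    """Generate and print a story for a local image file (sketch)."""
    img = Image.open(image_path)
    story, objects, scene = storyteller.process_image_and_generate_story(img, creativity_level)
    print(f"Scene: {scene} | Objects: {', '.join(objects)}")
    print(story)
    return story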

# Get example images from local directory
def get_example_images():
    """Get example images from local directory"""
    example_images = []
    for i in range(1, 17):
        img_path = f"obj_{i:02d}.jpg"
        if os.path.exists(img_path):
            # Load and resize the image for the gallery
            img = Image.open(img_path)
            # Resize to smaller size for gallery display
            img.thumbnail((150, 150))
            example_images.append(img)
        else:
            print(f"Warning: {img_path} not found")
            # Create a simple placeholder image
            placeholder = Image.new('RGB', (150, 150), color=(73, 109, 137))
            example_images.append(placeholder)
    return example_images

def load_selected_example(evt: gr.SelectData):
    """Load the full-size version of the selected example image"""
    if evt.index < 16:  # 16 example images: obj_01.jpg through obj_16.jpg
        img_path = f"obj_{evt.index+1:02d}.jpg"
        if os.path.exists(img_path):
            return Image.open(img_path)
    return None

# Create Gradio interface
with gr.Blocks(title="Who says AI isn’t creative? Watch it turn a single image into a beautifully written story", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Image Story Teller")
    gr.Markdown("**Upload an image to analyse content and generate stories**")
    
    # Load example images
    example_images_list = get_example_images()
    
    

    custom_css = """
    <style>
    .gradio-container {
        height: 100vh !important;
        max-height: 100vh !important;
        overflow: hidden !important;
    }
    
    #blocks-container {
        height: calc(100vh - 100px) !important;
        overflow-y: auto !important;
    }
    
    /* Remove gallery selection frames */
    .gallery .wrap.contain .grid .wrap,
    .gallery .wrap.contain .grid .wrap.selected,
    .gallery .thumbnail,
    .gallery .thumbnail.selected,
    .gallery .wrap.gradio-image,
    .gallery .wrap.gradio-image.selected {
        border: none !important;
        box-shadow: none !important;
        outline: none !important;
    }
    </style>
    """
    
    
    javascript = """
    <script>
    document.addEventListener('DOMContentLoaded', function() {
    const stopExpansion = function() {
        // More aggressive containment
        document.body.style.maxHeight = '100vh';
        document.body.style.overflow = 'hidden';
        
        const containers = document.querySelectorAll('div');
        containers.forEach(container => {
            if (container.scrollHeight > window.innerHeight) {
                container.style.maxHeight = '95vh';
                container.style.overflowY = 'auto';
            }
        });
    };
    
    stopExpansion();
    setInterval(stopExpansion, 1000); // Keep checking every second
    });
    </script>
    """


    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil", 
                label="πŸ–ΌοΈ Upload Your Image",
                height=400
            )
            
            # Buttons row
            with gr.Row():
                process_btn = gr.Button("✨ Generate Story", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear Image", variant="secondary", size="lg")
            
            # Example images section
            gr.Markdown("### πŸ“Έ Example Images (Click to load)")
            
            # Display example images in a gallery with custom CSS to remove frames
            example_gallery = gr.Gallery(
                value=example_images_list,
                label="",
                columns=4,
                rows=2,
                height="auto",
                scale=2,
                object_fit="contain",
                show_label=False,
                show_download_button=False,
                container=True,
                preview=False,
                allow_preview=False,
                elem_id="example-gallery"
            )
        
        with gr.Column():
            story_output = gr.Textbox(
                label="🔍 Story",
                lines=10,
                max_lines=20,
                show_copy_button=True,
                interactive=False,
                autoscroll=False
            )
    
    # with gr.Row():
    #     with gr.Column():
    #         story_output = gr.Image(
    #             label="📖 Story",
    #             height=400,
    #             show_download_button=True
    #         )
    
    # Background removal section
    with gr.Row():
        with gr.Column():
            bg_remove_btn = gr.Button("🎯 Remove Background", variant="secondary", size="lg")
            background_output = gr.Image(
                label="Background Removed",
                height=400,
                show_download_button=True
            )
        
        with gr.Column():
            fg_remove_btn = gr.Button("🎯 Remove Foreground", variant="secondary", size="lg")
            foreground_output = gr.Image(
                label="Foreground Removed",
                height=400,
                show_download_button=True
            )
    
    def clear_all():
        """Clear the input image and all outputs"""
        return None, None, None, None
    
    # Set up the processing
    process_btn.click(
        fn=storyteller.process_image,
        inputs=input_image,
        outputs=[story_output]
    )
    
    # Clear button functionality
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[input_image, story_output, background_output, foreground_output]
    )
    
    # Example gallery selection - load full-size image when clicked
    example_gallery.select(
        fn=load_selected_example,
        inputs=[],
        outputs=input_image
    )
    
    # Background removal
    bg_remove_btn.click(
        fn=storyteller.remove_background,
        inputs=input_image,
        outputs=background_output
    )
    
    # Foreground removal
    fg_remove_btn.click(
        fn=storyteller.remove_foreground,
        inputs=input_image,
        outputs=foreground_output
    )

# Launch the application
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )