import os
from io import BytesIO

import gradio as gr
import numpy as np
import rembg
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor
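
# A minimal requirements sketch for this Space (package names assumed from the
# imports above; pin versions as needed for your environment):
#
#   gradio
#   transformers
#   torch
#   accelerate      # needed for device_map="auto" below
#   rembg
#   onnxruntime     # backend assumed by rembg
#   pillow
#   numpy
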
class ImageStoryteller:
    # Alternative LLMs: "microsoft/phi-2", "microsoft/Phi-3-mini-4k-instruct"
    def __init__(self, llm_model_id="Qwen/Qwen1.5-1.8B-Chat"):
        print("Initializing Image Storyteller with CLIP-ViT and LLM...")

        # Load CLIP model for image understanding
        try:
            self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            print("CLIP-ViT model loaded successfully!")
        except Exception as e:
            print(f"CLIP loading failed: {e}")
            self.clip_model = None
            self.clip_processor = None

        # Load LLM for story generation
        try:
            # Choose your LLM (these checkpoints don't require a Hugging Face login)
            self.llm_model_id = llm_model_id
            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                llm_model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
            # Alternatively, only enable remote code for model families that need it:
            # trust_remote_code=any(k in llm_model_id.lower() for k in ["phi", "qwen", "yi", "deepseek"])
            print(f"LLM model {llm_model_id} loaded successfully!")
        except Exception as e:
            print(f"LLM loading failed: {e}")
            self.llm_model = None
            self.tokenizer = None

        # Common objects for scene understanding
        self.common_objects = [
            'person', 'people', 'human', 'man', 'woman', 'child', 'baby',
            'dog', 'cat', 'animal', 'bird', 'horse', 'cow', 'sheep',
            'car', 'vehicle', 'bus', 'truck', 'bicycle', 'motorcycle',
            'building', 'house', 'skyscraper', 'architecture',
            'tree', 'forest', 'nature', 'mountain', 'sky', 'clouds',
            'water', 'ocean', 'river', 'lake', 'beach',
            'food', 'fruit', 'vegetable', 'meal',
            'indoor', 'outdoor', 'urban', 'rural'
        ]

        # Scene categories for classification
        self.scene_categories = [
            "portrait", "landscape", "cityscape", "indoor scene", "outdoor scene",
            "nature", "urban", "beach", "mountain", "forest", "street",
            "party", "celebration", "sports", "action", "still life",
            "abstract", "art", "architecture", "wildlife", "pet"
        ]

    def analyze_image_with_clip(self, image):
        """Analyze image using CLIP to understand content and scene"""
        if self.clip_model is None or self.clip_processor is None:
            return self.fallback_image_analysis(image)

        try:
            # Convert PIL to RGB
            image_rgb = image.convert('RGB')

            # Analyze objects in the image
            object_inputs = self.clip_processor(
                text=self.common_objects,
                images=image_rgb,
                return_tensors="pt",
                padding=True
            )
            with torch.no_grad():
                object_outputs = self.clip_model(**object_inputs)
                object_logits = object_outputs.logits_per_image
                object_probs = object_logits.softmax(dim=1)

            # Get top objects
            top_object_indices = torch.topk(object_probs, 5, dim=1).indices[0]
            detected_objects = []
            for idx in top_object_indices:
                obj_name = self.common_objects[idx]
                confidence = object_probs[0][idx].item()
                if confidence > 0.1:  # Confidence threshold
                    detected_objects.append({
                        'name': obj_name,
                        'confidence': confidence
                    })

            # Analyze scene type
            scene_inputs = self.clip_processor(
                text=self.scene_categories,
                images=image_rgb,
                return_tensors="pt",
                padding=True
            )
            with torch.no_grad():
                scene_outputs = self.clip_model(**scene_inputs)
                scene_logits = scene_outputs.logits_per_image
                scene_probs = scene_logits.softmax(dim=1)

            top_scene_indices = torch.topk(scene_probs, 3, dim=1).indices[0]
            scene_types = []
            for idx in top_scene_indices:
                scene_name = self.scene_categories[idx]
                confidence = scene_probs[0][idx].item()
                scene_types.append({
                    'type': scene_name,
                    'confidence': confidence
                })

            return {
                'objects': detected_objects,
                'scenes': scene_types,
                'success': True
            }
        except Exception as e:
            print(f"CLIP analysis failed: {e}")
            return self.fallback_image_analysis(image)
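
    # The method above is plain zero-shot classification with CLIP: every text label is
    # scored against the image via logits_per_image, and softmax turns the scores into a
    # probability distribution over the labels. A minimal standalone sketch of the same
    # idea (label list and image path are placeholders):
    #
    #   from transformers import CLIPModel, CLIPProcessor
    #   from PIL import Image
    #
    #   model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    #   processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    #   labels = ["a dog", "a cat", "a car"]
    #   inputs = processor(text=labels, images=Image.open("example.jpg"),
    #                      return_tensors="pt", padding=True)
    #   probs = model(**inputs).logits_per_image.softmax(dim=1)
    #   print(dict(zip(labels, probs[0].tolist())))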

    def fallback_image_analysis(self, image):
        """Fallback analysis when CLIP fails"""
        return {
            'objects': [{'name': 'scene', 'confidence': 1.0}],
            'scenes': [{'type': 'general image', 'confidence': 1.0}],
            'success': False
        }

    def generate_story(self, analysis_result, creativity_level=0.7):
        """Generate a story with caption based on detected objects and scene using Qwen"""
        if self.llm_model is None:
            return "Story generation model not available."

        try:
            # Extract detected objects and scene
            objects = [obj['name'] for obj in analysis_result['objects']]
            scenes = [scene['type'] for scene in analysis_result['scenes']]

            # Create a prompt for the LLM
            objects_str = ", ".join(objects)
            scene_str = scenes[0] if scenes else "general scene"

            # Convert creativity_level to float if it's a tuple
            if isinstance(creativity_level, (tuple, list)):
                creativity_level = float(creativity_level[0])

            # Simplified prompts - no numbered lists or complex formatting
            if creativity_level > 0.8:
                prompt = f"Write a catchy 5-7 word YouTube-style caption, then a creative 3-4 paragraph story about {objects_str} in a {scene_str}."
            elif creativity_level > 0.5:
                prompt = f"Create a short caption and a 2-3 paragraph story about {objects_str} in a {scene_str}."
            else:
                prompt = f"Write a caption and a 1-2 paragraph description of {objects_str} in a {scene_str}."

            # Model-specific chat formatting
            if "qwen" in self.llm_model_id.lower():
                formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
            elif "phi" in self.llm_model_id.lower():
                formatted_prompt = f"Instruct: {prompt}\nOutput:"
            elif "gemma" in self.llm_model_id.lower():
                formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
            else:
                formatted_prompt = f"User: {prompt}\nAssistant:"

            # Tokenize and generate
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=300,
                    temperature=creativity_level,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=3
                )

            # Decode and clean up
            raw_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the assistant's response (case-insensitive search, without
            # lowercasing the story itself)
            lower_output = raw_output.lower()
            if "assistant" in lower_output:
                start = lower_output.rfind("assistant") + len("assistant")
                story = raw_output[start:].lstrip(":").strip()
            else:
                story = raw_output

            # Clean Qwen chat tokens if present
            for token in ("<|im_start|>", "<|im_end|>", "<|endoftext|>"):
                story = story.replace(token, "").strip()

            # Clean any remaining prompt text
            story = story.replace(prompt, "").strip()

            # Extract or create a caption from the story
            sentences = story.split('. ')
            if sentences:
                # Take the first sentence as the caption
                caption = sentences[0].strip()
                if not caption.endswith('.'):
                    caption += '.'

                # Rest of the story
                if len(sentences) > 1:
                    story_text = '. '.join(sentences[1:])
                else:
                    story_text = story.replace(caption, "").strip()

                # Format with the caption at the top and a separator line
                formatted_output = f"{caption}\n{'─' * 40}\n{story_text}"
            else:
                formatted_output = story

            # Clean up any extra whitespace
            formatted_output = '\n'.join([line.strip() for line in formatted_output.split('\n') if line.strip()])
            return formatted_output

        except Exception as e:
            print(f"Story generation failed: {e}")
            objects_str = ", ".join(objects) if 'objects' in locals() else "unknown"
            scene_str = scenes[0] if 'scenes' in locals() and scenes else "unknown scene"
            return f"Caption: Analysis of {objects_str}\n{'─' * 40}\nFailed to generate story. Detected: {objects_str} in {scene_str}."

    def process_image_and_generate_story(self, image, creativity_level=0.7):
        """Complete pipeline: analyze image and generate story"""
        print("Analyzing image...")
        analysis = self.analyze_image_with_clip(image)

        print("Generating story...")
        story = self.generate_story(analysis, creativity_level)

        # Return both the analysis and the story
        detected_objects = [obj['name'] for obj in analysis['objects']]
        scene_type = analysis['scenes'][0]['type'] if analysis['scenes'] else "unknown"
        return story, detected_objects, scene_type

    def create_story_overlay(self, image, story):
        """Create formatted text with caption and story for textbox display"""
        # Generate caption (first sentence of the story)
        caption = ""
        sentences = story.split('. ')
        if sentences:
            caption = sentences[0].strip()
            if not caption.endswith('.'):
                caption += '.'

        # Format the text with the caption separated from the story by a separator line
        separator = "─" * 40
        formatted_text = f"{caption}\n{separator}\n{story}"
        return formatted_text

    def remove_background(self, image):
        """Remove background using rembg"""
        try:
            # Convert the PIL image to bytes
            img_byte_arr = BytesIO()
            image.save(img_byte_arr, format='PNG')
            img_byte_arr = img_byte_arr.getvalue()

            # Remove the background
            output = rembg.remove(img_byte_arr)

            # Convert back to a PIL Image
            result_image = Image.open(BytesIO(output))
            return result_image
        except Exception as e:
            print(f"Background removal failed: {e}")
            return image
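
    # Note: rembg.remove() also accepts a PIL.Image directly and returns one, which would
    # skip the PNG round-trip above. Reusing a session avoids reloading the segmentation
    # model on every call (model name below is the assumed default):
    #
    #   session = rembg.new_session("u2net")
    #   cutout = rembg.remove(image, session=session)   # PIL in, PIL out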

    def remove_foreground(self, image):
        """Remove foreground and keep only background using inpainting"""
        try:
            # First remove the background to get a foreground mask
            img_byte_arr = BytesIO()
            image.save(img_byte_arr, format='PNG')
            img_byte_arr = img_byte_arr.getvalue()

            # Remove the background to get an alpha channel
            output = rembg.remove(img_byte_arr)
            foreground_image = Image.open(BytesIO(output))

            # Convert to numpy arrays
            original_np = np.array(image.convert('RGB'))
            foreground_np = np.array(foreground_image.convert('RGBA'))

            # Create a mask where foreground exists (alpha > 0)
            mask = foreground_np[:, :, 3] > 0

            # Create a background-only image by filling foreground areas
            background_np = original_np.copy()

            # Simple inpainting: fill foreground areas with the average background color,
            # computed from the pixels outside the foreground mask
            bg_pixels = original_np[~mask]
            if len(bg_pixels) > 0:
                avg_color = np.mean(bg_pixels, axis=0)
                background_np[mask] = avg_color.astype(np.uint8)

            return Image.fromarray(background_np)
        except Exception as e:
            print(f"Foreground removal failed: {e}")
            return image
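
    # Flat-color filling looks crude on textured backgrounds. A sketch of a smoother
    # alternative using OpenCV inpainting (requires opencv-python, which is not imported
    # above; the 3 px radius is an assumption to tune):
    #
    #   import cv2
    #   inpaint_mask = mask.astype(np.uint8) * 255
    #   background_np = cv2.inpaint(original_np, inpaint_mask, 3, cv2.INPAINT_TELEA)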

    def process_image(self, image):
        """Main processing function for the Generate Story button"""
        try:
            # Analyze the image with CLIP-ViT
            analysis_result = self.analyze_image_with_clip(image)

            # Generate the story (generate_story already prepends the caption and separator)
            story = self.generate_story(analysis_result, creativity_level=0.7)
            return story
        except Exception as e:
            error_msg = f"An error occurred: {str(e)}"
            print(error_msg)
            # Return the error message so it appears in the story textbox
            return error_msg

# Initialize the storyteller
storyteller = ImageStoryteller()
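
# Quick programmatic check of the pipeline without the UI (the image path points at one
# of the example images expected below; uncomment to run locally):
#
#   test_image = Image.open("obj_01.jpg")
#   story, detected_objects, scene_type = storyteller.process_image_and_generate_story(
#       test_image, creativity_level=0.7
#   )
#   print(scene_type, detected_objects)
#   print(story)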

# Get example images from the local directory
def get_example_images():
    """Get example images from the local directory"""
    example_images = []
    for i in range(1, 17):
        img_path = f"obj_{i:02d}.jpg"
        if os.path.exists(img_path):
            # Load and resize the image for the gallery
            img = Image.open(img_path)
            # Resize to a smaller size for gallery display
            img.thumbnail((150, 150))
            example_images.append(img)
        else:
            print(f"Warning: {img_path} not found")
            # Create a simple placeholder image
            placeholder = Image.new('RGB', (150, 150), color=(73, 109, 137))
            example_images.append(placeholder)
    return example_images

def load_selected_example(evt: gr.SelectData):
    """Load the full-size version of the selected example image"""
    if evt.index < 16:  # The gallery holds 16 example images
        img_path = f"obj_{evt.index + 1:02d}.jpg"
        if os.path.exists(img_path):
            return Image.open(img_path)
    return None

# Create the Gradio interface
with gr.Blocks(title="Who says AI isn't creative? Watch it turn a single image into a beautifully written story", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Image Story Teller")
    gr.Markdown("**Upload an image to analyse content and generate stories**")

    # Load example images
    example_images_list = get_example_images()

    # NOTE: the CSS and JavaScript below are defined but not currently attached to the
    # interface (they could be passed to gr.Blocks via its css / head parameters).
    custom_css = """
    <style>
    .gradio-container {
        height: 100vh !important;
        max-height: 100vh !important;
        overflow: hidden !important;
    }
    #blocks-container {
        height: calc(100vh - 100px) !important;
        overflow-y: auto !important;
    }
    /* Remove gallery selection frames */
    .gallery .wrap.contain .grid .wrap,
    .gallery .wrap.contain .grid .wrap.selected,
    .gallery .thumbnail,
    .gallery .thumbnail.selected,
    .gallery .wrap.gradio-image,
    .gallery .wrap.gradio-image.selected {
        border: none !important;
        box-shadow: none !important;
        outline: none !important;
    }
    </style>
    """

    javascript = """
    <script>
    document.addEventListener('DOMContentLoaded', function() {
        const stopExpansion = function() {
            // More aggressive containment
            document.body.style.maxHeight = '100vh';
            document.body.style.overflow = 'hidden';
            const containers = document.querySelectorAll('div');
            containers.forEach(container => {
                if (container.scrollHeight > window.innerHeight) {
                    container.style.maxHeight = '95vh';
                    container.style.overflowY = 'auto';
                }
            });
        };
        stopExpansion();
        setInterval(stopExpansion, 1000);  // Keep checking every second
    });
    </script>
    """

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="🖼️ Upload Your Image",
                height=400
            )

            # Buttons row
            with gr.Row():
                process_btn = gr.Button("✨ Generate Story", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear Image", variant="secondary", size="lg")

            # Example images section
            gr.Markdown("### 📸 Example Images (Click to load)")

            # Display example images in a gallery with custom CSS to remove frames
            example_gallery = gr.Gallery(
                value=example_images_list,
                label="",
                columns=4,
                rows=2,
                height="auto",
                scale=2,
                object_fit="contain",
                show_label=False,
                show_download_button=False,
                container=True,
                preview=False,
                allow_preview=False,
                elem_id="example-gallery"
            )

        with gr.Column():
            story_output = gr.Textbox(
                label="📖 Story",
                lines=10,
                max_lines=20,
                show_copy_button=True,
                interactive=False,
                autoscroll=False
            )

    # Background removal section
    with gr.Row():
        with gr.Column():
            bg_remove_btn = gr.Button("🎯 Remove Background", variant="secondary", size="lg")
            background_output = gr.Image(
                label="Background Removed",
                height=400,
                show_download_button=True
            )
        with gr.Column():
            fg_remove_btn = gr.Button("🎯 Remove Foreground", variant="secondary", size="lg")
            foreground_output = gr.Image(
                label="Foreground Removed",
                height=400,
                show_download_button=True
            )

    def clear_all():
        """Clear the input image and all outputs"""
        return None, None, None, None

    # Set up the processing
    process_btn.click(
        fn=storyteller.process_image,
        inputs=input_image,
        outputs=[story_output]
    )

    # Clear button functionality
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[input_image, story_output, background_output, foreground_output]
    )

    # Example gallery selection - load the full-size image when clicked
    example_gallery.select(
        fn=load_selected_example,
        inputs=[],
        outputs=input_image
    )

    # Background removal
    bg_remove_btn.click(
        fn=storyteller.remove_background,
        inputs=input_image,
        outputs=background_output
    )

    # Foreground removal
    fg_remove_btn.click(
        fn=storyteller.remove_foreground,
        inputs=input_image,
        outputs=foreground_output
    )

# Launch the application
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
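
# On Hugging Face Spaces the launch() call above is what starts the app. For a local run
# (assuming this file is saved as app.py and the dependencies listed near the imports are
# installed):
#
#   python app.py
#   # then open http://localhost:7860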