import os
import warnings
from io import BytesIO

import gradio as gr
import cv2  # noqa: F401  -- not referenced in this file; kept from the original imports
import numpy as np
from PIL import Image
import torch
from transformers import (  # BitsAndBytesConfig is not referenced in this file
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    CLIPModel,
    CLIPProcessor,
)
import rembg

# Silence all Python warnings globally to keep the console output readable.
warnings.filterwarnings("ignore")

CLIP_MODEL_ID = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

# Candidate object labels scored by CLIP. ImageStoryteller.__init__ removes
# duplicates (order-preserving) so a repeated label does not split its
# probability mass across two entries.
COMMON_OBJECTS = (
    # People & Faces
    'person', 'people', 'human', 'man', 'woman', 'child', 'baby', 'face', 'head', 'hand',
    'foot', 'body', 'crowd', 'group', 'family', 'couple', 'friends', 'audience', 'team',
    'worker', 'athlete', 'dancer', 'singer', 'artist', 'doctor',
    # Animals
    'dog', 'cat', 'bird', 'horse', 'cow', 'sheep', 'pig', 'elephant', 'lion', 'tiger',
    'bear', 'wolf', 'fox', 'deer', 'rabbit', 'squirrel', 'butterfly', 'fish', 'shark',
    'whale', 'dolphin', 'turtle', 'snake', 'spider', 'insect',
    # Vehicles
    'car', 'truck', 'bus', 'motorcycle', 'bicycle', 'airplane', 'helicopter', 'train',
    'boat', 'ship', 'sailboat', 'submarine', 'rocket', 'tractor', 'ambulance',
    'fire truck', 'police car', 'taxi', 'racing car', 'bike',
    # Buildings & Structures
    'building', 'house', 'skyscraper', 'tower', 'castle', 'bridge', 'monument', 'statue',
    'fountain', 'church', 'temple', 'mosque', 'school', 'hospital', 'hotel', 'restaurant',
    'store', 'mall', 'factory', 'lighthouse',
    # Nature & Outdoor
    'tree', 'forest', 'flower', 'plant', 'grass', 'mountain', 'hill', 'valley', 'cliff',
    'cave', 'water', 'ocean', 'sea', 'river', 'lake', 'waterfall', 'beach', 'sand', 'rock',
    'stone', 'sky', 'cloud', 'sun', 'moon', 'star', 'rain', 'snow', 'ice', 'fire', 'smoke',
    # Food & Drinks
    'food', 'fruit', 'vegetable', 'bread', 'pizza', 'cake', 'dessert', 'ice cream',
    'chocolate', 'coffee', 'tea', 'wine', 'beer', 'water', 'meal', 'breakfast', 'lunch',
    'dinner', 'restaurant', 'kitchen',
    # Furniture & Household
    'chair', 'table', 'bed', 'sofa', 'couch', 'desk', 'lamp', 'clock', 'mirror', 'window',
    'door', 'stairs', 'shelf', 'cabinet', 'refrigerator', 'oven', 'sink', 'toilet',
    'shower', 'bathtub',
    # Electronics & Items
    'computer', 'laptop', 'phone', 'television', 'camera', 'book', 'newspaper', 'pen',
    'paper', 'keyboard', 'mouse', 'headphones', 'speaker', 'microphone', 'watch',
    'glasses', 'sunglasses', 'umbrella', 'bag', 'backpack',
    # Clothing
    'clothing', 'shirt', 'pants', 'dress', 'skirt', 'jacket', 'coat', 'shoes', 'boots',
    'sneakers', 'hat', 'cap', 'helmet', 'gloves', 'scarf', 'tie', 'belt', 'jewelry',
    'necklace', 'ring',
)

# Candidate scene labels scored by CLIP (50 common scene types).
SCENE_CATEGORIES = (
    "portrait", "landscape", "cityscape", "indoor", "outdoor", "nature", "urban",
    "beach", "mountain", "forest", "street", "road", "park", "garden", "field",
    "room", "kitchen", "bedroom", "living room", "office", "restaurant", "cafe",
    "store", "mall", "school", "sports", "game", "concert", "party", "wedding",
    "food", "meal", "cooking", "drinking", "eating", "animal", "pet", "wildlife",
    "zoo", "farm", "vehicle", "traffic", "transportation", "travel", "journey",
    "art", "painting", "drawing", "photography", "design",
)

# Example images are looked up as obj_01.jpg ... obj_20.jpg in the working dir.
EXAMPLE_IMAGE_COUNT = 20
EXAMPLE_THUMBNAIL_SIZE = (150, 150)


class ImageStoryteller:
    """Turn an image into a short captioned story.

    CLIP scores the image against fixed object and scene label lists, and a
    causal LLM writes a story from the top labels. The class also provides
    rembg-based background / foreground removal.
    """

    def __init__(self, llm_model_id="Qwen/Qwen2.5-1.5B-Instruct"):
        """Load the CLIP model and the story-generation LLM.

        Args:
            llm_model_id: Hugging Face model id of the chat/causal LM.
                Alternatives previously tried here:
                "Qwen/Qwen2.5-3B-Instruct", "microsoft/Phi-3-mini-4k-instruct".

        A model that fails to load is reported on stdout and its attributes
        are set to ``None``; the analysis/generation methods check for this
        and degrade gracefully.
        """
        self.llm_model_id = llm_model_id

        # Load CLIP for image understanding.
        try:
            self.clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID)
            self.clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
            print("CLIP-ViT model loaded successfully!")
        except Exception as e:
            print(f"CLIP loading failed: {e}")
            self.clip_model = None
            self.clip_processor = None

        # Load the LLM for story generation.
        try:
            use_cuda = torch.cuda.is_available()
            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                llm_model_id,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                # NOTE(review): trust_remote_code lets the model repo execute
                # custom Python code; confirm it is needed for the chosen id.
                trust_remote_code=True,
            )
            print(f"LLM model {llm_model_id} loaded successfully!")
        except Exception as e:
            print(f"LLM loading failed: {e}")
            self.llm_model = None
            self.tokenizer = None

        # dict.fromkeys de-duplicates while keeping first-seen order.
        self.common_objects = list(dict.fromkeys(COMMON_OBJECTS))
        self.scene_categories = list(dict.fromkeys(SCENE_CATEGORIES))

    def _clip_probabilities(self, pixel_values, labels, batch_size):
        """Score ``labels`` against one preprocessed image with CLIP.

        Text labels are encoded ``batch_size`` at a time to bound memory, but a
        single softmax is taken over the concatenated logits of *all* batches,
        so the result is one probability distribution over ``labels``.

        Returns:
            1-D tensor of length ``len(labels)`` whose entries sum to 1.
        """
        batch_logits = []
        for start in range(0, len(labels), batch_size):
            text_inputs = self.clip_processor(
                text=labels[start:start + batch_size],
                return_tensors="pt",
                padding=True,
            )
            with torch.no_grad():
                outputs = self.clip_model(pixel_values=pixel_values, **text_inputs)
            # logits_per_image: (num_images, num_texts_in_this_batch); one image here.
            batch_logits.append(outputs.logits_per_image)
        return torch.cat(batch_logits, dim=1).softmax(dim=1)[0]

    def analyze_image_with_clip(self, image):
        """Analyze ``image`` with CLIP and return its likely objects and scenes.

        Args:
            image: PIL image.

        Returns:
            dict with:
              * ``'objects'``: up to 5 ``{'name', 'confidence'}`` entries whose
                confidence exceeds 0.1 (or the single best one if none do),
              * ``'scenes'``: top 3 ``{'type', 'confidence'}`` entries,
              * ``'success'``: True.
            Falls back to :meth:`fallback_image_analysis` when CLIP is not
            loaded or the analysis raises.
        """
        if self.clip_model is None or self.clip_processor is None:
            return self.fallback_image_analysis(image)

        try:
            image_rgb = image.convert("RGB")
            # Preprocess the image once; it is reused for every label batch.
            pixel_values = self.clip_processor(images=image_rgb, return_tensors="pt")["pixel_values"]

            object_probs = self._clip_probabilities(pixel_values, self.common_objects, batch_size=50)
            top_objects = torch.topk(object_probs, min(5, len(self.common_objects)))
            candidates = [
                {"name": self.common_objects[idx], "confidence": conf}
                for conf, idx in zip(top_objects.values.tolist(), top_objects.indices.tolist())
            ]
            detected_objects = [obj for obj in candidates if obj["confidence"] > 0.1]
            if not detected_objects:
                # Never hand the story prompt an empty object list.
                detected_objects = candidates[:1]

            scene_probs = self._clip_probabilities(pixel_values, self.scene_categories, batch_size=30)
            top_scenes = torch.topk(scene_probs, min(3, len(self.scene_categories)))
            scene_types = [
                {"type": self.scene_categories[idx], "confidence": conf}
                for conf, idx in zip(top_scenes.values.tolist(), top_scenes.indices.tolist())
            ]

            return {"objects": detected_objects, "scenes": scene_types, "success": True}
        except Exception as e:
            print(f"CLIP analysis failed: {e}")
            return self.fallback_image_analysis(image)

    def fallback_image_analysis(self, image):
        """Return a generic analysis result, used when CLIP is unavailable or fails."""
        return {
            'objects': [{'name': 'scene', 'confidence': 1.0}],
            'scenes': [{'type': 'general image', 'confidence': 1.0}],
            'success': False,
        }

    def _format_prompt(self, prompt):
        """Wrap ``prompt`` for the LLM, using the tokenizer's chat template if it has one."""
        if getattr(self.tokenizer, "chat_template", None):
            messages = [{"role": "user", "content": prompt}]
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return f"User: {prompt}\nAssistant:"

    @staticmethod
    def _add_caption_divider(text):
        """Insert a 40-character rule between the first line (caption) and the rest."""
        caption, newline, body = text.partition("\n")
        if not newline:
            return text
        return f"{caption}\n{'─' * 40}\n{body}"

    def generate_story(self, analysis_result, creativity_level=0.7):
        """Generate a caption plus story from a CLIP analysis result.

        Args:
            analysis_result: dict as returned by :meth:`analyze_image_with_clip`.
            creativity_level: sampling temperature; a tuple/list is reduced to
                its first element. Clamped to a small positive minimum because
                sampling needs a strictly positive temperature.

        Returns:
            The generated text with a divider line after its first line, or a
            message string if the model is unavailable or generation fails.
        """
        if self.llm_model is None or self.tokenizer is None:
            return "Story generation model not available."

        try:
            objects = [obj['name'] for obj in analysis_result['objects']]
            scenes = [scene['type'] for scene in analysis_result['scenes']]
            objects_str = ", ".join(objects[:5]) or "an interesting subject"
            scene_str = scenes[0] if scenes else "general scene"

            if isinstance(creativity_level, (tuple, list)):
                creativity_level = creativity_level[0]
            temperature = max(float(creativity_level), 0.01)

            prompt = f"Write a creative story about {objects_str} in a {scene_str}. First give a short caption, then a story."
            inputs = self.tokenizer(self._format_prompt(prompt), return_tensors="pt").to(self.llm_model.device)
            prompt_length = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=300,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    use_cache=True,
                )

            # Decode only the newly generated tokens. This drops the prompt and
            # chat markup without lower-casing or truncating the story itself.
            story = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
            return self._add_caption_divider(story)
        except Exception as e:
            print(f"Story generation failed: {e}")
            return f"Error generating story: {str(e)}"

    def process_image_and_generate_story(self, image, creativity_level=0.7):
        """Run the full pipeline: analyze ``image``, then generate a story.

        Returns:
            ``(story, detected_object_names, top_scene_type)``; for a missing
            image, a prompt message, an empty list and "No image".
        """
        if image is None:
            return "Please upload an image first.", [], "No image"

        print("Analyzing image...")
        analysis = self.analyze_image_with_clip(image)

        print("Generating story...")
        story = self.generate_story(analysis, creativity_level)

        detected_objects = [obj['name'] for obj in analysis['objects']]
        scene_type = analysis['scenes'][0]['type'] if analysis['scenes'] else "unknown"
        return story, detected_objects, scene_type

    @staticmethod
    def _rembg_cutout(image):
        """Run rembg on ``image`` (round-tripped through PNG bytes) and return the result image."""
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        return Image.open(BytesIO(rembg.remove(buffer.getvalue())))

    def remove_background(self, image):
        """Remove the background with rembg; returns the input image unchanged on failure."""
        if image is None:
            return None
        try:
            return self._rembg_cutout(image)
        except Exception as e:
            print(f"Background removal failed: {e}")
            return image

    def remove_foreground(self, image):
        """Keep only the background: pixels rembg marks as foreground get the mean background colour.

        Any pixel with non-zero alpha in the rembg cut-out counts as foreground.
        Returns the input image unchanged on failure.
        """
        if image is None:
            return None
        try:
            foreground_np = np.array(self._rembg_cutout(image).convert('RGBA'))
            original_np = np.array(image.convert('RGB'))

            mask = foreground_np[:, :, 3] > 0
            background_np = original_np.copy()

            bg_pixels = original_np[~mask]
            if len(bg_pixels) > 0:
                avg_color = np.mean(bg_pixels, axis=0)
                background_np[mask] = avg_color.astype(np.uint8)

            return Image.fromarray(background_np)
        except Exception as e:
            print(f"Foreground removal failed: {e}")
            return image


# Initialize the storyteller (this loads the models at import time).
storyteller = ImageStoryteller()


def _find_example_images():
    """Return ``(number, path)`` for every existing ``obj_NN.jpg`` with NN in 01..20."""
    numbered = ((i, f"obj_{i:02d}.jpg") for i in range(1, EXAMPLE_IMAGE_COUNT + 1))
    return [(i, path) for i, path in numbered if os.path.exists(path)]


# Both the gallery thumbnails and the click handler index into this one list,
# so a thumbnail's position always maps back to the file it was built from.
EXAMPLE_IMAGES = _find_example_images()


def get_example_images():
    """Build ``(thumbnail, caption)`` pairs for the example gallery.

    Produces exactly one entry per item of ``EXAMPLE_IMAGES``; a file that
    cannot be read gets a solid-colour placeholder so gallery positions stay
    aligned with ``EXAMPLE_IMAGES``.
    """
    example_images = []
    for number, path in EXAMPLE_IMAGES:
        try:
            with Image.open(path) as img:
                img.thumbnail(EXAMPLE_THUMBNAIL_SIZE)
                # Copy before the file is closed; closing releases the image data.
                thumbnail = img.copy()
            example_images.append((thumbnail, f"Example {number}"))
        except Exception as e:
            print(f"Could not load example image {path}: {e}")
            placeholder = Image.new('RGB', EXAMPLE_THUMBNAIL_SIZE, color=(73, 109, 137))
            example_images.append((placeholder, f"Placeholder {number}"))
    return example_images


def load_selected_example(evt: gr.SelectData):
    """Return the full-size example image for the clicked gallery thumbnail, or None."""
    index = evt.index
    if not isinstance(index, int) or not 0 <= index < len(EXAMPLE_IMAGES):
        return None
    _, path = EXAMPLE_IMAGES[index]
    try:
        with Image.open(path) as img:
            return img.copy()
    except Exception as e:
        print(f"Could not load example image {path}: {e}")
        return None


def run_story_pipeline(image, creativity):
    """Gradio handler: run the pipeline and join detected objects into Textbox-friendly text."""
    story, objects, scene = storyteller.process_image_and_generate_story(image, creativity)
    return story, ", ".join(objects), scene


def clear_all():
    """Clear all inputs and outputs (one value per component wired to the clear button)."""
    return None, "", "", "", None, None


# Create Gradio interface
with gr.Blocks(
    title="Image Story Teller - Turn images into stories",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1400px !important;
        margin: auto !important;
    }
    .gallery .thumb {
        border: none !important;
        box-shadow: none !important;
    }
    .gallery .thumb.selected {
        border: 2px solid #4CAF50 !important;
    }
    """,
) as demo:
    gr.Markdown("# 🎨 Image Story Teller")
    gr.Markdown("Upload an image to analyze content and generate creative stories")

    example_images_list = get_example_images()

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="📤 Upload Your Image",
                height=400,
                interactive=True,
            )

            with gr.Row():
                process_btn = gr.Button("✨ Generate Story", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary", size="lg")

            creativity_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Creativity Level",
                info="Higher = more creative, Lower = more factual",
            )

            gr.Markdown("### 📸 Example Images")
            example_gallery = gr.Gallery(
                value=[img for img, _ in example_images_list],
                label="Click an image to load it",
                columns=4,
                rows=2,
                height="auto",
                object_fit="contain",
                show_label=True,
                allow_preview=False,
                preview=False,
            )

        with gr.Column(scale=1):
            story_output = gr.Textbox(
                label="📖 Generated Story",
                lines=15,
                max_lines=20,
                interactive=False,
            )

            with gr.Accordion("📊 Analysis Details", open=False):
                objects_output = gr.Textbox(label="Detected Objects", interactive=False, lines=3)
                scene_output = gr.Textbox(label="Scene Type", interactive=False, lines=2)

    with gr.Row():
        with gr.Column():
            bg_remove_btn = gr.Button("🎯 Remove Background", variant="secondary", size="lg")
            background_output = gr.Image(label="Background Removed", height=300, interactive=False)

        with gr.Column():
            fg_remove_btn = gr.Button("🎯 Remove Foreground", variant="secondary", size="lg")
            foreground_output = gr.Image(label="Foreground Removed", height=300, interactive=False)

    # Event handlers
    process_btn.click(
        fn=run_story_pipeline,
        inputs=[input_image, creativity_slider],
        outputs=[story_output, objects_output, scene_output],
    )

    clear_btn.click(
        fn=clear_all,
        inputs=[],
        # Must match the six values returned by clear_all, in order.
        outputs=[input_image, story_output, objects_output, scene_output,
                 background_output, foreground_output],
    )

    example_gallery.select(
        fn=load_selected_example,
        inputs=[],
        outputs=input_image,
    )

    bg_remove_btn.click(
        fn=storyteller.remove_background,
        inputs=input_image,
        outputs=background_output,
    )

    fg_remove_btn.click(
        fn=storyteller.remove_foreground,
        inputs=input_image,
        outputs=foreground_output,
    )


# Launch the application
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        favicon_path=None,
        inbrowser=True,
    )