# Tas01's picture
# Update app.py
# c95b9e0 verified
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import rembg
from io import BytesIO
import os
import torch
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from PIL import Image
import gradio as gr
class ImageStoryteller:
    """Describe an image with CLIP and turn that description into a story.

    CLIP (``openai/clip-vit-base-patch32``) scores the image against fixed
    lists of object labels and scene categories; a causal LLM then writes a
    caption plus a short story about the top-scoring labels. rembg powers the
    background/foreground removal helpers.

    Either model may fail to load; the corresponding attributes are then
    ``None`` and the methods return fallback results instead of raising.
    """

    # Line drawn between the caption and the story body in formatted output.
    SEPARATOR = "─" * 40

    def __init__(self, llm_model_id="Qwen/Qwen1.5-1.8B-Chat"):
        """Load CLIP and the story LLM.

        Args:
            llm_model_id: Hugging Face id of the causal LM used by
                ``generate_story``. The prompt layout is chosen from this id
                (Qwen, Phi, Gemma, or a generic User/Assistant layout).
                Ids used previously: ``microsoft/phi-2``,
                ``microsoft/phi-3-mini-4k-instruct``.
        """
        print("Initializing Image Storyteller with CLIP-ViT and LLM...")
        # Set before loading so generate_story can read it even if loading fails.
        self.llm_model_id = llm_model_id

        # Load CLIP model for image understanding
        try:
            self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            print("CLIP-ViT model loaded successfully!")
        except Exception as e:
            print(f"CLIP loading failed: {e}")
            self.clip_model = None
            self.clip_processor = None

        # Load LLM for story generation
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
            # Use half precision only on GPU. On CPU, float16 generation is slow
            # and some PyTorch builds lack half-precision kernels for it.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                llm_model_id,
                torch_dtype=dtype,
                device_map="auto",
                trust_remote_code=True,
            )
            print(f"LLM model {llm_model_id} loaded successfully!")
        except Exception as e:
            print(f"LLM loading failed: {e}")
            self.llm_model = None
            self.tokenizer = None

        # Object labels CLIP scores the image against.
        self.common_objects = [
            'person', 'people', 'human', 'man', 'woman', 'child', 'baby',
            'dog', 'cat', 'animal', 'bird', 'horse', 'cow', 'sheep',
            'car', 'vehicle', 'bus', 'truck', 'bicycle', 'motorcycle',
            'building', 'house', 'skyscraper', 'architecture',
            'tree', 'forest', 'nature', 'mountain', 'sky', 'clouds',
            'water', 'ocean', 'river', 'lake', 'beach',
            'food', 'fruit', 'vegetable', 'meal',
            'indoor', 'outdoor', 'urban', 'rural'
        ]
        # Scene categories CLIP scores the image against.
        self.scene_categories = [
            "portrait", "landscape", "cityscape", "indoor scene", "outdoor scene",
            "nature", "urban", "beach", "mountain", "forest", "street",
            "party", "celebration", "sports", "action", "still life",
            "abstract", "art", "architecture", "wildlife", "pet"
        ]

    def _rank_labels(self, image_rgb, labels, top_k):
        """Return up to ``top_k`` ``(label, probability)`` pairs for ``image_rgb``, best first.

        Probabilities are a softmax of CLIP's image-text logits over ``labels`` only.
        """
        inputs = self.clip_processor(
            text=labels,
            images=image_rgb,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            probs = self.clip_model(**inputs).logits_per_image.softmax(dim=1)[0]
        top = torch.topk(probs, min(top_k, len(labels)))
        return [
            (labels[index], probability)
            for probability, index in zip(top.values.tolist(), top.indices.tolist())
        ]

    def analyze_image_with_clip(self, image):
        """Identify likely objects and scene types in ``image`` using CLIP.

        Returns:
            A dict with these keys:

            - ``'objects'``: up to 5 ``{'name', 'confidence'}`` entries with
              probability above 0.1.
            - ``'scenes'``: the top 3 ``{'type', 'confidence'}`` entries.
            - ``'success'``: True.

            ``fallback_image_analysis`` is returned instead when there is no
            image, CLIP is not loaded, or the analysis raises.
        """
        if image is None or self.clip_model is None or self.clip_processor is None:
            return self.fallback_image_analysis(image)
        try:
            image_rgb = image.convert('RGB')
            detected_objects = [
                {'name': name, 'confidence': confidence}
                for name, confidence in self._rank_labels(image_rgb, self.common_objects, 5)
                if confidence > 0.1  # Confidence threshold
            ]
            scene_types = [
                {'type': name, 'confidence': confidence}
                for name, confidence in self._rank_labels(image_rgb, self.scene_categories, 3)
            ]
            return {
                'objects': detected_objects,
                'scenes': scene_types,
                'success': True
            }
        except Exception as e:
            print(f"CLIP analysis failed: {e}")
            return self.fallback_image_analysis(image)

    def fallback_image_analysis(self, image):
        """Return a generic analysis result (``success`` False) when CLIP is unusable.

        ``image`` is not inspected.
        """
        return {
            'objects': [{'name': 'scene', 'confidence': 1.0}],
            'scenes': [{'type': 'general image', 'confidence': 1.0}],
            'success': False
        }

    @staticmethod
    def _split_caption(text):
        """Split ``text`` into ``(caption, rest)``.

        The caption is the text before the first sentence break (``'. '``) or
        line break, whichever comes first. It gets a trailing period if it
        lacks one. ``rest`` is the remaining text, stripped.
        """
        text = text.strip()
        breaks = [
            (pos, width)
            for pos, width in ((text.find('. '), 2), (text.find('\n'), 1))
            if pos != -1
        ]
        if breaks:
            pos, width = min(breaks)
            caption, rest = text[:pos], text[pos + width:]
        else:
            caption, rest = text, ""
        caption = caption.strip()
        if caption and not caption.endswith('.'):
            caption += '.'
        return caption, rest.strip()

    def generate_story(self, analysis_result, creativity_level=0.7):
        """Generate a caption and a story from a CLIP analysis result.

        Args:
            analysis_result: dict with ``'objects'`` (list of dicts with a
                ``'name'``) and ``'scenes'`` (list of dicts with a ``'type'``),
                as returned by ``analyze_image_with_clip``.
            creativity_level: sampling temperature. Higher values also ask for
                a longer story. A tuple/list is accepted and its first element
                is used. Values <= 0 switch to greedy decoding.

        Returns:
            ``"<caption>\\n<separator>\\n<story>"``, or a message string when
            the LLM is unavailable or generation fails.
        """
        if self.llm_model is None or self.tokenizer is None:
            return "Story generation model not available."
        objects, scenes = [], []
        try:
            objects = [obj['name'] for obj in analysis_result['objects']]
            scenes = [scene['type'] for scene in analysis_result['scenes']]
            # Without this default, an empty object list yields "about  in a ...".
            objects_str = ", ".join(objects) if objects else "the scene"
            scene_str = scenes[0] if scenes else "general scene"

            # The value may arrive wrapped in a tuple/list.
            if isinstance(creativity_level, (tuple, list)):
                creativity_level = creativity_level[0]
            creativity_level = float(creativity_level)

            # Simple prompts: no numbered lists or complex formatting.
            if creativity_level > 0.8:
                prompt = f"Write a catchy 5-7 word YouTube-style caption, then a creative 3-4 paragraph story about {objects_str} in a {scene_str}."
            elif creativity_level > 0.5:
                prompt = f"Create a short caption and a 2-3 paragraph story about {objects_str} in a {scene_str}."
            else:
                prompt = f"Write a caption and a 1-2 paragraph description of {objects_str} in a {scene_str}."

            # Wrap the prompt in the chat layout of the model family.
            model_id = self.llm_model_id.lower()
            if "qwen" in model_id:
                formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
            elif "phi" in model_id:
                formatted_prompt = f"Instruct: {prompt}\nOutput:"
            elif "gemma" in model_id:
                formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
            else:
                formatted_prompt = f"User: {prompt}\nAssistant:"

            inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.llm_model.device)
            prompt_length = inputs["input_ids"].shape[1]
            generation_kwargs = {
                "max_new_tokens": 300,
                "repetition_penalty": 1.1,
                "eos_token_id": self.tokenizer.eos_token_id,
                "pad_token_id": self.tokenizer.eos_token_id,
                "no_repeat_ngram_size": 3,
            }
            if creativity_level > 0:
                generation_kwargs.update(do_sample=True, temperature=creativity_level, top_p=0.9)
            else:
                # Sampling requires a strictly positive temperature.
                generation_kwargs["do_sample"] = False
            with torch.no_grad():
                outputs = self.llm_model.generate(**inputs, **generation_kwargs)

            # A causal LM's generate() returns prompt + continuation. Decode only
            # the continuation, so the prompt and role markers never leak into
            # the story and its casing is preserved.
            story = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            for token in ("<|im_start|>", "<|im_end|>", "<|endoftext|>"):
                story = story.replace(token, "")
            story = story.strip()
            if not story:
                raise ValueError("model returned an empty story")

            caption, story_text = self._split_caption(story)
            formatted_output = f"{caption}\n{self.SEPARATOR}\n{story_text}"
            # Drop blank lines and stray indentation.
            return '\n'.join(line.strip() for line in formatted_output.split('\n') if line.strip())
        except Exception as e:
            print(f"Story generation failed: {e}")
            objects_str = ", ".join(objects) if objects else "unknown"
            scene_str = scenes[0] if scenes else "unknown scene"
            return f"Caption: Analysis of {objects_str}\n{'─' * 40}\nFailed to generate story. Detected: {objects_str} in {scene_str}."

    def process_image_and_generate_story(self, image, creativity_level=0.7):
        """Run the full pipeline and also expose the intermediate analysis.

        Returns:
            ``(story, detected_object_names, top_scene_type)``. The scene is
            ``"unknown"`` when the analysis lists no scenes.
        """
        print("Analyzing image...")
        analysis = self.analyze_image_with_clip(image)
        print("Generating story...")
        story = self.generate_story(analysis, creativity_level)
        detected_objects = [obj['name'] for obj in analysis['objects']]
        scene_type = analysis['scenes'][0]['type'] if analysis['scenes'] else "unknown"
        return story, detected_objects, scene_type

    def create_story_overlay(self, image, story):
        """Format ``story`` for the textbox as caption / separator line / story.

        The caption is the first sentence (or first line) of ``story``. If
        ``story`` already contains the separator line, as ``generate_story``
        output does, it is returned unchanged so the caption is not
        duplicated. ``image`` is unused and kept for interface compatibility.
        """
        if self.SEPARATOR in story:
            return story
        caption, _ = self._split_caption(story)
        return f"{caption}\n{self.SEPARATOR}\n{story}"

    @staticmethod
    def _rembg_cutout(image):
        """Run ``rembg.remove`` on ``image`` (round-tripped through PNG bytes) and return a PIL image."""
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        return Image.open(BytesIO(rembg.remove(buffer.getvalue())))

    def remove_background(self, image):
        """Return ``image`` with its background removed by rembg.

        On any failure (including ``image`` being None), the error is printed
        and ``image`` is returned unchanged.
        """
        try:
            return self._rembg_cutout(image)
        except Exception as e:
            print(f"Background removal failed: {e}")
            return image

    def remove_foreground(self, image):
        """Blank out the foreground of ``image``, keeping only the background.

        Foreground pixels are those with non-zero alpha in rembg's cut-out.
        They are filled with the mean colour of the remaining pixels. If no
        background pixels remain, the RGB image is returned without filling.
        On any failure, the error is printed and ``image`` is returned unchanged.
        """
        try:
            foreground_np = np.array(self._rembg_cutout(image).convert('RGBA'))
            original_np = np.array(image.convert('RGB'))
            mask = foreground_np[:, :, 3] > 0
            background_np = original_np.copy()
            bg_pixels = original_np[~mask]
            if len(bg_pixels) > 0:
                background_np[mask] = np.mean(bg_pixels, axis=0).astype(np.uint8)
            return Image.fromarray(background_np)
        except Exception as e:
            print(f"Foreground removal failed: {e}")
            return image

    def process_image(self, image):
        """Gradio handler: analyse ``image`` and return the formatted story text.

        Always returns exactly one string, matching the single story textbox
        output. A missing image and unexpected errors are reported as message
        strings instead of being raised.
        """
        if image is None:
            return "Please upload an image first."
        try:
            analysis_result = self.analyze_image_with_clip(image)
            # generate_story already returns caption + separator + story.
            return self.generate_story(analysis_result, creativity_level=0.7)
        except Exception as e:
            error_msg = f"An error occurred: {str(e)}"
            print(error_msg)
            return error_msg
# One shared storyteller, built at import time: its constructor loads the CLIP
# and LLM weights once, and every UI callback below reuses this instance.
storyteller: ImageStoryteller = ImageStoryteller()
# Get example images from local directory
def get_example_images(count=16, thumb_size=(150, 150)):
    """Load ``obj_01.jpg`` ... ``obj_NN.jpg`` from the working directory as gallery thumbnails.

    Args:
        count: number of example files to look for (``obj_01.jpg`` through
            ``obj_{count:02d}.jpg``).
        thumb_size: maximum (width, height) of each thumbnail.
            ``Image.thumbnail`` preserves the aspect ratio.

    Returns:
        A list of exactly ``count`` PIL images, in file order. A missing or
        unreadable file is replaced by a solid placeholder of ``thumb_size``,
        so gallery positions keep matching the file numbers that
        ``load_selected_example`` derives from the clicked index.
    """
    example_images = []
    for i in range(1, count + 1):
        img_path = f"obj_{i:02d}.jpg"
        thumbnail = None
        if os.path.exists(img_path):
            try:
                # Shrink while the file is open, then copy, so the handle is
                # released and the copy stays usable after close.
                with Image.open(img_path) as img:
                    img.thumbnail(thumb_size)
                    thumbnail = img.copy()
            except OSError as e:
                # A corrupt example must not crash app startup.
                print(f"Warning: could not read {img_path}: {e}")
                thumbnail = None
        else:
            print(f"Warning: {img_path} not found")
        if thumbnail is None:
            # Create a simple placeholder image
            thumbnail = Image.new('RGB', thumb_size, color=(73, 109, 137))
        example_images.append(thumbnail)
    return example_images
def load_selected_example(evt: gr.SelectData):
    """Load the full-size example image for the clicked gallery thumbnail.

    Gradio supplies ``evt`` because of the ``gr.SelectData`` annotation; its
    ``index`` is the clicked thumbnail's position. Returns None when the index
    is out of range or the file is missing or unreadable.
    """
    # get_example_images() builds 16 thumbnails (obj_01.jpg ... obj_16.jpg),
    # so valid indices are 0-15.
    if not 0 <= evt.index < 16:
        return None
    img_path = f"obj_{evt.index + 1:02d}.jpg"
    if not os.path.exists(img_path):
        return None
    try:
        # Load eagerly and copy, so the file handle is released here and a
        # corrupt file is reported now instead of failing later in Gradio.
        with Image.open(img_path) as img:
            return img.copy()
    except OSError as e:
        print(f"Could not load example {img_path}: {e}")
        return None
# Create Gradio interface. Every component and event listener must be created
# inside this Blocks context.
with gr.Blocks(title="Who says AI isn’t creative? Watch it turn a single image into a beautifully written story", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Image Story Teller")
    gr.Markdown("**Upload an image to analyse content and generate stories**")
    # Load example images (thumbnails for the gallery below)
    example_images_list = get_example_images()

    # NOTE(review): custom_css and javascript are never passed to gr.Blocks
    # (e.g. via css= / head=), so neither affects the page today. custom_css
    # also carries <style> tags, but css= expects raw CSS. Enabling them would
    # lock the page height and force scroll containers, so check the layout
    # in a browser before wiring them in.
    custom_css = """
<style>
.gradio-container {
height: 100vh !important;
max-height: 100vh !important;
overflow: hidden !important;
}
#blocks-container {
height: calc(100vh - 100px) !important;
overflow-y: auto !important;
}
/* Remove gallery selection frames */
.gallery .wrap.contain .grid .wrap,
.gallery .wrap.contain .grid .wrap.selected,
.gallery .thumbnail,
.gallery .thumbnail.selected,
.gallery .wrap.gradio-image,
.gallery .wrap.gradio-image.selected {
border: none !important;
box-shadow: none !important;
outline: none !important;
}
</style>
"""
    javascript = """
<script>
document.addEventListener('DOMContentLoaded', function() {
const stopExpansion = function() {
// More aggressive containment
document.body.style.maxHeight = '100vh';
document.body.style.overflow = 'hidden';
const containers = document.querySelectorAll('div');
containers.forEach(container => {
if (container.scrollHeight > window.innerHeight) {
container.style.maxHeight = '95vh';
container.style.overflowY = 'auto';
}
});
};
stopExpansion();
setInterval(stopExpansion, 1000); // Keep checking every second
});
</script>
"""

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="🖼️ Upload Your Image",
                height=400,
            )
            # Buttons row
            with gr.Row():
                process_btn = gr.Button("✨ Generate Story", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear Image", variant="secondary", size="lg")
            # Example images section; clicking a thumbnail loads the full image
            gr.Markdown("### 📸 Example Images (Click to load)")
            example_gallery = gr.Gallery(
                value=example_images_list,
                label="",
                columns=4,
                rows=2,
                height="auto",
                scale=2,
                object_fit="contain",
                show_label=False,
                show_download_button=False,
                container=True,
                preview=False,
                allow_preview=False,
                elem_id="example-gallery",
            )
        with gr.Column():
            story_output = gr.Textbox(
                label="🔍 Story",
                lines=10,
                max_lines=20,
                show_copy_button=True,
                interactive=False,
                autoscroll=False,
            )

    # Background removal section
    with gr.Row():
        with gr.Column():
            bg_remove_btn = gr.Button("🎯 Remove Background", variant="secondary", size="lg")
            background_output = gr.Image(
                label="Background Removed",
                height=400,
                show_download_button=True,
            )
        with gr.Column():
            fg_remove_btn = gr.Button("🎯 Remove Foreground", variant="secondary", size="lg")
            foreground_output = gr.Image(
                label="Foreground Removed",
                height=400,
                show_download_button=True,
            )

    def clear_all():
        """Clear the input image and every output component."""
        # Exactly one value per component in clear_btn's outputs list:
        # input image, story, background result, foreground result.
        return None, None, None, None

    # Set up the processing
    process_btn.click(
        fn=storyteller.process_image,
        inputs=input_image,
        outputs=[story_output],
    )
    # Clear button functionality
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[input_image, story_output, background_output, foreground_output],
    )
    # Example gallery selection - load full-size image when clicked
    example_gallery.select(
        fn=load_selected_example,
        inputs=[],
        outputs=input_image,
    )
    # Background removal
    bg_remove_btn.click(
        fn=storyteller.remove_background,
        inputs=input_image,
        outputs=background_output,
    )
    # Foreground removal
    fg_remove_btn.click(
        fn=storyteller.remove_foreground,
        inputs=input_image,
        outputs=foreground_output,
    )
# Entry point: only serve the UI when this file is run directly.
if __name__ == "__main__":
    # Listen on every network interface at port 7860; no public share link.
    launch_settings = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_settings)