adityaardak committed on
Commit 9a6f0ec · verified · 1 Parent(s): 3bd8cbe

Update app.py

Files changed (1):
  1. app.py +169 -75
app.py CHANGED
@@ -1,79 +1,173 @@
- diff --git a/app.py b/app.py
- index 0000000..1111111 100644
- --- a/app.py
- +++ b/app.py
- @@ -1,16 +1,28 @@
- import gradio as gr
- import torch
- from PIL import Image
- from transformers import AutoTokenizer, AutoModelForCausalLM
- -import spaces

- # Model configuration
- MID = "apple/FastVLM-0.5B"
- IMAGE_TOKEN_INDEX = -200

- # Load model and tokenizer (will be loaded on first GPU allocation)
- tok = None
- model = None
- def load_model():
-     global tok, model
-     if tok is None or model is None:
-         print("Loading model...")
-         tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
- -        model = AutoModelForCausalLM.from_pretrained(
- -            MID,
- -            torch_dtype=torch.float16,
- -            device_map="cuda",
- -            trust_remote_code=True,
- -        )
- +        # ---- CPU-first, with dynamic fallback if CUDA exists ----
- +        use_cuda = torch.cuda.is_available()
- +        device_map = "cuda" if use_cuda else "cpu"
- +        # float16 is great on GPU, but unsafe on CPU; use float32 on CPU
- +        dtype = torch.float16 if use_cuda else torch.float32
- +
- +        model = AutoModelForCausalLM.from_pretrained(
- +            MID,
- +            torch_dtype=dtype,
- +            device_map=device_map,
- +            trust_remote_code=True,
- +        )
-         print("Model loaded successfully!")
-     return tok, model
- -
- -@spaces.GPU(duration=60)
- +
- +# Removed GPU decorator so CPU Spaces don't request a GPU
- def caption_image(image, custom_prompt=None):
- @@ -66,16 +78,23 @@ def caption_image(image, custom_prompt=None):
-     # Insert IMAGE token id at placeholder position
- -    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
- -    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
- -    attention_mask = torch.ones_like(input_ids, device=model.device)
- +    # Derive device/dtype from model parameters (robust on CPU or GPU)
- +    model_device = next(model.parameters()).device
- +    model_dtype = next(model.parameters()).dtype
- +
- +    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype, device=model_device)
- +    input_ids = torch.cat([pre_ids.to(model_device), img_tok, post_ids.to(model_device)], dim=1)
- +    attention_mask = torch.ones_like(input_ids, device=model_device)

-     # Preprocess image using model's vision tower
-     px = model.get_vision_tower().image_processor(
-         images=image, return_tensors="pt"
-     )["pixel_values"]
- -    px = px.to(model.device, dtype=model.dtype)
- +    px = px.to(model_device, dtype=model_dtype)

-     # Generate caption
-     with torch.no_grad():
-         out = model.generate(
-             inputs=input_ids,
-             attention_mask=attention_mask,
-             images=px,
-             max_new_tokens=128,
-             do_sample=False,  # Deterministic generation
- -            temperature=1.0,
- +            # temperature is ignored when do_sample=False
-         )
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import spaces

+ # Model configuration
+ MID = "apple/FastVLM-0.5B"
+ IMAGE_TOKEN_INDEX = -200

+ # Load model and tokenizer (will be loaded on first GPU allocation)
+ tok = None
+ model = None

+ def load_model():
+     global tok, model
+     if tok is None or model is None:
+         print("Loading model...")
+         tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
+         model = AutoModelForCausalLM.from_pretrained(
+             MID,
+             torch_dtype=torch.float16,
+             device_map="cuda",
+             trust_remote_code=True,
+         )
+         print("Model loaded successfully!")
+     return tok, model

+ @spaces.GPU(duration=60)
+ def caption_image(image, custom_prompt=None):
+     """
+     Generate a caption for the input image.
+
+     Args:
+         image: PIL Image from Gradio
+         custom_prompt: Optional custom prompt to use instead of default
+
+     Returns:
+         Generated caption text
+     """
+     if image is None:
+         return "Please upload an image first."
+
+     try:
+         # Load model if not already loaded
+         tok, model = load_model()
+         # Convert image to RGB if needed
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+
+         # Use custom prompt or default
+         prompt = custom_prompt if custom_prompt else "Describe this image in detail."
+
+         # Build chat message
+         messages = [
+             {"role": "user", "content": f"<image>\n{prompt}"}
+         ]
+
+         # Render to string to place <image> token correctly
+         rendered = tok.apply_chat_template(
+             messages, add_generation_prompt=True, tokenize=False
+         )
+
+         # Split at image token
+         pre, post = rendered.split("<image>", 1)
+
+         # Tokenize text around the image token
+         pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
+         post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
+
+         # Insert IMAGE token id at placeholder position
+         img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
+         input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
+         attention_mask = torch.ones_like(input_ids, device=model.device)
+
+         # Preprocess image using model's vision tower
+         px = model.get_vision_tower().image_processor(
+             images=image, return_tensors="pt"
+         )["pixel_values"]
+         px = px.to(model.device, dtype=model.dtype)
+
+         # Generate caption
+         with torch.no_grad():
+             out = model.generate(
+                 inputs=input_ids,
+                 attention_mask=attention_mask,
+                 images=px,
+                 max_new_tokens=128,
+                 do_sample=False,  # Deterministic generation
+                 temperature=1.0,
+             )
+
+         # Decode and return the generated text
+         generated_text = tok.decode(out[0], skip_special_tokens=True)
+
+         # Extract only the assistant's response
+         if "assistant" in generated_text:
+             response = generated_text.split("assistant")[-1].strip()
+         else:
+             response = generated_text
+
+         return response
+
+     except Exception as e:
+         return f"Error generating caption: {str(e)}"

+ # Create Gradio interface
+ with gr.Blocks(title="FastVLM Image Captioning") as demo:
+     gr.Markdown(
+         """
+         # 🖼️ FastVLM Image Captioning
+
+         Upload an image to generate a detailed caption using Apple's FastVLM-0.5B model.
+         You can use the default prompt or provide your own custom prompt.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(
+                 type="pil",
+                 label="Upload Image",
+                 elem_id="image-upload"
+             )
+
+             custom_prompt = gr.Textbox(
+                 label="Custom Prompt (Optional)",
+                 placeholder="Leave empty for default: 'Describe this image in detail.'",
+                 lines=2
+             )
+
+             with gr.Row():
+                 clear_btn = gr.ClearButton([image_input, custom_prompt])
+                 generate_btn = gr.Button("Generate Caption", variant="primary")
+
+         with gr.Column():
+             output = gr.Textbox(
+                 label="Generated Caption",
+                 lines=8,
+                 max_lines=15,
+                 show_copy_button=True
+             )
+
+     # Event handlers
+     generate_btn.click(
+         fn=caption_image,
+         inputs=[image_input, custom_prompt],
+         outputs=output
+     )
+
+     # Also generate on image upload if no custom prompt
+     image_input.change(
+         fn=lambda img, prompt: caption_image(img, prompt) if img is not None and not prompt else None,
+         inputs=[image_input, custom_prompt],
+         outputs=output
+     )
+
+     gr.Markdown(
+         """
+         ---
+         **Model:** [apple/FastVLM-0.5B](https://huggingface.co/apple/FastVLM-0.5B)
+
+         **Note:** This Space uses ZeroGPU for dynamic GPU allocation.
+         """
+     )

+ if __name__ == "__main__":
+     demo.launch(
+         share=False,
+         show_error=True,
+         server_name="0.0.0.0",
+         server_port=7860
+     )