multimodalart HF Staff commited on
Commit
02b00c0
·
verified ·
1 Parent(s): 51b7c7d
Files changed (1) hide show
  1. app.py +117 -88
app.py CHANGED
@@ -139,14 +139,43 @@ def upsample_prompt_logic(prompt, image_list):
139
  print(f"Upsampling failed: {e}")
140
  return prompt
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  # Updated duration function to match generate_image arguments (including progress)
143
- def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, force_dimensions, progress=gr.Progress(track_tqdm=True)):
144
  num_images = 0 if image_list is None else len(image_list)
145
  step_duration = 1 + 0.8 * num_images
146
  return max(65, num_inference_steps * step_duration + 10)
147
 
148
  @spaces.GPU(duration=get_duration)
149
- def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, force_dimensions, progress=gr.Progress(track_tqdm=True)):
150
  # Move embeddings to GPU only when inside the GPU decorated function
151
  prompt_embeds = prompt_embeds.to(device)
152
 
@@ -158,12 +187,10 @@ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps
158
  "num_inference_steps": num_inference_steps,
159
  "guidance_scale": guidance_scale,
160
  "generator": generator,
 
 
161
  }
162
 
163
- if image_list is None or force_dimensions:
164
- pipe_kwargs["width"] = width
165
- pipe_kwargs["height"] = height
166
-
167
  # Progress bar for the actual generation steps
168
  if progress:
169
  progress(0, desc="Starting generation...")
@@ -171,7 +198,7 @@ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps
171
  image = pipe(**pipe_kwargs).images[0]
172
  return image
173
 
174
- def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, force_dimensions=False, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
175
 
176
  if randomize_seed:
177
  seed = random.randint(0, MAX_SEED)
@@ -206,7 +233,6 @@ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024,
206
  num_inference_steps,
207
  guidance_scale,
208
  seed,
209
- force_dimensions,
210
  progress
211
  )
212
 
@@ -227,7 +253,7 @@ examples_images = [
227
  css="""
228
  #col-container {
229
  margin: 0 auto;
230
- max-width: 620px;
231
  }
232
  .gallery-container img{
233
  object-fit: contain;
@@ -240,89 +266,85 @@ with gr.Blocks() as demo:
240
  gr.Markdown(f"""# FLUX.2 [dev]
241
  FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and combining images based on text instructions model [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
242
  """)
243
-
244
- with gr.Accordion("Input image(s) (optional)", open=True):
245
- input_images = gr.Gallery(
246
- label="Input Image(s)",
247
- type="pil",
248
- columns=3,
249
- rows=1,
250
- )
251
-
252
  with gr.Row():
253
-
254
- prompt = gr.Text(
255
- label="Prompt",
256
- show_label=False,
257
- max_lines=2,
258
- placeholder="Enter your prompt",
259
- container=False,
260
- scale=3
261
- )
262
-
263
- run_button = gr.Button("Run", scale=1)
264
-
265
- result = gr.Image(label="Result", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
- with gr.Accordion("Advanced Settings", open=False):
268
-
269
- prompt_upsampling = gr.Checkbox(
270
- label="Prompt Upsampling",
271
- value=True,
272
- info="Automatically enhance the prompt using a VLM"
273
- )
274
-
275
- seed = gr.Slider(
276
- label="Seed",
277
- minimum=0,
278
- maximum=MAX_SEED,
279
- step=1,
280
- value=0,
281
- )
282
-
283
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
284
-
285
- with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- width = gr.Slider(
288
- label="Width",
289
- minimum=256,
290
- maximum=MAX_IMAGE_SIZE,
291
- step=32,
292
- value=1024,
293
- )
294
 
295
- height = gr.Slider(
296
- label="Height",
297
- minimum=256,
298
- maximum=MAX_IMAGE_SIZE,
299
- step=32,
300
- value=1024,
301
- )
302
 
303
- force_dimensions = gr.Checkbox(
304
- label="Force width/height when image input",
305
- value=False,
306
- info="When unchecked, width/height settings are ignored if input images are provided"
307
- )
308
-
309
- with gr.Row():
310
-
311
- num_inference_steps = gr.Slider(
312
- label="Number of inference steps",
313
- minimum=1,
314
- maximum=100,
315
- step=1,
316
- value=30,
317
- )
318
-
319
- guidance_scale = gr.Slider(
320
- label="Guidance scale",
321
- minimum=0.0,
322
- maximum=10.0,
323
- step=0.1,
324
- value=4,
325
- )
326
 
327
  gr.Examples(
328
  examples=examples,
@@ -342,10 +364,17 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
342
  cache_mode="lazy"
343
  )
344
 
 
 
 
 
 
 
 
345
  gr.on(
346
  triggers=[run_button.click, prompt.submit],
347
  fn=infer,
348
- inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, force_dimensions, prompt_upsampling],
349
  outputs=[result, seed]
350
  )
351
 
 
139
  print(f"Upsampling failed: {e}")
140
  return prompt
141
 
142
+ def update_dimensions_from_image(image_list):
143
+ """Update width/height sliders based on uploaded image aspect ratio.
144
+ Keeps one side at 1024 and scales the other proportionally, with both sides as multiples of 8."""
145
+ if image_list is None or len(image_list) == 0:
146
+ return 1024, 1024 # Default dimensions
147
+
148
+ # Get the first image to determine dimensions
149
+ img = image_list[0][0] # Gallery returns list of tuples (image, caption)
150
+ img_width, img_height = img.size
151
+
152
+ aspect_ratio = img_width / img_height
153
+
154
+ if aspect_ratio >= 1: # Landscape or square
155
+ new_width = 1024
156
+ new_height = int(1024 / aspect_ratio)
157
+ else: # Portrait
158
+ new_height = 1024
159
+ new_width = int(1024 * aspect_ratio)
160
+
161
+ # Round to nearest multiple of 8
162
+ new_width = round(new_width / 8) * 8
163
+ new_height = round(new_height / 8) * 8
164
+
165
+ # Ensure within valid range (minimum 256, maximum 1024)
166
+ new_width = max(256, min(1024, new_width))
167
+ new_height = max(256, min(1024, new_height))
168
+
169
+ return new_width, new_height
170
+
171
  # Updated duration function to match generate_image arguments (including progress)
172
+ def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
173
  num_images = 0 if image_list is None else len(image_list)
174
  step_duration = 1 + 0.8 * num_images
175
  return max(65, num_inference_steps * step_duration + 10)
176
 
177
  @spaces.GPU(duration=get_duration)
178
+ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
179
  # Move embeddings to GPU only when inside the GPU decorated function
180
  prompt_embeds = prompt_embeds.to(device)
181
 
 
187
  "num_inference_steps": num_inference_steps,
188
  "guidance_scale": guidance_scale,
189
  "generator": generator,
190
+ "width": width,
191
+ "height": height,
192
  }
193
 
 
 
 
 
194
  # Progress bar for the actual generation steps
195
  if progress:
196
  progress(0, desc="Starting generation...")
 
198
  image = pipe(**pipe_kwargs).images[0]
199
  return image
200
 
201
+ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
202
 
203
  if randomize_seed:
204
  seed = random.randint(0, MAX_SEED)
 
233
  num_inference_steps,
234
  guidance_scale,
235
  seed,
 
236
  progress
237
  )
238
 
 
253
  css="""
254
  #col-container {
255
  margin: 0 auto;
256
+ max-width: 1200px;
257
  }
258
  .gallery-container img{
259
  object-fit: contain;
 
266
  gr.Markdown(f"""# FLUX.2 [dev]
267
  FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and combining images based on text instructions model [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
268
  """)
 
 
 
 
 
 
 
 
 
269
  with gr.Row():
270
+ with gr.Column():
271
+ with gr.Row():
272
+ prompt = gr.Text(
273
+ label="Prompt",
274
+ show_label=False,
275
+ max_lines=2,
276
+ placeholder="Enter your prompt",
277
+ container=False,
278
+ scale=3
279
+ )
280
+
281
+ run_button = gr.Button("Run", scale=1)
282
+
283
+ with gr.Accordion("Input image(s) (optional)", open=True):
284
+ input_images = gr.Gallery(
285
+ label="Input Image(s)",
286
+ type="pil",
287
+ columns=3,
288
+ rows=1,
289
+ )
290
+
291
+ with gr.Accordion("Advanced Settings", open=False):
292
+ prompt_upsampling = gr.Checkbox(
293
+ label="Prompt Upsampling",
294
+ value=True,
295
+ info="Automatically enhance the prompt using a VLM"
296
+ )
297
 
298
+ seed = gr.Slider(
299
+ label="Seed",
300
+ minimum=0,
301
+ maximum=MAX_SEED,
302
+ step=1,
303
+ value=0,
304
+ )
305
+
306
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
307
+
308
+ with gr.Row():
309
+
310
+ width = gr.Slider(
311
+ label="Width",
312
+ minimum=256,
313
+ maximum=MAX_IMAGE_SIZE,
314
+ step=8,
315
+ value=1024,
316
+ )
317
+
318
+ height = gr.Slider(
319
+ label="Height",
320
+ minimum=256,
321
+ maximum=MAX_IMAGE_SIZE,
322
+ step=8,
323
+ value=1024,
324
+ )
325
+
326
+ with gr.Row():
327
+
328
+ num_inference_steps = gr.Slider(
329
+ label="Number of inference steps",
330
+ minimum=1,
331
+ maximum=100,
332
+ step=1,
333
+ value=30,
334
+ )
335
+
336
+ guidance_scale = gr.Slider(
337
+ label="Guidance scale",
338
+ minimum=0.0,
339
+ maximum=10.0,
340
+ step=0.1,
341
+ value=4,
342
+ )
343
 
 
 
 
 
 
 
 
344
 
345
+ with gr.Column():
346
+ result = gr.Image(label="Result", show_label=False)
 
 
 
 
 
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
  gr.Examples(
350
  examples=examples,
 
364
  cache_mode="lazy"
365
  )
366
 
367
+ # Auto-update dimensions when images are uploaded
368
+ input_images.upload(
369
+ fn=update_dimensions_from_image,
370
+ inputs=[input_images],
371
+ outputs=[width, height]
372
+ )
373
+
374
  gr.on(
375
  triggers=[run_button.click, prompt.submit],
376
  fn=infer,
377
+ inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
378
  outputs=[result, seed]
379
  )
380