import os
import cv2
import numpy as np
import random
import sys
import subprocess
from typing import Sequence, Mapping, Any, Optional, Union
import torch
from tqdm import tqdm
import argparse
import json
import logging
import shutil
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
import time
import traceback
from utils import get_path_after_pexel
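# Redirect Gradio's temporary files to a local, writable directory.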
LOCAL_GRADIO_TMP = os.path.abspath("./gradio_tmp")
os.makedirs(LOCAL_GRADIO_TMP, exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = LOCAL_GRADIO_TMP
HF_REPOS = {
"QingyanBai/Ditto_models": ["models_comfy/ditto_global_comfy.safetensors"],
"Kijai/WanVideo_comfy": [
"Wan2_1-T2V-14B_fp8_e4m3fn.safetensors",
"Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
"Wan2_1_VAE_bf16.safetensors",
"umt5-xxl-enc-bf16.safetensors",
],
}
MODELS_ROOT = os.path.abspath(os.path.join(os.getcwd(), "models"))
PATHS = {
"diffusion_model": os.path.join(MODELS_ROOT, "diffusion_models"),
"vae_wan": os.path.join(MODELS_ROOT, "vae", "wan"),
"loras": os.path.join(MODELS_ROOT, "loras"),
"text_encoders": os.path.join(MODELS_ROOT, "text_encoders"),
}
REQUIRED_FILES = [
("Wan2_1-T2V-14B_fp8_e4m3fn.safetensors", "diffusion_model"),
("ditto_global_comfy.safetensors", "diffusion_model"),
("Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", "loras"),
("Wan2_1_VAE_bf16.safetensors", "vae_wan"),
("umt5-xxl-enc-bf16.safetensors", "text_encoders"),
]
def ensure_dir(path: str) -> None:
os.makedirs(path, exist_ok=True)
def ensure_models() -> None:
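    """Download any missing model weights from the Hugging Face Hub.

    For each (filename, target-dir key) in REQUIRED_FILES, skip files that
    already exist non-empty; otherwise locate the owning repo in HF_REPOS,
    fetch the file with snapshot_download, and drop a `.READY` marker next
    to it.
    """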
for filename, key in REQUIRED_FILES:
target_dir = PATHS[key]
ensure_dir(target_dir)
target_path = os.path.join(target_dir, filename)
ready_flag = os.path.join(target_dir, f"{filename}.READY")
if os.path.exists(target_path) and os.path.getsize(target_path) > 0:
open(ready_flag, "a").close()
continue
repo_id = None
repo_file_path = None
for repo, files in HF_REPOS.items():
for file_path in files:
if filename in file_path:
repo_id = repo
repo_file_path = file_path
break
if repo_id:
break
if repo_id is None:
raise RuntimeError(f"Could not find repository for file: {filename}")
print(f"Downloading {filename} from {repo_id} to {target_dir} ...")
snapshot_download(
repo_id=repo_id,
local_dir=target_dir,
local_dir_use_symlinks=False,
allow_patterns=[repo_file_path],
token=os.getenv("HF_TOKEN", None),
)
if not os.path.exists(target_path):
found = []
for root, _, files in os.walk(target_dir):
for f in files:
if f == filename:
found.append(os.path.join(root, f))
if found:
src = found[0]
if src != target_path:
shutil.copy2(src, target_path)
if not os.path.exists(target_path):
raise RuntimeError(f"Failed to download required file: {filename}")
open(ready_flag, "w").close()
print(f"Downloaded and ready: {target_path}")
ensure_models()
def ensure_t5_tokenizer() -> None:
"""
Ensure the local T5 tokenizer folder exists and contains valid files.
If missing or corrupted, download from 'google/umt5-xxl' and save locally
to the exact path expected by the WanVideo wrapper nodes.
"""
try:
script_directory = os.path.dirname(os.path.abspath(__file__))
tokenizer_dir = os.path.join(
script_directory,
"custom_nodes",
"ComfyUI_WanVideoWrapper",
"configs",
"T5_tokenizer",
)
os.makedirs(tokenizer_dir, exist_ok=True)
required_files = [
"tokenizer.json",
"tokenizer_config.json",
"spiece.model",
"special_tokens_map.json",
]
def is_valid(path: str) -> bool:
return os.path.exists(path) and os.path.getsize(path) > 0
all_ok = all(is_valid(os.path.join(tokenizer_dir, f)) for f in required_files)
if all_ok:
print(f"T5 tokenizer ready at: {tokenizer_dir}")
return
print(f"Preparing T5 tokenizer at: {tokenizer_dir} ...")
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"google/umt5-xxl",
use_fast=True,
trust_remote_code=False,
)
tok.save_pretrained(tokenizer_dir)
# Re-check
all_ok = all(is_valid(os.path.join(tokenizer_dir, f)) for f in required_files)
if not all_ok:
raise RuntimeError("Tokenizer files not fully prepared after save_pretrained")
print("T5 tokenizer prepared successfully.")
except Exception as e:
print(f"Failed to prepare T5 tokenizer: {e}\n{traceback.format_exc()}")
raise
ensure_t5_tokenizer()
def setup_global_logging_filter():
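    """Attach a keyword filter to the root logging handler that silences the
    verbose memory-usage lines emitted by the WanVideo wrapper."""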
class MemoryLogFilter(logging.Filter):
def filter(self, record):
msg = record.getMessage()
keywords = [
"Allocated memory:",
"Max allocated memory:",
"Max reserved memory:",
"memory=",
"max_memory=",
"max_reserved=",
"Block swap memory summary",
"Transformer blocks on",
"Total memory used by",
"Non-blocking memory transfer"
]
return not any(kw in msg for kw in keywords)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
force=True
)
logging.getLogger().handlers[0].addFilter(MemoryLogFilter())
setup_global_logging_filter()
def tensor_to_video(video_tensor, output_path, fps=20, crf=20):
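    """Encode a (frames, height, width, channels) RGB tensor to an H.264 MP4
    by piping raw frames into ffmpeg. Float tensors in [0, 1] are rescaled
    to uint8 first."""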
frames = video_tensor.detach().cpu().numpy()
if frames.dtype != np.uint8:
if frames.max() <= 1.0:
frames = (frames * 255).astype(np.uint8)
else:
frames = frames.astype(np.uint8)
num_frames, height, width, _ = frames.shape
command = [
'ffmpeg',
'-y',
'-f', 'rawvideo',
'-vcodec', 'rawvideo',
'-pix_fmt', 'rgb24',
'-s', f'{width}x{height}',
'-r', str(fps),
'-i', '-',
'-c:v', 'libx264',
'-pix_fmt', 'yuv420p',
'-crf', str(crf),
'-preset', 'medium',
'-r', str(fps),
'-an',
output_path
]
    # Stream raw frames into ffmpeg's stdin. stderr is discarded rather than
    # piped: with stderr=PIPE, ffmpeg can deadlock once the unread pipe
    # buffer fills while we are still writing frames.
    with subprocess.Popen(command, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL) as proc:
        for frame in frames:
            proc.stdin.write(frame.tobytes())
        proc.stdin.close()
        returncode = proc.wait()
    if returncode != 0:
        raise RuntimeError(f"ffmpeg exited with code {returncode} while encoding {output_path}")
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
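    """Index into a ComfyUI node output: try obj[index] directly, falling
    back to obj["result"][index] for mapping-style outputs."""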
try:
return obj[index]
except KeyError:
return obj["result"][index]
def find_path(name: str, path: Optional[str] = None) -> Optional[str]:
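    """Search for `name` in `path` (defaults to the CWD), then walk up the
    parent directories; return the full path, or None if never found."""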
if path is None:
path = os.getcwd()
if name in os.listdir(path):
path_name = os.path.join(path, name)
print(f"{name} found: {path_name}")
return path_name
parent_directory = os.path.dirname(path)
if parent_directory == path:
return None
return find_path(name, parent_directory)
def add_comfyui_directory_to_sys_path() -> None:
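    """Locate the ComfyUI directory and append it to sys.path so that its
    modules can be imported."""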
comfyui_path = find_path("ComfyUI")
if comfyui_path is not None and os.path.isdir(comfyui_path):
if comfyui_path not in sys.path:
sys.path.append(comfyui_path)
print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
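    """Load extra_model_paths.yaml (if present) via ComfyUI's
    load_extra_path_config so additional model folders are registered."""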
try:
from main import load_extra_path_config
except ImportError:
print(
"Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
)
from utils.extra_config import load_extra_path_config
extra_model_paths = find_path("extra_model_paths.yaml")
if extra_model_paths is not None:
load_extra_path_config(extra_model_paths)
else:
print("Could not find the extra_model_paths config file.")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
def import_custom_nodes() -> None:
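    """Initialize ComfyUI's custom nodes exactly once: spin up a PromptServer
    on a fresh asyncio event loop (some nodes expect one to exist) and call
    init_extra_nodes()."""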
import asyncio
import execution
from nodes import init_extra_nodes
import server
if getattr(import_custom_nodes, "_initialized", False):
return
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_instance = server.PromptServer(loop)
execution.PromptQueue(server_instance)
init_extra_nodes()
import_custom_nodes._initialized = True
from nodes import NODE_CLASS_MAPPINGS
print(f"Loading custom nodes and models...")
import_custom_nodes()
@spaces.GPU()
def run_pipeline(vpath, prompt, width, height, fps, frame_count, outdir):
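    """Run the full Ditto editing pipeline on one video and return the path
    of the saved MP4: load the WanVideo model with the VACE weights and
    CausVid LoRA, encode the input frames and prompt, sample, decode, and
    write the result to `outdir`."""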
try:
import gc
# Clean memory before starting
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
os.makedirs(outdir, exist_ok=True)
with torch.inference_mode():
from custom_nodes.ComfyUI_WanVideoWrapper import nodes as wan_nodes
vhs_loadvideo = NODE_CLASS_MAPPINGS["VHS_LoadVideo"]()
# Set model and settings.
wanvideovacemodelselect = wan_nodes.WanVideoVACEModelSelect()
wanvideovacemodelselect_89 = wanvideovacemodelselect.getvacepath(
vace_model="ditto_global_comfy.safetensors"
)
wanvideoslg = wan_nodes.WanVideoSLG()
wanvideoslg_113 = wanvideoslg.process(
blocks="2",
                start_percent=0.2,
                end_percent=0.7,
)
wanvideovaeloader = wan_nodes.WanVideoVAELoader()
wanvideovaeloader_133 = wanvideovaeloader.loadmodel(
model_name="wan/Wan2_1_VAE_bf16.safetensors", precision="bf16"
)
loadwanvideot5textencoder = wan_nodes.LoadWanVideoT5TextEncoder()
loadwanvideot5textencoder_134 = loadwanvideot5textencoder.loadmodel(
model_name="umt5-xxl-enc-bf16.safetensors",
precision="bf16",
load_device="offload_device",
quantization="disabled",
)
wanvideoblockswap = wan_nodes.WanVideoBlockSwap()
wanvideoblockswap_137 = wanvideoblockswap.setargs(
blocks_to_swap=20,
offload_img_emb=False,
offload_txt_emb=False,
use_non_blocking=True,
vace_blocks_to_swap=0,
)
wanvideoloraselect = wan_nodes.WanVideoLoraSelect()
wanvideoloraselect_380 = wanvideoloraselect.getlorapath(
lora="Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
strength=1.0,
low_mem_load=False,
)
wanvideomodelloader = wan_nodes.WanVideoModelLoader()
imageresizekjv2 = NODE_CLASS_MAPPINGS["ImageResizeKJv2"]()
wanvideovaceencode = wan_nodes.WanVideoVACEEncode()
wanvideotextencode = wan_nodes.WanVideoTextEncode()
wanvideosampler = wan_nodes.WanVideoSampler()
wanvideodecode = wan_nodes.WanVideoDecode()
wanvideomodelloader_142 = wanvideomodelloader.loadmodel(
model="Wan2_1-T2V-14B_fp8_e4m3fn.safetensors",
base_precision="fp16",
quantization="disabled",
load_device="offload_device",
attention_mode="sdpa",
block_swap_args=get_value_at_index(wanvideoblockswap_137, 0),
lora=get_value_at_index(wanvideoloraselect_380, 0),
vace_model=get_value_at_index(wanvideovacemodelselect_89, 0),
)
fname = os.path.basename(vpath)
fname_clean = os.path.splitext(fname)[0]
vhs_loadvideo_70 = vhs_loadvideo.load_video(
video=vpath,
force_rate=20,
custom_width=width,
custom_height=height,
frame_load_cap=frame_count,
skip_first_frames=1,
select_every_nth=1,
format="AnimateDiff",
unique_id=16696422174153060213,
)
imageresizekjv2_205 = imageresizekjv2.resize(
width=width,
height=height,
upscale_method="area",
keep_proportion="resize",
pad_color="0, 0, 0",
crop_position="center",
divisible_by=8,
device="cpu",
image=get_value_at_index(vhs_loadvideo_70, 0),
)
wanvideovaceencode_29 = wanvideovaceencode.process(
width=width,
height=height,
num_frames=frame_count,
                strength=0.975,
vace_start_percent=0,
vace_end_percent=1,
tiled_vae=False,
vae=get_value_at_index(wanvideovaeloader_133, 0),
input_frames=get_value_at_index(imageresizekjv2_205, 0),
)
wanvideotextencode_148 = wanvideotextencode.process(
positive_prompt=prompt,
negative_prompt="flickering artifact, jpg artifacts, compression, distortion, morphing, low-res, fake, oversaturated, overexposed, over bright, strange behavior, distorted limbs, unnatural motion, unrealistic anatomy, glitch, extra limbs,",
force_offload=True,
t5=get_value_at_index(loadwanvideot5textencoder_134, 0),
model_to_offload=get_value_at_index(wanvideomodelloader_142, 0),
)
# Clean memory before sampling (most memory-intensive step)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
wanvideosampler_2 = wanvideosampler.process(
steps=4,
                cfg=1.2,
                shift=2.0,
                # Seeds are unsigned 64-bit; keep the upper bound at 2**64 - 1.
                seed=random.randint(0, 2**64 - 1),
force_offload=True,
scheduler="unipc",
riflex_freq_index=0,
denoise_strength=1,
batched_cfg=False,
rope_function="comfy",
model=get_value_at_index(wanvideomodelloader_142, 0),
image_embeds=get_value_at_index(wanvideovaceencode_29, 0),
text_embeds=get_value_at_index(wanvideotextencode_148, 0),
slg_args=get_value_at_index(wanvideoslg_113, 0),
)
res = wanvideodecode.decode(
enable_vae_tiling=False,
tile_x=272,
tile_y=272,
tile_stride_x=144,
tile_stride_y=128,
vae=get_value_at_index(wanvideovaeloader_133, 0),
samples=get_value_at_index(wanvideosampler_2, 0),
)
save_path = os.path.join(outdir, f'{fname_clean}_edit.mp4')
tensor_to_video(res[0], save_path, fps=fps)
# Clean up memory after generation
del res
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"Done. Saved to: {save_path}")
return save_path
except Exception as e:
err = f"Error: {e}\n{traceback.format_exc()}"
print(err)
# Clean memory on error too
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
raise
@spaces.GPU()
def gradio_infer(vfile, prompt, width, height, fps, frame_count, progress=gr.Progress(track_tqdm=True)):
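    """Gradio callback: resolve the uploaded file to a local path, then run
    the editing pipeline and return the output video path."""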
    if vfile is None:
        raise gr.Error("Please upload a video first!")
vpath = vfile if isinstance(vfile, str) else vfile.name
if not os.path.exists(vpath) and hasattr(vfile, "save"):
os.makedirs("uploads", exist_ok=True)
vpath = os.path.join("uploads", os.path.basename(vfile.name))
vfile.save(vpath)
outdir = "results"
os.makedirs(outdir, exist_ok=True)
save_path = run_pipeline(
vpath=vpath,
prompt=prompt,
width=int(width),
height=int(height),
fps=int(fps),
frame_count=int(frame_count),
outdir=outdir,
)
return save_path
def build_interface():
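    """Build the Gradio Blocks UI: input/output videos, the editing
    instruction prompt, size/FPS/frame-count controls, a Run button, and
    example presets."""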
with gr.Blocks(title="Ditto") as demo:
gr.Markdown(
"""# Ditto: Scaling Instruction-Based Video Editing with a High-Quality Synthetic Dataset
<div style="font-size: 1.8rem; line-height: 1.6; margin-bottom: 1rem;">
<a href="https://arxiv.org/abs/2510.15742" target="_blank">📄 Paper</a>
|
<a href="https://ezioby.github.io/Ditto_page/" target="_blank">🌐 Project Page</a>
|
<a href="https://github.com/EzioBy/Ditto/" target="_blank"> 💻 Github Code </a>
|
<a href="https://huggingface.co/QingyanBai/Ditto_models/tree/main" target="_blank">📦 Model Weights</a>
|
<a href="https://huggingface.co/datasets/QingyanBai/Ditto-1M" target="_blank">📊 Dataset</a>
</div>
<b>Note1:</b> The backend of this demo is comfy. Though it runs fast, please note that due to the use of quantized and distilled models, there may be some quality degradation.
<b>Note2:</b> Considering the limited memory, please try test cases with lower resolution and frame count, otherwise it may cause out of memory error (you can also try re-running it).
If you like this project, please consider <a href="https://github.com/EzioBy/Ditto/" target="_blank">starring the repo</a> to motivate us. Thank you!
"""
)
with gr.Column():
with gr.Row():
vfile = gr.Video(label="Input Video", value=os.path.join("input", "dasha.mp4"),
sources="upload", interactive=True)
out_video = gr.Video(label="Result")
prompt = gr.Textbox(label="Editing Instruction", value="Make it in the style of Japanese anime")
with gr.Row():
width = gr.Number(label="Width", value=576, precision=0)
height = gr.Number(label="Height", value=324, precision=0)
fps = gr.Number(label="FPS", value=20, precision=0)
frame_count = gr.Number(label="Frame Count", value=49, precision=0)
run_btn = gr.Button("Run", variant="primary")
run_btn.click(
fn=gradio_infer,
inputs=[vfile, prompt, width, height, fps, frame_count],
outputs=[out_video]
)
examples = [
[
os.path.join("input", "dasha.mp4"),
"Add some fire and flame to the background",
576, 324, 20, 49
],
[
os.path.join("input", "dasha.mp4"),
"Add some snow and flakes to the background",
576, 324, 20, 49
],
[
os.path.join("input", "dasha.mp4"),
"Make it in the style of pencil sketch",
576, 324, 20, 49
],
]
gr.Examples(
examples=examples,
inputs=[vfile, prompt, width, height, fps, frame_count],
label="Examples"
)
return demo
if __name__ == "__main__":
demo = build_interface()
demo.launch()