import imageio, os, torch, warnings, torchvision, argparse, json
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator
import glob, re, math
import time
class ImageDataset(torch.utils.data.Dataset):
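    """(Image, prompt) dataset.

    Reads images from `base_path`. Metadata may be a CSV or JSON file; when no
    metadata path is given, every image is paired with a same-named `.txt`
    prompt file in the same folder.
    """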
    def __init__(
        self,
        base_path=None, metadata_path=None,
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        data_file_keys=("image",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        repeat=1,
        args=None,
    ):
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat
        self.base_path = base_path
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.repeat = repeat
        if height is not None and width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif height is None and width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        else:
            # Previously this case left `dynamic_resolution` undefined, causing an AttributeError later.
            raise ValueError("`height` and `width` must either both be set or both be None.")
        if metadata_path is None:
            print("No metadata. Trying to generate it.")
            metadata = self.generate_metadata(base_path)
            print(f"{len(metadata)} lines in metadata.")
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            self.data = metadata
        else:
            metadata = pd.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
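    # The accepted metadata layouts, sketched with illustrative file names:
    # a CSV with `image` and `prompt` columns,
    #     image,prompt
    #     cat.jpg,a cat on a sofa
    # or a JSON list of records,
    #     [{"image": "cat.jpg", "prompt": "a cat on a sofa"}]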
    def generate_metadata(self, folder):
        image_list, prompt_list = [], []
        file_set = set(os.listdir(folder))
        for file_name in file_set:
            if "." not in file_name:
                continue
            file_ext_name = file_name.split(".")[-1].lower()
            file_base_name = file_name[:-len(file_ext_name)-1]
            if file_ext_name not in self.image_file_extension:
                continue
            prompt_file_name = file_base_name + ".txt"
            if prompt_file_name not in file_set:
                continue
            with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
                prompt = f.read().strip()
            image_list.append(file_name)
            prompt_list.append(prompt)
        metadata = pd.DataFrame()
        metadata["image"] = image_list
        metadata["prompt"] = prompt_list
        return metadata
    def crop_and_resize(self, image, target_height, target_width):
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image
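    # `max(...)` picks the scale at which the resized image fully covers the
    # target box on both axes; `center_crop` then trims the overflow, so the
    # aspect ratio is preserved without letterboxing.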
    def get_height_width(self, image):
        if self.dynamic_resolution:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width
    def load_image(self, file_path):
        image = Image.open(file_path).convert("RGB")
        image = self.crop_and_resize(image, *self.get_height_width(image))
        return image
    def load_data(self, file_path):
        return self.load_image(file_path)
    def __getitem__(self, data_id):
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Warn with the path; `data[key]` is already None at this point.
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data
    def __len__(self):
        return len(self.data) * self.repeat
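# A minimal usage sketch (paths are illustrative, not part of this module):
#
#     dataset = ImageDataset(
#         base_path="data/images",   # folder with cat.jpg, cat.txt, ...
#         metadata_path=None,        # auto-pair images with .txt prompt files
#         max_pixels=1024*1024,
#     )
#     sample = dataset[0]            # {"image": PIL.Image, "prompt": str}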
class VideoDataset(torch.utils.data.Dataset):
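    """(Video, prompt) dataset.

    Like `ImageDataset`, but an entry may be a video, a single image (treated
    as a one-frame clip), or a JSON file carrying shot-cut annotations. Loaded
    clips are truncated so that the frame count satisfies
    `num_frames % time_division_factor == time_division_remainder`.
    """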
    def __init__(
        self,
        base_path=None, metadata_path=None,
        num_frames=81,
        time_division_factor=4, time_division_remainder=1,
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        data_file_keys=("video_path",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
        repeat=1,
        args=None,
    ):
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            num_frames = args.num_frames
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat
        self.base_path = base_path
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.video_file_extension = video_file_extension
        self.repeat = repeat
        if height is not None and width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif height is None and width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        else:
            # Previously this case left `dynamic_resolution` undefined, causing an AttributeError later.
            raise ValueError("`height` and `width` must either both be set or both be None.")
        if metadata_path is None:
            print("No metadata. Trying to generate it.")
            metadata = self.generate_metadata(base_path)
            print(f"{len(metadata)} lines in metadata.")
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            self.data = metadata
        else:
            metadata = pd.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
    def generate_metadata(self, folder):
        video_list, prompt_list = [], []
        file_set = set(os.listdir(folder))
        for file_name in file_set:
            if "." not in file_name:
                continue
            file_ext_name = file_name.split(".")[-1].lower()
            file_base_name = file_name[:-len(file_ext_name)-1]
            if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension:
                continue
            prompt_file_name = file_base_name + ".txt"
            if prompt_file_name not in file_set:
                continue
            with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
                prompt = f.read().strip()
            video_list.append(file_name)
            prompt_list.append(prompt)
        metadata = pd.DataFrame()
        metadata["video"] = video_list
        metadata["prompt"] = prompt_list
        return metadata
    def crop_and_resize(self, image, target_height, target_width):
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image
    def get_height_width(self, image):
        if self.dynamic_resolution:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width
    def get_num_frames(self, reader):
        # Count frames once; the call can be expensive for some containers.
        num_frames = min(self.num_frames, int(reader.count_frames()))
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames
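    # With the defaults (factor 4, remainder 1) the clip is shortened to
    # 4k + 1 frames: 81, 77, 73, ... This presumably matches causal video VAEs
    # in the Wan family, which compress 4k + 1 frames into k + 1 latent frames;
    # other models may need different values for the two constructor arguments.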
    def load_video(self, file_path):
        reader = imageio.get_reader(file_path)
        num_frames = self.get_num_frames(reader)
        frames = []
        for frame_id in range(num_frames):
            frame = reader.get_data(frame_id)
            frame = Image.fromarray(frame)
            frame = self.crop_and_resize(frame, *self.get_height_width(frame))
            frames.append(frame)
        reader.close()
        return frames
    def load_json(self, file_path):
        with open(file_path, "r") as f:
            json_data = json.load(f)
        num_shots = json_data["num_shots"]
        shot_cut_frames = json_data["shot_cut_frames"]
        return num_shots, shot_cut_frames
    def load_image(self, file_path):
        image = Image.open(file_path).convert("RGB")
        image = self.crop_and_resize(image, *self.get_height_width(image))
        frames = [image]
        return frames
    def is_image(self, file_path):
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.image_file_extension
    def is_video(self, file_path):
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.video_file_extension
    def is_json(self, file_path):
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() == "json"
    def load_data(self, file_path):
        if self.is_image(file_path):
            return self.load_image(file_path)
        elif self.is_video(file_path):
            return self.load_video(file_path)
        elif self.is_json(file_path):
            return self.load_json(file_path)
        else:
            return None
    def __getitem__(self, data_id):
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Warn with the path; `data[key]` is already None at this point.
                    warnings.warn(f"cannot load file {path}.")
                    return None
                if key == "json_path":
                    # `load_data` already returned (num_shots, shot_cut_frames); no need to re-read the file.
                    num_shots, shot_cut_frames = data[key]
                    data["num_shots"] = num_shots
                    data["shot_cut_frames"] = shot_cut_frames
        return data
    def __len__(self):
        return len(self.data) * self.repeat
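# A minimal usage sketch (paths are illustrative, not part of this module):
#
#     dataset = VideoDataset(
#         base_path="data/videos",   # folder with clip.mp4, clip.txt, ...
#         num_frames=81,             # truncated to 4k + 1 frames per clip
#         max_pixels=1280*720,
#     )
#     sample = dataset[0]            # {"video": [PIL.Image, ...], "prompt": str}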
class DiffusionTrainingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def to(self, *args, **kwargs):
        for name, model in self.named_children():
            model.to(*args, **kwargs)
        return self
    def trainable_modules(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
        return trainable_modules
    def trainable_param_names(self):
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        return trainable_param_names
    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
        if lora_alpha is None:
            lora_alpha = lora_rank
        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
        model = inject_adapter_in_model(lora_config, model)
        return model
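    # A sketch of how a subclass might use this (`self.pipe.dit` is a
    # placeholder for whatever base model the subclass trains):
    #
    #     self.pipe.dit = self.add_lora_to_model(
    #         self.pipe.dit,
    #         target_modules="q,k,v,o,ffn.0,ffn.2".split(","),
    #         lora_rank=32,
    #     )
    #
    # peft marks only the injected adapter weights as trainable, so they are
    # exactly what `trainable_param_names()` and `export_trainable_state_dict()`
    # pick up afterwards.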
    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        trainable_param_names = self.trainable_param_names()
        state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
        if remove_prefix is not None:
            state_dict_ = {}
            for name, param in state_dict.items():
                if name.startswith(remove_prefix):
                    name = name[len(remove_prefix):]
                state_dict_[name] = param
            state_dict = state_dict_
        return state_dict
class ModelLogger:
    def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x, validation_config=None, save_every_n_steps=1000):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        self.state_dict_converter = state_dict_converter
        self.validation_config = validation_config
        self.save_every_n_steps = save_every_n_steps
        # Create subdirectories for clarity
        self.resumable_path = os.path.join(output_path, "resumable")
        self.portable_path = os.path.join(output_path, "portable")
        self.validation_path = os.path.join(output_path, "validation")
        os.makedirs(self.resumable_path, exist_ok=True)
        os.makedirs(self.portable_path, exist_ok=True)
        os.makedirs(self.validation_path, exist_ok=True)
    def on_step_end(self, accelerator, model, global_step):
        self.save_model(accelerator, model, f"step-{global_step}")
    def on_epoch_end(self, accelerator, model, epoch_id):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.get_state_dict(model)
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            # Use the .safetensors extension, since the state dict is saved with safe serialization.
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)
    def save_model(self, accelerator, model, file_name):
        accelerator.wait_for_everyone()
        path = os.path.join(self.output_path, file_name)
        if accelerator.is_main_process:
            os.makedirs(path, exist_ok=True)
        accelerator.wait_for_everyone()
        accelerator.save_state(path)
def launch_training_task(
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    num_epochs: int = 1,
    gradient_accumulation_steps: int = 1,
    resume_from_checkpoint: str = None,
    save_every_n_steps: int = 1000,
    mixed_precision: str = "bf16",
    enabled_deepspeed: bool = False,
    model_ds_config: str = None,
):
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, num_workers=1, collate_fn=lambda x: x[0])
    if enabled_deepspeed:
        from accelerate.utils import DeepSpeedPlugin
        # Precision and the ZeRO stage are governed by the DeepSpeed config file.
        accelerator = Accelerator(
            deepspeed_plugins=DeepSpeedPlugin(hf_ds_config=model_ds_config),
            gradient_accumulation_steps=gradient_accumulation_steps,
        )
        if accelerator.is_main_process:
            print("DeepSpeed optimization enabled.")
    else:
        # Previously `mixed_precision` was accepted but never used; pass it through here.
        accelerator = Accelerator(
            gradient_accumulation_steps=gradient_accumulation_steps,
            mixed_precision=mixed_precision,
        )
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    # Resuming logic
    global_step = 0
    if resume_from_checkpoint is not None:
        print(f"Resuming from checkpoint: {resume_from_checkpoint}")
        accelerator.load_state(resume_from_checkpoint)
        try:
            global_step = int(re.search(r"step-(\d+)", resume_from_checkpoint).group(1))
            print(f"Restored global_step to {global_step}")
        except (AttributeError, ValueError):
            print("Could not parse global_step from checkpoint path. Starting from 0.")
            global_step = 0
    num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
    starting_epoch = global_step // num_update_steps_per_epoch if num_update_steps_per_epoch > 0 else 0
    for epoch_id in range(starting_epoch, num_epochs):
        for data in tqdm(dataloader, desc=f"Epoch {epoch_id}"):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                if accelerator.sync_gradients:
                    if global_step == 0 or global_step % save_every_n_steps == 0:
                        model_logger.on_step_end(accelerator, model, global_step)
                    global_step += 1
        # model_logger.on_epoch_end(accelerator, model, epoch_id)
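# A minimal wiring sketch (names are illustrative; `MyTrainingModule` stands in
# for a concrete DiffusionTrainingModule subclass whose forward returns a loss):
#
#     args = wan_parser().parse_args()
#     dataset = VideoDataset(args=args)
#     model = MyTrainingModule(args)
#     optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
#     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
#     logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
#     launch_training_task(
#         dataset, model, logger, optimizer, scheduler,
#         num_epochs=args.num_epochs,
#         gradient_accumulation_steps=args.gradient_accumulation_steps,
#         save_every_n_steps=args.save_every_n_steps,
#         resume_from_checkpoint=args.resume_from_checkpoint,
#     )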
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, num_workers=1, collate_fn=lambda x: x[0])
    accelerator = Accelerator()
    model, dataloader = accelerator.prepare(model, dataloader)
    os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
    for data_id, data in enumerate(tqdm(dataloader)):
        with torch.no_grad():
            inputs = model.forward_preprocess(data)
            inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
            torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
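# Cached samples can be reloaded later with `torch.load("./models/data_cache/0.pth")`,
# so the (typically expensive) encoding done in `forward_preprocess` runs only once
# per sample. `forward_preprocess` and `model_input_keys` are expected to be provided
# by the concrete DiffusionTrainingModule subclass.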
def wan_parser():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
    parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
    parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution.")
    parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
    parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
    parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
    parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
    parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
    parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Prefix to remove from parameter names in the saved checkpoint.")
    parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--extra_inputs", type=str, default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
    parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
    parser.add_argument("--save_every_n_steps", type=int, default=1000, help="Save a checkpoint every N steps.")
    parser.add_argument("--validation_config_path", type=str, default=None, help="Path to a JSON file containing a list of validation configurations.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to a checkpoint folder to resume from. Set to 'latest' to automatically find the latest.")
    parser.add_argument("--max_shots", type=int, default=20, help="Maximum number of shots.")
    # `action="store_true"` instead of `type=bool`: argparse's `type=bool` treats any non-empty string (including "False") as True.
    parser.add_argument("--use_shot_embedding", default=False, action="store_true", help="Use shot embedding.")
    parser.add_argument("--shot_embedding_init", type=str, default="normal", help="Shot embedding initialization method.")
    parser.add_argument("--enabled_deepspeed", default=False, action="store_true", help="Whether to use DeepSpeed.")
    parser.add_argument("--model_ds_config", type=str, default=None, help="Path to a DeepSpeed config file.")
    parser.add_argument("--mixed_precision", type=str, default="bf16", help="Mixed precision mode, e.g., bf16, fp16, no.")
    # parser.add_argument("--use_shot_mask", default=False, action="store_true", help="Whether to use shot mask.")
    # parser.add_argument("--shot_mask_type", type=str, default="id", choices=[None, "id", "normalized", "alternating"])
    return parser
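# An illustrative invocation (the script name, model ID, and paths are
# placeholders, not part of this module):
#
#   accelerate launch train.py \
#     --dataset_base_path data/videos \
#     --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors" \
#     --lora_base_model dit --lora_rank 32 \
#     --output_path ./models --num_epochs 1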
def flux_parser():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
    parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per image, used for dynamic resolution.")
    parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.")
    parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
    parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
    parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
    parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Prefix to remove from parameter names in the saved checkpoint.")
    parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--extra_inputs", type=str, default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the LoRA format to the open-source format. Only for DiT's LoRA.")
    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    return parser