# diffsynth/pipelines/flux_image_new.py
import torch, warnings, glob, os, types
import numpy as np
from PIL import Image
from einops import rearrange, reduce, repeat
from dataclasses import dataclass
from modelscope import snapshot_download
from tqdm import tqdm
from typing import Optional, Union
from typing_extensions import Literal
from ..schedulers import FlowMatchScheduler
from ..prompters import FluxPrompter
from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder
from ..models.step1x_connector import Qwen2Connector
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_value_control import MultiValueEncoder
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.flux_lora_encoder import FluxLoRAEncoder, LoRALayerBlock
from ..models.tiler import FastTileWorker
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader, FluxLoraPatcher, FluxLoRAFuser
from ..models.flux_dit import RMSNorm
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
@dataclass
class ControlNetInput:
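    """Settings for one ControlNet: which model to use (`controlnet_id`), its strength (`scale`),
    the progress window in which it is active (`start`/`end`), and its conditioning image,
    optional inpaint mask and preprocessor id."""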
controlnet_id: int = 0
scale: float = 1.0
start: float = 1.0
end: float = 0.0
image: Image.Image = None
inpaint_mask: Image.Image = None
processor_id: str = None
class MultiControlNet(torch.nn.Module):
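    """Wraps several FluxControlNet models and sums their scaled residuals for the joint and
    single DiT blocks."""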
def __init__(self, models: list[FluxControlNet]):
super().__init__()
self.models = torch.nn.ModuleList(models)
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
model = self.models[controlnet_input.controlnet_id]
res_stack, single_res_stack = model(
controlnet_conditioning=conditioning,
processor_id=controlnet_input.processor_id,
**kwargs
)
res_stack = [res * controlnet_input.scale for res in res_stack]
single_res_stack = [res * controlnet_input.scale for res in single_res_stack]
return res_stack, single_res_stack
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
res_stack, single_res_stack = None, None
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
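            # progress counts down from 1 (first step) to 0 (last step); skip this ControlNet outside its [end, start] window.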
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start or progress < controlnet_input.end:
continue
res_stack_, single_res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
if res_stack is None:
res_stack = res_stack_
single_res_stack = single_res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
return res_stack, single_res_stack
class FluxImagePipeline(BasePipeline):
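    """FLUX text-to-image / image-to-image pipeline.

    Holds the text encoders, DiT and VAE together with optional add-ons (ControlNet, IP-Adapter,
    InfiniteYou, Step1x, value controller, LoRA encoder) and runs a list of PipelineUnit objects
    to prepare the inputs for the flow-matching denoising loop.
    """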
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler()
self.prompter = FluxPrompter()
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
self.dit: FluxDiT = None
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.controlnet: MultiControlNet = None
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder = None
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.value_controller: MultiValueEncoder = None
self.infinityou_processor: InfinitYou = None
self.image_proj_model: InfiniteYouImageProjector = None
self.lora_patcher: FluxLoraPatcher = None
self.lora_encoder: FluxLoRAEncoder = None
self.unit_runner = PipelineUnitRunner()
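        # Models that are passed to model_fn on every denoising step and loaded to the device before the loop.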
self.in_iteration_models = ("dit", "step1x_connector", "controlnet", "lora_patcher")
self.units = [
FluxImageUnit_ShapeChecker(),
FluxImageUnit_NoiseInitializer(),
FluxImageUnit_PromptEmbedder(),
FluxImageUnit_InputImageEmbedder(),
FluxImageUnit_ImageIDs(),
FluxImageUnit_EmbeddedGuidanceEmbedder(),
FluxImageUnit_Kontext(),
FluxImageUnit_InfiniteYou(),
FluxImageUnit_ControlNet(),
FluxImageUnit_IPAdapter(),
FluxImageUnit_EntityControl(),
FluxImageUnit_TeaCache(),
FluxImageUnit_Flex(),
FluxImageUnit_Step1x(),
FluxImageUnit_ValueControl(),
FluxImageUnit_LoRAEncode(),
]
self.model_fn = model_fn_flux_image
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
state_dict=None,
):
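        """Load a single LoRA (from `lora_config` or an explicit `state_dict`) into `module`.

        With `hotload=True` the converted A/B matrices are appended to the `lora_A_weights` /
        `lora_B_weights` lists of each matching AutoWrappedLinear layer (removable via
        `clear_lora`); otherwise the LoRA is applied through FluxLoRALoader.load.
        """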
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = loader.convert_state_dict(lora)
        if hotload:
            for name, submodule in module.named_modules():
                if isinstance(submodule, AutoWrappedLinear):
                    lora_a_name = f'{name}.lora_A.default.weight'
                    lora_b_name = f'{name}.lora_B.default.weight'
                    if lora_a_name in lora and lora_b_name in lora:
                        submodule.lora_A_weights.append(lora[lora_a_name] * alpha)
                        submodule.lora_B_weights.append(lora[lora_b_name])
else:
loader.load(module, lora, alpha=alpha)
def load_loras(
self,
module: torch.nn.Module,
lora_configs: list[Union[ModelConfig, str]],
alpha=1,
hotload=False,
extra_fused_lora=False,
):
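        """Load several LoRAs in sequence; with `extra_fused_lora`, additionally load a LoRA fused
        from all of them."""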
for lora_config in lora_configs:
self.load_lora(module, lora_config, hotload=hotload, alpha=alpha)
if extra_fused_lora:
            lora_fuser = FluxLoRAFuser(device=self.device, torch_dtype=self.torch_dtype)
fused_lora = lora_fuser(lora_configs)
self.load_lora(module, state_dict=fused_lora, hotload=hotload, alpha=alpha)
def clear_lora(self):
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
def training_loss(self, **inputs):
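        """Flow-matching training loss: add noise at a random timestep, predict with model_fn and
        return the timestep-weighted MSE against the scheduler's training target."""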
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss
def _enable_vram_management_with_default_config(self, model, vram_limit):
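        """Enable VRAM management for `model` with a generic module map: weights are kept on CPU
        and cast to the computation dtype/device on demand, subject to `vram_limit`."""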
if model is not None:
dtype = next(iter(model.parameters())).dtype
enable_vram_management(
model,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
RMSNorm: AutoWrappedModule,
LoRALayerBlock: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
def enable_lora_magic(self):
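        """Wrap the DiT's Linear layers in AutoWrappedLinear (kept on the current device) so that
        LoRAs can be hot-loaded, and attach the LoRA patcher's per-layer mergers if one is loaded."""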
if self.dit is not None:
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device=self.device,
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=None,
)
if self.lora_patcher is not None:
for name, module in self.dit.named_modules():
if isinstance(module, AutoWrappedLinear):
merger_name = name.replace(".", "___")
if merger_name in self.lora_patcher.model_dict:
module.lora_merger = self.lora_patcher.model_dict[merger_name]
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
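        """Enable VRAM management for all loaded models.

        If `num_persistent_param_in_dit` is given it caps how many DiT parameters stay resident on
        the device; otherwise a VRAM budget is used (`vram_limit`, defaulting to the detected VRAM
        minus `vram_buffer`).
        """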
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
# Default config
default_vram_management_models = ["text_encoder_1", "vae_decoder", "vae_encoder", "controlnet", "image_proj_model", "ipadapter", "lora_patcher", "value_controller", "step1x_connector", "lora_encoder"]
for model_name in default_vram_management_models:
self._enable_vram_management_with_default_config(getattr(self, model_name), vram_limit)
# Special config
if self.text_encoder_2 is not None:
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
T5LayerNorm: AutoWrappedModule,
T5DenseActDense: AutoWrappedModule,
T5DenseGatedActDense: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.ipadapter_image_encoder is not None:
from transformers.models.siglip.modeling_siglip import SiglipVisionEmbeddings, SiglipEncoder, SiglipMultiheadAttentionPoolingHead
dtype = next(iter(self.ipadapter_image_encoder.parameters())).dtype
enable_vram_management(
self.ipadapter_image_encoder,
module_map = {
SiglipVisionEmbeddings: AutoWrappedModule,
SiglipEncoder: AutoWrappedModule,
SiglipMultiheadAttentionPoolingHead: AutoWrappedModule,
torch.nn.MultiheadAttention: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.qwenvl is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionPatchEmbed, Qwen2_5_VLVisionBlock, Qwen2_5_VLPatchMerger,
Qwen2_5_VLDecoderLayer, Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
)
dtype = next(iter(self.qwenvl.parameters())).dtype
enable_vram_management(
self.qwenvl,
module_map = {
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
Qwen2_5_VLVisionBlock: AutoWrappedModule,
Qwen2_5_VLPatchMerger: AutoWrappedModule,
Qwen2_5_VLDecoderLayer: AutoWrappedModule,
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
Qwen2RMSNorm: AutoWrappedModule,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
):
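        """Build a FluxImagePipeline from a list of ModelConfig entries.

        Each config is downloaded if necessary and loaded through ModelManager, then the pipeline
        components are fetched by model name (e.g. "flux_dit", "flux_vae_decoder"). Multiple
        "flux_controlnet" / "flux_value_controller" models are wrapped in MultiControlNet /
        MultiValueEncoder.

        Example (the path is a placeholder):
            pipe = FluxImagePipeline.from_pretrained(model_configs=[ModelConfig(path="...")])
            image = pipe(prompt="a cat")
        """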
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary()
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
pipe = FluxImagePipeline(device=device, torch_dtype=torch_dtype)
pipe.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
pipe.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
pipe.dit = model_manager.fetch_model("flux_dit")
pipe.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
pipe.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2)
pipe.ipadapter = model_manager.fetch_model("flux_ipadapter")
pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
pipe.qwenvl = model_manager.fetch_model("qwenvl")
pipe.step1x_connector = model_manager.fetch_model("step1x_connector")
pipe.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
if pipe.image_proj_model is not None:
pipe.infinityou_processor = InfinitYou(device=device)
pipe.lora_patcher = model_manager.fetch_model("flux_lora_patcher")
pipe.lora_encoder = model_manager.fetch_model("flux_lora_encoder")
# ControlNet
controlnets = []
for model_name, model in zip(model_manager.model_name, model_manager.model):
if model_name == "flux_controlnet":
controlnets.append(model)
if len(controlnets) > 0:
pipe.controlnet = MultiControlNet(controlnets)
# Value Controller
value_controllers = []
for model_name, model in zip(model_manager.model_name, model_manager.model):
if model_name == "flux_value_controller":
value_controllers.append(model)
if len(value_controllers) > 0:
pipe.value_controller = MultiValueEncoder(value_controllers)
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 3.5,
t5_sequence_length: int = 512,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Scheduler
sigma_shift: float = None,
# Steps
num_inference_steps: int = 30,
        # Local prompts
multidiffusion_prompts=(),
multidiffusion_masks=(),
multidiffusion_scales=(),
# Kontext
kontext_images: Union[list[Image.Image], Image.Image] = None,
# ControlNet
controlnet_inputs: list[ControlNetInput] = None,
# IP-Adapter
ipadapter_images: Union[list[Image.Image], Image.Image] = None,
ipadapter_scale: float = 1.0,
# EliGen
eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None,
eligen_enable_on_negative: bool = False,
eligen_enable_inpaint: bool = False,
# InfiniteYou
infinityou_id_image: Image.Image = None,
infinityou_guidance: float = 1.0,
# Flex
flex_inpaint_image: Image.Image = None,
flex_inpaint_mask: Image.Image = None,
flex_control_image: Image.Image = None,
flex_control_strength: float = 0.5,
flex_control_stop: float = 0.5,
# Value Controller
value_controller_inputs: Union[list[float], float] = None,
# Step1x
step1x_reference_image: Image.Image = None,
# LoRA Encoder
lora_encoder_inputs: Union[list[ModelConfig], ModelConfig, str] = None,
lora_encoder_scale: float = 1.0,
# TeaCache
tea_cache_l1_thresh: float = None,
# Tile
tiled: bool = False,
tile_size: int = 128,
tile_stride: int = 64,
# Progress bar
progress_bar_cmd = tqdm,
):
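        """Generate an image from the prompt (optionally starting from `input_image`).

        All pipeline units are run to prepare the shared / positive / negative inputs, then
        `num_inference_steps` flow-matching denoising steps are performed (with classifier-free
        guidance when `cfg_scale != 1.0`) and the final latents are decoded by the VAE decoder.
        """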
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"negative_prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance, "t5_sequence_length": t5_sequence_length,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
"kontext_images": kontext_images,
"controlnet_inputs": controlnet_inputs,
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
"infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance,
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
"value_controller_inputs": value_controller_inputs,
"step1x_reference_image": step1x_reference_image,
"lora_encoder_inputs": lora_encoder_inputs, "lora_encoder_scale": lora_encoder_scale,
"tea_cache_l1_thresh": tea_cache_l1_thresh,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"progress_bar_cmd": progress_bar_cmd,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
if cfg_scale != 1.0:
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
# Decode
self.load_models_to_device(['vae_decoder'])
image = self.vae_decoder(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
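# Pipeline units: each unit prepares one group of inputs (shapes, noise, prompt embeddings, extra conditionings) before the denoising loop.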
class FluxImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width"))
def process(self, pipe: FluxImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class FluxImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width", "seed", "rand_device"))
def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device)
return {"noise": noise}
class FluxImageUnit_InputImageEmbedder(PipelineUnit):
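    """VAE-encode the optional input image (img2img). During training the clean latents are
    returned separately so that noise can be added per sampled timestep."""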
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: FluxImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae_encoder'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": None}
class FluxImageUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
input_params=("t5_sequence_length",),
onload_model_names=("text_encoder_1", "text_encoder_2")
)
def process(self, pipe: FluxImagePipeline, prompt, t5_sequence_length, positive) -> dict:
if pipe.text_encoder_1 is not None and pipe.text_encoder_2 is not None:
prompt_emb, pooled_prompt_emb, text_ids = pipe.prompter.encode_prompt(
prompt, device=pipe.device, positive=positive, t5_sequence_length=t5_sequence_length
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
else:
return {}
class FluxImageUnit_ImageIDs(PipelineUnit):
def __init__(self):
super().__init__(input_params=("latents",))
def process(self, pipe: FluxImagePipeline, latents):
latent_image_ids = pipe.dit.prepare_image_ids(latents)
return {"image_ids": latent_image_ids}
class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
def __init__(self):
super().__init__(input_params=("embedded_guidance", "latents"))
def process(self, pipe: FluxImagePipeline, embedded_guidance, latents):
guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
return {"guidance": guidance}
class FluxImageUnit_Kontext(PipelineUnit):
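    """VAE-encode Kontext reference images into patchified latents with their own image ids
    (coordinate channel 0 set to 1), to be appended to the token sequence in model_fn."""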
def __init__(self):
super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride"))
def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
if kontext_images is None:
return {}
if not isinstance(kontext_images, list):
kontext_images = [kontext_images]
kontext_latents = []
kontext_image_ids = []
for kontext_image in kontext_images:
kontext_image = pipe.preprocess_image(kontext_image)
kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image_ids = pipe.dit.prepare_image_ids(kontext_latent)
image_ids[..., 0] = 1
kontext_image_ids.append(image_ids)
kontext_latent = pipe.dit.patchify(kontext_latent)
kontext_latents.append(kontext_latent)
kontext_latents = torch.concat(kontext_latents, dim=1)
kontext_image_ids = torch.concat(kontext_image_ids, dim=-2)
return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids}
class FluxImageUnit_ControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae_encoder",)
)
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
mask = (pipe.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, pipe, image, mask):
mask = mask.resize(image.size)
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def process(self, pipe: FluxImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
if controlnet_inputs is None:
return {}
pipe.load_models_to_device(['vae_encoder'])
conditionings = []
for controlnet_input in controlnet_inputs:
image = controlnet_input.image
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
image = pipe.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
conditionings.append(image)
return {"controlnet_conditionings": conditionings}
class FluxImageUnit_IPAdapter(PipelineUnit):
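    """Encode IP-Adapter reference images (resized to 384x384) with the SigLIP image encoder and
    turn the pooled features into per-block kwargs via FluxIpAdapter."""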
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("ipadapter_image_encoder", "ipadapter")
)
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0)
if ipadapter_images is None:
return inputs_shared, inputs_posi, inputs_nega
if not isinstance(ipadapter_images, list):
ipadapter_images = [ipadapter_images]
pipe.load_models_to_device(self.onload_model_names)
images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images]
images = [pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) for image in images]
ipadapter_images = torch.cat(images, dim=0)
ipadapter_image_encoding = pipe.ipadapter_image_encoder(ipadapter_images).pooler_output
inputs_posi.update({"ipadapter_kwargs_list": pipe.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)})
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
inputs_nega.update({"ipadapter_kwargs_list": pipe.ipadapter(torch.zeros_like(ipadapter_image_encoding))})
return inputs_shared, inputs_posi, inputs_nega
class FluxImageUnit_EntityControl(PipelineUnit):
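    """EliGen entity control: encode per-entity prompts and downsample their masks to latent
    resolution, optionally mirroring them on the negative branch when CFG is enabled."""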
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("text_encoder_1", "text_encoder_2")
)
def preprocess_masks(self, pipe, masks, height, width, dim):
out_masks = []
for mask in masks:
mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype)
out_masks.append(mask)
return out_masks
def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512):
entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1)
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
prompt_emb, _, _ = pipe.prompter.encode_prompt(
entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length
)
return prompt_emb.unsqueeze(0), entity_masks
def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale):
entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length)
if enable_eligen_on_negative and cfg_scale != 1.0:
entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1)
entity_masks_nega = entity_masks_posi
else:
entity_prompt_emb_nega, entity_masks_nega = None, None
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
return eligen_kwargs_posi, eligen_kwargs_nega
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None)
if eligen_entity_prompts is None or eligen_entity_masks is None:
return inputs_shared, inputs_posi, inputs_nega
pipe.load_models_to_device(self.onload_model_names)
eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega,
eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"],
inputs_shared["t5_sequence_length"], inputs_shared["eligen_enable_on_negative"], inputs_shared["cfg_scale"])
inputs_posi.update(eligen_kwargs_posi)
if inputs_shared.get("cfg_scale", 1.0) != 1.0:
inputs_nega.update(eligen_kwargs_nega)
return inputs_shared, inputs_posi, inputs_nega
class FluxImageUnit_Step1x(PipelineUnit):
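    """Step1x editing: encode the positive/negative prompts with the QwenVL model conditioned on
    the reference image, and VAE-encode the reference image as extra latents."""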
def __init__(self):
        super().__init__(take_over=True, onload_model_names=("qwenvl", "vae_encoder"))
    def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict):
        image = inputs_shared.get("step1x_reference_image", None)
if image is None:
return inputs_shared, inputs_posi, inputs_nega
else:
pipe.load_models_to_device(self.onload_model_names)
prompt = inputs_posi["prompt"]
nega_prompt = inputs_nega["negative_prompt"]
captions = [prompt, nega_prompt]
ref_images = [image, image]
embs, masks = pipe.qwenvl(captions, ref_images)
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
image = pipe.vae_encoder(image)
inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image})
if inputs_shared.get("cfg_scale", 1) != 1:
inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image})
return inputs_shared, inputs_posi, inputs_nega
class FluxImageUnit_TeaCache(PipelineUnit):
def __init__(self):
super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"))
def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh):
if tea_cache_l1_thresh is None:
return {}
else:
return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)}
class FluxImageUnit_Flex(PipelineUnit):
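    """Build the extra conditioning channels expected by Flex-style DiTs (input_dim == 196): an
    inpaint latent masked by the inpaint mask plus a strength-scaled control latent, with an
    unconditioned variant used once `flex_control_stop` is reached."""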
def __init__(self):
super().__init__(
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
if pipe.dit.input_dim == 196:
if flex_control_stop is None:
flex_control_stop = 1
pipe.load_models_to_device(self.onload_model_names)
if flex_inpaint_image is None:
flex_inpaint_image = torch.zeros_like(latents)
else:
flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype)
flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if flex_inpaint_mask is None:
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
else:
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype)
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
if flex_control_image is None:
flex_control_image = torch.zeros_like(latents)
else:
flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype)
flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))]
return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
else:
return {}
class FluxImageUnit_InfiniteYou(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("infinityou_id_image", "infinityou_guidance"),
onload_model_names=("infinityou_processor",)
)
def process(self, pipe: FluxImagePipeline, infinityou_id_image, infinityou_guidance):
pipe.load_models_to_device("infinityou_processor")
if infinityou_id_image is not None:
return pipe.infinityou_processor.prepare_infinite_you(pipe.image_proj_model, infinityou_id_image, infinityou_guidance, pipe.device)
else:
return {}
class FluxImageUnit_ValueControl(PipelineUnit):
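    """Encode scalar control values with the value controller and append them to the prompt
    embedding as extra text tokens (with zero-valued text ids)."""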
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params=("value_controller_inputs",),
onload_model_names=("value_controller",)
)
def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
return prompt_emb, text_ids
def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
if value_controller_inputs is None:
return {}
if not isinstance(value_controller_inputs, list):
value_controller_inputs = [value_controller_inputs]
value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
pipe.load_models_to_device(["value_controller"])
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
value_emb = value_emb.unsqueeze(0)
prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
return {"prompt_emb": prompt_emb, "text_ids": text_ids}
class InfinitYou(torch.nn.Module):
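    """Face identity extractor for InfiniteYou: detects the largest face with InsightFace
    (antelopev2, at several detection sizes) and encodes it with ArcFace before projecting it
    through the InfiniteYou image projector."""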
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
self.device = device
self.torch_dtype = torch_dtype
insightface_root_path = 'models/ByteDance/InfiniteYou/supports/insightface'
self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_640.prepare(ctx_id=0, det_size=(640, 640))
self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_320.prepare(ctx_id=0, det_size=(320, 320))
self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.app_160.prepare(ctx_id=0, det_size=(160, 160))
self.arcface_model = init_recognition_model('arcface', device=self.device).to(torch_dtype)
def _detect_face(self, id_image_cv2):
face_info = self.app_640.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_320.get(id_image_cv2)
if len(face_info) > 0:
return face_info
face_info = self.app_160.get(id_image_cv2)
return face_info
def extract_arcface_bgr_embedding(self, in_image, landmark, device):
from insightface.utils import face_align
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.contiguous().to(device=device, dtype=self.torch_dtype)
face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
return face_emb
def prepare_infinite_you(self, model, id_image, infinityou_guidance, device):
import cv2
if id_image is None:
return {'id_emb': None}
id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
face_info = self._detect_face(id_image_cv2)
if len(face_info) == 0:
raise ValueError('No face detected in the input ID image')
landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # only use the maximum face
id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark, device)
id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=device, dtype=self.torch_dtype)
return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}
class FluxImageUnit_LoRAEncode(PipelineUnit):
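    """Encode one or more LoRA checkpoints with the FluxLoRAEncoder and append the resulting
    embeddings to the positive prompt embedding as extra text tokens."""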
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("lora_encoder",)
)
def parse_lora_encoder_inputs(self, lora_encoder_inputs):
if not isinstance(lora_encoder_inputs, list):
lora_encoder_inputs = [lora_encoder_inputs]
lora_configs = []
for lora_encoder_input in lora_encoder_inputs:
if isinstance(lora_encoder_input, str):
lora_encoder_input = ModelConfig(path=lora_encoder_input)
lora_encoder_input.download_if_necessary()
lora_configs.append(lora_encoder_input)
return lora_configs
def load_lora(self, lora_config, dtype, device):
loader = FluxLoRALoader(torch_dtype=dtype, device=device)
lora = load_state_dict(lora_config.path, torch_dtype=dtype, device=device)
lora = loader.convert_state_dict(lora)
return lora
def lora_embedding(self, pipe, lora_encoder_inputs):
lora_emb = []
for lora_config in self.parse_lora_encoder_inputs(lora_encoder_inputs):
lora = self.load_lora(lora_config, pipe.torch_dtype, pipe.device)
lora_emb.append(pipe.lora_encoder(lora))
lora_emb = torch.concat(lora_emb, dim=1)
return lora_emb
def add_to_text_embedding(self, prompt_emb, text_ids, lora_emb):
prompt_emb = torch.concat([prompt_emb, lora_emb], dim=1)
extra_text_ids = torch.zeros((lora_emb.shape[0], lora_emb.shape[1], 3), device=lora_emb.device, dtype=lora_emb.dtype)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
return prompt_emb, text_ids
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("lora_encoder_inputs", None) is None:
return inputs_shared, inputs_posi, inputs_nega
# Encode
pipe.load_models_to_device(["lora_encoder"])
lora_encoder_inputs = inputs_shared["lora_encoder_inputs"]
lora_emb = self.lora_embedding(pipe, lora_encoder_inputs)
# Scale
lora_encoder_scale = inputs_shared.get("lora_encoder_scale", None)
if lora_encoder_scale is not None:
lora_emb = lora_emb * lora_encoder_scale
# Add to prompt embedding
inputs_posi["prompt_emb"], inputs_posi["text_ids"] = self.add_to_text_embedding(
inputs_posi["prompt_emb"], inputs_posi["text_ids"], lora_emb)
return inputs_shared, inputs_posi, inputs_nega
class TeaCache:
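    """TeaCache accelerator: tracks how much the first block's modulated input changes between
    steps and, while the accumulated relative change stays below `rel_l1_thresh`, reuses the
    cached DiT residual instead of running the transformer blocks."""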
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
def check(self, dit: FluxDiT, hidden_states, conditioning):
inp = hidden_states.clone()
temb_ = conditioning.clone()
modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
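            # Empirical polynomial that rescales the relative L1 change of the modulated input before it is accumulated.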
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = hidden_states.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
def model_fn_flux_image(
dit: FluxDiT,
controlnet=None,
step1x_connector=None,
latents=None,
timestep=None,
prompt_emb=None,
pooled_prompt_emb=None,
guidance=None,
text_ids=None,
image_ids=None,
kontext_latents=None,
kontext_image_ids=None,
controlnet_inputs=None,
controlnet_conditionings=None,
tiled=False,
tile_size=128,
tile_stride=64,
entity_prompt_emb=None,
entity_masks=None,
ipadapter_kwargs_list={},
id_emb=None,
infinityou_guidance=None,
flex_condition=None,
flex_uncondition=None,
flex_control_stop_timestep=None,
step1x_llm_embedding=None,
step1x_mask=None,
step1x_reference_latents=None,
tea_cache: TeaCache = None,
progress_id=0,
num_inference_steps=1,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs
):
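    """Run one denoising step of the FLUX DiT.

    Optionally applies ControlNet residuals, Flex conditioning channels, Kontext and Step1x
    reference latents, EliGen entity masks, IP-Adapter kwargs and TeaCache skipping; when
    `tiled` is set, the forward pass is tiled over the latent with FastTileWorker.
    """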
if tiled:
def flux_forward_fn(hl, hr, wl, wr):
tiled_controlnet_conditionings = [f[:, :, hl: hr, wl: wr] for f in controlnet_conditionings] if controlnet_conditionings is not None else None
return model_fn_flux_image(
dit=dit,
controlnet=controlnet,
latents=latents[:, :, hl: hr, wl: wr],
timestep=timestep,
prompt_emb=prompt_emb,
pooled_prompt_emb=pooled_prompt_emb,
guidance=guidance,
text_ids=text_ids,
image_ids=None,
controlnet_inputs=controlnet_inputs,
controlnet_conditionings=tiled_controlnet_conditionings,
tiled=False,
**kwargs
)
return FastTileWorker().tiled_forward(
flux_forward_fn,
latents,
tile_size=tile_size,
tile_stride=tile_stride,
tile_device=latents.device,
tile_dtype=latents.dtype
)
hidden_states = latents
# ControlNet
if controlnet is not None and controlnet_conditionings is not None:
controlnet_extra_kwargs = {
"hidden_states": hidden_states,
"timestep": timestep,
"prompt_emb": prompt_emb,
"pooled_prompt_emb": pooled_prompt_emb,
"guidance": guidance,
"text_ids": text_ids,
"image_ids": image_ids,
"controlnet_inputs": controlnet_inputs,
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride,
"progress_id": progress_id,
"num_inference_steps": num_inference_steps,
}
if id_emb is not None:
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
controlnet_res_stack, controlnet_single_res_stack = controlnet(
controlnet_conditionings, **controlnet_extra_kwargs
)
# Flex
if flex_condition is not None:
if timestep.tolist()[0] >= flex_control_stop_timestep:
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
else:
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
# Step1x
if step1x_llm_embedding is not None:
prompt_emb, pooled_prompt_emb = step1x_connector(step1x_llm_embedding, timestep / 1000, step1x_mask)
text_ids = torch.zeros((1, prompt_emb.shape[1], 3), dtype=prompt_emb.dtype, device=prompt_emb.device)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)
conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
if dit.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
# Kontext
if kontext_latents is not None:
image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2)
hidden_states = torch.concat([hidden_states, kontext_latents], dim=1)
# Step1x
if step1x_reference_latents is not None:
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
step1x_reference_latents = dit.patchify(step1x_reference_latents)
image_ids = torch.concat([image_ids, step1x_reference_image_ids], dim=-2)
hidden_states = torch.concat([hidden_states, step1x_reference_latents], dim=1)
hidden_states = dit.x_embedder(hidden_states)
# EliGen
if entity_prompt_emb is not None and entity_masks is not None:
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
else:
tea_cache_update = False
if tea_cache_update:
hidden_states = tea_cache.update(hidden_states)
else:
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None),
)
# ControlNet
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
if kontext_latents is None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
else:
hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
)
# ControlNet
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
if kontext_latents is None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
else:
hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
if tea_cache is not None:
tea_cache.store(hidden_states)
hidden_states = dit.final_norm_out(hidden_states, conditioning)
hidden_states = dit.final_proj_out(hidden_states)
# Step1x
if step1x_reference_latents is not None:
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
# Kontext
if kontext_latents is not None:
hidden_states = hidden_states[:, :-kontext_latents.shape[1]]
hidden_states = dit.unpatchify(hidden_states, height, width)
return hidden_states