import torch
from .sd_text_encoder import CLIPEncoderLayer


class LoRALayerBlock(torch.nn.Module):
    """Map one layer's LoRA factor pair to ``L`` normalized embeddings.

    A learnable probe ``x`` of shape ``(1, L, dim_in)`` is multiplied by the
    transposed LoRA matrices and the product is layer-normalized over a last
    dimension of size ``dim_out``.
    """

    def __init__(self, L, dim_in, dim_out):
        super().__init__()
        self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))
        self.layer_norm = torch.nn.LayerNorm(dim_out)

    def forward(self, lora_A, lora_B):
        """Return ``layer_norm(x @ lora_A.T @ lora_B.T)``.

        NOTE(review): the matmul chain implies lora_A is (rank, dim_in) and
        lora_B is (dim_out, rank) -- confirm against the LoRA checkpoint.
        """
        x = self.x @ lora_A.T @ lora_B.T
        x = self.layer_norm(x)
        return x


class LoRAEmbedder(torch.nn.Module):
    """Embed every LoRA-adapted layer listed in ``lora_patterns``.

    Each pattern is a dict with the keys ``"name"`` (state-dict prefix of the
    layer), ``"dim"`` (``(dim_in, dim_out)``) and ``"type"`` (layer kind).
    One ``LoRALayerBlock`` is created per layer name and one ``Linear``
    projection to ``out_dim`` is created per layer type, so layers of the
    same type share their projection.
    """

    def __init__(self, lora_patterns=None, L=1, out_dim=2048):
        super().__init__()
        if lora_patterns is None:
            lora_patterns = self.default_lora_patterns()

        # Module names may not contain ".", hence the "___" substitution.
        model_dict = {}
        for lora_pattern in lora_patterns:
            name, dim = lora_pattern["name"], lora_pattern["dim"]
            model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim[0], dim[1])
        self.model_dict = torch.nn.ModuleDict(model_dict)

        proj_dict = {}
        for lora_pattern in lora_patterns:
            # Test membership with the *sanitized* key. Checking the raw type
            # name never matched for dotted types, so their (large) Linear
            # layer was rebuilt and thrown away once per block index.
            proj_key = lora_pattern["type"].replace(".", "___")
            if proj_key not in proj_dict:
                proj_dict[proj_key] = torch.nn.Linear(lora_pattern["dim"][1], out_dim)
        self.proj_dict = torch.nn.ModuleDict(proj_dict)

        self.lora_patterns = lora_patterns

    def default_lora_patterns(self):
        """Return the built-in pattern list.

        It holds 19 ``blocks.{i}`` entries for each of ten layer suffixes,
        followed by 38 ``single_blocks.{i}`` entries for each of three
        suffixes, in that order.
        """
        lora_patterns = []

        block_dims = {
            "attn.a_to_qkv": (3072, 9216),
            "attn.a_to_out": (3072, 3072),
            "ff_a.0": (3072, 12288),
            "ff_a.2": (12288, 3072),
            "norm1_a.linear": (3072, 18432),
            "attn.b_to_qkv": (3072, 9216),
            "attn.b_to_out": (3072, 3072),
            "ff_b.0": (3072, 12288),
            "ff_b.2": (12288, 3072),
            "norm1_b.linear": (3072, 18432),
        }
        for i in range(19):
            for suffix, dim in block_dims.items():
                lora_patterns.append({
                    "name": f"blocks.{i}.{suffix}",
                    "dim": dim,
                    "type": suffix,
                })

        single_block_dims = {
            "to_qkv_mlp": (3072, 21504),
            "proj_out": (15360, 3072),
            "norm.linear": (3072, 9216),
        }
        for i in range(38):
            for suffix, dim in single_block_dims.items():
                lora_patterns.append({
                    "name": f"single_blocks.{i}.{suffix}",
                    "dim": dim,
                    "type": suffix,
                })

        return lora_patterns

    def forward(self, lora):
        """Embed a LoRA state dict.

        ``lora`` must map ``"<name>.lora_A.default.weight"`` and
        ``"<name>.lora_B.default.weight"`` to tensors for every pattern name;
        a missing key raises ``KeyError``. The per-layer embeddings are
        concatenated along dimension 1 in pattern order.
        """
        lora_emb = []
        for lora_pattern in self.lora_patterns:
            name, layer_type = lora_pattern["name"], lora_pattern["type"]
            lora_A = lora[name + ".lora_A.default.weight"]
            lora_B = lora[name + ".lora_B.default.weight"]
            lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
            lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
            lora_emb.append(lora_out)
        lora_emb = torch.concat(lora_emb, dim=1)
        return lora_emb


class FluxLoRAEncoder(torch.nn.Module):
    """Encode a LoRA state dict into ``num_special_embeds`` embeddings.

    The LoRA embeddings from ``LoRAEmbedder`` are prefixed with learnable
    special embeddings, run through the encoder layers, and only the special
    positions are kept, normalized and linearly projected.
    """

    def __init__(self, embed_dim=4096, encoder_intermediate_size=8192, num_encoder_layers=1, num_embeds_per_lora=16, num_special_embeds=1):
        super().__init__()
        self.num_embeds_per_lora = num_embeds_per_lora
        # embedder
        self.embedder = LoRAEmbedder(L=num_embeds_per_lora, out_dim=embed_dim)

        # encoders
        self.encoders = torch.nn.ModuleList([
            CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=32, head_dim=128)
            for _ in range(num_encoder_layers)
        ])

        # special embedding
        self.special_embeds = torch.nn.Parameter(torch.randn(1, num_special_embeds, embed_dim))
        self.num_special_embeds = num_special_embeds

        # final layer
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
        self.final_linear = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, lora):
        """Return the processed special embeddings for the given LoRA dict.

        The result keeps the first ``num_special_embeds`` positions along
        dimension 1, i.e. the slots the special embeddings were placed in.
        """
        lora_embeds = self.embedder(lora)
        # Match the dtype/device of the computed embeddings before joining.
        special_embeds = self.special_embeds.to(dtype=lora_embeds.dtype, device=lora_embeds.device)
        embeds = torch.concat([special_embeds, lora_embeds], dim=1)
        for encoder in self.encoders:
            embeds = encoder(embeds)
        # Special embeddings were concatenated first, so they sit at the front.
        embeds = embeds[:, :self.num_special_embeds]
        embeds = self.final_layer_norm(embeds)
        embeds = self.final_linear(embeds)
        return embeds

    @staticmethod
    def state_dict_converter():
        """Return a new ``FluxLoRAEncoderStateDictConverter``."""
        return FluxLoRAEncoderStateDictConverter()


class FluxLoRAEncoderStateDictConverter:
    """State-dict converter for ``FluxLoRAEncoder``; performs no renaming."""

    def from_civitai(self, state_dict):
        """Return ``state_dict`` unchanged."""
        return state_dict