smzerbe committed
Commit 490d2ca (verified) · 1 parent: ca796f0

Update extensions-builtin/Lora/networks.py

Files changed (1): extensions-builtin/Lora/networks.py (+737 -737; the only visible content change is at line 584)
extensions-builtin/Lora/networks.py CHANGED
@@ -578,10 +578,10 @@
 def network_Linear_forward(self, input):
     if shared.opts.lora_functional:
         return network_forward(self, input, originals.Linear_forward)
 
     network_apply_weights(self)
 
-    return originals.Linear_forward(self, input)
+    return originals.Linear_forward(self, input.float())
 
 
 def network_Linear_load_state_dict(self, *args, **kwargs):
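
For context, a minimal, self-contained sketch (not taken from the repository) of what the added `.float()` call does mechanically: it upcasts the tensor handed to the original `Linear` forward to float32, so half-precision activations no longer hit a float32 layer with mismatched dtypes. The `torch.nn.Linear(4, 2)` below is a hypothetical stand-in for a real model layer, not the webui's actual code path.

```python
import torch

# Illustrative sketch only: the commit's change upcasts the layer input with
# `.float()` before calling the un-patched Linear forward.
linear = torch.nn.Linear(4, 2)                 # parameters default to float32
x = torch.randn(1, 4, dtype=torch.float16)     # half-precision activations

# Without the cast, mixing fp16 activations with fp32 weights raises a dtype error.
try:
    linear(x)
except RuntimeError as err:
    print("fp16 input into fp32 Linear failed:", err)

# With the cast, the matmul runs entirely in float32.
print(linear(x.float()).dtype)                 # torch.float32
```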