Spaces:
Paused
Paused
File size: 8,016 Bytes
c1a26d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import torch, copy
from ..models.utils import init_weights_on_device
def cast_to(weight, dtype, device):
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight)
return r
class AutoTorchModule(torch.nn.Module):
def __init__(self):
super().__init__()
def check_free_vram(self):
gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024 ** 3)
return used_memory < self.vram_limit
def offload(self):
if self.state != 0:
self.to(dtype=self.offload_dtype, device=self.offload_device)
self.state = 0
def onload(self):
if self.state != 1:
self.to(dtype=self.onload_dtype, device=self.onload_device)
self.state = 1
def keep(self):
if self.state != 2:
self.to(dtype=self.computation_dtype, device=self.computation_device)
self.state = 2
class AutoWrappedModule(AutoTorchModule):
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
super().__init__()
self.module = module.to(dtype=offload_dtype, device=offload_device)
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
def forward(self, *args, **kwargs):
if self.state == 2:
module = self.module
else:
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
module = self.module
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
module = self.module
else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
return module(*args, **kwargs)
class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
self.bias = module.bias
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
def forward(self, x, *args, **kwargs):
if self.state == 2:
weight, bias = self.weight, self.bias
else:
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
weight, bias = self.weight, self.bias
else:
weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
with torch.amp.autocast(device_type=x.device.type):
x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
return x
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs):
with init_weights_on_device(device=torch.device("meta")):
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
self.weight = module.weight
self.bias = module.bias
self.offload_dtype = offload_dtype
self.offload_device = offload_device
self.onload_dtype = onload_dtype
self.onload_device = onload_device
self.computation_dtype = computation_dtype
self.computation_device = computation_device
self.vram_limit = vram_limit
self.state = 0
self.name = name
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
def forward(self, x, *args, **kwargs):
if self.state == 2:
weight, bias = self.weight, self.bias
else:
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
weight, bias = self.weight, self.bias
elif self.vram_limit is not None and self.check_free_vram():
self.keep()
weight, bias = self.weight, self.bias
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
out = torch.nn.functional.linear(x, weight, bias)
if len(self.lora_A_weights) == 0:
# No LoRA
return out
elif self.lora_merger is None:
# Native LoRA inference
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
out = out + x @ lora_A.T @ lora_B.T
else:
# LoRA fusion
lora_output = []
for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
lora_output.append(x @ lora_A.T @ lora_B.T)
lora_output = torch.stack(lora_output)
out = self.lora_merger(out, lora_output)
return out
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
for name, module in model.named_children():
layer_name = name if name_prefix == "" else name_prefix + "." + name
for source_module, target_module in module_map.items():
if isinstance(module, source_module):
num_param = sum(p.numel() for p in module.parameters())
if max_num_param is not None and total_num_param + num_param > max_num_param:
module_config_ = overflow_module_config
else:
module_config_ = module_config
module_ = target_module(module, **module_config_, vram_limit=vram_limit, name=layer_name)
setattr(model, name, module_)
total_num_param += num_param
break
else:
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=layer_name)
return total_num_param
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None):
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0, vram_limit=vram_limit)
model.vram_management_enabled = True
|