Spaces:
Runtime error
Runtime error
File size: 5,171 Bytes
d30efe2 119afbd c3173d9 d30efe2 a084789 d30efe2 a084789 d30efe2 c3173d9 119afbd c3173d9 119afbd c3173d9 a084789 c3173d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import torch.nn as nn
import json
import hashlib
import gc
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForImageClassification
from accelerate import init_empty_weights
# Process-wide memo of get_model_structure() results; no code in this file evicts entries.
model_structure_cache: dict = dict()
def module_hash(module):
    """
    Return a structural fingerprint of ``module`` as an MD5 hex digest.

    The fingerprint covers the module's class name, the fingerprints of its
    direct children (recursively, in registration order), and the name, shape
    and ``requires_grad`` flag of each parameter registered directly on the
    module. Parameter values, buffers and child attribute names are NOT part
    of the fingerprint, so two modules whose children differ only in attribute
    name produce the same hash.

    Args:
        module: a ``torch.nn.Module``.

    Returns:
        A 32-character lowercase hexadecimal string.
    """
    child_hashes = tuple(module_hash(child) for _, child in module.named_children())
    param_info = tuple(
        (name, tuple(p.shape), p.requires_grad)
        for name, p in module.named_parameters(recurse=False)
    )
    rep = (module.__class__.__name__, child_hashes, param_info)
    # MD5 is only a fingerprint here, not a security primitive; declaring that
    # keeps hashlib usable on FIPS-restricted OpenSSL builds, where a plain
    # md5() call raises.
    return hashlib.md5(str(rep).encode("utf-8"), usedforsecurity=False).hexdigest()
def is_number_string(value):
    """
    Return True if ``value`` is a non-empty ``str`` made only of decimal digits.

    Non-strings, the empty string, signs, decimal points and whitespace all
    yield False. hf_style_structural_dict uses this to decide which children
    are eligible for repeat-collapsing.
    """
    # str.isdecimal() accepts exactly the digit characters int() can parse;
    # str.isdigit() would also accept e.g. superscript "²", which int() rejects.
    return isinstance(value, str) and value.isdecimal()
def hf_style_structural_dict(module):
    """
    Recursively describe ``module`` as a JSON-serializable dict, in the spirit
    of Hugging Face's ``print(model)`` output.

    Each node contains:
      * ``"class_name"``: the module's class name;
      * ``"params"`` (only when the module owns parameters directly): maps each
        parameter name to ``{"shape": [...], "requires_grad": bool}``;
      * ``"children"`` (only when the module has children): maps each child
        name to that child's node.

    A run of consecutive children whose names are ALL decimal strings and whose
    ``module_hash`` fingerprints are equal is collapsed into a single entry,
    stored under the first child's name and annotated with ``"num_repeats"``.

    Args:
        module: a ``torch.nn.Module``.

    Returns:
        A nested dict of str / int / bool / list values.
    """
    result = {"class_name": module.__class__.__name__}

    params = {
        name: {"shape": list(p.shape), "requires_grad": p.requires_grad}
        for name, p in module.named_parameters(recurse=False)
    }
    if params:
        result["params"] = params

    children = list(module.named_children())
    if not children:
        return result

    # Fingerprint each child once rather than re-hashing inside the scan below.
    hashes = [module_hash(child) for _, child in children]

    child_dict = {}
    i = 0
    while i < len(children):
        name, child = children[i]
        j = i + 1
        if is_number_string(name):
            # Extend the run only over siblings that are index-named too; a
            # non-numeric sibling with identical structure must keep its own
            # entry instead of silently vanishing into the run.
            while (
                j < len(children)
                and is_number_string(children[j][0])
                and hashes[j] == hashes[i]
            ):
                j += 1
        entry = hf_style_structural_dict(child)
        count = j - i
        if count > 1:
            entry["num_repeats"] = count
        child_dict[name] = entry
        i = j
    result["children"] = child_dict
    return result
def get_model_structure(model_name: str, model_type: str | None = None) -> dict:
    """
    Build, and memoize, a JSON-serializable summary of a Hugging Face model.

    The model is instantiated from its config only (``from_config``, not
    ``from_pretrained``), under ``accelerate.init_empty_weights`` for task
    heads, or on the ``meta`` device for the generic ``AutoModel`` fallback.
    Its module tree is then summarized with ``hf_style_structural_dict``.

    Args:
        model_name: Hub repo id or local path accepted by
            ``AutoConfig.from_pretrained``.
        model_type: task head to instantiate: "causal", "masked", "sequence",
            "token", "qa", "s2s" or "vision". Any other value, including None,
            builds the bare ``AutoModel``.

    Returns:
        dict with the config's ``model_type``, several common config fields
        (None when the config lacks them), and ``"layers"``, the structural
        dict of the instantiated model. Results are cached per
        ``(model_name, model_type)`` pair; the cached object itself is returned.

    Raises:
        Whatever ``AutoConfig.from_pretrained`` / ``from_config`` raise (for
        example for unknown repos, or a head the config does not support).
        Failures are not cached.
    """
    # Key on both arguments: the same repo yields different module trees
    # depending on which task head is instantiated.
    cache_key = (model_name, model_type)
    if cache_key in model_structure_cache:
        return model_structure_cache[cache_key]

    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # model repo. If model_name comes from end users, this runs code from any
    # repo they name -- confirm that is acceptable for this deployment.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    auto_classes = {
        "causal": AutoModelForCausalLM,
        "masked": AutoModelForMaskedLM,
        "sequence": AutoModelForSequenceClassification,
        "token": AutoModelForTokenClassification,
        "qa": AutoModelForQuestionAnswering,
        "s2s": AutoModelForSeq2SeqLM,
        "vision": AutoModelForImageClassification,
    }
    auto_cls = auto_classes.get(model_type)
    if auto_cls is not None:
        with init_empty_weights():
            # trust_remote_code is passed for every head (not just "causal") so
            # configs that rely on remote model code can be built for any task.
            model = auto_cls.from_config(config, trust_remote_code=True)
    else:
        with torch.device("meta"):
            model = AutoModel.from_config(config, trust_remote_code=True)

    try:
        structure = {
            "model_type": config.model_type,
            "hidden_size": getattr(config, "hidden_size", None),
            "num_hidden_layers": getattr(config, "num_hidden_layers", None),
            "num_attention_heads": getattr(config, "num_attention_heads", None),
            "image_size": getattr(config, "image_size", None),
            "intermediate_size": getattr(config, "intermediate_size", None),
            "patch_size": getattr(config, "patch_size", None),
            "vocab_size": getattr(config, "vocab_size", None),
            "layers": hf_style_structural_dict(model),
        }
    finally:
        # Release the model even if summarizing it fails; the summary holds
        # only plain strings, numbers and lists, never the model itself.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_structure_cache[cache_key] = structure
    return structure