# model_structure_viewer/backend/hf_model_utils.py
import gc
import hashlib

import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForImageClassification,
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
from accelerate import init_empty_weights

# In-memory cache: model name -> JSON-serializable structure dict.
model_structure_cache = {}


def module_hash(module):
    """
    Generate a hash representing the structure of a module.

    Uses the class name, child hashes, and parameter shapes so that
    structurally identical modules (e.g. repeated transformer layers)
    hash to the same value.
    """
    child_hashes = [module_hash(child) for _, child in module.named_children()]
    # Hash over the class name, the child hashes, and the shapes of this
    # module's own (non-recursive) parameters.
    param_info = [(name, tuple(p.shape), p.requires_grad)
                  for name, p in module.named_parameters(recurse=False)]
    rep = (module.__class__.__name__, tuple(child_hashes), tuple(param_info))
    return hashlib.md5(str(rep).encode("utf-8")).hexdigest()
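
# Illustrative sanity check (a sketch, not part of the app; torch.nn is
# imported here only for the example): two Linear layers with the same
# shape hash identically, a differently shaped one does not.
#
#   >>> import torch.nn as nn
#   >>> module_hash(nn.Linear(4, 4)) == module_hash(nn.Linear(4, 4))
#   True
#   >>> module_hash(nn.Linear(4, 4)) == module_hash(nn.Linear(4, 8))
#   False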


def is_number_string(value):
    """Return True for strings like "0" or "11" (child names in an nn.ModuleList)."""
    return isinstance(value, str) and value.isdigit()


def hf_style_structural_dict(module):
    """
    Recursively convert a PyTorch module into a dict mirroring
    Hugging Face's print(model), collapsing consecutive children into a
    repeat count only when their structure is identical.
    """
    result = {"class_name": module.__class__.__name__}

    # Include this module's own parameters, if any.
    params = {name: {"shape": list(p.shape), "requires_grad": p.requires_grad}
              for name, p in module.named_parameters(recurse=False)}
    if params:
        result["params"] = params

    children = list(module.named_children())
    if children:
        child_dict = {}
        i = 0
        while i < len(children):
            name, child = children[i]
            current_hash = module_hash(child)
            count = 1
            j = i + 1
            # Collapse consecutive, structurally identical children with
            # numeric names (e.g. the layers of an nn.ModuleList).
            while (j < len(children)
                   and is_number_string(name)
                   and is_number_string(children[j][0])
                   and module_hash(children[j][1]) == current_hash):
                count += 1
                j += 1
            child_entry = hf_style_structural_dict(child)
            if count > 1:
                child_entry["num_repeats"] = count
            child_dict[name] = child_entry
            i += count
        result["children"] = child_dict
    return result
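
# Shape sketch (illustrative; exact keys depend on the model): for a
# 12-layer BERT encoder, the ModuleList's identical layers collapse into
# a single entry carrying "num_repeats":
#
#   {"class_name": "BertEncoder",
#    "children": {"layer": {"class_name": "ModuleList",
#                           "children": {"0": {"class_name": "BertLayer",
#                                              "num_repeats": 12,
#                                              "children": {...}}}}}}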


# Map the API's model_type strings to the matching Auto* classes.
AUTO_CLASS_BY_TYPE = {
    "causal": AutoModelForCausalLM,
    "masked": AutoModelForMaskedLM,
    "sequence": AutoModelForSequenceClassification,
    "token": AutoModelForTokenClassification,
    "qa": AutoModelForQuestionAnswering,
    "s2s": AutoModelForSeq2SeqLM,
    "vision": AutoModelForImageClassification,
}


def get_model_structure(model_name: str, model_type: str | None):
    # 1. Return the cached structure if we have already built it.
    if model_name in model_structure_cache:
        return model_structure_cache[model_name]

    # 2. If not cached, instantiate the model with empty (meta-device)
    #    weights -- nothing is downloaded or allocated -- and walk its
    #    module tree. Unknown or missing model_type falls back to AutoModel.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    auto_class = AUTO_CLASS_BY_TYPE.get(model_type, AutoModel)
    with init_empty_weights():
        model = auto_class.from_config(config, trust_remote_code=True)

    structure = {
        "model_type": config.model_type,
        "hidden_size": getattr(config, "hidden_size", None),
        "num_hidden_layers": getattr(config, "num_hidden_layers", None),
        "num_attention_heads": getattr(config, "num_attention_heads", None),
        "image_size": getattr(config, "image_size", None),
        "intermediate_size": getattr(config, "intermediate_size", None),
        "patch_size": getattr(config, "patch_size", None),
        "vocab_size": getattr(config, "vocab_size", None),
        "layers": hf_style_structural_dict(model),
    }

    # 3. Free memory (the meta-device model holds no real weights, but drop
    #    the Python objects anyway).
    del model
    gc.collect()
    if torch.cuda.is_available():  # only relevant when a GPU is present
        torch.cuda.empty_cache()

    # 4. Cache and return the JSON-serializable structure.
    model_structure_cache[model_name] = structure
    return structure
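

if __name__ == "__main__":
    # Minimal usage sketch (assumes network access to the Hugging Face Hub;
    # "bert-base-uncased" is an illustrative model name, not part of the app).
    structure = get_model_structure("bert-base-uncased", "masked")
    print(structure["model_type"])         # "bert"
    print(structure["num_hidden_layers"])  # 12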