Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import json | |
| import hashlib | |
| import gc | |
| from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForImageClassification | |
| from accelerate import init_empty_weights | |
# Process-wide memo of get_model_structure() results, so a given request only
# fetches the config and builds the empty-weight model once per process.
# NOTE(review): nothing visible in this file evicts entries, so the dict grows
# with every distinct request for the lifetime of the process.
model_structure_cache: dict = {}
def module_hash(module):
    """
    Return an MD5 hex digest summarising the structure of *module*.

    Two modules get the same digest when they share:
      - the class name;
      - the ``extra_repr()`` text, i.e. the hyper-parameters shown by
        ``print(model)`` such as Linear features, Conv stride/padding and
        Dropout ``p``;
      - the directly-owned parameters (name, shape, requires_grad);
      - recursively, the same named children.

    Used to detect repeated sibling blocks. It is not a security hash.
    """
    # Include child names so containers whose children are wired differently
    # (same classes, different attribute names) do not collide.
    child_hashes = tuple(
        (child_name, module_hash(child))
        for child_name, child in module.named_children()
    )
    param_info = tuple(
        (param_name, tuple(p.shape), p.requires_grad)
        for param_name, p in module.named_parameters(recurse=False)
    )
    # extra_repr() carries hyper-parameters that parameter shapes alone miss
    # (e.g. conv stride, dropout probability), which matter for "identical".
    rep = (module.__class__.__name__, module.extra_repr(), child_hashes, param_info)
    return hashlib.md5(str(rep).encode("utf-8"), usedforsecurity=False).hexdigest()
def is_number_string(value):
    """Return True only for a ``str`` whose characters are all digits (per ``str.isdigit``)."""
    if not isinstance(value, str):
        return False
    return value.isdigit()
def hf_style_structural_dict(module):
    """
    Recursively convert a PyTorch module into a dict mirroring
    Hugging Face's print(model).

    Consecutive numbered siblings (e.g. the entries of an ``nn.ModuleList``)
    are collapsed into a single entry when ``module_hash`` says their
    structure is identical.

    Args:
        module: the ``nn.Module`` to describe.

    Returns:
        A dict with the following keys:
          - "class_name": the module's class name.
          - "params": present only if the module owns parameters directly.
            Maps each parameter name to
            ``{"shape": [...], "requires_grad": bool}``.
          - "children": present only if the module has children. Maps each
            child name to that child's dict. The first child of a collapsed
            run additionally carries "num_repeats" (the run length).
    """
    result = {"class_name": module.__class__.__name__}

    params = {
        name: {"shape": list(p.shape), "requires_grad": p.requires_grad}
        for name, p in module.named_parameters(recurse=False)
    }
    if params:
        result["params"] = params

    children = list(module.named_children())
    if not children:
        return result

    # Hash every child exactly once. Each child needed hashing anyway, and
    # this avoids re-hashing the child that terminates a run.
    hashes = [module_hash(child) for _, child in children]

    child_dict = {}
    i = 0
    while i < len(children):
        name, child = children[i]
        count = 1
        # Only collapse runs of numbered siblings. Every sibling in the run
        # must itself be numbered, so a named child that happens to share
        # the structure is never swallowed (and lost) into the run.
        if is_number_string(name):
            while (
                i + count < len(children)
                and is_number_string(children[i + count][0])
                and hashes[i + count] == hashes[i]
            ):
                count += 1
        child_entry = hf_style_structural_dict(child)
        if count > 1:
            child_entry["num_repeats"] = count
        child_dict[name] = child_entry
        i += count
    result["children"] = child_dict
    return result
# Maps the caller-facing ``model_type`` tag to the transformers Auto* class
# used to instantiate the model skeleton. Unknown tags (and None) fall back
# to the bare AutoModel.
_AUTO_MODEL_CLASSES = {
    "causal": AutoModelForCausalLM,
    "masked": AutoModelForMaskedLM,
    "sequence": AutoModelForSequenceClassification,
    "token": AutoModelForTokenClassification,
    "qa": AutoModelForQuestionAnswering,
    "s2s": AutoModelForSeq2SeqLM,
    "vision": AutoModelForImageClassification,
}


def get_model_structure(model_name: str, model_type: str | None):
    """
    Describe the architecture of a Hugging Face model without loading weights.

    The config is fetched with ``AutoConfig.from_pretrained``. The model is
    then instantiated from that config inside ``init_empty_weights()``, so
    parameters are created on the meta device rather than materialised.

    Args:
        model_name: Hub id or local path accepted by
            ``AutoConfig.from_pretrained``.
        model_type: one of "causal", "masked", "sequence", "token", "qa",
            "s2s" or "vision". Any other value, including None, uses
            ``AutoModel``.

    Returns:
        A JSON-serializable dict containing:
          - selected config fields ("model_type", "hidden_size", ...), each
            None when the config lacks the attribute;
          - "layers": the ``hf_style_structural_dict`` description of the
            instantiated model.
        The dict is cached per (model_name, model_type) and the same object
        is returned on later identical calls, so callers should not mutate
        it.

    Note:
        ``trust_remote_code=True`` is passed to both the config and the model
        constructor. This executes any custom code shipped with the
        repository.
    """
    # Key by both arguments: the same checkpoint instantiated with a
    # different head class yields a different module tree.
    cache_key = (model_name, model_type)
    cached = model_structure_cache.get(cache_key)
    if cached is not None:
        return cached

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    auto_cls = _AUTO_MODEL_CLASSES.get(model_type, AutoModel)
    with init_empty_weights():
        # Pass trust_remote_code for every head type, not only causal, so
        # repos that ship custom modelling code work for all of them.
        model = auto_cls.from_config(config, trust_remote_code=True)

    try:
        structure = {
            "model_type": config.model_type,
            "hidden_size": getattr(config, "hidden_size", None),
            "num_hidden_layers": getattr(config, "num_hidden_layers", None),
            "num_attention_heads": getattr(config, "num_attention_heads", None),
            "image_size": getattr(config, "image_size", None),
            "intermediate_size": getattr(config, "intermediate_size", None),
            "patch_size": getattr(config, "patch_size", None),
            "vocab_size": getattr(config, "vocab_size", None),
            "layers": hf_style_structural_dict(model),
        }
    finally:
        # Drop the skeleton promptly, even if describing it failed; only the
        # plain-dict description is kept.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_structure_cache[cache_key] = structure
    return structure