import gc
import hashlib
import json

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForImageClassification,
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)

# In-memory cache: model name -> JSON-serializable structure dict.
model_structure_cache = {}


def module_hash(module):
    """
    Generate a hash representing the structure of a module.

    Combines the class name, the hashes of all children, and the module's own
    parameter shapes, so two modules hash equally only when their structure is
    identical. Used to detect repeated (stacked) layers.
    """
    child_hashes = [module_hash(child) for _, child in module.named_children()]
    # Include the class name and per-parameter (name, shape, requires_grad).
    param_info = [
        (name, tuple(p.shape), p.requires_grad)
        for name, p in module.named_parameters(recurse=False)
    ]
    rep = (module.__class__.__name__, tuple(child_hashes), tuple(param_info))
    return hashlib.md5(str(rep).encode("utf-8")).hexdigest()


def is_number_string(value):
    """Return True for purely numeric child names such as "0" or "12"."""
    return isinstance(value, str) and value.isdigit()


def hf_style_structural_dict(module):
    """
    Recursively convert a PyTorch module into a dict mirroring Hugging Face's
    print(model) output, collapsing consecutive children into a single entry
    with "num_repeats" only when their structure is identical.
    """
    children = list(module.named_children())
    result = {"class_name": module.__class__.__name__}

    # Include the module's own parameters, if any.
    params = {
        name: {"shape": list(p.shape), "requires_grad": p.requires_grad}
        for name, p in module.named_parameters(recurse=False)
    }
    if params:
        result["params"] = params

    if children:
        child_dict = {}
        i = 0
        while i < len(children):
            name, child = children[i]
            current_hash = module_hash(child)
            count = 1
            j = i + 1
            # Count consecutive numerically named children (e.g. the layers of
            # an nn.ModuleList) that are structurally identical to this one.
            while (
                j < len(children)
                and is_number_string(name)
                and is_number_string(children[j][0])
                and module_hash(children[j][1]) == current_hash
            ):
                count += 1
                j += 1
            child_entry = hf_style_structural_dict(child)
            if count > 1:
                child_entry["num_repeats"] = count
            child_dict[name] = child_entry
            i += count
        result["children"] = child_dict

    return result
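# Illustrative sketch (the helper below is hypothetical, not part of the
# original API): nn.Sequential names its children "0", "1", ..., so three
# identical Linear layers collapse into one entry with "num_repeats": 3.
def _demo_repeat_collapsing():
    stack = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))
    # All three children share a module_hash, so the result contains a single
    # child entry: {"class_name": "Linear", ..., "num_repeats": 3}.
    return hf_style_structural_dict(stack)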
def get_model_structure(model_name: str, model_type: str | None):
    # 1. Return the cached structure if we have already built it.
    if model_name in model_structure_cache:
        return model_structure_cache[model_name]

    # 2. Not cached: instantiate the architecture on the meta device so that
    # no weights are downloaded or allocated; trust_remote_code=True lets
    # custom architectures hosted on the Hub be resolved.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    model_classes = {
        "causal": AutoModelForCausalLM,
        "masked": AutoModelForMaskedLM,
        "sequence": AutoModelForSequenceClassification,
        "token": AutoModelForTokenClassification,
        "qa": AutoModelForQuestionAnswering,
        "s2s": AutoModelForSeq2SeqLM,
        "vision": AutoModelForImageClassification,
    }
    model_class = model_classes.get(model_type, AutoModel)
    with init_empty_weights():
        model = model_class.from_config(config, trust_remote_code=True)

    structure = {
        "model_type": config.model_type,
        "hidden_size": getattr(config, "hidden_size", None),
        "num_hidden_layers": getattr(config, "num_hidden_layers", None),
        "num_attention_heads": getattr(config, "num_attention_heads", None),
        "image_size": getattr(config, "image_size", None),
        "intermediate_size": getattr(config, "intermediate_size", None),
        "patch_size": getattr(config, "patch_size", None),
        "vocab_size": getattr(config, "vocab_size", None),
        "layers": hf_style_structural_dict(model),
    }

    # 3. Free the meta model; clear the CUDA cache only when a GPU is in use.
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # 4. Cache and return the JSON-serializable structure.
    model_structure_cache[model_name] = structure
    return structure
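if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: network access to the Hugging Face
    # Hub, and "gpt2" as a stand-in model id; any Hub model works here.
    print(json.dumps(_demo_repeat_collapsing(), indent=2))
    print(json.dumps(get_model_structure("gpt2", "causal"), indent=2))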