File size: 5,171 Bytes
d30efe2
 
 
 
 
119afbd
c3173d9
 
 
 
d30efe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a084789
 
d30efe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a084789
d30efe2
 
 
 
 
 
 
 
 
 
c3173d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119afbd
 
c3173d9
 
119afbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3173d9
 
 
 
 
 
 
 
 
 
a084789
 
 
 
c3173d9
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
import torch.nn as nn
import json
import hashlib
import gc
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForImageClassification
from accelerate import init_empty_weights


# Process-lifetime, in-memory cache of computed model structures so repeated
# lookups do not have to rebuild the (weightless) model each time.
model_structure_cache = {}


def module_hash(module):
    """Return a stable MD5 fingerprint of a module's structure.

    The fingerprint combines the module's class name, the recursive
    fingerprints of its children (in registration order), and the
    shape/requires_grad of each direct parameter, so two modules hash
    equally exactly when their structures match. Used to detect
    repeated (structurally identical) layers.
    """
    # Recurse first: a child's digest feeds into its parent's digest.
    child_digests = [module_hash(child) for _, child in module.named_children()]

    # Only direct parameters — children already account for their own.
    direct_params = [
        (name, tuple(param.shape), param.requires_grad)
        for name, param in module.named_parameters(recurse=False)
    ]

    fingerprint = (module.__class__.__name__, tuple(child_digests), tuple(direct_params))
    return hashlib.md5(str(fingerprint).encode('utf-8')).hexdigest()

def is_number_string(value):
    """Return True when *value* is a str consisting only of decimal digits."""
    if not isinstance(value, str):
        return False
    return value.isdigit()

def hf_style_structural_dict(module):
    """
    Recursively convert a PyTorch module into a JSON-serializable dict
    mirroring Hugging Face's print(model) output.

    Consecutive children with numeric names (ModuleList-style "0", "1", ...)
    that are structurally identical (same module_hash) are collapsed into a
    single entry carrying a "num_repeats" count.

    Returns a dict with:
      - "class_name": the module's class name
      - "params" (optional): direct parameter shapes / requires_grad flags
      - "children" (optional): recursively built child entries
    """
    children = list(module.named_children())
    result = {"class_name": module.__class__.__name__}

    # Direct (non-recursive) parameters, if any; children report their own.
    params = {name: {"shape": list(p.shape), "requires_grad": p.requires_grad}
              for name, p in module.named_parameters(recurse=False)}
    if params:
        result["params"] = params

    if children:
        # Hash each child exactly once up front. The previous version
        # recomputed module_hash (a full deep traversal) inside the
        # comparison loop, making repeat detection quadratic in the
        # number of repeated layers.
        child_hashes = [module_hash(child) for _, child in children]

        child_dict = {}
        i = 0
        while i < len(children):
            name, child = children[i]
            count = 1
            # Collapse a run of structurally identical children, but only
            # when this child's name is numeric (i.e. a ModuleList index).
            if is_number_string(name):
                while (i + count < len(children)
                       and child_hashes[i + count] == child_hashes[i]):
                    count += 1

            child_entry = hf_style_structural_dict(child)
            if count > 1:
                child_entry["num_repeats"] = count
            child_dict[name] = child_entry
            i += count
        result["children"] = child_dict

    return result



def get_model_structure(model_name: str, model_type: str | None):
    """
    Build (and cache) a JSON-serializable summary of a Hugging Face
    model's architecture without allocating any weight memory.

    Args:
        model_name: Hub id or local path accepted by AutoConfig.from_pretrained.
        model_type: task selector — one of "causal", "masked", "sequence",
            "token", "qa", "s2s", "vision"; any other value (including None)
            falls back to the bare AutoModel backbone.

    Returns:
        A JSON-serializable dict of key config fields plus a "layers" tree
        produced by hf_style_structural_dict.
    """
    # Cache per (name, type): the same checkpoint yields a different head —
    # and hence a different structure — for each task type, so keying by
    # name alone would return the wrong structure on a second call with a
    # different model_type.
    cache_key = (model_name, model_type)
    if cache_key in model_structure_cache:
        return model_structure_cache[cache_key]

    # Dispatch table replaces the previous eight near-identical branches.
    auto_class = {
        "causal": AutoModelForCausalLM,
        "masked": AutoModelForMaskedLM,
        "sequence": AutoModelForSequenceClassification,
        "token": AutoModelForTokenClassification,
        "qa": AutoModelForQuestionAnswering,
        "s2s": AutoModelForSeq2SeqLM,
        "vision": AutoModelForImageClassification,
    }.get(model_type, AutoModel)

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    # init_empty_weights places every parameter on the meta device, so the
    # module tree is built without materializing weights.
    with init_empty_weights():
        model = auto_class.from_config(config, trust_remote_code=True)

    structure = {
        "model_type": config.model_type,
        "hidden_size": getattr(config, "hidden_size", None),
        "num_hidden_layers": getattr(config, "num_hidden_layers", None),
        "num_attention_heads": getattr(config, "num_attention_heads", None),
        "image_size": getattr(config, "image_size", None),
        "intermediate_size": getattr(config, "intermediate_size", None),
        "patch_size": getattr(config, "patch_size", None),
        "vocab_size": getattr(config, "vocab_size", None),
        "layers": hf_style_structural_dict(model)
    }

    # Drop the (weightless) module tree promptly.
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model_structure_cache[cache_key] = structure
    return structure