File size: 3,351 Bytes
c3173d9
 
a084789
 
 
c3173d9
a084789
 
c3173d9
 
 
 
 
a084789
c3173d9
 
a084789
c3173d9
 
a084789
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3173d9
a084789
 
 
 
 
c3173d9
a084789
c3173d9
a084789
c3173d9
 
 
a084789
c3173d9
a084789
c3173d9
 
 
 
 
 
 
a084789
c3173d9
a084789
c3173d9
 
 
 
a084789
c3173d9
 
 
a084789
c3173d9
 
 
 
 
 
 
 
 
a084789
c3173d9
 
a084789
c3173d9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import React from "react";

// Matches a string consisting only of one or more ASCII decimal digits.
const DIGITS_ONLY_PATTERN = /^[0-9]+$/;

/**
 * Tells whether `value` is a non-empty string of ASCII digits (e.g. "0", "12").
 * Used to suppress purely positional child keys in the layer tree.
 *
 * @param {*} value - Any value; non-strings always yield false.
 * @returns {boolean} true only for strings like "42"; false for "", "4a", "1.5", numbers, null.
 */
function isNumberString(value) {
  if (typeof value !== "string") {
    return false;
  }
  return DIGITS_ONLY_PATTERN.test(value);
}

function LayerNode({ node, nodeKey, last, depth = 0 }) {
  if (!node) return null;
  const name = node.class_name || node.name || "<unknown>";
  const shape = node.params?.weight?.shape || [];

  // Normalize children: object map
  let children = [];
  let keys = [];
  if (node.children && typeof node.children === "object") {
    children = Object.values(node.children);
    keys = Object.keys(node.children);
  }

  // choose background by depth for subtle alternation
  const bgByDepth = ["bg-slate-50", "bg-slate-100", "bg-slate-200", "bg-slate-50"];
  const bgClass = bgByDepth[depth % bgByDepth.length];
  const isEmbed = String(name).toLowerCase().includes("embed");
  const embedStyle = isEmbed ? { backgroundColor: "#fbdfe2" } : undefined;
  const isAttn = String(name).toLowerCase().includes("attention") || String(name).toLowerCase().includes("attn");
  const attnStyle = isAttn ? { backgroundColor: "#fddfba" } : undefined;
  const isFFN = String(name).toLowerCase().includes("ffn") || String(name).toLowerCase().includes("mlp");
  const ffnStyle = isFFN ? { backgroundColor: "#c2e6f8" } : undefined;
  const isNorm = String(name).toLowerCase().includes("norm");
  const normStyle = isNorm ? { backgroundColor: "#f3f7c3" } : undefined;
  const isLastNorm = last && isNorm;
  const lastNormStyle = isLastNorm ? { backgroundColor: "#DEDFF1" } : undefined;


  return (
    <div
      className={`pl-2 ${!isEmbed ? bgClass : ""} rounded-md border border-slate-200 p-2`} 
      style={embedStyle || attnStyle || ffnStyle || lastNormStyle || normStyle}
    > 
      <div className="flex items-center gap-3">
        <div className="text-sm text-slate-800 font-medium">{name}</div>
        {nodeKey && !isNumberString(nodeKey) && <div className="text-sm text-slate-800 font-medium">({nodeKey})</div>}
        <div className="text-xs text-slate-500">{shape.join(" x ")}</div>
        <div>{node.num_repeats && <span className="text-blue-600 font-bold tracking-wide">x {node.num_repeats}</span>}</div>
      </div>

      {children.length > 0 && (
        <div className="pl-4 mt-2">
          {children.map((child, idx) => (
            <LayerNode key={idx} node={child} nodeKey={keys[idx]} depth={depth + 1} />
          ))}
        </div>
      )}
    </div>
  );
}

export default function ModelLayersCard({ layers = {}, name = "" }) {
  let rootNodes = [];
  let rootKeys = [];

  if (layers && typeof layers === "object") {
    if (layers.children && typeof layers.children === "object") {
      rootNodes = Object.values(layers.children);
      rootKeys = Object.keys(layers.children);
    } else {
      // Fallback: convert the top-level keyed object into an array
      rootNodes = Object.values(layers);
      rootKeys = Object.keys(layers);
    }
  }

  if (!rootNodes || rootNodes.length === 0) {
    return null;
  }

  return (
    <div className="w-full max-w-3xl mt-6 bg-white rounded-2xl shadow-lg p-4">
      <h2 className="text-lg font-semibold text-slate-800 mb-3">{name}</h2>
      <div className="space-y-2">
        {rootNodes.map((node, idx) => (
          <LayerNode key={idx} node={node} nodeKey={rootKeys[idx]} last={idx == rootNodes.length - 1} />
        ))}
      </div>
    </div>
  );
}