Spaces:
Running
Running
File size: 3,351 Bytes
c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 a084789 c3173d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import React from "react";
/**
 * Reports whether `value` is a non-empty string consisting solely of ASCII digits 0-9.
 * LayerNode uses this to avoid labelling children whose map key is purely numeric.
 *
 * @param {*} value - Anything; non-string values always yield false.
 * @returns {boolean} True only for strings such as "0" or "123".
 */
function isNumberString(value) {
  if (typeof value !== "string") {
    return false;
  }
  // `\d` in a JS regex (with or without the `u` flag) matches only ASCII 0-9.
  const digitsOnly = /^\d+$/;
  return digitsOnly.test(value);
}
function LayerNode({ node, nodeKey, last, depth = 0 }) {
if (!node) return null;
const name = node.class_name || node.name || "<unknown>";
const shape = node.params?.weight?.shape || [];
// Normalize children: object map
let children = [];
let keys = [];
if (node.children && typeof node.children === "object") {
children = Object.values(node.children);
keys = Object.keys(node.children);
}
// choose background by depth for subtle alternation
const bgByDepth = ["bg-slate-50", "bg-slate-100", "bg-slate-200", "bg-slate-50"];
const bgClass = bgByDepth[depth % bgByDepth.length];
const isEmbed = String(name).toLowerCase().includes("embed");
const embedStyle = isEmbed ? { backgroundColor: "#fbdfe2" } : undefined;
const isAttn = String(name).toLowerCase().includes("attention") || String(name).toLowerCase().includes("attn");
const attnStyle = isAttn ? { backgroundColor: "#fddfba" } : undefined;
const isFFN = String(name).toLowerCase().includes("ffn") || String(name).toLowerCase().includes("mlp");
const ffnStyle = isFFN ? { backgroundColor: "#c2e6f8" } : undefined;
const isNorm = String(name).toLowerCase().includes("norm");
const normStyle = isNorm ? { backgroundColor: "#f3f7c3" } : undefined;
const isLastNorm = last && isNorm;
const lastNormStyle = isLastNorm ? { backgroundColor: "#DEDFF1" } : undefined;
return (
<div
className={`pl-2 ${!isEmbed ? bgClass : ""} rounded-md border border-slate-200 p-2`}
style={embedStyle || attnStyle || ffnStyle || lastNormStyle || normStyle}
>
<div className="flex items-center gap-3">
<div className="text-sm text-slate-800 font-medium">{name}</div>
{nodeKey && !isNumberString(nodeKey) && <div className="text-sm text-slate-800 font-medium">({nodeKey})</div>}
<div className="text-xs text-slate-500">{shape.join(" x ")}</div>
<div>{node.num_repeats && <span className="text-blue-600 font-bold tracking-wide">x {node.num_repeats}</span>}</div>
</div>
{children.length > 0 && (
<div className="pl-4 mt-2">
{children.map((child, idx) => (
<LayerNode key={idx} node={child} nodeKey={keys[idx]} depth={depth + 1} />
))}
</div>
)}
</div>
);
}
export default function ModelLayersCard({ layers = {}, name = "" }) {
let rootNodes = [];
let rootKeys = [];
if (layers && typeof layers === "object") {
if (layers.children && typeof layers.children === "object") {
rootNodes = Object.values(layers.children);
rootKeys = Object.keys(layers.children);
} else {
// Fallback: convert the top-level keyed object into an array
rootNodes = Object.values(layers);
rootKeys = Object.keys(layers);
}
}
if (!rootNodes || rootNodes.length === 0) {
return null;
}
return (
<div className="w-full max-w-3xl mt-6 bg-white rounded-2xl shadow-lg p-4">
<h2 className="text-lg font-semibold text-slate-800 mb-3">{name}</h2>
<div className="space-y-2">
{rootNodes.map((node, idx) => (
<LayerNode key={idx} node={node} nodeKey={rootKeys[idx]} last={idx == rootNodes.length - 1} />
))}
</div>
</div>
);
}
|