maomao88's picture
display model structure
a084789
import React from "react";
/**
 * Report whether `value` is a non-empty string consisting only of ASCII digits
 * (e.g. the "0", "1", ... keys produced for list-like children).
 *
 * @param {*} value - Anything; non-strings always yield false.
 * @returns {boolean} True only for strings matching one or more characters 0-9.
 */
function isNumberString(value) {
  if (typeof value !== "string") {
    return false;
  }
  const digitsOnly = /^[0-9]+$/;
  return digitsOnly.test(value);
}
function LayerNode({ node, nodeKey, last, depth = 0 }) {
if (!node) return null;
const name = node.class_name || node.name || "<unknown>";
const shape = node.params?.weight?.shape || [];
// Normalize children: object map
let children = [];
let keys = [];
if (node.children && typeof node.children === "object") {
children = Object.values(node.children);
keys = Object.keys(node.children);
}
// choose background by depth for subtle alternation
const bgByDepth = ["bg-slate-50", "bg-slate-100", "bg-slate-200", "bg-slate-50"];
const bgClass = bgByDepth[depth % bgByDepth.length];
const isEmbed = String(name).toLowerCase().includes("embed");
const embedStyle = isEmbed ? { backgroundColor: "#fbdfe2" } : undefined;
const isAttn = String(name).toLowerCase().includes("attention") || String(name).toLowerCase().includes("attn");
const attnStyle = isAttn ? { backgroundColor: "#fddfba" } : undefined;
const isFFN = String(name).toLowerCase().includes("ffn") || String(name).toLowerCase().includes("mlp");
const ffnStyle = isFFN ? { backgroundColor: "#c2e6f8" } : undefined;
const isNorm = String(name).toLowerCase().includes("norm");
const normStyle = isNorm ? { backgroundColor: "#f3f7c3" } : undefined;
const isLastNorm = last && isNorm;
const lastNormStyle = isLastNorm ? { backgroundColor: "#DEDFF1" } : undefined;
return (
<div
className={`pl-2 ${!isEmbed ? bgClass : ""} rounded-md border border-slate-200 p-2`}
style={embedStyle || attnStyle || ffnStyle || lastNormStyle || normStyle}
>
<div className="flex items-center gap-3">
<div className="text-sm text-slate-800 font-medium">{name}</div>
{nodeKey && !isNumberString(nodeKey) && <div className="text-sm text-slate-800 font-medium">({nodeKey})</div>}
<div className="text-xs text-slate-500">{shape.join(" x ")}</div>
<div>{node.num_repeats && <span className="text-blue-600 font-bold tracking-wide">x {node.num_repeats}</span>}</div>
</div>
{children.length > 0 && (
<div className="pl-4 mt-2">
{children.map((child, idx) => (
<LayerNode key={idx} node={child} nodeKey={keys[idx]} depth={depth + 1} />
))}
</div>
)}
</div>
);
}
export default function ModelLayersCard({ layers = {}, name = "" }) {
let rootNodes = [];
let rootKeys = [];
if (layers && typeof layers === "object") {
if (layers.children && typeof layers.children === "object") {
rootNodes = Object.values(layers.children);
rootKeys = Object.keys(layers.children);
} else {
// Fallback: convert the top-level keyed object into an array
rootNodes = Object.values(layers);
rootKeys = Object.keys(layers);
}
}
if (!rootNodes || rootNodes.length === 0) {
return null;
}
return (
<div className="w-full max-w-3xl mt-6 bg-white rounded-2xl shadow-lg p-4">
<h2 className="text-lg font-semibold text-slate-800 mb-3">{name}</h2>
<div className="space-y-2">
{rootNodes.map((node, idx) => (
<LayerNode key={idx} node={node} nodeKey={rootKeys[idx]} last={idx == rootNodes.length - 1} />
))}
</div>
</div>
);
}