Spaces:
Running
Running
| import React from "react"; | |
/**
 * Tell whether a value is a string made up solely of ASCII decimal digits.
 * Used to hide purely positional child keys such as "0", "1", "2".
 *
 * @param {*} value - Candidate value (usually an object key).
 * @returns {boolean} true for non-empty all-digit strings, false otherwise
 *   (including non-string inputs and the empty string).
 */
function isNumberString(value) {
  if (typeof value !== "string") {
    return false;
  }
  // In JS regexes `\d` matches only [0-9], identical to the explicit class.
  return /^\d+$/.test(value);
}
| function LayerNode({ node, nodeKey, last, depth = 0 }) { | |
| if (!node) return null; | |
| const name = node.class_name || node.name || "<unknown>"; | |
| const shape = node.params?.weight?.shape || []; | |
| // Normalize children: object map | |
| let children = []; | |
| let keys = []; | |
| if (node.children && typeof node.children === "object") { | |
| children = Object.values(node.children); | |
| keys = Object.keys(node.children); | |
| } | |
| // choose background by depth for subtle alternation | |
| const bgByDepth = ["bg-slate-50", "bg-slate-100", "bg-slate-200", "bg-slate-50"]; | |
| const bgClass = bgByDepth[depth % bgByDepth.length]; | |
| const isEmbed = String(name).toLowerCase().includes("embed"); | |
| const embedStyle = isEmbed ? { backgroundColor: "#fbdfe2" } : undefined; | |
| const isAttn = String(name).toLowerCase().includes("attention") || String(name).toLowerCase().includes("attn"); | |
| const attnStyle = isAttn ? { backgroundColor: "#fddfba" } : undefined; | |
| const isFFN = String(name).toLowerCase().includes("ffn") || String(name).toLowerCase().includes("mlp"); | |
| const ffnStyle = isFFN ? { backgroundColor: "#c2e6f8" } : undefined; | |
| const isNorm = String(name).toLowerCase().includes("norm"); | |
| const normStyle = isNorm ? { backgroundColor: "#f3f7c3" } : undefined; | |
| const isLastNorm = last && isNorm; | |
| const lastNormStyle = isLastNorm ? { backgroundColor: "#DEDFF1" } : undefined; | |
| return ( | |
| <div | |
| className={`pl-2 ${!isEmbed ? bgClass : ""} rounded-md border border-slate-200 p-2`} | |
| style={embedStyle || attnStyle || ffnStyle || lastNormStyle || normStyle} | |
| > | |
| <div className="flex items-center gap-3"> | |
| <div className="text-sm text-slate-800 font-medium">{name}</div> | |
| {nodeKey && !isNumberString(nodeKey) && <div className="text-sm text-slate-800 font-medium">({nodeKey})</div>} | |
| <div className="text-xs text-slate-500">{shape.join(" x ")}</div> | |
| <div>{node.num_repeats && <span className="text-blue-600 font-bold tracking-wide">x {node.num_repeats}</span>}</div> | |
| </div> | |
| {children.length > 0 && ( | |
| <div className="pl-4 mt-2"> | |
| {children.map((child, idx) => ( | |
| <LayerNode key={idx} node={child} nodeKey={keys[idx]} depth={depth + 1} /> | |
| ))} | |
| </div> | |
| )} | |
| </div> | |
| ); | |
| } | |
| export default function ModelLayersCard({ layers = {}, name = "" }) { | |
| let rootNodes = []; | |
| let rootKeys = []; | |
| if (layers && typeof layers === "object") { | |
| if (layers.children && typeof layers.children === "object") { | |
| rootNodes = Object.values(layers.children); | |
| rootKeys = Object.keys(layers.children); | |
| } else { | |
| // Fallback: convert the top-level keyed object into an array | |
| rootNodes = Object.values(layers); | |
| rootKeys = Object.keys(layers); | |
| } | |
| } | |
| if (!rootNodes || rootNodes.length === 0) { | |
| return null; | |
| } | |
| return ( | |
| <div className="w-full max-w-3xl mt-6 bg-white rounded-2xl shadow-lg p-4"> | |
| <h2 className="text-lg font-semibold text-slate-800 mb-3">{name}</h2> | |
| <div className="space-y-2"> | |
| {rootNodes.map((node, idx) => ( | |
| <LayerNode key={idx} node={node} nodeKey={rootKeys[idx]} last={idx == rootNodes.length - 1} /> | |
| ))} | |
| </div> | |
| </div> | |
| ); | |
| } | |