import json
import os
from functools import partial

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)

# =============================================================
# LOAD MODULES.JSON
# =============================================================

with open("modules.json", "r", encoding="utf-8") as f:
    MODULES = json.load(f)["modules"]

GENERATORS = [m for m in MODULES if m.get("type") == "generator"]
CHECKERS = {m["id"]: m for m in MODULES if m.get("type") == "checker"}
GEN_BY_ID = {m["id"]: m for m in GENERATORS}
LABEL_TO_ID = {m["label"]: m["id"] for m in GENERATORS}
LABEL_LIST = list(LABEL_TO_ID.keys())
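# Illustrative modules.json shape, inferred from how the fields are read
# in this file. The id "analysis_note_v1" appears in RULES below; the
# labels, placeholders, and sections here are made-up examples and will
# differ in a real modules file:
#
# {
#   "modules": [
#     {
#       "id": "analysis_note_v1",
#       "type": "generator",
#       "label": "Analysis Note",
#       "domain": "analysis",
#       "checker_id": "analysis_checker_v1",
#       "input_placeholders": {"topic": "What should be analyzed?"},
#       "output_sections": ["Summary", "Key Points", "Risks"]
#     },
#     {
#       "id": "analysis_checker_v1",
#       "type": "checker",
#       "label": "Analysis Checker",
#       "output_sections": ["Issues Found", "Verdict"]
#     }
#   ]
# }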
# =============================================================
# BASE MODEL (ENGINE) — Can be swapped
# =============================================================

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
# GPT-2 ships without a pad token; reuse EOS so generation runs without
# warnings. Note that GPT-2's 1024-token context is a hard ceiling for
# prompt + generated text, so keep inputs short.
tokenizer.pad_token = tokenizer.eos_token
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=300,
)

# =============================================================
# HYBRID ROUTER (RULES + ZERO-SHOT)
# =============================================================

# ----------- RULE-BASED ROUTER -----------
RULES = [
    ("contract", "document_explainer_v1"),
    ("agreement", "document_explainer_v1"),
    ("policy", "document_explainer_v1"),
    ("judgment", "document_explainer_v1"),
    ("options", "strategy_memo_v1"),
    ("trade-off", "strategy_memo_v1"),
    ("recommendation", "strategy_memo_v1"),
    ("compare", "strategy_memo_v1"),
    ("system", "system_blueprint_v1"),
    ("architecture", "system_blueprint_v1"),
    ("flow", "system_blueprint_v1"),
    ("analysis", "analysis_note_v1"),
    ("summarize", "analysis_note_v1"),
    ("explain", "analysis_note_v1"),
]


def rule_router(text: str):
    """Return the first module whose keyword appears in the task, else None."""
    t = text.lower()
    for keyword, module_id in RULES:
        if keyword in t:
            return module_id
    return None


# ----------- ZERO-SHOT ROUTER -----------
zero_shot_classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)


def zero_shot_route(text):
    res = zero_shot_classifier(text, candidate_labels=LABEL_LIST, multi_label=False)
    label = res["labels"][0]
    module_id = LABEL_TO_ID[label]
    scores = "\n".join(f"{l}: {s:.2f}" for l, s in zip(res["labels"], res["scores"]))
    return label, module_id, scores


# ----------- HYBRID ROUTE CALL -----------
def hybrid_route(task: str):
    """Try the cheap keyword rules first; fall back to zero-shot classification."""
    if not task.strip():
        return "No input", "", ""
    route = rule_router(task)
    if route:
        return GEN_BY_ID[route]["label"], route, "Rule-based match"
    return zero_shot_route(task)


# =============================================================
# DOMAIN HEAD LOADER (LoRA-STYLE ADAPTERS)
# =============================================================

ADAPTER_PATHS = {
    "legal": "domain_heads/legal_head.pt",
    "strategy": "domain_heads/strategy_head.pt",
    "analysis": "domain_heads/analysis_head.pt",
    "systems": "domain_heads/systems_head.pt",
}

# Track the adapter currently merged into the base weights so switching
# domains does not stack deltas on top of one another.
_ACTIVE_ADAPTER = {"domain": None, "weights": None}


def load_domain_adapter(domain: str):
    """Merge additive weight deltas for `domain` into the shared base model."""
    if domain == _ACTIVE_ADAPTER["domain"]:
        return
    path = ADAPTER_PATHS.get(domain)
    if path is None or not os.path.exists(path):
        return
    adapter = torch.load(path, map_location="cpu")
    with torch.no_grad():
        for name, param in model.named_parameters():
            # Undo the previously merged adapter before applying the new one.
            if _ACTIVE_ADAPTER["weights"] and name in _ACTIVE_ADAPTER["weights"]:
                param -= _ACTIVE_ADAPTER["weights"][name]
            if name in adapter:
                param += adapter[name]
    _ACTIVE_ADAPTER["domain"] = domain
    _ACTIVE_ADAPTER["weights"] = adapter


# =============================================================
# REASONING SCAFFOLDS
# =============================================================

# ----------- CHAIN-OF-THOUGHT -----------
def apply_cot(prompt: str) -> str:
    return (
        "Think step-by-step. Explain your reasoning before answering.\n\n"
        + prompt
        + "\n\nNow think step-by-step and answer:"
    )


# ----------- CRITIQUE + REFINE LOOP -----------
# Uses the same GPT-2 checkpoint as the engine, loaded separately so the
# critic can be swapped for a stronger model independently.
critic = pipeline(
    "text-generation",
    model="openai-community/gpt2",
    max_new_tokens=200,
    do_sample=False,
)


def _strip_prompt(out: str, prompt: str) -> str:
    """The text-generation pipeline echoes the prompt; return only the new text."""
    return out[len(prompt):].strip() if out.startswith(prompt) else out.strip()


def critique(text: str) -> str:
    prompt = (
        "Review this draft. Identify unclear reasoning, gaps, contradictions.\n\n"
        "DRAFT:\n" + text + "\n\nReturn critique only:\n"
    )
    return _strip_prompt(critic(prompt)[0]["generated_text"], prompt)


def refine(text: str, critique_text: str) -> str:
    prompt = (
        "Improve the draft using the critique. Fix gaps, strengthen logic.\n\n"
        "CRITIQUE:\n" + critique_text + "\n\nDRAFT:\n" + text
        + "\n\nReturn improved output:\n"
    )
    return _strip_prompt(critic(prompt)[0]["generated_text"], prompt)


def critique_and_refine(text: str) -> str:
    return refine(text, critique(text))


# =============================================================
# LLM CALL + PROMPT BUILDING
# =============================================================

def call_llm(prompt: str) -> str:
    out = llm(prompt, do_sample=False)[0]["generated_text"]
    return _strip_prompt(out, prompt)


def build_generator_prompt(module_id: str, *inputs: str) -> str:
    m = GEN_BY_ID[module_id]
    keys = list(m["input_placeholders"].keys())
    vals = {k: inputs[i] if i < len(inputs) else "" for i, k in enumerate(keys)}
    secs = m["output_sections"]
    p = [f"MODULE: {m['label']} (id={module_id})"]
    p.append("You must follow the structured reasoning format.\n")
    p.append("INPUTS:")
    for k, v in vals.items():
        p.append(f"{k.upper()}: {v}")
    p.append("\nOutput sections:")
    for s in secs:
        p.append(f"- {s}")
    p.append("\nFormat exactly as:")
    for s in secs:
        p.append(f"{s}:\n[content]\n")
    return "\n".join(p)


def build_checker_prompt(checker_id: str, *vals: str) -> str:
    c = CHECKERS[checker_id]
    secs = c["output_sections"]
    # The last value is the generator's draft; everything before it is the
    # original task input.
    if len(vals) < 2:
        original = ""
        draft = vals[0] if vals else ""
    else:
        original = "\n\n".join(vals[:-1])
        draft = vals[-1]
    p = [f"CHECKER: {c['label']} (id={checker_id})"]
    p.append("Review for structure, alignment, and reasoning quality.\n")
    p.append("ORIGINAL TASK:\n" + original + "\n")
    p.append("DRAFT OUTPUT:\n" + draft + "\n")
    p.append("Sections required:")
    for s in secs:
        p.append(f"- {s}")
    p.append("\nFormat:")
    for s in secs:
        p.append(f"{s}:\n[content]\n")
    return "\n".join(p)


# =============================================================
# GENERATOR + CHECKER EXECUTION
# =============================================================

def run_generator(module_id: str, *inputs: str) -> str:
    m = GEN_BY_ID[module_id]
    if m.get("domain"):
        load_domain_adapter(m["domain"])
    prompt = apply_cot(build_generator_prompt(module_id, *inputs))
    draft = call_llm(prompt)
    return critique_and_refine(draft)


def run_checker(checker_id: str, *inputs: str) -> str:
    prompt = apply_cot(build_checker_prompt(checker_id, *inputs))
    return call_llm(prompt)
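# Quick smoke test of the full pipeline outside the UI. The task below
# hits the "analysis" keyword in RULES and routes to "analysis_note_v1";
# substitute an id and inputs that match your own modules.json:
#
#   label, module_id, detail = hybrid_route("Summarize this quarterly analysis")
#   print(label, module_id, detail)
#   print(run_generator(module_id, "Quarterly revenue trends"))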
gr.Textbox(label="Describe your task", lines=6) out_name = gr.Textbox(label="Suggested Module", interactive=False) out_id = gr.Textbox(label="Module ID", interactive=False) out_scores = gr.Textbox(label="Routing Details", lines=12, interactive=False) gr.Button("Classify Task").click( fn=hybrid_route, inputs=[task_box], outputs=[out_name, out_id, out_scores], ) # ---------------- MODULE TABS ---------------- for m in GENERATORS: with gr.Tab(m["label"]): gr.Markdown(f"**Module ID:** `{m['id']}` | **Domain:** `{m.get('domain','general')}`") inputs = [] for key, placeholder in m["input_placeholders"].items(): t = gr.Textbox(label=key, placeholder=placeholder, lines=4) inputs.append(t) output = gr.Textbox(label="Generator Output", lines=18) gr.Button("Run Module").click( fn=partial(run_generator, m["id"]), inputs=inputs, outputs=output, ) checker_id = m.get("checker_id") if checker_id in CHECKERS: check_out = gr.Textbox(label="Checker Output", lines=15) gr.Button("Run Checker").click( fn=partial(run_checker, checker_id), inputs=inputs + [output], outputs=check_out, ) else: gr.Markdown("_No checker for this module._") return demo if __name__ == "__main__": ui = build_ui() ui.launch()