# botbottingbot's picture
# Update app.py
# 425725b verified
import json
import os
import re
from functools import partial

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
# =============================================================
# LOAD MODULES.JSON
# =============================================================
# modules.json holds the list of module specs under its top-level "modules" key.
with open("modules.json", "r", encoding="utf-8") as f:
    MODULES = json.load(f)["modules"]

# Split the specs by their "type" field.
GENERATORS = [spec for spec in MODULES if spec.get("type") == "generator"]
CHECKERS = {spec["id"]: spec for spec in MODULES if spec.get("type") == "checker"}

# Generator lookup tables: by id, and label -> id. The labels (LABEL_LIST) are
# the candidate classes for the zero-shot router.
GEN_BY_ID = {spec["id"]: spec for spec in GENERATORS}
LABEL_TO_ID = {spec["label"]: spec["id"] for spec in GENERATORS}
LABEL_LIST = list(LABEL_TO_ID)
# =============================================================
# BASE MODEL (ENGINE) — Can be swapped
# =============================================================
# Hub id of the causal LM behind `llm`; change it here to swap the engine.
BASE_MODEL_ID = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# `model` is also patched in place by the domain-adapter loader further down.
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)

# Shared text-generation pipeline; each call may add up to 300 new tokens.
# NOTE(review): GPT-2's context window is believed to be 1024 tokens, so a long
# prompt plus 300 new tokens may not fit — confirm prompts stay short enough.
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=300,
)
# =============================================================
# HYBRID ROUTER (RULES + ZERO-SHOT)
# =============================================================
# ----------- RULE-BASED ROUTER -----------
# (keyword, module_id) pairs, checked in order: the first keyword found wins.
RULES = [
    ("contract", "document_explainer_v1"),
    ("agreement", "document_explainer_v1"),
    ("policy", "document_explainer_v1"),
    ("judgment", "document_explainer_v1"),
    ("options", "strategy_memo_v1"),
    ("trade-off", "strategy_memo_v1"),
    ("recommendation", "strategy_memo_v1"),
    ("compare", "strategy_memo_v1"),
    ("system", "system_blueprint_v1"),
    ("architecture", "system_blueprint_v1"),
    ("flow", "system_blueprint_v1"),
    ("analysis", "analysis_note_v1"),
    ("summarize", "analysis_note_v1"),
    ("explain", "analysis_note_v1"),
]

# Each keyword must begin at a word boundary. Plain substring matching routed
# e.g. "ecosystem" and "workflow" to the system blueprint. Keywords may still be
# followed by more letters ("systems", "explained").
_RULE_PATTERNS = [
    (re.compile(r"\b" + re.escape(keyword)), module_id)
    for keyword, module_id in RULES
]


def rule_router(text: str):
    """Return the module id of the first rule whose keyword starts a word in *text*.

    Matching is case-insensitive (the text is lower-cased; keywords are lower-case).

    Returns:
        The matching module id, or None when no rule fires.
    """
    lowered = text.lower()
    for pattern, module_id in _RULE_PATTERNS:
        if pattern.search(lowered):
            return module_id
    return None
# ----------- ZERO-SHOT ROUTER -----------
# Zero-shot classifier that scores a task description against every generator label.
zero_shot_classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)


def zero_shot_route(text):
    """Pick the generator whose label the zero-shot classifier ranks first.

    Returns:
        A ``(label, module_id, details)`` tuple. ``details`` lists every candidate
        as ``"label: score"`` (two decimals), one per line, in the order the
        classifier returned them.
    """
    result = zero_shot_classifier(text, candidate_labels=LABEL_LIST, multi_label=False)
    ranked_labels = result["labels"]
    best = ranked_labels[0]
    best_id = LABEL_TO_ID[best]
    details = "\n".join(
        f"{name}: {score:.2f}" for name, score in zip(ranked_labels, result["scores"])
    )
    return best, best_id, details
# ----------- HYBRID ROUTE CALL -----------
def hybrid_route(task: str):
    """Route a free-text task description to a generator module.

    Keyword rules are tried first. The zero-shot classifier decides when no rule
    fires, or when the rule names a module id that modules.json does not define.

    Returns:
        ``(label, module_id, details)`` strings for the three routing textboxes.
        Empty, whitespace-only or None input yields ``("No input", "", "")``.
    """
    # Guard None too: the original `task.strip()` raised AttributeError on it.
    if not task or not task.strip():
        return "No input", "", ""
    route = rule_router(task)
    # RULES hard-codes module ids. If modules.json lacks one, indexing GEN_BY_ID
    # would raise KeyError, so fall back to the zero-shot router instead.
    if route and route in GEN_BY_ID:
        return GEN_BY_ID[route]["label"], route, "Rule-based match"
    return zero_shot_route(task)
# =============================================================
# DOMAIN HEAD LOADER (LoRA-STYLE ADAPTERS)
# =============================================================
# Domain name -> file holding a mapping from parameter name to a tensor that is
# added onto that parameter of `model`.
ADAPTER_PATHS = {
    "legal": "domain_heads/legal_head.pt",
    "strategy": "domain_heads/strategy_head.pt",
    "analysis": "domain_heads/analysis_head.pt",
    "systems": "domain_heads/systems_head.pt",
}

# Bookkeeping for the adapter currently baked into `model`. It records the
# adapter's domain and a copy of the base value of every parameter it modified,
# so the adapter can be undone exactly (no float drift from subtracting deltas).
_active_adapter = {"domain": None, "backup": {}}


def _restore_base_weights() -> None:
    """Copy saved base values back into `model`, removing the active adapter."""
    backup = _active_adapter["backup"]
    if backup:
        params = dict(model.named_parameters())
        with torch.no_grad():
            for name, base_value in backup.items():
                params[name].copy_(base_value)
    _active_adapter["domain"] = None
    _active_adapter["backup"] = {}


def load_domain_adapter(domain: str):
    """Make `model` carry exactly the adapter for *domain*, replacing any other.

    Behaviour:
        - Requesting the domain that is already active does nothing.
        - Otherwise the previously applied adapter is undone first.
        - If *domain* has no entry in ADAPTER_PATHS, or its file does not exist,
          the model is left at its base weights.

    Returns:
        None.
    """
    if domain == _active_adapter["domain"]:
        return
    # Undo the previous adapter first. Without this, every call added its delta
    # on top of earlier ones, so repeated runs and domain switches kept
    # accumulating changes in the weights.
    _restore_base_weights()
    path = ADAPTER_PATHS.get(domain)
    if path is None or not os.path.exists(path):
        return
    # weights_only=True refuses arbitrary pickled objects; the file is expected
    # to contain only tensors keyed by parameter name.
    adapter = torch.load(path, map_location="cpu", weights_only=True)
    backup = {}
    # Register the (still empty) backup before mutating anything. If the loop
    # fails part-way (e.g. a shape mismatch), the parameters already modified
    # can still be restored on the next call.
    _active_adapter["backup"] = backup
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in adapter:
                backup[name] = param.detach().clone()
                param += adapter[name]
    _active_adapter["domain"] = domain
# =============================================================
# REASONING SCAFFOLDS
# =============================================================
# ----------- CHAIN-OF-THOUGHT -----------
def apply_cot(prompt: str) -> str:
    """Wrap *prompt* in step-by-step reasoning instructions.

    A fixed instruction paragraph goes before the prompt and a matching answer
    cue goes after it.
    """
    preamble = "Think step-by-step. Explain your reasoning before answering.\n\n"
    cue = "\n\nNow think step-by-step and answer:"
    return "".join((preamble, prompt, cue))
# ----------- CRITIQUE + REFINE LOOP -----------
# A separate GPT-2 text-generation pipeline with greedy decoding
# (do_sample=False) and a 200-new-token budget. It serves both the critique pass
# and the refine pass.
critic = pipeline(
    "text-generation",
    model="openai-community/gpt2",
    max_new_tokens=200,
    do_sample=False,
)


def _critic_completion(prompt: str) -> str:
    """Run `critic` on *prompt* and return the stripped completion.

    The echoed prompt is removed from the start of the output when present.
    """
    generated = critic(prompt)[0]["generated_text"]
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()


def critique(text: str) -> str:
    """Ask the critic model to list weaknesses in the draft *text*."""
    prompt = "".join((
        "Review this draft. Identify unclear reasoning, gaps, contradictions.\n\n",
        "DRAFT:\n",
        text,
        "\n\nReturn critique only:\n",
    ))
    return _critic_completion(prompt)


def refine(text: str, critique_text: str) -> str:
    """Ask the critic model to rewrite *text* so it addresses *critique_text*."""
    prompt = "".join((
        "Improve the draft using the critique. Fix gaps, strengthen logic.\n\n",
        "CRITIQUE:\n",
        critique_text,
        "\n\nDRAFT:\n",
        text,
        "\n\nReturn improved output:\n",
    ))
    return _critic_completion(prompt)


def critique_and_refine(text: str) -> str:
    """Run one self-critique round: critique *text*, then rewrite it using that critique."""
    return refine(text, critique(text))
# =============================================================
# LLM CALL + PROMPT BUILDING
# =============================================================
def call_llm(prompt: str) -> str:
    """Greedy-decode a completion from the base `llm` pipeline.

    Returns only the stripped continuation: the echoed prompt is removed from
    the start of the output when present.
    """
    text = llm(prompt, do_sample=False)[0]["generated_text"]
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()
def build_generator_prompt(module_id: str, *inputs: str) -> str:
    """Build the structured prompt for generator *module_id*.

    The prompt has four parts, in order:
        1. A module header line.
        2. The module's declared inputs, each labelled with its upper-cased
           placeholder name. Values pair positionally with *inputs*; missing
           values become "" and extras are ignored.
        3. The list of required output sections.
        4. A "section:\\n[content]" template for each section.
    """
    spec = GEN_BY_ID[module_id]
    names = list(spec["input_placeholders"].keys())
    padded = list(inputs) + [""] * (len(names) - len(inputs))
    sections = spec["output_sections"]

    lines = [
        f"MODULE: {spec['label']} (id={module_id})",
        "You must follow the structured reasoning format.\n",
        "INPUTS:",
    ]
    lines.extend(f"{name.upper()}: {value}" for name, value in zip(names, padded))
    lines.append("\nOutput sections:")
    lines.extend(f"- {section}" for section in sections)
    lines.append("\nFormat exactly as:")
    lines.extend(f"{section}:\n[content]\n" for section in sections)
    return "\n".join(lines)
def build_checker_prompt(checker_id: str, *vals: str) -> str:
    """Build the review prompt for checker *checker_id*.

    *vals* are the generator's input values followed by its draft output.
    Everything before the last value is joined (blank-line separated) into the
    "original task"; the last value is the draft. With fewer than two values
    the task is empty and the single value (if any) is the draft.
    """
    spec = CHECKERS[checker_id]
    sections = spec["output_sections"]

    if len(vals) >= 2:
        *task_parts, draft = vals
        original = "\n\n".join(task_parts)
    else:
        original = ""
        draft = vals[0] if vals else ""

    lines = [
        f"CHECKER: {spec['label']} (id={checker_id})",
        "Review for structure, alignment and reasoning quality.\n",
        "ORIGINAL TASK:\n" + original + "\n",
        "DRAFT OUTPUT:\n" + draft + "\n",
        "Sections required:",
    ]
    lines.extend(f"- {section}" for section in sections)
    lines.append("\nFormat:")
    lines.extend(f"{section}:\n[content]\n" for section in sections)
    return "\n".join(lines)
# =============================================================
# GENERATOR + CHECKER EXECUTION
# =============================================================
def run_generator(module_id: str, *inputs: str) -> str:
    """Produce the final output of generator *module_id* from its input values.

    Steps:
        1. Apply the module's domain adapter, if the module declares a domain.
        2. Build the structured prompt and wrap it in chain-of-thought instructions.
        3. Draft an answer with the base LLM.
        4. Run one critique-and-refine pass over the draft.
    """
    domain = GEN_BY_ID[module_id].get("domain")
    if domain:
        load_domain_adapter(domain)
    scaffolded = apply_cot(build_generator_prompt(module_id, *inputs))
    draft = call_llm(scaffolded)
    return critique_and_refine(draft)
def run_checker(checker_id: str, *inputs: str) -> str:
    """Review a draft with checker *checker_id*.

    *inputs* are the generator's input values followed by the draft. The checker
    prompt is wrapped in chain-of-thought instructions and answered by the base
    LLM; no critique pass is applied.
    """
    return call_llm(apply_cot(build_checker_prompt(checker_id, *inputs)))
# =============================================================
# GRADIO UI
# =============================================================
def build_ui():
    """Assemble and return the Gradio Blocks app.

    Layout:
        - An "Auto-Route" tab that classifies a free-text task.
        - One tab per generator module, with its input boxes and a
          "Run Module" button.
        - A "Run Checker" button on a module tab, shown only when the module's
          ``checker_id`` names a known checker.
    """
    with gr.Blocks(title="Modular Intelligence — Unified System") as demo:
        gr.Markdown("# Modular Intelligence\nUnified architecture with routing, adapters, and reasoning scaffolds.")

        # ---------------- AUTO-ROUTE TAB ----------------
        with gr.Tab("Auto-Route"):
            task_box = gr.Textbox(label="Describe your task", lines=6)
            route_fields = [
                gr.Textbox(label="Suggested Module", interactive=False),
                gr.Textbox(label="Module ID", interactive=False),
                gr.Textbox(label="Routing Details", lines=12, interactive=False),
            ]
            classify_button = gr.Button("Classify Task")
            classify_button.click(fn=hybrid_route, inputs=[task_box], outputs=route_fields)

        # ---------------- MODULE TABS ----------------
        for spec in GENERATORS:
            with gr.Tab(spec["label"]):
                gr.Markdown(f"**Module ID:** `{spec['id']}` | **Domain:** `{spec.get('domain', 'general')}`")
                # One textbox per declared input, in declaration order;
                # run_generator receives their values positionally in that order.
                field_boxes = [
                    gr.Textbox(label=name, placeholder=hint, lines=4)
                    for name, hint in spec["input_placeholders"].items()
                ]
                draft_box = gr.Textbox(label="Generator Output", lines=18)
                gr.Button("Run Module").click(
                    fn=partial(run_generator, spec["id"]),
                    inputs=field_boxes,
                    outputs=draft_box,
                )

                checker_id = spec.get("checker_id")
                if checker_id not in CHECKERS:
                    gr.Markdown("_No checker for this module._")
                else:
                    verdict_box = gr.Textbox(label="Checker Output", lines=15)
                    # The checker receives the module inputs followed by the draft.
                    gr.Button("Run Checker").click(
                        fn=partial(run_checker, checker_id),
                        inputs=field_boxes + [draft_box],
                        outputs=verdict_box,
                    )
    return demo
if __name__ == "__main__":
    # Build the interface and start the Gradio server.
    build_ui().launch()