import json
import os
from functools import partial

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)

# Load the module registry that describes every generator and checker module.
with open("modules.json", "r", encoding="utf-8") as f:
    MODULES = json.load(f)["modules"]
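
# Expected shape of modules.json, inferred from how fields are read below
# (the concrete values here are illustrative, not taken from the real file):
# {
#   "modules": [
#     {
#       "id": "analysis_note_v1",
#       "type": "generator",                  # or "checker"
#       "label": "Analysis Note",
#       "domain": "analysis",                 # optional; selects a domain adapter
#       "checker_id": "analysis_checker_v1",  # optional; links to a checker module
#       "input_placeholders": {"topic": "What should be analyzed?"},
#       "output_sections": ["Summary", "Key Points"]
#     }
#   ]
# }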

# Index the registry for lookup by type, id, and display label.
GENERATORS = [m for m in MODULES if m.get("type") == "generator"]
CHECKERS = {m["id"]: m for m in MODULES if m.get("type") == "checker"}
GEN_BY_ID = {m["id"]: m for m in GENERATORS}

LABEL_TO_ID = {m["label"]: m["id"] for m in GENERATORS}
LABEL_LIST = list(LABEL_TO_ID.keys())

# Base generator model: GPT-2 wrapped in a text-generation pipeline.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
llm = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=300)
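# Note: GPT-2's context window is 1024 tokens, so the structured prompts assembled
# below may need truncation (or a larger model) for long inputs.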

# Keyword-to-module routing rules; scanned in order, first match wins.
RULES = [
    ("contract", "document_explainer_v1"),
    ("agreement", "document_explainer_v1"),
    ("policy", "document_explainer_v1"),
    ("judgment", "document_explainer_v1"),
    ("options", "strategy_memo_v1"),
    ("trade-off", "strategy_memo_v1"),
    ("recommendation", "strategy_memo_v1"),
    ("compare", "strategy_memo_v1"),
    ("system", "system_blueprint_v1"),
    ("architecture", "system_blueprint_v1"),
    ("flow", "system_blueprint_v1"),
    ("analysis", "analysis_note_v1"),
    ("summarize", "analysis_note_v1"),
    ("explain", "analysis_note_v1"),
]


def rule_router(text: str):
    """Return the first module id whose keyword appears in the task text, else None."""
    t = text.lower()
    for keyword, module_id in RULES:
        if keyword in t:
            return module_id
    return None

# Zero-shot fallback: classify the task against the generator labels with BART-MNLI.
zero_shot_classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)


def zero_shot_route(text):
    """Pick the best-scoring generator label; return (label, module_id, score report)."""
    res = zero_shot_classifier(text, candidate_labels=LABEL_LIST, multi_label=False)
    label = res["labels"][0]
    module_id = LABEL_TO_ID[label]
    scores = "\n".join(f"{l}: {s:.2f}" for l, s in zip(res["labels"], res["scores"]))
    return label, module_id, scores


def hybrid_route(task: str):
    """Route a task description: cheap keyword rules first, zero-shot as fallback."""
    if not task.strip():
        return "No input", "", ""

    route = rule_router(task)
    if route:
        return GEN_BY_ID[route]["label"], route, "Rule-based match"

    return zero_shot_route(task)
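
# Example: a task mentioning "options" or "compare" routes to strategy_memo_v1 via
# the keyword rules alone; a task with no keyword hit falls through to the
# zero-shot classifier.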


# Optional per-domain weight deltas ("adapter heads") applied on top of the base model.
ADAPTER_PATHS = {
    "legal": "domain_heads/legal_head.pt",
    "strategy": "domain_heads/strategy_head.pt",
    "analysis": "domain_heads/analysis_head.pt",
    "systems": "domain_heads/systems_head.pt",
}

# Track applied adapters so repeated calls do not re-add the same delta.
_APPLIED_ADAPTERS = set()


def load_domain_adapter(domain: str):
    """Add a domain-specific weight delta to the shared base model, once per domain.

    Deltas are additive and never removed, so adapters for different domains
    accumulate on the same model instance.
    """
    if domain not in ADAPTER_PATHS or domain in _APPLIED_ADAPTERS:
        return

    path = ADAPTER_PATHS[domain]
    if not os.path.exists(path):
        return

    adapter = torch.load(path, map_location="cpu")
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in adapter:
                param.add_(adapter[name])
    _APPLIED_ADAPTERS.add(domain)
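
# Assumption: adapter files hold a plain dict mapping parameter names to delta
# tensors. If they come from an untrusted source, loading them with
# torch.load(path, map_location="cpu", weights_only=True) restricts unpickling
# to tensors and is the safer choice.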


def apply_cot(prompt: str) -> str:
    """Wrap a prompt in a simple chain-of-thought scaffold."""
    return (
        "Think step-by-step. Explain your reasoning before answering.\n\n"
        + prompt
        + "\n\nNow think step-by-step and answer:"
    )

# Critic model used by the critique-and-refine loop (greedy decoding).
critic = pipeline(
    "text-generation",
    model="openai-community/gpt2",
    max_new_tokens=200,
    do_sample=False,
)


def critique(text: str) -> str:
    """Ask the critic model to point out weaknesses in a draft."""
    prompt = (
        "Review this draft. Identify unclear reasoning, gaps, contradictions.\n\n"
        "DRAFT:\n" + text + "\n\nReturn critique only:\n"
    )
    out = critic(prompt)[0]["generated_text"]
    return out[len(prompt):].strip() if out.startswith(prompt) else out.strip()


def refine(text: str, critique_text: str) -> str:
    """Rewrite a draft so that it addresses the critic's feedback."""
    prompt = (
        "Improve the draft using the critique. Fix gaps, strengthen logic.\n\n"
        "CRITIQUE:\n" + critique_text +
        "\n\nDRAFT:\n" + text +
        "\n\nReturn improved output:\n"
    )
    out = critic(prompt)[0]["generated_text"]
    return out[len(prompt):].strip() if out.startswith(prompt) else out.strip()


def critique_and_refine(text: str) -> str:
    """One critique pass followed by one refinement pass."""
    c = critique(text)
    return refine(text, c)
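
# Note: base GPT-2 is not instruction-tuned, so it follows these critique and
# refinement prompts only loosely; swapping in an instruction-tuned model would
# make the critic step far more reliable.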


def call_llm(prompt: str) -> str:
    """Run the main generation pipeline and return only the newly generated text."""
    out = llm(prompt, do_sample=False)[0]["generated_text"]
    return out[len(prompt):].strip() if out.startswith(prompt) else out.strip()
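
# By default the text-generation pipeline returns the prompt plus the continuation,
# which the prefix strip above accounts for; passing return_full_text=False when
# calling the pipeline is an equivalent alternative.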


def build_generator_prompt(module_id: str, *inputs: str) -> str:
    """Assemble the structured prompt for a generator module from the user's inputs."""
    m = GEN_BY_ID[module_id]
    keys = list(m["input_placeholders"].keys())
    # Pair positional inputs with the module's input fields; missing values become "".
    vals = {k: inputs[i] if i < len(inputs) else "" for i, k in enumerate(keys)}
    secs = m["output_sections"]

    p = []
    p.append(f"MODULE: {m['label']} (id={module_id})")
    p.append("You must follow the structured reasoning format.\n")
    p.append("INPUTS:")
    for k, v in vals.items():
        p.append(f"{k.upper()}: {v}")
    p.append("\nOutput sections:")
    for s in secs:
        p.append(f"- {s}")
    p.append("\nFormat exactly as:")
    for s in secs:
        p.append(f"{s}:\n[content]\n")
    return "\n".join(p)
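
# For illustration only: given a hypothetical module labelled "Analysis Note" with
# id "analysis_note_v1", a single input field "topic", and output sections
# ["Summary", "Key Points"], the prompt assembled above looks roughly like:
#
#   MODULE: Analysis Note (id=analysis_note_v1)
#   You must follow the structured reasoning format.
#
#   INPUTS:
#   TOPIC: <user text>
#
#   Output sections:
#   - Summary
#   - Key Points
#
#   Format exactly as:
#   Summary:
#   [content]
#
#   Key Points:
#   [content]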


def build_checker_prompt(checker_id: str, *vals: str) -> str:
    """Assemble the review prompt from the original task inputs and the draft to check."""
    c = CHECKERS[checker_id]
    secs = c["output_sections"]

    # The last positional value is the draft; everything before it is the original task.
    if len(vals) < 2:
        original = ""
        draft = vals[0] if vals else ""
    else:
        original = "\n\n".join(vals[:-1])
        draft = vals[-1]

    p = []
    p.append(f"CHECKER: {c['label']} (id={checker_id})")
    p.append("Review for structure, alignment and reasoning quality.\n")
    p.append("ORIGINAL TASK:\n" + original + "\n")
    p.append("DRAFT OUTPUT:\n" + draft + "\n")
    p.append("Sections required:")
    for s in secs:
        p.append(f"- {s}")
    p.append("\nFormat:")
    for s in secs:
        p.append(f"{s}:\n[content]\n")
    return "\n".join(p)


def run_generator(module_id: str, *inputs: str) -> str:
    """Full generation pass: adapter -> structured prompt -> CoT -> draft -> critique/refine."""
    m = GEN_BY_ID[module_id]

    # Apply the module's domain adapter, if one is configured and present on disk.
    if m.get("domain"):
        load_domain_adapter(m["domain"])

    prompt = build_generator_prompt(module_id, *inputs)
    prompt = apply_cot(prompt)
    draft = call_llm(prompt)
    final = critique_and_refine(draft)
    return final


def run_checker(checker_id: str, *inputs: str) -> str:
    """Run a checker module over the original inputs plus the generated draft."""
    prompt = build_checker_prompt(checker_id, *inputs)
    prompt = apply_cot(prompt)
    return call_llm(prompt)


def build_ui():
    """Build the Gradio Blocks UI: a routing tab plus one tab per generator module."""
    with gr.Blocks(title="Modular Intelligence — Unified System") as demo:
        gr.Markdown("# Modular Intelligence\nUnified architecture with routing, adapters, and reasoning scaffolds.")

        # Routing tab: suggest the best module for a free-form task description.
        with gr.Tab("Auto-Route"):
            task_box = gr.Textbox(label="Describe your task", lines=6)
            out_name = gr.Textbox(label="Suggested Module", interactive=False)
            out_id = gr.Textbox(label="Module ID", interactive=False)
            out_scores = gr.Textbox(label="Routing Details", lines=12, interactive=False)

            gr.Button("Classify Task").click(
                fn=hybrid_route,
                inputs=[task_box],
                outputs=[out_name, out_id, out_scores],
            )

        # One tab per generator module, with its optional checker underneath.
        for m in GENERATORS:
            with gr.Tab(m["label"]):
                gr.Markdown(f"**Module ID:** `{m['id']}` | **Domain:** `{m.get('domain', 'general')}`")

                inputs = []
                for key, placeholder in m["input_placeholders"].items():
                    t = gr.Textbox(label=key, placeholder=placeholder, lines=4)
                    inputs.append(t)

                output = gr.Textbox(label="Generator Output", lines=18)
                gr.Button("Run Module").click(
                    fn=partial(run_generator, m["id"]),
                    inputs=inputs,
                    outputs=output,
                )

                checker_id = m.get("checker_id")
                if checker_id in CHECKERS:
                    check_out = gr.Textbox(label="Checker Output", lines=15)
                    gr.Button("Run Checker").click(
                        fn=partial(run_checker, checker_id),
                        inputs=inputs + [output],
                        outputs=check_out,
                    )
                else:
                    gr.Markdown("_No checker for this module._")

    return demo


if __name__ == "__main__":
    ui = build_ui()
    ui.launch()