# Source path: sentinel/apps/streamlit_ui/pages/3_Assessment.py
# Hosting-page residue: avatar "jeuko"; commit "Sync from GitHub (main)", cc034ee (verified).
"""Streamlit page: Assessment."""
import os
import tempfile
from pathlib import Path
import streamlit as st
import yaml
# Configure page layout to be wider
st.set_page_config(layout="wide")
from collections import Counter
import pandas as pd
import plotly.graph_objects as go
from ui_utils import initialize_session_state
from sentinel.config import AppConfig, ModelConfig, ResourcePaths
from sentinel.conversation import ConversationManager
from sentinel.factory import SentinelFactory
from sentinel.reporting import generate_excel_report, generate_pdf_report
# Run the shared session-state setup helper before any session key is read.
initialize_session_state()

# An assessment is computed from the stored user profile, so without one this
# page only shows a hint and halts the script run.
profile_missing = st.session_state.user_profile is None
if profile_missing:
    st.warning(
        "Please complete your profile on the Profile page before running an assessment."
    )
    st.stop()
def create_conversation_manager(config: dict) -> ConversationManager:
    """Create a conversation manager from the current configuration.

    Args:
        config: A dictionary containing the current configuration. The
            ``"model"`` entry is required and names the YAML file under
            ``configs/model`` (without the ``.yaml`` suffix). The
            ``"cancer_modules"`` and ``"dx_protocols"`` entries are optional
            and default to empty lists.

    Returns:
        ConversationManager: A conversation manager instance.

    Raises:
        KeyError: If ``config`` has no ``"model"`` entry, or the loaded model
            YAML mapping lacks ``provider`` or ``model_name``.
        OSError: If the model configuration file cannot be opened.
    """
    # Define base paths relative to project root (three directories above the
    # directory that contains this file).
    root = Path(__file__).resolve().parents[3]
    configs_dir = root / "configs"
    prompts_dir = root / "prompts"

    # Load model config to get provider and model name. The encoding is pinned
    # to UTF-8 so the result does not depend on the platform's default locale.
    model_config_path = configs_dir / "model" / f"{config['model']}.yaml"
    with open(model_config_path, encoding="utf-8") as f:
        model_data = yaml.safe_load(f)

    # Create knowledge base paths
    knowledge_base_paths = ResourcePaths(
        persona=prompts_dir / "persona" / "default.md",
        instruction_assessment=prompts_dir / "instruction" / "assessment.md",
        instruction_conversation=prompts_dir / "instruction" / "conversation.md",
        output_format_assessment=configs_dir / "output_format" / "assessment.yaml",
        output_format_conversation=configs_dir / "output_format" / "conversation.yaml",
        cancer_modules_dir=configs_dir / "knowledge_base" / "cancer_modules",
        dx_protocols_dir=configs_dir / "knowledge_base" / "dx_protocols",
    )

    # Create app config
    app_config = AppConfig(
        model=ModelConfig(
            provider=model_data["provider"], model_name=model_data["model_name"]
        ),
        knowledge_base_paths=knowledge_base_paths,
        selected_cancer_modules=config.get("cancer_modules", []),
        selected_dx_protocols=config.get("dx_protocols", []),
    )

    # Create factory and conversation manager
    factory = SentinelFactory(app_config)
    return factory.create_conversation_manager()
# Build a manager from the session's current configuration and publish it in
# session state under "conversation_manager".
# NOTE(review): this executes on every script run, so each rerun replaces the
# stored manager with a new instance. If ConversationManager keeps per-instance
# conversation history, follow-up questions would not see the earlier
# assessment -- confirm against the ConversationManager implementation.
manager = st.session_state.conversation_manager = create_conversation_manager(
    st.session_state.config
)

st.title("🔬 Assessment")

# Run the initial assessment on demand and keep the result in session state so
# it is still available on later script runs.
run_requested = st.button("Run Assessment", type="primary")
if run_requested:
    with st.spinner("Running..."):
        fresh_assessment = manager.initial_assessment(st.session_state.user_profile)
    st.session_state.assessment = fresh_assessment
def _render_report_bytes(generate_report, report_assessment, user_profile, suffix):
    """Generate a report into a temporary file and return its content as bytes.

    Args:
        generate_report: Callable invoked as
            ``generate_report(report_assessment, user_profile, path)``; it is
            expected to write the report to ``path``.
        report_assessment: Assessment object forwarded to ``generate_report``.
        user_profile: User profile forwarded to ``generate_report``.
        suffix: File extension for the temporary file, e.g. ``".pdf"``.

    Returns:
        bytes: The content of the generated report file.
    """
    # Reserve a unique path, then close our own handle before the generator
    # writes to it, so the file is never held open twice.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        report_path = tmp.name
    try:
        generate_report(report_assessment, user_profile, report_path)
        with open(report_path, "rb") as report_file:
            return report_file.read()
    finally:
        # delete=False means nothing else removes the file; clean up even when
        # report generation fails.
        os.unlink(report_path)


def _level_bar_figure(names, levels, label):
    """Build a horizontal bar chart of levels on a fixed 0-5 axis.

    Args:
        names: Full item names; shown untruncated in the hover text.
        levels: Numeric level for each name, in the same order.
        label: Text used for both the x-axis title and the hover label.

    Returns:
        A plotly figure with one horizontal bar per name.
    """
    # Axis labels are cut to 28 characters so long names do not squeeze the
    # plotting area; the hover text keeps the full name.
    short_names = [n[:28] + "..." if len(n) > 28 else n for n in names]
    figure = go.Figure(
        go.Bar(
            x=levels,
            y=short_names,
            orientation="h",
            hovertext=names,
            hovertemplate=f"<b>%{{hovertext}}</b><br>{label}: %{{x}}<extra></extra>",
        )
    )
    figure.update_layout(
        xaxis=dict(range=[0, 5], title=label),
        # Reversed so the first (highest-level) entry is drawn at the top.
        yaxis=dict(autorange="reversed"),
        margin=dict(t=20, b=40, l=40, r=40),
    )
    return figure


assessment = st.session_state.get("assessment")
if assessment:
    # --- 1. PRE-SORT DATA ---
    # Highest levels first; a missing level is treated as 0.
    sorted_risk_assessments = sorted(
        assessment.risk_assessments, key=lambda x: x.risk_level or 0, reverse=True
    )
    sorted_dx_recommendations = sorted(
        assessment.dx_recommendations,
        key=lambda x: x.recommendation_level or 0,
        reverse=True,
    )

    # --- 2. ROW 1: OVERALL RISK SCORE ---
    st.subheader("Overall Risk Score")
    if assessment.overall_risk_score is not None:
        fig = go.Figure(
            go.Indicator(
                mode="gauge+number",
                value=assessment.overall_risk_score,
                title={"text": "Overall Score"},
                gauge={"axis": {"range": [0, 100]}},
            )
        )
        fig.update_layout(height=300, margin=dict(t=50, b=40, l=40, r=40))
        st.plotly_chart(fig, use_container_width=True)
    st.divider()

    # --- 3. ROW 2: RISK & RECOMMENDATION CHARTS ---
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Cancer Risk Levels")
        if sorted_risk_assessments:
            cancers = [ra.cancer_type for ra in sorted_risk_assessments]
            levels = [ra.risk_level or 0 for ra in sorted_risk_assessments]
            st.plotly_chart(
                _level_bar_figure(cancers, levels, "Risk Level"),
                use_container_width=True,
            )
    with col2:
        st.subheader("Dx Recommendations")
        if sorted_dx_recommendations:
            tests = [dx.test_name for dx in sorted_dx_recommendations]
            recs = [dx.recommendation_level or 0 for dx in sorted_dx_recommendations]
            st.plotly_chart(
                _level_bar_figure(tests, recs, "Recommendation"),
                use_container_width=True,
            )
    st.divider()

    # --- 4. ROW 3: RISK FACTOR VISUALIZATIONS ---
    if assessment.identified_risk_factors:
        col3, col4 = st.columns(2)
        with col3:
            st.subheader("Risk Factor Summary")
            # Number of identified risk factors per category.
            category_counts = Counter(
                rf.category.value for rf in assessment.identified_risk_factors
            )
            pie_fig = go.Figure(
                go.Pie(
                    labels=list(category_counts.keys()),
                    values=list(category_counts.values()),
                    hole=0.3,
                )
            )
            pie_fig.update_layout(
                height=400,
                margin=dict(t=20, b=40, l=40, r=40),
                legend=dict(
                    orientation="v", yanchor="middle", y=0.5, xanchor="left", x=1.05
                ),
            )
            st.plotly_chart(pie_fig, use_container_width=True)
        with col4:
            st.subheader("Identified Risk Factors")
            risk_factor_data = [
                {"Category": rf.category.value, "Description": rf.description}
                for rf in assessment.identified_risk_factors
            ]
            rf_df = pd.DataFrame(risk_factor_data)
            st.dataframe(rf_df, use_container_width=True, height=400, hide_index=True)

    # --- 5. EXPANDERS (using sorted data) ---
    with st.expander("Overall Summary"):
        # NOTE(review): unsafe_allow_html=True renders any HTML contained in the
        # summary text; confirm the summary can be trusted or sanitize it.
        st.markdown(assessment.overall_summary, unsafe_allow_html=True)

    with st.expander("Calculated Risk Scores (Ground Truth)"):
        if assessment.calculated_risk_scores:
            st.info(
                "These scores have been calculated using validated clinical risk models "
                "and represent the authoritative risk assessment."
            )
            for cancer_type, scores in sorted(
                assessment.calculated_risk_scores.items()
            ):
                st.markdown(f"### {cancer_type}")
                for score in scores:
                    st.markdown(f"**{score.name}**: {score.score}")
                    if score.description:
                        st.write(f"*{score.description}*")
                    if score.interpretation:
                        st.write(score.interpretation)
                    if score.references:
                        # NOTE(review): this expander sits inside another
                        # expander; some Streamlit versions reject nested
                        # expanders -- verify against the pinned version.
                        with st.expander("References"):
                            for ref in score.references:
                                st.write(f"- {ref}")
                    st.divider()
        else:
            st.write("No risk scores calculated.")

    with st.expander("AI-Generated Risk Interpretations"):
        for ra in sorted_risk_assessments:
            st.markdown(f"**{ra.cancer_type}** - {ra.risk_level or 'N/A'}/5")
            st.write(ra.explanation)
            if ra.recommended_steps:
                st.write("**Recommended Steps:**")
                steps = ra.recommended_steps
                # Steps may be a list or a single value; render both as bullets.
                if isinstance(steps, list):
                    for step in steps:
                        st.write(f"- {step}")
                else:
                    st.write(f"- {steps}")
            if ra.lifestyle_advice:
                st.write(f"*{ra.lifestyle_advice}*")
            st.divider()

    with st.expander("Dx Recommendations"):
        for dx in sorted_dx_recommendations:
            st.markdown(f"**{dx.test_name}** - {dx.recommendation_level or 'N/A'}/5")
            if dx.frequency:
                st.write(f"Frequency: {dx.frequency}")
            st.write(dx.rationale)
            if dx.applicable_guideline:
                st.write(f"Guideline: {dx.applicable_guideline}")
            st.divider()

    # --- 6. EXISTING DOWNLOAD AND CHAT LOGIC ---
    # Both reports go through the same helper so each temporary file is removed
    # once its bytes have been read.
    pdf_data = _render_report_bytes(
        generate_pdf_report, assessment, st.session_state.user_profile, ".pdf"
    )
    st.download_button("Download PDF", pdf_data, file_name="assessment.pdf")

    xls_data = _render_report_bytes(
        generate_excel_report, assessment, st.session_state.user_profile, ".xlsx"
    )
    st.download_button("Download Excel", xls_data, file_name="assessment.xlsx")

    # Replay of earlier question/answer pairs is currently disabled:
    # for q, a in manager.history:
    #     st.chat_message("user").write(q)
    #     st.chat_message("assistant").write(a)

    if question := st.chat_input("Ask a follow-up question"):
        with st.spinner("Thinking..."):
            resp = manager.follow_up(question)
        st.chat_message("user").write(question)
        st.chat_message("assistant").write(resp.response)