Spaces:
Runtime error
Runtime error
| """Streamlit page: Assessment.""" | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| import streamlit as st | |
| import yaml | |
# Configure page layout to be wider.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command a page runs. That is presumably why this call sits
# above the project-local imports: they may touch Streamlit at import time.
# Confirm before moving it.
st.set_page_config(layout="wide")
| from collections import Counter | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from ui_utils import initialize_session_state | |
| from sentinel.config import AppConfig, ModelConfig, ResourcePaths | |
| from sentinel.conversation import ConversationManager | |
| from sentinel.factory import SentinelFactory | |
| from sentinel.reporting import generate_excel_report, generate_pdf_report | |
initialize_session_state()

# The assessment is computed from the user's profile, so halt this page
# (st.stop ends the script run) until the Profile page has populated it.
_missing_profile = st.session_state.user_profile is None
if _missing_profile:
    _PROFILE_PROMPT = (
        "Please complete your profile on the Profile page before running an assessment."
    )
    st.warning(_PROFILE_PROMPT)
    st.stop()
def create_conversation_manager(config: dict) -> ConversationManager:
    """Create a conversation manager from the current configuration.

    Args:
        config: Page configuration. ``config["model"]`` names a YAML file in
            ``configs/model/`` that must define ``provider`` and
            ``model_name``. The optional ``cancer_modules`` and
            ``dx_protocols`` entries are passed through as the selected
            knowledge-base modules and protocols.

    Returns:
        ConversationManager: A conversation manager instance built by
        ``SentinelFactory``.

    Raises:
        KeyError: If ``config`` has no ``"model"`` entry, or the model YAML
            lacks ``provider`` / ``model_name``.
        FileNotFoundError: If the model YAML file does not exist.
        ValueError: If the model YAML is empty or not a mapping.
    """
    # All prompt/config paths are resolved relative to the project root,
    # taken to be parents[3] of this file.
    root = Path(__file__).resolve().parents[3]
    configs_dir = root / "configs"
    prompts_dir = root / "prompts"
    knowledge_base_dir = configs_dir / "knowledge_base"

    # Load the model config to get the provider and model name.
    model_config_path = configs_dir / "model" / f"{config['model']}.yaml"
    with open(model_config_path, encoding="utf-8") as f:
        model_data = yaml.safe_load(f)
    # safe_load returns None for an empty file; fail with a clear message
    # instead of an opaque TypeError on the subscripts below.
    if not isinstance(model_data, dict):
        raise ValueError(
            f"Model config {model_config_path} must contain a YAML mapping, "
            f"got {type(model_data).__name__}"
        )

    knowledge_base_paths = ResourcePaths(
        persona=prompts_dir / "persona" / "default.md",
        instruction_assessment=prompts_dir / "instruction" / "assessment.md",
        instruction_conversation=prompts_dir / "instruction" / "conversation.md",
        output_format_assessment=configs_dir / "output_format" / "assessment.yaml",
        output_format_conversation=configs_dir / "output_format" / "conversation.yaml",
        cancer_modules_dir=knowledge_base_dir / "cancer_modules",
        dx_protocols_dir=knowledge_base_dir / "dx_protocols",
    )

    app_config = AppConfig(
        model=ModelConfig(
            provider=model_data["provider"], model_name=model_data["model_name"]
        ),
        knowledge_base_paths=knowledge_base_paths,
        selected_cancer_modules=config.get("cancer_modules", []),
        selected_dx_protocols=config.get("dx_protocols", []),
    )

    factory = SentinelFactory(app_config)
    return factory.create_conversation_manager()
def _get_conversation_manager(config: dict) -> ConversationManager:
    """Return this session's conversation manager, rebuilding it only when needed.

    Streamlit re-executes the whole page on every widget interaction.
    Unconditionally building a new manager therefore discarded whatever state
    the previous one had accumulated, and re-read the model YAML on every
    rerun. The accumulated state is presumably the conversation context a
    follow-up question relies on; confirm against ConversationManager. The
    manager is kept in ``st.session_state`` and rebuilt only when ``config``
    differs from the config it was built with.

    Args:
        config: The current page configuration (``st.session_state.config``).

    Returns:
        The cached manager, or a newly created one stored in session state.
    """
    import copy  # local: only this helper needs it

    cached = st.session_state.get("conversation_manager")
    built_for = st.session_state.get("_conversation_manager_config")
    if cached is not None and built_for == config:
        return cached

    fresh = create_conversation_manager(config)
    st.session_state["conversation_manager"] = fresh
    # Snapshot a deep copy so later in-place edits to the live config
    # (e.g. appending to a module list) still register as a change.
    st.session_state["_conversation_manager_config"] = copy.deepcopy(config)
    return fresh


manager = _get_conversation_manager(st.session_state.config)

st.title("🔬 Assessment")

if st.button("Run Assessment", type="primary"):
    with st.spinner("Running..."):
        result = manager.initial_assessment(st.session_state.user_profile)
    st.session_state.assessment = result
_LABEL_LIMIT = 28
_XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"


def _shorten(label: str, limit: int = _LABEL_LIMIT) -> str:
    """Return ``label`` cut to ``limit`` characters plus "..." when it is longer."""
    return label[:limit] + "..." if len(label) > limit else label


def _format_level(level) -> str:
    """Format a level for "<level>/5" display, using "N/A" only when it is None."""
    # ``level or "N/A"`` would also have turned a genuine 0 into "N/A".
    return "N/A" if level is None else str(level)


def _horizontal_bar(labels: list, values: list, value_name: str) -> go.Figure:
    """Build a horizontal bar chart on a fixed 0-5 axis, first item drawn on top.

    Axis ticks use shortened labels. Hovering a bar shows the full label and
    ``"<value_name>: <value>"``. ``value_name`` is also the x-axis title.
    """
    fig = go.Figure(
        go.Bar(
            x=values,
            y=[_shorten(label) for label in labels],
            orientation="h",
            hovertext=labels,
            hovertemplate="<b>%{hovertext}</b><br>"
            + value_name
            + ": %{x}<extra></extra>",
        )
    )
    fig.update_layout(
        xaxis=dict(range=[0, 5], title=value_name),
        # Reversed so the first item (the highest, after pre-sorting) is on top.
        yaxis=dict(autorange="reversed"),
        margin=dict(t=20, b=40, l=40, r=40),
    )
    return fig


def _report_bytes(generate, assessment, profile, suffix: str) -> bytes:
    """Run a report generator into a private temp dir and return the file's bytes.

    ``generate`` is called as ``generate(assessment, profile, path)``, the
    signature used here for ``generate_pdf_report`` / ``generate_excel_report``.

    A temporary directory guarantees cleanup. The previous code never deleted
    the Excel temp file, so one leaked on every rerun. It also read back
    through a handle that was still open while the generator wrote the same
    path, which fails on Windows; reading the finished file avoids that.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        out_path = Path(tmp_dir) / f"assessment{suffix}"
        generate(assessment, profile, str(out_path))
        return out_path.read_bytes()


def _render_risk_factors(risk_factors) -> None:
    """Row 3: pie chart of risk-factor categories beside a table of all factors."""
    col3, col4 = st.columns(2)
    with col3:
        st.subheader("Risk Factor Summary")
        category_counts = Counter(rf.category.value for rf in risk_factors)
        pie_fig = go.Figure(
            go.Pie(
                labels=list(category_counts.keys()),
                values=list(category_counts.values()),
                hole=0.3,
            )
        )
        pie_fig.update_layout(
            height=400,
            margin=dict(t=20, b=40, l=40, r=40),
            legend=dict(
                orientation="v", yanchor="middle", y=0.5, xanchor="left", x=1.05
            ),
        )
        st.plotly_chart(pie_fig, use_container_width=True)
    with col4:
        st.subheader("Identified Risk Factors")
        rf_df = pd.DataFrame(
            [
                {"Category": rf.category.value, "Description": rf.description}
                for rf in risk_factors
            ]
        )
        st.dataframe(rf_df, use_container_width=True, height=400, hide_index=True)


def _render_calculated_scores(calculated_scores) -> None:
    """List the calculated scores per cancer type, or say that there are none."""
    if not calculated_scores:
        st.write("No risk scores calculated.")
        return
    st.info(
        "These scores have been calculated using validated clinical risk models "
        "and represent the authoritative risk assessment."
    )
    for cancer_type, scores in sorted(calculated_scores.items()):
        st.markdown(f"### {cancer_type}")
        for score in scores:
            st.markdown(f"**{score.name}**: {score.score}")
            if score.description:
                st.write(f"*{score.description}*")
            if score.interpretation:
                st.write(score.interpretation)
            if score.references:
                # This is rendered inside an st.expander, and Streamlit raises
                # StreamlitAPIException for nested expanders, so list the
                # references inline instead of in a "References" expander.
                st.markdown("**References:**")
                for ref in score.references:
                    st.write(f"- {ref}")
            st.divider()


assessment = st.session_state.get("assessment")
if assessment:
    profile = st.session_state.user_profile

    # --- 1. PRE-SORT DATA (highest level first; a missing level counts as 0) ---
    sorted_risk_assessments = sorted(
        assessment.risk_assessments, key=lambda x: x.risk_level or 0, reverse=True
    )
    sorted_dx_recommendations = sorted(
        assessment.dx_recommendations,
        key=lambda x: x.recommendation_level or 0,
        reverse=True,
    )

    # --- 2. ROW 1: OVERALL RISK SCORE ---
    st.subheader("Overall Risk Score")
    if assessment.overall_risk_score is not None:
        fig = go.Figure(
            go.Indicator(
                mode="gauge+number",
                value=assessment.overall_risk_score,
                title={"text": "Overall Score"},
                gauge={"axis": {"range": [0, 100]}},
            )
        )
        fig.update_layout(height=300, margin=dict(t=50, b=40, l=40, r=40))
        st.plotly_chart(fig, use_container_width=True)
    st.divider()

    # --- 3. ROW 2: RISK & RECOMMENDATION CHARTS ---
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Cancer Risk Levels")
        if sorted_risk_assessments:
            risk_fig = _horizontal_bar(
                [ra.cancer_type for ra in sorted_risk_assessments],
                [ra.risk_level or 0 for ra in sorted_risk_assessments],
                "Risk Level",
            )
            st.plotly_chart(risk_fig, use_container_width=True)
    with col2:
        st.subheader("Dx Recommendations")
        if sorted_dx_recommendations:
            dx_fig = _horizontal_bar(
                [dx.test_name for dx in sorted_dx_recommendations],
                [dx.recommendation_level or 0 for dx in sorted_dx_recommendations],
                "Recommendation",
            )
            st.plotly_chart(dx_fig, use_container_width=True)
    st.divider()

    # --- 4. ROW 3: RISK FACTOR VISUALIZATIONS ---
    if assessment.identified_risk_factors:
        _render_risk_factors(assessment.identified_risk_factors)

    # --- 5. EXPANDERS (using sorted data) ---
    with st.expander("Overall Summary"):
        # NOTE(review): overall_summary is presumably model-generated text.
        # With unsafe_allow_html=True any HTML it contains reaches the
        # browser unescaped; confirm that is intended.
        st.markdown(assessment.overall_summary, unsafe_allow_html=True)

    with st.expander("Calculated Risk Scores (Ground Truth)"):
        _render_calculated_scores(assessment.calculated_risk_scores)

    with st.expander("AI-Generated Risk Interpretations"):
        for ra in sorted_risk_assessments:
            st.markdown(f"**{ra.cancer_type}** - {_format_level(ra.risk_level)}/5")
            st.write(ra.explanation)
            if ra.recommended_steps:
                st.write("**Recommended Steps:**")
                steps = ra.recommended_steps
                # recommended_steps may be a list or a single value.
                for step in steps if isinstance(steps, list) else [steps]:
                    st.write(f"- {step}")
            if ra.lifestyle_advice:
                st.write(f"*{ra.lifestyle_advice}*")
            st.divider()

    with st.expander("Dx Recommendations"):
        for dx in sorted_dx_recommendations:
            st.markdown(
                f"**{dx.test_name}** - {_format_level(dx.recommendation_level)}/5"
            )
            if dx.frequency:
                st.write(f"Frequency: {dx.frequency}")
            st.write(dx.rationale)
            if dx.applicable_guideline:
                st.write(f"Guideline: {dx.applicable_guideline}")
            st.divider()

    # --- 6. DOWNLOADS ---
    pdf_data = _report_bytes(generate_pdf_report, assessment, profile, ".pdf")
    st.download_button(
        "Download PDF", pdf_data, file_name="assessment.pdf", mime="application/pdf"
    )
    xls_data = _report_bytes(generate_excel_report, assessment, profile, ".xlsx")
    st.download_button(
        "Download Excel", xls_data, file_name="assessment.xlsx", mime=_XLSX_MIME
    )

    # --- 7. FOLLOW-UP CHAT ---
    # TODO: earlier follow-up exchanges are not re-rendered after a rerun. The
    # manager may keep a history (``manager.history``?); confirm its shape
    # before displaying it here.
    if question := st.chat_input("Ask a follow-up question"):
        st.chat_message("user").write(question)
        with st.spinner("Thinking..."):
            resp = manager.follow_up(question)
        st.chat_message("assistant").write(resp.response)