File size: 10,433 Bytes
8018595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc034ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8018595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
"""Streamlit page: Assessment."""

import os
import tempfile
from pathlib import Path

import streamlit as st
import yaml

# Configure page layout to be wider.
# NOTE: Streamlit has historically required set_page_config to be the first
# Streamlit command executed on a page, which is presumably why it runs here,
# between the imports and before initialize_session_state() or any widget.
st.set_page_config(layout="wide")
from collections import Counter

import pandas as pd
import plotly.graph_objects as go
from ui_utils import initialize_session_state

from sentinel.config import AppConfig, ModelConfig, ResourcePaths
from sentinel.conversation import ConversationManager
from sentinel.factory import SentinelFactory
from sentinel.reporting import generate_excel_report, generate_pdf_report

initialize_session_state()

# Everything below reads the user's profile. Without one there is nothing to
# assess, so direct the user to the Profile page and halt this script run.
profile_missing = st.session_state.user_profile is None
if profile_missing:
    st.warning(
        "Please complete your profile on the Profile page before running an assessment."
    )
    st.stop()


def create_conversation_manager(config: dict) -> ConversationManager:
    """Create a conversation manager from the current configuration.

    Args:
        config: The current configuration. ``config["model"]`` names a model
            YAML file (without extension) under ``configs/model/``. The
            optional ``"cancer_modules"`` and ``"dx_protocols"`` entries select
            knowledge-base modules and diagnostic protocols (default: none).

    Returns:
        ConversationManager: A conversation manager built by ``SentinelFactory``.

    Raises:
        FileNotFoundError: If the model configuration file does not exist.
        ValueError: If the model configuration file is not a YAML mapping.
        KeyError: If ``config`` has no ``"model"`` entry, or the model YAML
            lacks ``provider`` or ``model_name``.
    """
    # Project root: the fourth ancestor of this page file (parents[3]); this
    # depends on where the page lives inside the repository.
    root = Path(__file__).resolve().parents[3]
    prompts_dir = root / "prompts"
    configs_dir = root / "configs"
    knowledge_base_dir = configs_dir / "knowledge_base"

    # Load the model config to get the provider and model name. Decode as
    # UTF-8 explicitly so the result does not depend on the platform locale.
    model_config_path = configs_dir / "model" / f"{config['model']}.yaml"
    model_data = yaml.safe_load(model_config_path.read_text(encoding="utf-8"))
    if not isinstance(model_data, dict):
        # safe_load returns None for an empty file; fail with a clear message
        # instead of an opaque TypeError on the subscripts below.
        raise ValueError(
            f"Model configuration {model_config_path} must be a YAML mapping"
        )

    knowledge_base_paths = ResourcePaths(
        persona=prompts_dir / "persona" / "default.md",
        instruction_assessment=prompts_dir / "instruction" / "assessment.md",
        instruction_conversation=prompts_dir / "instruction" / "conversation.md",
        output_format_assessment=configs_dir / "output_format" / "assessment.yaml",
        output_format_conversation=configs_dir / "output_format" / "conversation.yaml",
        cancer_modules_dir=knowledge_base_dir / "cancer_modules",
        dx_protocols_dir=knowledge_base_dir / "dx_protocols",
    )

    app_config = AppConfig(
        model=ModelConfig(
            provider=model_data["provider"], model_name=model_data["model_name"]
        ),
        knowledge_base_paths=knowledge_base_paths,
        selected_cancer_modules=config.get("cancer_modules", []),
        selected_dx_protocols=config.get("dx_protocols", []),
    )

    factory = SentinelFactory(app_config)
    return factory.create_conversation_manager()


def _get_conversation_manager(
    config: dict, *, fresh: bool = False
) -> ConversationManager:
    """Return the session's conversation manager, building it only when needed.

    Streamlit re-executes this script on every interaction (button click,
    chat message, widget change). Building a new manager on every rerun
    would discard any state the manager keeps from the initial assessment
    before a follow-up question reaches it. That state is presumably the
    conversation history ``follow_up`` relies on; confirm against
    ``ConversationManager``. The instance is therefore kept in
    ``st.session_state`` and rebuilt only when ``fresh`` is requested, no
    manager exists yet, or the configuration differs from the one it was
    built with.

    Args:
        config: The current configuration dict (see ``create_conversation_manager``).
        fresh: Force a new manager even if a cached one matches ``config``.

    Returns:
        ConversationManager: The manager stored under
        ``st.session_state["conversation_manager"]``.
    """
    import copy

    cached = st.session_state.get("conversation_manager")
    built_for = st.session_state.get("conversation_manager_config")
    if fresh or cached is None or built_for != config:
        cached = create_conversation_manager(config)
        st.session_state["conversation_manager"] = cached
        # Deep copy so that in-place edits to the live config (e.g. appending
        # to a list of modules) still register as a change on the next rerun.
        st.session_state["conversation_manager_config"] = copy.deepcopy(config)
    return cached


st.title("🔬 Assessment")

run_requested = st.button("Run Assessment", type="primary")
# A new assessment starts from a clean manager, so a repeated run does not
# inherit the previous run's conversation state.
manager = _get_conversation_manager(st.session_state.config, fresh=run_requested)

if run_requested:
    with st.spinner("Running..."):
        result = manager.initial_assessment(st.session_state.user_profile)
        st.session_state.assessment = result

def _truncate_label(text: str, max_len: int = 28) -> str:
    """Shorten *text* for use as a chart axis label.

    Text longer than ``max_len`` characters is cut to ``max_len`` characters
    and gets a trailing ``"..."``. Shorter text is returned unchanged.
    """
    return text[:max_len] + "..." if len(text) > max_len else text


def _render_level_bar_chart(labels: list, levels: list, value_label: str) -> None:
    """Draw a horizontal bar chart on a fixed 0-5 scale, first entry on top.

    Args:
        labels: Full item names. They are shown truncated on the y axis and in
            full in the hover tooltip.
        levels: Bar lengths, in the same order as ``labels``.
        value_label: Name of the plotted quantity, used as the x-axis title
            and in the tooltip.
    """
    fig = go.Figure(
        go.Bar(
            x=levels,
            y=[_truncate_label(label) for label in labels],
            orientation="h",
            hovertext=labels,
            hovertemplate=f"<b>%{{hovertext}}</b><br>{value_label}: %{{x}}<extra></extra>",
        )
    )
    fig.update_layout(
        xaxis=dict(range=[0, 5], title=value_label),
        # Reverse the category axis so the first (pre-sorted) entry is drawn
        # at the top rather than the bottom.
        yaxis=dict(autorange="reversed"),
        margin=dict(t=20, b=40, l=40, r=40),
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_report_bytes(generator, assessment_result, profile, suffix: str) -> bytes:
    """Run a report *generator* into a scratch file and return the file's bytes.

    The generators are called as ``generator(assessment, profile, path)`` and
    write to a filesystem path. The file is created inside a private
    temporary directory and read back by path after the generator returns.
    This avoids reading through a still-open ``NamedTemporaryFile``, whose
    name cannot be reopened on Windows. The directory and the file are
    removed even if generation raises.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the extension: a generator may infer the output format from it.
        report_path = Path(tmp_dir) / f"assessment{suffix}"
        generator(assessment_result, profile, str(report_path))
        return report_path.read_bytes()


assessment = st.session_state.get("assessment")

if assessment:
    # --- 1. PRE-SORT DATA ---
    # Highest risk / strongest recommendation first; missing levels sort as 0.
    sorted_risk_assessments = sorted(
        assessment.risk_assessments, key=lambda x: x.risk_level or 0, reverse=True
    )
    sorted_dx_recommendations = sorted(
        assessment.dx_recommendations,
        key=lambda x: x.recommendation_level or 0,
        reverse=True,
    )

    # --- 2. ROW 1: OVERALL RISK SCORE ---
    st.subheader("Overall Risk Score")
    if assessment.overall_risk_score is not None:
        fig = go.Figure(
            go.Indicator(
                mode="gauge+number",
                value=assessment.overall_risk_score,
                title={"text": "Overall Score"},
                gauge={"axis": {"range": [0, 100]}},
            )
        )
        fig.update_layout(height=300, margin=dict(t=50, b=40, l=40, r=40))
        st.plotly_chart(fig, use_container_width=True)
    st.divider()

    # --- 3. ROW 2: RISK & RECOMMENDATION CHARTS ---
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Cancer Risk Levels")
        if sorted_risk_assessments:
            _render_level_bar_chart(
                [ra.cancer_type for ra in sorted_risk_assessments],
                [ra.risk_level or 0 for ra in sorted_risk_assessments],
                "Risk Level",
            )

    with col2:
        st.subheader("Dx Recommendations")
        if sorted_dx_recommendations:
            _render_level_bar_chart(
                [dx.test_name for dx in sorted_dx_recommendations],
                [dx.recommendation_level or 0 for dx in sorted_dx_recommendations],
                "Recommendation",
            )
    st.divider()

    # --- 4. ROW 3: RISK FACTOR VISUALIZATIONS ---
    if assessment.identified_risk_factors:
        col3, col4 = st.columns(2)
        with col3:
            st.subheader("Risk Factor Summary")
            category_counts = Counter(
                rf.category.value for rf in assessment.identified_risk_factors
            )
            pie_fig = go.Figure(
                go.Pie(
                    labels=list(category_counts.keys()),
                    values=list(category_counts.values()),
                    hole=0.3,
                )
            )
            pie_fig.update_layout(
                height=400,
                margin=dict(t=20, b=40, l=40, r=40),
                legend=dict(
                    orientation="v", yanchor="middle", y=0.5, xanchor="left", x=1.05
                ),
            )
            st.plotly_chart(pie_fig, use_container_width=True)

        with col4:
            st.subheader("Identified Risk Factors")
            rf_df = pd.DataFrame(
                [
                    {"Category": rf.category.value, "Description": rf.description}
                    for rf in assessment.identified_risk_factors
                ]
            )
            st.dataframe(rf_df, use_container_width=True, height=400, hide_index=True)

    # --- 5. EXPANDERS (using sorted data) ---
    with st.expander("Overall Summary"):
        # NOTE(review): overall_summary is model-generated text rendered with
        # unsafe_allow_html=True, so any raw HTML it contains is injected into
        # the page unescaped. Confirm HTML rendering is actually required.
        st.markdown(assessment.overall_summary, unsafe_allow_html=True)

    with st.expander("Calculated Risk Scores (Ground Truth)"):
        if assessment.calculated_risk_scores:
            st.info(
                "These scores have been calculated using validated clinical risk models "
                "and represent the authoritative risk assessment."
            )
            for cancer_type, scores in sorted(
                assessment.calculated_risk_scores.items()
            ):
                st.markdown(f"### {cancer_type}")
                for score in scores:
                    st.markdown(f"**{score.name}**: {score.score}")
                    if score.description:
                        st.write(f"*{score.description}*")
                    if score.interpretation:
                        st.write(score.interpretation)
                    if score.references:
                        # Streamlit forbids an expander inside another expander
                        # (StreamlitAPIException), so list references inline.
                        st.write("**References:**")
                        for ref in score.references:
                            st.write(f"- {ref}")
                    st.divider()
        else:
            st.write("No risk scores calculated.")

    with st.expander("AI-Generated Risk Interpretations"):
        for ra in sorted_risk_assessments:
            st.markdown(f"**{ra.cancer_type}** - {ra.risk_level or 'N/A'}/5")
            st.write(ra.explanation)
            if ra.recommended_steps:
                st.write("**Recommended Steps:**")
                steps = ra.recommended_steps
                if isinstance(steps, list):
                    for step in steps:
                        st.write(f"- {step}")
                else:
                    st.write(f"- {steps}")
            if ra.lifestyle_advice:
                st.write(f"*{ra.lifestyle_advice}*")
            st.divider()

    with st.expander("Dx Recommendations"):
        for dx in sorted_dx_recommendations:
            st.markdown(f"**{dx.test_name}** - {dx.recommendation_level or 'N/A'}/5")
            if dx.frequency:
                st.write(f"Frequency: {dx.frequency}")
            st.write(dx.rationale)
            if dx.applicable_guideline:
                st.write(f"Guideline: {dx.applicable_guideline}")
            st.divider()

    # --- 6. DOWNLOADS AND FOLLOW-UP CHAT ---
    user_profile = st.session_state.user_profile

    pdf_data = _render_report_bytes(
        generate_pdf_report, assessment, user_profile, ".pdf"
    )
    st.download_button(
        "Download PDF", pdf_data, file_name="assessment.pdf", mime="application/pdf"
    )

    xls_data = _render_report_bytes(
        generate_excel_report, assessment, user_profile, ".xlsx"
    )
    st.download_button(
        "Download Excel",
        xls_data,
        file_name="assessment.xlsx",
        mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

    # NOTE: only the latest question/answer pair is drawn; earlier exchanges
    # are not re-rendered on subsequent reruns.
    if question := st.chat_input("Ask a follow-up question"):
        with st.spinner("Thinking..."):
            resp = manager.follow_up(question)
        st.chat_message("user").write(question)
        st.chat_message("assistant").write(resp.response)