import os

# Suppress TensorFlow oneDNN optimization messages if not needed
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
# Disable ChromaDB telemetry to prevent log errors
os.environ["ANONYMIZED_TELEMETRY"] = "False"
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.tools.retriever import create_retriever_tool
from langchain_chroma import Chroma
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import ChatOllama
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

# Set the device for HuggingFace models
device = "cpu"

# --- Configuration Constants ---
APP_HOST = "0.0.0.0"
APP_PORT = 7860

THREAD_ID = "global_health_chat_session"  # Unique ID for the chat session

# Models and Paths
EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
CROSS_ENCODER_MODEL_NAME = "BAAI/bge-reranker-v2-m3"
LLM_MODEL_NAME = "custom-model"  # Replace with your actual model, e.g., "hf.co/phureexd/qwen3_v2_gguf:Q4_K_M"
VECTOR_DB_PATH = "/app/chroma_db" if os.path.exists("/app/chroma_db") else "chroma_db"

# LLM Parameters
LLM_TEMPERATURE = 0.7
LLM_TOP_P = 0.8
LLM_TOP_K = 20
LLM_NUM_PREDICT = 512

# Retriever Parameters
RETRIEVER_SEARCH_K = 6  # Number of documents to fetch initially
RERANKER_TOP_N = 3  # Number of documents after reranking
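# Together these implement two-stage retrieval: the vector store fetches
# RETRIEVER_SEARCH_K candidates by embedding similarity, and the cross-encoder
# then narrows them to the RERANKER_TOP_N most relevant (see init_retriever_tool).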

# --- System Prompts ---

INITIAL_SYSTEM_MESSAGE = SystemMessage(
    content="""
You are a health assistant designed to answer questions related to health, wellness, nutrition, exercise, symptoms, diseases, prevention, treatment, mental health, and medical advice. This explicitly includes general statements about feeling unwell or sick (e.g., "I'm sick", "I don't feel good"). For ANY query that falls into these categories, you MUST use the retrieve_health_info tool to fetch relevant information from the database before providing an answer. This ensures your responses are accurate and based on trusted sources. Do not answer health-related questions directly without using the tool, even if you think you know the answer.

If the query is clearly unrelated to health (e.g., general knowledge questions), you can answer directly without the tool.

**Important Guidelines:**
- If the query mentions or implies health, feeling unwell, sickness, treatment, symptoms, diseases, nutrition, exercise, mental health, or wellness, use the tool.
- Even if the query is only slightly related to health, or is a general statement about feeling unwell, use the tool to provide an informed answer.
- Always respond in the same language as the user's query.
- When in doubt, err on the side of using the tool.

**Examples:**

1. **Health-Related (Use Tool):**
   - User: "What are the symptoms of diabetes?"
   - Assistant: [Uses retrieve_health_info tool] "Common symptoms of diabetes include frequent urination, excessive thirst, and fatigue."

2. **Slightly Health-Related (Use Tool):**
   - User: "Is it okay to exercise when I have a cold?"
   - Assistant: [Uses retrieve_health_info tool] "Light exercise might be okay, but rest if you have a fever."

3. **General Sickness Statement (Use Tool):**
   - User: "I'm sick."
   - Assistant: [Uses retrieve_health_info tool] "I'm sorry to hear you're not feeling well. Common advice includes resting and staying hydrated. If you have specific symptoms, I can try to provide more information."

4. **Non-Health-Related (No Tool):**
   - User: "What is the capital of France?"
   - Assistant: "The capital of France is Paris."

5. **Health-Related in Thai (Use Tool):**
   - User: "อาการของโรคเบาหวานมีอะไรบ้าง?"
   - Assistant: [Uses retrieve_health_info tool] "อาการทั่วไปของโรคเบาหวาน ได้แก่ ปัสสาวะบ่อย กระหายน้ำมาก และอ่อนเพลีย"

6. **Non-Health-Related in Thai (No Tool):**
   - User: "เมืองหลวงของฝรั่งเศสคืออะไร?"
   - Assistant: "เมืองหลวงของฝรั่งเศสคือปารีส"
/no_think
"""
)
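
# Note: the trailing "/no_think" directive in the prompts is a Qwen3-style soft
# switch that disables the model's thinking mode; models that do not recognize
# it will treat it as plain text.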

RAG_SYSTEM_PROMPT_TEMPLATE = """
You are a health assistant for question-answering tasks.
Use the following retrieved documents to answer the question.
If you don't know the answer, say that you don't know.
Keep the answer concise and accurate.

**Extremely important: Answer in the same language as the user query.**

### Retrieved documents (if applicable):
{docs_content}

### Example responses:
**Example 1 (English):**
User: I feel a bit tired, what could it be?
Assistant: Fatigue can be caused by lack of sleep, stress, or dehydration. Ensure you get 7-8 hours of sleep and stay hydrated.

**Example 2 (English):**
User: Does coffee affect my health?
Assistant: Moderate coffee consumption can improve alertness but may cause insomnia or anxiety if overconsumed.

**Example 3 (Thai):**
User: ฉันรู้สึกเหนื่อยเล็กน้อย เกิดจากอะไรได้บ้าง?
Assistant: อาการเหนื่อยอาจเกิดจากการนอนหลับไม่เพียงพอ ความเครียด หรือภาวะขาดน้ำ ควรนอนหลับ 7-8 ชั่วโมงและดื่มน้ำให้เพียงพอ
/no_think
"""

# --- Initialization of Langchain Components ---


def init_embeddings(model_name: str):
    """Initializes HuggingFace embeddings."""
    return HuggingFaceEmbeddings(model_name=model_name)


def init_vector_store(embedding_function, persist_directory: str):
    """Initializes Chroma vector store."""
    return Chroma(
        embedding_function=embedding_function,
        persist_directory=persist_directory,
    )


def init_llm(
    model_name: str, temperature: float, top_p: float, top_k: int, num_predict: int
):
    """Initializes ChatOllama LLM."""
    return ChatOllama(
        model=model_name,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_predict=num_predict,
    )
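
# Note: ChatOllama talks to a running Ollama server (http://localhost:11434 by
# default); the model named in LLM_MODEL_NAME must already be pulled, e.g. with
# `ollama pull <model-name>`.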


def init_retriever_tool(
    vector_store_instance,
    cross_encoder_model_name: str,
    base_retriever_k: int,
    reranker_top_n: int,
):
    """Initializes the retriever tool with reranking."""
    base_retriever = vector_store_instance.as_retriever(
        search_kwargs={"k": base_retriever_k}
    )

    cross_encoder = HuggingFaceCrossEncoder(
        model_name=cross_encoder_model_name,
        model_kwargs={"device": device},  # Specify device if needed, e.g., "cuda"
    )
    reranker = CrossEncoderReranker(model=cross_encoder, top_n=reranker_top_n)

    compression_retriever = ContextualCompressionRetriever(
        base_compressor=reranker,
        base_retriever=base_retriever,
    )

    return create_retriever_tool(
        retriever=compression_retriever,
        name="retrieve_health_info",
        description=(
            "Use this tool to retrieve documents from the database relevant to queries "
            "about health, wellness, nutrition, exercise, symptoms, diseases, treatment, "
            "prevention, mental health, or medical advice. "
            "Use it even if the query is only slightly related to health. "
            f"Returns the top {reranker_top_n} most relevant documents."
        ),
        response_format="content_and_artifact",  # Ensures artifact contains Document objects
    )
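
# Note on response_format="content_and_artifact": when a ToolNode executes this
# tool, the resulting ToolMessage carries the formatted text in .content and the
# raw Document objects in .artifact, which generate_rag_response_node_logic and
# the streaming endpoint read below.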


# Initialize components
print("Initializing Embeddings...")
embeddings = init_embeddings(EMBEDDING_MODEL_NAME)
print("Embeddings Initialized.")

print("Initializing Vector Store...")
vector_store = init_vector_store(embeddings, VECTOR_DB_PATH)
print("Vector Store Initialized.")

print("Initializing LLM...")
llm = init_llm(LLM_MODEL_NAME, LLM_TEMPERATURE, LLM_TOP_P, LLM_TOP_K, LLM_NUM_PREDICT)
print("LLM Initialized.")

print("Initializing Retriever Tool...")
retriever_tool = init_retriever_tool(
    vector_store, CROSS_ENCODER_MODEL_NAME, RETRIEVER_SEARCH_K, RERANKER_TOP_N
)
print("Retriever Tool Initialized.")

# --- LangGraph Node Definitions ---


async def query_or_respond_node_logic(state: MessagesState):
    """
    Node function: Decides whether to call a tool for retrieval or respond directly.
    Binds the retriever_tool to the LLM for this decision.
    """
    response = await llm.bind_tools([retriever_tool]).ainvoke(state["messages"])
    return {"messages": [response]}


async def generate_rag_response_node_logic(state: MessagesState):
    """
    Node function: Generates a response using retrieved documents (if any).
    """
    # Extract the most recent contiguous block of tool messages
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":  # or isinstance(message, ToolMessage)
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]

    # Format retrieved document content for the prompt
    doc_strings = []
    for tool_msg in tool_messages:
        # Ensure artifact is a list of Langchain Document objects
        if hasattr(tool_msg, "artifact") and isinstance(tool_msg.artifact, list):
            for doc in tool_msg.artifact:
                if hasattr(doc, "page_content") and hasattr(
                    doc, "metadata"
                ):  # Document structure check
                    source = doc.metadata.get("source", "Unknown source")
                    content = doc.page_content
                    doc_strings.append(f"Source: {source}\nContent: {content}")

    docs_content = (
        "\n\n".join(doc_strings)
        if doc_strings
        else "No relevant documents were found to answer the current question."
    )

    # Prepare messages for the generation LLM call (history + new system prompt with docs)
    # Include human messages, initial system messages, and AI responses (not tool calls)
    conversation_history_for_llm = [
        msg
        for msg in state["messages"]
        if msg.type in ("human", "system") or (msg.type == "ai" and not msg.tool_calls)
    ]

    # Construct the system prompt with retrieved documents
    current_system_prompt_content = RAG_SYSTEM_PROMPT_TEMPLATE.format(
        docs_content=docs_content
    )

    prompt_for_generation = [
        SystemMessage(content=current_system_prompt_content)
    ] + conversation_history_for_llm

    response = await llm.ainvoke(prompt_for_generation)
    return {"messages": [response]}


# --- LangGraph Graph Construction ---


def create_lang_graph(checkpointer_instance):
    """Creates and compiles the LangGraph."""
    graph_builder = StateGraph(MessagesState)

    # Define nodes
    graph_builder.add_node("query_or_respond", query_or_respond_node_logic)
    tools_node = ToolNode([retriever_tool])  # Define tool execution node
    graph_builder.add_node("tools", tools_node)
    graph_builder.add_node("generate_rag_response", generate_rag_response_node_logic)

    # Define edges
    graph_builder.set_entry_point("query_or_respond")
    graph_builder.add_conditional_edges(
        "query_or_respond",
        tools_condition,  # Prebuilt condition to check for tool calls
        {END: END, "tools": "tools"},
    )
    graph_builder.add_edge("tools", "generate_rag_response")
    graph_builder.add_edge("generate_rag_response", END)

    return graph_builder.compile(checkpointer=checkpointer_instance)


# Initialize checkpointer and compile graph
memory_saver = MemorySaver()
graph = create_lang_graph(memory_saver)
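
# A quick smoke test (sketch): invoke the compiled graph once outside FastAPI,
# assuming an Ollama server is running and the vector store is populated.
#   import asyncio
#   result = asyncio.run(graph.ainvoke(
#       {"messages": [INITIAL_SYSTEM_MESSAGE, HumanMessage(content="I'm sick")]},
#       {"configurable": {"thread_id": "smoke_test"}},
#   ))
#   print(result["messages"][-1].content)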

# Optional: Save graph visualization
# try:
#     graph.get_graph().draw_mermaid_png(output_file_path="graph.png")
#     print("Graph visualization saved to graph.png")
# except Exception as e:
#     print(f"Could not save graph visualization: {e}")


# --- FastAPI Application Setup ---
app = FastAPI(
    title="Health Assistant API",
    description="API for a health assistant using a retrieval-augmented generation approach.",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for simplicity; restrict in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- FastAPI Endpoints ---


@app.get("/", summary="Root endpoint")
async def root():
    return {"message": "Health Assistant API is running!", "docs_url": "/docs"}


@app.get("/generate", summary="Generate a response from the health assistant model")
async def generate_endpoint(
    query: str,
    useRAG: bool = False,
    thread_id: str = THREAD_ID,
):
    """
    Handles a user query, streams back the assistant's responses.
    - `query`: The user's question.
    - `useRAG`: If true, forces the use of the retrieval tool via a system message.
    - `thread_id`: Unique identifier for the conversation session.
    """
    print(f"Received query: '{query}', Force RAG: {useRAG}, Thread ID: {thread_id}")

    config = {"configurable": {"thread_id": thread_id}}

    # Prepare input messages for the graph
    input_messages = []
    current_checkpoint_tuple = memory_saver.get_tuple(config)  # Check if history exists

    if current_checkpoint_tuple is None:  # No history, it's a new or cleared session
        input_messages.append(INITIAL_SYSTEM_MESSAGE)
        print("Starting new conversation: Added initial system message.")

    if useRAG:
        # This message is added to strongly encourage tool use for the current query,
        # supplementing the INITIAL_SYSTEM_MESSAGE.
        input_messages.append(
            SystemMessage(
                content="You MUST use the retrieve_health_info tool for this query even if it seems unrelated to health."
            )
        )
        print("Forcing RAG for this query with an additional system message.")

    input_messages.append(HumanMessage(content=query))
    graph_input = {"messages": input_messages}

    async def stream_response_events():
        # graph.astream with stream_mode="messages" yields (message, metadata)
        # tuples for each message produced inside the graph: streamed AI token
        # chunks as well as tool results.
        async for stream_event in graph.astream(
            graph_input, config, stream_mode="messages"
        ):
            if not stream_event:
                continue

            # The first element of the tuple is the message itself
            latest_message = stream_event[0]

            if isinstance(latest_message, AIMessage):
                if latest_message.content:  # Streamed text chunk from the LLM
                    yield f"data: {latest_message.content}\n\n"
                elif latest_message.tool_calls:  # AI message requesting a tool call
                    print(f"AI requested Tool call: {latest_message.tool_calls}")
                    # You might want to send a status to the client, e.g., "Thinking..." or "Retrieving info..."
                    # yield f"event: tool_call\ndata: {json.dumps(latest_message.tool_calls)}\n\n"
            elif isinstance(
                latest_message, ToolMessage
            ):  # Message containing tool execution results
                if latest_message.name == "retrieve_health_info" and hasattr(
                    latest_message, "artifact"
                ):
                    print(f"Tool '{latest_message.name}' executed. Artifact content:")
                    # Initialized before the inner check so the yield below
                    # never hits an undefined name when the artifact is empty.
                    source_list = set()
                    if latest_message.artifact and isinstance(
                        latest_message.artifact, list
                    ):
                        # Log every document in the artifact and collect its sources
                        for doc in latest_message.artifact:
                            source = doc.metadata.get("source", "Unknown source")
                            if source != "Unknown source":
                                source_list.add(source)
                            print(f"  Source: {source}\n   Content: {doc.page_content}")
                    yield f"data: **Source:** {', '.join(sorted(source_list))}\n\n"

    return StreamingResponse(
        stream_response_events(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )


@app.get("/clear", summary="Clear conversation history")
async def clear_conversation_endpoint(thread_id: str = THREAD_ID):
    """Clears the conversation history for the specified thread_id."""
    try:
        # MemorySaver keeps checkpoints in an in-memory dict keyed by thread_id,
        # and not all versions expose a public deletion API, so this reaches
        # into its storage as a best-effort cleanup. A database-backed
        # checkpointer would delete the thread's rows here instead.
        if hasattr(memory_saver, "storage") and thread_id in memory_saver.storage:
            del memory_saver.storage[thread_id]

        print(f"Conversation history cleared for thread_id: {thread_id}")
        return {"status": "success", "message": "Conversation history cleared."}
    except Exception as e:
        print(f"Error clearing conversation history for thread_id {thread_id}: {e}")
        return {"status": "error", "message": f"Failed to clear history: {e}"}


# --- Main Execution ---
if __name__ == "__main__":
    print(f"Starting Health Assistant API on {APP_HOST}:{APP_PORT}")
    uvicorn.run(app, host=APP_HOST, port=APP_PORT)