#!/usr/bin/env python3
"""
Mock LLM function calls for testing CLI topic extraction without API costs.
This module patches requests.post to intercept HTTP calls to inference servers
and return mock responses instead.
"""
import json
import os

# Holds the original requests.post so it can be restored after testing
_original_requests = None


def _generate_mock_response(prompt: str, system_prompt: str) -> str:
    """
    Generate a mock response that satisfies validation requirements.

    The response must:
    - Be longer than 120 characters
    - Contain a markdown table (with | characters)

    Args:
        prompt: The user prompt
        system_prompt: The system prompt

    Returns:
        A mock markdown table response
    """
    # The prompts are accepted for interface parity with a real LLM call but
    # do not vary the canned response.
    # Generate a simple markdown table that satisfies the validation;
    # this mimics a topic extraction table response.
    mock_table = """| Reference | General Topic | Sub-topic | Sentiment |
|-----------|---------------|-----------|-----------|
| 1 | Test Topic | Test Subtopic | Positive |
| 2 | Another Topic | Another Subtopic | Neutral |
| 3 | Third Topic | Third Subtopic | Negative |

This is a mock response from the test inference server. The actual content would be generated by a real LLM model, but for testing purposes, this dummy response allows us to verify that the CLI commands work correctly without incurring API costs."""
    return mock_table
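

# Import-time sanity check (our own addition, mirroring the validation rules
# stated in the docstring above): the canned reply must exceed 120 characters
# and contain markdown table pipes.
_sample_reply = _generate_mock_response("", "")
assert len(_sample_reply) > 120 and "|" in _sample_reply
del _sample_reply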


def _estimate_tokens(text: str) -> int:
    """Estimate token count (rough approximation: ~4 characters per token)."""
    return max(1, len(text) // 4)


def mock_requests_post(url, **kwargs):
    """
    Mock version of requests.post that intercepts inference-server calls.

    Returns a mock response object that mimics the real requests.Response.
    """
    # Only mock inference-server URLs
    if "/v1/chat/completions" not in url:
        # requests.post has been rebound to this function, so calling it here
        # would recurse forever; use the saved original when it is available
        if _original_requests is not None:
            return _original_requests(url, **kwargs)
        import requests
        return requests.post(url, **kwargs)

    # Extract the JSON payload and its message list
    payload = kwargs.get("json", {})
    messages = payload.get("messages", [])

    # Pull out the system and user prompts
    system_prompt = ""
    user_prompt = ""
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role == "system":
            system_prompt = content
        elif role == "user":
            user_prompt = content

    # Generate mock response
    response_text = _generate_mock_response(user_prompt, system_prompt)

    # Estimate token counts for the usage block
    input_tokens = _estimate_tokens(system_prompt + "\n" + user_prompt)
    output_tokens = _estimate_tokens(response_text)

    # Check if streaming is requested
    stream = payload.get("stream", False)

    if stream:
        # For streaming, return a response whose iter_lines() yields
        # OpenAI-style Server-Sent Events chunks
        class MockStreamResponse:
            def __init__(self, text):
                self.text = text
                self.status_code = 200
                self.lines = []
                # Split the response into small "delta" chunks
                chunk_size = 20
                for i in range(0, len(text), chunk_size):
                    chunk = text[i : i + chunk_size]
                    chunk_data = {
                        "choices": [
                            {
                                "delta": {"content": chunk},
                                "index": 0,
                                "finish_reason": None,
                            }
                        ]
                    }
                    self.lines.append(f"data: {json.dumps(chunk_data)}\n\n".encode())
                self.lines.append(b"data: [DONE]\n\n")

            def raise_for_status(self):
                pass

            def iter_lines(self):
                for line in self.lines:
                    yield line

        return MockStreamResponse(response_text)
    else:
        # For non-streaming, return a response whose json() looks like an
        # OpenAI-style chat-completion payload
        class MockResponse:
            def __init__(self, text, input_tokens, output_tokens):
                self._json_data = {
                    "choices": [
                        {
                            "index": 0,
                            "finish_reason": "stop",
                            "message": {
                                "role": "assistant",
                                "content": text,
                            },
                        }
                    ],
                    "usage": {
                        "prompt_tokens": input_tokens,
                        "completion_tokens": output_tokens,
                        "total_tokens": input_tokens + output_tokens,
                    },
                }
                self.status_code = 200

            def raise_for_status(self):
                pass

            def json(self):
                return self._json_data

        return MockResponse(response_text, input_tokens, output_tokens)
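

# Illustrative helper (our own addition, not used by the mock or the CLI
# itself): shows how a client would reassemble the text that
# MockStreamResponse.iter_lines() streams as SSE chunks.
def _collect_streamed_text(response) -> str:
    """Join the "delta" content chunks from a mocked SSE stream."""
    pieces = []
    for raw_line in response.iter_lines():
        line = raw_line.decode().strip()
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        pieces.append(chunk["choices"][0]["delta"].get("content", ""))
    return "".join(pieces)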


def apply_mock_patches():
    """
    Apply patches to mock HTTP requests.

    This should be called before importing modules that use requests.
    """
    global _original_requests
    try:
        import requests

        # Guard against double-patching: only save the original once, so a
        # second call cannot capture the mock itself as the "original"
        if _original_requests is None:
            _original_requests = requests.post
        requests.post = mock_requests_post
        print("[Mock] Patched requests.post for inference-server calls")
    except ImportError:
        # requests is not installed, so there is nothing to patch
        pass


def restore_original():
    """Restore the original requests.post if it was patched."""
    global _original_requests
    if _original_requests:
        try:
            import requests

            requests.post = _original_requests
            _original_requests = None
            print("[Mock] Restored original requests.post")
        except ImportError:
            pass


# Auto-apply patches if the TEST_MODE or USE_MOCK_LLM environment variable is set
if os.environ.get("TEST_MODE") == "1" or os.environ.get("USE_MOCK_LLM") == "1":
    apply_mock_patches()
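

# Minimal smoke test (a sketch, assuming `requests` is installed): patch,
# issue one non-streaming and one streaming call against a placeholder
# inference-server URL, and print what the mock returns. No network traffic
# occurs; the mock intercepts both calls.
if __name__ == "__main__":
    apply_mock_patches()
    import requests

    url = "http://localhost:8000/v1/chat/completions"  # placeholder host/port
    messages = [
        {"role": "system", "content": "You are a topic extractor."},
        {"role": "user", "content": "Extract topics from these responses."},
    ]

    resp = requests.post(url, json={"messages": messages})
    print(resp.json()["choices"][0]["message"]["content"][:80])

    stream_resp = requests.post(url, json={"messages": messages, "stream": True})
    print(_collect_streamed_text(stream_resp)[:80])

    restore_original()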