ratyim committed on
Commit a422356 · verified
1 Parent(s): 7c1a0cc

Create app.py

Files changed (1)
  1. app.py +1544 -0
app.py ADDED
@@ -0,0 +1,1544 @@
1
+ """
2
+ Llama + FAISS RAG System for Fire Evacuation with Advanced Reasoning
3
+
4
+ This module implements a RAG (Retrieval-Augmented Generation) system for fire evacuation scenarios
5
+ with advanced LLM reasoning techniques including:
6
+
7
+ 1. Chain-of-Thought (CoT) Prompting:
8
+ - Enables step-by-step reasoning through intermediate steps
9
+ - Improves complex problem-solving capabilities
10
+ - Reference: https://arxiv.org/pdf/2201.11903
11
+
12
+ 2. Tree-of-Thoughts (ToT):
13
+ - Maintains multiple reasoning paths
14
+ - Self-evaluates progress through intermediate thoughts
15
+ - Enables deliberate reasoning process
16
+ - Reference: https://arxiv.org/pdf/2305.10601
17
+
18
+ 3. Reflexion:
19
+ - Reinforces language-based agents through linguistic feedback
20
+ - Self-reflection and iterative improvement
21
+ - Reference: https://arxiv.org/pdf/2303.11366
22
+
23
+ 4. CoT with Tools:
24
+ - Combines CoT prompting with external tools
25
+ - Interleaved reasoning and tool usage
26
+ - Reference: https://arxiv.org/pdf/2303.09014
27
+
28
+ 5. Advanced Decoding Strategies:
29
+ - Greedy: Deterministic highest probability
30
+ - Sampling: Random sampling with temperature
31
+ - Beam Search: Explores multiple paths
32
+ - Nucleus (Top-p): Samples from top-p probability mass
33
+ - Temperature: Temperature-based sampling
34
+
35
+ Downloads the Llama model, creates a JSON dataset, builds a FAISS index, and serves RAG queries through a Gradio interface.
36
+ """
37
+ import unsloth
38
+ import json
39
+ import os
40
+ import pickle
41
+ import glob
42
+ import re
43
+ from typing import List, Dict, Any, Optional, Tuple
44
+
45
+ from pathlib import Path
46
+ from enum import Enum
47
+ import copy
48
+
49
+ import numpy as np
50
+ import faiss
51
+ import torch
52
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
53
+ from sentence_transformers import SentenceTransformer
54
+ import gradio as gr
55
+
56
+ # Project imports (use helper_files package)
57
+ from helper_files.floor_plan import create_sample_floor_plan, FloorPlan
58
+ from helper_files.sensor_system import create_sample_fire_scenario, SensorSystem
59
+ from helper_files.pathfinding import PathFinder
60
+
61
+
62
+
63
+ class FireEvacuationDataExporter:
64
+ """Exports fire evacuation system data to JSON format"""
65
+
66
+ def __init__(self, floor_plan: FloorPlan, sensor_system: SensorSystem, pathfinder: PathFinder):
67
+ self.floor_plan = floor_plan
68
+ self.sensor_system = sensor_system
69
+ self.pathfinder = pathfinder
70
+
71
+ def export_room_data(self, room_id: str) -> Dict[str, Any]:
72
+ """Export comprehensive room data to JSON"""
73
+ room = self.floor_plan.get_room(room_id)
74
+ sensor = self.sensor_system.get_sensor_reading(room_id)
75
+
76
+ if not room or not sensor:
77
+ return {}
78
+
79
+ return {
80
+ "room_id": room_id,
81
+ "name": room.name,
82
+ "room_type": room.room_type,
83
+ "position": room.position,
84
+ "size": room.size,
85
+ "has_oxygen_cylinder": room.has_oxygen_cylinder,
86
+ "has_fire_extinguisher": room.has_fire_extinguisher,
87
+ "connected_to": [conn[0] for conn in room.connected_to],
88
+ "sensor_data": {
89
+ "fire_detected": sensor.fire_detected,
90
+ "smoke_level": round(sensor.smoke_level, 2),
91
+ "temperature_c": round(sensor.temperature, 1),
92
+ "oxygen_pct": round(sensor.oxygen_level, 1),
93
+ "visibility_pct": round(sensor.visibility, 1),
94
+ "structural_integrity_pct": round(sensor.structural_integrity, 1),
95
+ "fire_growth_rate": round(sensor.fire_growth_rate, 2),
96
+ "flashover_risk": round(sensor.flashover_risk, 2),
97
+ "backdraft_risk": round(sensor.backdraft_risk, 2),
98
+ "heat_radiation": round(sensor.heat_radiation, 2),
99
+ "fire_type": sensor.fire_type,
100
+ "carbon_monoxide_ppm": round(sensor.carbon_monoxide, 1),
101
+ "carbon_dioxide_ppm": round(sensor.carbon_dioxide, 1),
102
+ "hydrogen_cyanide_ppm": round(sensor.hydrogen_cyanide, 2),
103
+ "hydrogen_chloride_ppm": round(sensor.hydrogen_chloride, 2),
104
+ "wind_direction": round(sensor.wind_direction, 1),
105
+ "wind_speed": round(sensor.wind_speed, 2),
106
+ "air_pressure": round(sensor.air_pressure, 2),
107
+ "humidity": round(sensor.humidity, 1),
108
+ "occupancy_density": round(sensor.occupancy_density, 2),
109
+ "mobility_limitations": sensor.mobility_limitations,
110
+ "panic_level": round(sensor.panic_level, 2),
111
+ "evacuation_progress": round(sensor.evacuation_progress, 1),
112
+ "sprinkler_active": sensor.sprinkler_active,
113
+ "emergency_lighting": sensor.emergency_lighting,
114
+ "elevator_available": sensor.elevator_available,
115
+ "stairwell_clear": sensor.stairwell_clear,
116
+ "exit_accessible": sensor.exit_accessible,
117
+ "exit_capacity": sensor.exit_capacity,
118
+ "ventilation_active": sensor.ventilation_active,
119
+ "time_since_fire_start": sensor.time_since_fire_start,
120
+ "estimated_time_to_exit": sensor.estimated_time_to_exit,
121
+ "emergency_comm_working": sensor.emergency_comm_working,
122
+ "wifi_signal_strength": round(sensor.wifi_signal_strength, 1),
123
+ "danger_score": round(sensor.calculate_danger_score(), 1),
124
+ "passable": sensor.is_passable()
125
+ }
126
+ }
127
+
128
+ def export_route_data(self, start_location: str = "R1") -> Dict[str, Any]:
129
+ """Export all evacuation routes with detailed information"""
130
+ routes = self.pathfinder.find_all_evacuation_routes(start_location)
131
+
132
+ route_data = {
133
+ "timestamp_sec": 0,
134
+ "start_location": start_location,
135
+ "total_routes": len(routes),
136
+ "routes": []
137
+ }
138
+
139
+ for idx, (exit_id, path, risk) in enumerate(routes, 1):
140
+ route_info = {
141
+ "route_id": f"Route {idx}",
142
+ "exit": exit_id,
143
+ "path": path,
144
+ "metrics": {
145
+ "avg_danger": round(risk['avg_danger'], 2),
146
+ "max_danger": round(risk['max_danger'], 2),
147
+ "max_danger_location": risk['max_danger_location'],
148
+ "total_danger": round(risk['total_danger'], 2),
149
+ "path_length": risk['path_length'],
150
+ "has_fire": risk['has_fire'],
151
+ "has_oxygen_hazard": risk['has_oxygen_hazard'],
152
+ "passable": risk['passable'],
153
+ "risk_factors": risk['risk_factors']
154
+ },
155
+ "nodes": []
156
+ }
157
+
158
+ # Add detailed node information
159
+ for room_id in path:
160
+ node_data = self.export_room_data(room_id)
161
+ if node_data:
162
+ route_info["nodes"].append(node_data)
163
+
164
+ route_data["routes"].append(route_info)
165
+
166
+ return route_data
167
+
168
+ def export_all_rooms(self) -> List[Dict[str, Any]]:
169
+ """Export all rooms as separate documents"""
170
+ all_rooms = []
171
+ for room_id in self.floor_plan.rooms:
172
+ room_data = self.export_room_data(room_id)
173
+ if room_data:
174
+ all_rooms.append(room_data)
175
+ return all_rooms
176
+
177
+ def export_to_json(self, output_path: str, start_location: str = "R1"):
178
+ """Export complete dataset to JSON file"""
179
+ data = {
180
+ "floor_plan": {
181
+ "floor_name": self.floor_plan.floor_name,
182
+ "total_rooms": len(self.floor_plan.rooms),
183
+ "exits": self.floor_plan.exits
184
+ },
185
+ "all_rooms": self.export_all_rooms(),
186
+ "evacuation_routes": self.export_route_data(start_location)
187
+ }
188
+
189
+ with open(output_path, 'w', encoding='utf-8') as f:
190
+ json.dump(data, f, indent=2, ensure_ascii=False)
191
+
192
+ print(f"[OK] Exported data to {output_path}")
193
+ return data
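+ # Top-level shape of the exported JSON (keys produced by this method):
+ # {
+ #   "floor_plan":        {"floor_name": ..., "total_rooms": ..., "exits": [...]},
+ #   "all_rooms":         [<export_room_data() dict, one per room>],
+ #   "evacuation_routes": {"timestamp_sec": 0, "start_location": ..., "total_routes": ..., "routes": [...]}
+ # }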
194
+
195
+
196
+ class ReasoningMode(Enum):
197
+ """Enumeration of reasoning modes"""
198
+ STANDARD = "standard"
199
+ CHAIN_OF_THOUGHT = "chain_of_thought"
200
+ TREE_OF_THOUGHTS = "tree_of_thoughts"
201
+ REFLEXION = "reflexion"
202
+ COT_WITH_TOOLS = "cot_with_tools"
203
+
204
+
205
+ class DecodingStrategy(Enum):
206
+ """Enumeration of decoding strategies"""
207
+ GREEDY = "greedy"
208
+ SAMPLING = "sampling"
209
+ BEAM_SEARCH = "beam_search"
210
+ NUCLEUS = "nucleus"
211
+ TEMPERATURE = "temperature"
212
+
213
+
214
+ class FireEvacuationRAG:
215
+ """RAG system using FAISS for retrieval and Llama for generation with advanced reasoning"""
216
+
217
+ def __init__(self, model_name: str = "nvidia/Llama-3.1-Minitron-4B-Width-Base", model_dir: str = "./models",
218
+ use_8bit: bool = False, use_unsloth: bool = False, load_in_4bit: bool = True, max_seq_length: int = 2048,
219
+ reasoning_mode: ReasoningMode = ReasoningMode.CHAIN_OF_THOUGHT,
220
+ decoding_strategy: DecodingStrategy = DecodingStrategy.NUCLEUS):
221
+ self.model_name = model_name
222
+ self.model_dir = model_dir
223
+ self.local_model_path = os.path.join(model_dir, model_name.replace("/", "_"))
224
+ self.use_8bit = use_8bit
225
+ self.use_unsloth = use_unsloth
226
+ self.load_in_4bit = load_in_4bit
227
+ self.max_seq_length = max_seq_length
228
+ self.reasoning_mode = reasoning_mode
229
+ self.decoding_strategy = decoding_strategy
230
+ self.tokenizer = None
231
+ self.model = None
232
+ self.pipe = None
233
+ self.embedder = None
234
+ self.index = None
235
+ self.documents = []
236
+ self.metadata = []
237
+ self.reflexion_history = [] # Store reflection history for Reflexion
238
+
239
+ # Create model directory if it doesn't exist
240
+ os.makedirs(self.model_dir, exist_ok=True)
241
+ os.makedirs(self.local_model_path, exist_ok=True)
242
+
243
+ print(f"Initializing RAG system with model: {model_name}")
244
+ print(f"Model will be saved to: {self.local_model_path}")
245
+ print(f"Reasoning mode: {reasoning_mode.value}")
246
+ print(f"Decoding strategy: {decoding_strategy.value}")
247
+ if use_unsloth:
248
+ print("[*] Unsloth enabled (faster loading and inference)")
249
+ if load_in_4bit:
250
+ print(" - 4-bit quantization enabled (very fast, low memory)")
251
+ elif use_8bit:
252
+ print("[!] 8-bit quantization enabled (faster loading, lower memory, slight quality trade-off)")
253
+
254
+ def _check_model_files_exist(self, model_path: str) -> bool:
255
+ """Check if model files actually exist (not just config.json)"""
256
+ required_files = [
257
+ "config.json",
258
+ "model.safetensors.index.json" # Check for sharded model index
259
+ ]
260
+
261
+ # Check for at least one model file
262
+ model_file_patterns = [
263
+ "model.safetensors",
264
+ "pytorch_model.bin",
265
+ "model-*.safetensors" # Sharded models
266
+ ]
267
+
268
+ config_exists = os.path.exists(os.path.join(model_path, "config.json"))
269
+ if not config_exists:
270
+ return False
271
+
272
+ # Check for model weight files
273
+ for pattern in model_file_patterns:
274
+ if glob.glob(os.path.join(model_path, pattern)):
275
+ return True
276
+
277
+ # Check for sharded model index
278
+ if os.path.exists(os.path.join(model_path, "model.safetensors.index.json")):
279
+ return True
280
+
281
+ return False
282
+
283
+ def download_model(self):
284
+ """Download and load the Llama model, saving weights to local directory"""
285
+ print("Downloading Llama model (this may take a while)...")
286
+ print(f"Model weights will be saved to: {self.local_model_path}")
287
+
288
+ # Use Unsloth if enabled (much faster loading) - PRIMARY METHOD
289
+ if self.use_unsloth:
290
+ try:
291
+ from unsloth import FastLanguageModel
292
+ from transformers import TextStreamer
293
+ print("[*] Using Unsloth for fast model loading...")
294
+
295
+ # Check if model name indicates it's already quantized (contains "bnb-4bit" or "bnb-8bit")
296
+ is_pre_quantized = "bnb-4bit" in self.model_name.lower() or "bnb-8bit" in self.model_name.lower()
297
+
298
+ # For pre-quantized models, don't set load_in_4bit (model is already quantized)
299
+ # For non-quantized models, check if bitsandbytes is available
300
+ if self.load_in_4bit and not is_pre_quantized:
301
+ try:
302
+ import bitsandbytes
303
+ print("[OK] bitsandbytes available for 4-bit quantization")
304
+ except ImportError:
305
+ print("[!] bitsandbytes not found. 4-bit quantization requires bitsandbytes.")
306
+ print(" Install with: pip install bitsandbytes")
307
+ print(" Falling back to full precision...")
308
+ self.load_in_4bit = False
309
+
310
+ # Check if model exists locally
311
+ if self._check_model_files_exist(self.local_model_path):
312
+ print(f"Loading from local path: {self.local_model_path}")
313
+ model_path = self.local_model_path
314
+ else:
315
+ print(f"Downloading model: {self.model_name}")
316
+ model_path = self.model_name
317
+
318
+ # ==== Load model with Unsloth ====
319
+ dtype = None # Auto-detect dtype
320
+
321
+ # Try loading with proper error handling for bitsandbytes
322
+ # The model config might have quantization settings that trigger bitsandbytes check
323
+ max_retries = 2
324
+ for attempt in range(max_retries):
325
+ try:
326
+ # For pre-quantized models, don't specify load_in_4bit (it's already quantized)
327
+ if is_pre_quantized or attempt > 0:
328
+ print("[OK] Loading model without quantization parameters...")
329
+ # Don't pass any quantization parameters
330
+ load_kwargs = {
331
+ "model_name": model_path,
332
+ "max_seq_length": self.max_seq_length,
333
+ "dtype": dtype,
334
+ }
335
+ else:
336
+ # For non-quantized models, try quantization if requested
337
+ load_kwargs = {
338
+ "model_name": model_path,
339
+ "max_seq_length": self.max_seq_length,
340
+ "dtype": dtype,
341
+ }
342
+ if self.load_in_4bit:
343
+ load_kwargs["load_in_4bit"] = True
344
+
345
+ self.model, self.tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)
346
+ break # Success, exit retry loop
347
+
348
+ except Exception as quant_error:  # Exception already covers ImportError
349
+ error_str = str(quant_error)
350
+ is_bitsandbytes_error = (
351
+ "bitsandbytes" in error_str.lower() or
352
+ "PackageNotFoundError" in error_str or
353
+ "No package metadata" in error_str or
354
+ "quantization_config" in error_str.lower()
355
+ )
356
+
357
+ if is_bitsandbytes_error and attempt < max_retries - 1:
358
+ print(f"[!] Attempt {attempt + 1}: bitsandbytes error detected.")
359
+ print(f" Error: {error_str[:150]}...")
360
+ print(" Retrying without quantization parameters...")
361
+ continue # Retry without quantization
362
+ elif is_bitsandbytes_error:
363
+ print("[!] bitsandbytes required but not installed.")
364
+ print(" Options:")
365
+ print(" 1. Install bitsandbytes: pip install bitsandbytes")
366
+ print(" 2. Use a non-quantized model")
367
+ print(" 3. Set USE_UNSLOTH=False to use standard loading")
368
+ raise ImportError(
369
+ "bitsandbytes is required for this model. "
370
+ "Install with: pip install bitsandbytes"
371
+ ) from quant_error
372
+ else:
373
+ # Re-raise if it's a different error
374
+ raise
375
+
376
+ # Optimize for inference
377
+ FastLanguageModel.for_inference(self.model)
378
+
379
+ print("[OK] Model loaded successfully with Unsloth!")
380
+
381
+ # Verify device
382
+ if torch.cuda.is_available():
383
+ actual_device = next(self.model.parameters()).device
384
+ print(f"[OK] Model loaded on {actual_device}!")
385
+ allocated = torch.cuda.memory_allocated(0) / 1024**3
386
+ print(f"[OK] GPU Memory allocated: {allocated:.2f} GB")
387
+ else:
388
+ print("[OK] Model loaded on CPU!")
389
+
390
+ # Set pipe to model for compatibility (we'll use model directly in generation)
391
+ self.pipe = self.model # Store model reference for compatibility checks
392
+
393
+ return # Exit early, Unsloth loading complete
394
+
395
+ except ImportError:
396
+ print("[!] Unsloth not installed. Falling back to standard loading.")
397
+ print(" Install with: pip install unsloth")
398
+ self.use_unsloth = False # Disable unsloth for this session
399
+ except Exception as e:
400
+ print(f"[!] Unsloth loading failed: {e}")
401
+ print(" Falling back to standard loading...")
402
+ self.use_unsloth = False
403
+
404
+ # Standard loading (original code)
405
+ # Check GPU availability and optimize settings
406
+ device = "cuda" if torch.cuda.is_available() else "cpu"
407
+ if device == "cuda":
408
+ gpu_name = torch.cuda.get_device_name(0)
409
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
410
+ print(f"[OK] GPU detected: {gpu_name}")
411
+ print(f"[OK] GPU Memory: {gpu_memory:.2f} GB")
412
+ # Use bfloat16 for faster loading and inference on GPU
413
+ torch_dtype = torch.bfloat16
414
+ print("[OK] Using bfloat16 precision for faster loading")
415
+ else:
416
+ print("[!] No GPU detected, using CPU (will be slower)")
417
+ torch_dtype = torch.float32
418
+ print("[OK] Using float32 precision for CPU")
419
+
420
+ # Check for optimized attention implementation
421
+ try:
422
+ import flash_attn # noqa: F401
423
+ attn_impl = 'flash_attention_2'
424
+ print("[OK] FlashAttention2 available - using for optimal performance")
425
+ except ImportError:
426
+ attn_impl = 'sdpa' # Scaled Dot Product Attention (built into PyTorch)
427
+ print("[OK] Using SDPA (Scaled Dot Product Attention) for faster inference")
428
+
429
+ # Check for 8-bit quantization support
430
+ use_quantization = False
431
+ if self.use_8bit and device == "cuda":
432
+ try:
433
+ import bitsandbytes
434
+ use_quantization = True
435
+ print("[OK] 8-bit quantization available - will use for faster loading")
436
+ except ImportError:
437
+ print("[!] 8-bit requested but bitsandbytes not installed, using full precision")
438
+ print(" Install with: pip install bitsandbytes")
439
+
440
+ try:
441
+ # Check if model already exists locally with actual model files
442
+ if self._check_model_files_exist(self.local_model_path):
443
+ print(f"Found existing model at {self.local_model_path}, loading from local...")
444
+ model_path = self.local_model_path
445
+ load_from_local = True
446
+ else:
447
+ print("Downloading model from HuggingFace...")
448
+ model_path = self.model_name
449
+ load_from_local = False
450
+
451
+ # Load tokenizer
452
+ print("Loading tokenizer...")
453
+ self.tokenizer = AutoTokenizer.from_pretrained(
454
+ model_path,
455
+ trust_remote_code=True
456
+ )
457
+
458
+ # Save tokenizer locally if downloaded (wrap in try-except to avoid crashes)
459
+ if not load_from_local:
460
+ try:
461
+ print("Saving tokenizer to local directory...")
462
+ self.tokenizer.save_pretrained(self.local_model_path)
463
+ print(f"[OK] Tokenizer saved to {self.local_model_path}")
464
+ except Exception as save_err:
465
+ print(f"[!] Warning: Could not save tokenizer locally: {save_err}")
466
+ print("Continuing without local save...")
467
+
468
+ # Load model with optimizations
469
+ print("Loading model with optimizations...")
470
+ load_kwargs = {
471
+ "trust_remote_code": True,
472
+ "low_cpu_mem_usage": True, # Reduces memory usage during loading
473
+ "_attn_implementation": attn_impl, # Optimized attention
474
+ }
475
+
476
+ # Add quantization or dtype
477
+ if use_quantization:
478
+ from transformers import BitsAndBytesConfig
479
+ load_kwargs["quantization_config"] = BitsAndBytesConfig(
480
+ load_in_8bit=True,
481
+ llm_int8_threshold=6.0
482
+ )
483
+ print("[OK] Using 8-bit quantization for faster loading and lower memory")
484
+ else:
485
+ load_kwargs["torch_dtype"] = torch_dtype
486
+
487
+ # Use device_map="auto" for GPU, manual placement for CPU
488
+ if device == "cuda":
489
+ try:
490
+ load_kwargs["device_map"] = "auto"
491
+ print("[OK] Using device_map='auto' for optimal GPU memory management")
492
+ except Exception as e:
493
+ print(f"[!] device_map='auto' failed, using manual GPU placement: {e}")
494
+ load_kwargs.pop("device_map", None)
495
+
496
+ self.model = AutoModelForCausalLM.from_pretrained(
497
+ model_path,
498
+ **load_kwargs
499
+ )
500
+
501
+ # Manual device placement if device_map wasn't used
502
+ if device == "cuda" and "device_map" not in load_kwargs:
503
+ self.model = self.model.cuda()
504
+ print("[OK] Model moved to GPU")
505
+
506
+ # Save model locally if downloaded (wrap in try-except to handle DTensor errors)
507
+ if not load_from_local:
508
+ try:
509
+ print("Saving model weights to local directory (this may take a while)...")
510
+ self.model.save_pretrained(
511
+ self.local_model_path,
512
+ safe_serialization=True # Use safetensors format
513
+ )
514
+ print(f"[OK] Model saved to {self.local_model_path}")
515
+ except ImportError as import_err:
516
+ if "DTensor" in str(import_err):
517
+ print(f"[!] Warning: Could not save model due to PyTorch/transformers compatibility issue: {import_err}")
518
+ print("This is a known issue with certain versions. Model will work but won't be saved locally.")
519
+ print("Continuing without local save...")
520
+ else:
521
+ raise
522
+ except Exception as save_err:
523
+ print(f"[!] Warning: Could not save model locally: {save_err}")
524
+ print("Continuing without local save...")
525
+
526
+ # Create pipeline with optimizations
527
+ print("Creating pipeline...")
528
+ pipeline_kwargs = {
529
+ "model": self.model,
530
+ "tokenizer": self.tokenizer,
531
+ }
532
+ if device == "cuda":
533
+ pipeline_kwargs["device_map"] = "auto"
534
+
535
+ self.pipe = pipeline("text-generation", **pipeline_kwargs)
536
+
537
+ # Verify model device
538
+ if device == "cuda":
539
+ actual_device = next(self.model.parameters()).device
540
+ print(f"[OK] Model loaded successfully on {actual_device}!")
541
+ if torch.cuda.is_available():
542
+ allocated = torch.cuda.memory_allocated(0) / 1024**3
543
+ print(f"[OK] GPU Memory allocated: {allocated:.2f} GB")
544
+ else:
545
+ print("[OK] Model loaded successfully on CPU!")
546
+
547
+ except Exception as e:
548
+ print(f"Error loading model: {e}")
549
+ print("Falling back to pipeline-only loading...")
550
+ try:
551
+ # Determine device and dtype for fallback
552
+ device = "cuda" if torch.cuda.is_available() else "cpu"
553
+ torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
554
+
555
+ # Try loading from local path first (only if model files actually exist)
556
+ if self._check_model_files_exist(self.local_model_path):
557
+ print(f"Attempting to load from local path: {self.local_model_path}")
558
+ pipeline_kwargs = {
559
+ "model": self.local_model_path,
560
+ "trust_remote_code": True,
561
+ "torch_dtype": torch_dtype,
562
+ }
563
+ if device == "cuda":
564
+ pipeline_kwargs["device_map"] = "auto"
565
+ self.pipe = pipeline("text-generation", **pipeline_kwargs)
566
+ # Extract tokenizer from pipeline if available
567
+ if hasattr(self.pipe, 'tokenizer'):
568
+ self.tokenizer = self.pipe.tokenizer
569
+ else:
570
+ print(f"Downloading model: {self.model_name}")
571
+ pipeline_kwargs = {
572
+ "model": self.model_name,
573
+ "trust_remote_code": True,
574
+ "torch_dtype": torch_dtype,
575
+ }
576
+ if device == "cuda":
577
+ pipeline_kwargs["device_map"] = "auto"
578
+ self.pipe = pipeline("text-generation", **pipeline_kwargs)
579
+ # Extract tokenizer from pipeline if available
580
+ if hasattr(self.pipe, 'tokenizer'):
581
+ self.tokenizer = self.pipe.tokenizer
582
+
583
+ # Try to save after loading (but don't fail if it doesn't work)
584
+ try:
585
+ if hasattr(self.pipe, 'model') and hasattr(self.pipe.model, 'save_pretrained'):
586
+ print("Attempting to save downloaded model to local directory...")
587
+ self.pipe.model.save_pretrained(self.local_model_path, safe_serialization=True)
588
+ if hasattr(self.pipe, 'tokenizer'):
589
+ self.pipe.tokenizer.save_pretrained(self.local_model_path)
590
+ print("[OK] Model saved successfully")
591
+ except ImportError as import_err:
592
+ if "DTensor" in str(import_err):
593
+ print(f"[!] Warning: Could not save model due to compatibility issue. Model will work but won't be saved locally.")
594
+ else:
595
+ print(f"[!] Warning: Could not save model: {import_err}")
596
+ except Exception as save_err:
597
+ print(f"[!] Warning: Could not save model locally: {save_err}")
598
+
599
+ except Exception as e2:
600
+ print(f"Pipeline loading also failed: {e2}")
601
+ raise
602
+
603
+ def load_embedder(self, model_name: str = "all-MiniLM-L6-v2"):
604
+ """Load sentence transformer for embeddings, saving to local directory"""
605
+ embedder_dir = os.path.join(self.model_dir, "embedder", model_name.replace("/", "_"))
606
+ os.makedirs(embedder_dir, exist_ok=True)
607
+
608
+ print(f"Loading embedding model: {model_name}...")
609
+ print(f"Embedder will be cached in: {embedder_dir}")
610
+
611
+ # Check if embedder exists locally (check for actual model files, not just config)
612
+ config_path = os.path.join(embedder_dir, "config.json")
613
+ has_model_files = False
614
+ if os.path.exists(config_path):
615
+ # Check if model files exist
616
+ model_files = glob.glob(os.path.join(embedder_dir, "*.safetensors")) + \
617
+ glob.glob(os.path.join(embedder_dir, "pytorch_model.bin"))
618
+ if model_files or os.path.exists(os.path.join(embedder_dir, "model.safetensors.index.json")):
619
+ has_model_files = True
620
+
621
+ if has_model_files:
622
+ print(f"Loading embedder from local cache: {embedder_dir}")
623
+ self.embedder = SentenceTransformer(embedder_dir)
624
+ else:
625
+ print("Downloading embedder from HuggingFace...")
626
+ self.embedder = SentenceTransformer(model_name, cache_folder=embedder_dir)
627
+ # Try to save to local directory (but don't fail if it doesn't work)
628
+ try:
629
+ self.embedder.save(embedder_dir)
630
+ print(f"[OK] Embedder saved to {embedder_dir}")
631
+ except ImportError as import_err:
632
+ if "DTensor" in str(import_err):
633
+ print(f"[!] Warning: Could not save embedder due to PyTorch/transformers compatibility issue: {import_err}")
634
+ print("This is a known issue with certain versions. Embedder will work but won't be saved locally.")
635
+ print("Continuing without local save...")
636
+ else:
637
+ print(f"[!] Warning: Could not save embedder: {import_err}")
638
+ except Exception as save_err:
639
+ print(f"[!] Warning: Could not save embedder locally: {save_err}")
640
+ print("Continuing without local save...")
641
+
642
+ print("[OK] Embedding model loaded!")
643
+
644
+ def build_faiss_index(self, documents: List[str], metadata: List[Dict] = None):
645
+ """
646
+ Build FAISS index from documents
647
+
648
+ Args:
649
+ documents: List of text documents to index
650
+ metadata: Optional metadata for each document
651
+ """
652
+ if not self.embedder:
653
+ self.load_embedder()
654
+
655
+ print(f"Building FAISS index for {len(documents)} documents...")
656
+
657
+ # Generate embeddings
658
+ embeddings = self.embedder.encode(documents, show_progress_bar=True)
659
+ embeddings = np.array(embeddings).astype('float32')
660
+
661
+ # Get dimension
662
+ dimension = embeddings.shape[1]
663
+
664
+ # Create FAISS index (L2 distance)
665
+ self.index = faiss.IndexFlatL2(dimension)
666
+
667
+ # Add embeddings to index
668
+ self.index.add(embeddings)
669
+
670
+ # Store documents and metadata
671
+ self.documents = documents
672
+ self.metadata = metadata if metadata else [{}] * len(documents)
673
+
674
+ print(f"[OK] FAISS index built with {self.index.ntotal} vectors")
675
+
676
+ def build_index_from_json(self, json_data: Dict[str, Any]):
677
+ """Build FAISS index from exported JSON data"""
678
+ documents = []
679
+ metadata = []
680
+
681
+ # Add room documents
682
+ for room in json_data.get("all_rooms", []):
683
+ # Create text representation
684
+ room_text = self._room_to_text(room)
685
+ documents.append(room_text)
686
+ metadata.append({
687
+ "type": "room",
688
+ "room_id": room.get("room_id"),
689
+ "data": room
690
+ })
691
+
692
+ # Add route documents
693
+ for route in json_data.get("evacuation_routes", {}).get("routes", []):
694
+ route_text = self._route_to_text(route)
695
+ documents.append(route_text)
696
+ metadata.append({
697
+ "type": "route",
698
+ "route_id": route.get("route_id"),
699
+ "exit": route.get("exit"),
700
+ "data": route
701
+ })
702
+
703
+ # Build index
704
+ self.build_faiss_index(documents, metadata)
705
+
706
+ def _room_to_text(self, room: Dict[str, Any]) -> str:
707
+ """Convert room data to searchable text"""
708
+ sensor = room.get("sensor_data", {})
709
+
710
+ text_parts = [
711
+ f"Room {room.get('room_id')} ({room.get('name')})",
712
+ f"Type: {room.get('room_type')}",
713
+ ]
714
+
715
+ if room.get("has_oxygen_cylinder"):
716
+ text_parts.append("[!] OXYGEN CYLINDER PRESENT - EXPLOSION RISK")
717
+
718
+ if sensor.get("fire_detected"):
719
+ text_parts.append("[FIRE] FIRE DETECTED")
720
+
721
+ text_parts.extend([
722
+ f"Temperature: {sensor.get('temperature_c')}°C",
723
+ f"Smoke level: {sensor.get('smoke_level')}",
724
+ f"Oxygen: {sensor.get('oxygen_pct')}%",
725
+ f"Visibility: {sensor.get('visibility_pct')}%",
726
+ f"Structural integrity: {sensor.get('structural_integrity_pct')}%",
727
+ f"Danger score: {sensor.get('danger_score')}",
728
+ f"Passable: {sensor.get('passable')}"
729
+ ])
730
+
731
+ if sensor.get("carbon_monoxide_ppm", 0) > 50:
732
+ text_parts.append(f"[!] HIGH CARBON MONOXIDE: {sensor.get('carbon_monoxide_ppm')} ppm")
733
+
734
+ if sensor.get("flashover_risk", 0) > 0.5:
735
+ text_parts.append(f"[!] FLASHOVER RISK: {sensor.get('flashover_risk')*100:.0f}%")
736
+
737
+ if not sensor.get("exit_accessible", True):
738
+ text_parts.append("[!] EXIT BLOCKED")
739
+
740
+ if sensor.get("occupancy_density", 0) > 0.7:
741
+ text_parts.append(f"[!] HIGH CROWD DENSITY: {sensor.get('occupancy_density')*100:.0f}%")
742
+
743
+ return " | ".join(text_parts)
744
+
745
+ def _route_to_text(self, route: Dict[str, Any]) -> str:
746
+ """Convert route data to searchable text"""
747
+ metrics = route.get("metrics", {})
748
+
749
+ text_parts = [
750
+ f"{route.get('route_id')} to {route.get('exit')}",
751
+ f"Path: {' → '.join(route.get('path', []))}",
752
+ f"Average danger: {metrics.get('avg_danger')}",
753
+ f"Max danger: {metrics.get('max_danger')} at {metrics.get('max_danger_location')}",
754
+ f"Passable: {metrics.get('passable')}",
755
+ f"Has fire: {metrics.get('has_fire')}",
756
+ f"Has oxygen hazard: {metrics.get('has_oxygen_hazard')}"
757
+ ]
758
+
759
+ risk_factors = metrics.get("risk_factors", [])
760
+ if risk_factors:
761
+ text_parts.append(f"Risks: {', '.join(risk_factors[:3])}")
762
+
763
+ return " | ".join(text_parts)
764
+
765
+ def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
766
+ """
767
+ Search FAISS index for relevant documents
768
+
769
+ Args:
770
+ query: Search query
771
+ k: Number of results to return
772
+
773
+ Returns:
774
+ List of relevant documents with metadata
775
+ """
776
+ if not self.index or not self.embedder:
777
+ raise ValueError("Index not built. Call build_faiss_index() first.")
778
+
779
+ # Encode query
780
+ query_embedding = self.embedder.encode([query])
781
+ query_embedding = np.array(query_embedding).astype('float32')
782
+
783
+ # Search
784
+ distances, indices = self.index.search(query_embedding, k)
785
+
786
+ # Return results
787
+ results = []
788
+ for i, idx in enumerate(indices[0]):
789
+ if idx < len(self.documents):
790
+ results.append({
791
+ "document": self.documents[idx],
792
+ "metadata": self.metadata[idx],
793
+ "distance": float(distances[0][i])
794
+ })
795
+
796
+ return results
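+ # e.g. rag.search("rooms with fire blocking the east exit", k=3) returns a list like:
+ # [{"document": "Room R3 (...) | ...", "metadata": {"type": "room", "room_id": "R3", ...},
+ #   "distance": 0.42}, ...]   # query text, room id and distance here are illustrative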
797
+
798
+ def _build_cot_prompt(self, query: str, context: List[str]) -> str:
799
+ """Build Chain-of-Thought prompt with step-by-step reasoning"""
800
+ context_text = "\n".join([f"- {ctx}" for ctx in context])
801
+
802
+ prompt = f"""You are an expert fire evacuation safety advisor. Use the following context to answer the question concisely.
803
+
804
+ CONTEXT:
805
+ {context_text}
806
+
807
+ QUESTION: {query}
808
+
809
+ Think step by step, then provide a brief answer:
810
+
811
+ REASONING:
812
+ 1. Analyze available information
813
+ 2. Identify key safety factors
814
+ 3. Evaluate risks and prioritize
815
+ 4. Conclude with recommendation
816
+
817
+ ANSWER:"""
818
+ return prompt
819
+
820
+ def _build_tot_prompt(self, query: str, context: List[str], thought: str = "") -> str:
821
+ """Build Tree-of-Thoughts prompt for exploring multiple reasoning paths"""
822
+ context_text = "\n".join([f"- {ctx}" for ctx in context])
823
+
824
+ if not thought:
825
+ prompt = f"""You are an expert fire evacuation safety advisor. Use the following context to explore different reasoning approaches.
826
+
827
+ CONTEXT:
828
+ {context_text}
829
+
830
+ QUESTION: {query}
831
+
832
+ Let's explore different reasoning approaches to solve this problem:
833
+
834
+ APPROACH 1 - Safety-First Analysis:
835
+ """
836
+ else:
837
+ prompt = f"""CONTEXT:
838
+ {context_text}
839
+
840
+ QUESTION: {query}
841
+
842
+ CURRENT THOUGHT: {thought}
843
+
844
+ Evaluate this thought:
845
+ - Is this reasoning sound?
846
+ - What are the strengths and weaknesses?
847
+ - What alternative approaches should we consider?
848
+
849
+ EVALUATION:
850
+ """
851
+ return prompt
852
+
853
+ def _build_reflexion_prompt(self, query: str, context: List[str], previous_answer: str = "",
854
+ reflection: str = "") -> str:
855
+ """Build Reflexion prompt for self-reflection and improvement"""
856
+ context_text = "\n".join([f"- {ctx}" for ctx in context])
857
+
858
+ if not previous_answer:
859
+ # Initial answer
860
+ prompt = f"""You are an expert fire evacuation safety advisor. Use the following context to answer the question.
861
+
862
+ CONTEXT:
863
+ {context_text}
864
+
865
+ QUESTION: {query}
866
+
867
+ Provide a clear, safety-focused answer based on the context.
868
+
869
+ ANSWER:"""
870
+ else:
871
+ # Reflection phase
872
+ prompt = f"""You are an expert fire evacuation safety advisor. Review and improve your previous answer.
873
+
874
+ CONTEXT:
875
+ {context_text}
876
+
877
+ QUESTION: {query}
878
+
879
+ PREVIOUS ANSWER:
880
+ {previous_answer}
881
+
882
+ REFLECTION:
883
+ {reflection}
884
+
885
+ Now provide an improved answer based on your reflection:
886
+
887
+ IMPROVED ANSWER:"""
888
+ return prompt
889
+
890
+ def _build_cot_with_tools_prompt(self, query: str, context: List[str], tool_results: List[str] = None) -> str:
891
+ """Build Chain-of-Thought prompt with tool integration"""
892
+ context_text = "\n".join([f"- {ctx}" for ctx in context])
893
+
894
+ tool_text = ""
895
+ if tool_results:
896
+ tool_text = "\nTOOL RESULTS:\n" + "\n".join([f"- {result}" for result in tool_results])
897
+
898
+ prompt = f"""You are an expert fire evacuation safety advisor. Use the following context and tool results to answer the question.
899
+
900
+ CONTEXT:
901
+ {context_text}
902
+ {tool_text}
903
+
904
+ QUESTION: {query}
905
+
906
+ Let's solve this step by step, using both the context and tool results:
907
+
908
+ STEP 1 - Understand the question and available data:
909
+ """
910
+ return prompt
911
+
912
+ def _generate_with_decoding_strategy(self, prompt: str, max_length: int = 500,
913
+ temperature: float = 0.7, top_p: float = 0.9,
914
+ num_beams: int = 3, stop_sequences: List[str] = None) -> str:
915
+ """Generate response using specified decoding strategy"""
916
+ if not self.pipe and not self.model:
917
+ raise ValueError("Model not loaded. Call download_model() first.")
918
+
919
+ try:
920
+ if self.use_unsloth and self.model:
921
+ inputs = self.tokenizer(
922
+ prompt,
923
+ return_tensors="pt",
924
+ truncation=True,
925
+ max_length=self.max_seq_length
926
+ ).to(self.model.device)
927
+
928
+ # Configure generation parameters based on decoding strategy
929
+ gen_kwargs = {
930
+ "max_new_tokens": max_length,
931
+ "pad_token_id": self.tokenizer.eos_token_id,
932
+ "eos_token_id": self.tokenizer.eos_token_id,
933
+ }
934
+
935
+ if self.decoding_strategy == DecodingStrategy.GREEDY:
936
+ gen_kwargs.update({
937
+ "do_sample": False,
938
+ "num_beams": 1
939
+ })
940
+ elif self.decoding_strategy == DecodingStrategy.SAMPLING:
941
+ gen_kwargs.update({
942
+ "do_sample": True,
943
+ "temperature": temperature,
944
+ "top_k": 50
945
+ })
946
+ elif self.decoding_strategy == DecodingStrategy.BEAM_SEARCH:
947
+ gen_kwargs.update({
948
+ "do_sample": False,
949
+ "num_beams": num_beams,
950
+ "early_stopping": True
951
+ })
952
+ elif self.decoding_strategy == DecodingStrategy.NUCLEUS:
953
+ gen_kwargs.update({
954
+ "do_sample": True,
955
+ "temperature": temperature,
956
+ "top_p": top_p,
957
+ "top_k": 0
958
+ })
959
+ elif self.decoding_strategy == DecodingStrategy.TEMPERATURE:
960
+ gen_kwargs.update({
961
+ "do_sample": True,
962
+ "temperature": temperature
963
+ })
964
+
965
+ with torch.no_grad():
966
+ outputs = self.model.generate(**inputs, **gen_kwargs)
967
+
968
+ response = self.tokenizer.batch_decode(
969
+ outputs,
970
+ skip_special_tokens=True
971
+ )[0]
972
+
973
+ # Extract response after prompt
974
+ if prompt in response:
975
+ response = response.split(prompt)[-1].strip()
976
+
977
+ # Post-process to stop at verbose endings
978
+ stop_phrases = [
979
+ "\n\nHowever, please note",
980
+ "\n\nAdditionally,",
981
+ "\n\nLet me know",
982
+ "\n\nIf you have",
983
+ "\n\nHere's another",
984
+ "\n\nQUESTION:",
985
+ "\n\nLet's break",
986
+ "\n\nHave a great day",
987
+ "\n\nI'm here to help"
988
+ ]
989
+ for phrase in stop_phrases:
990
+ if phrase in response:
991
+ response = response.split(phrase)[0].strip()
992
+ break
993
+
994
+ return response
995
+ else:
996
+ # Use pipeline for standard models
997
+ gen_kwargs = {
998
+ "max_length": len(self.tokenizer.encode(prompt)) + max_length,
999
+ "num_return_sequences": 1,
1000
+ }
1001
+
1002
+ if self.decoding_strategy == DecodingStrategy.GREEDY:
1003
+ gen_kwargs.update({
1004
+ "do_sample": False
1005
+ })
1006
+ elif self.decoding_strategy == DecodingStrategy.SAMPLING:
1007
+ gen_kwargs.update({
1008
+ "do_sample": True,
1009
+ "temperature": temperature,
1010
+ "top_k": 50
1011
+ })
1012
+ elif self.decoding_strategy == DecodingStrategy.BEAM_SEARCH:
1013
+ gen_kwargs.update({
1014
+ "do_sample": False,
1015
+ "num_beams": num_beams,
1016
+ "early_stopping": True
1017
+ })
1018
+ elif self.decoding_strategy == DecodingStrategy.NUCLEUS:
1019
+ gen_kwargs.update({
1020
+ "do_sample": True,
1021
+ "temperature": temperature,
1022
+ "top_p": top_p,
1023
+ "top_k": 0
1024
+ })
1025
+ elif self.decoding_strategy == DecodingStrategy.TEMPERATURE:
1026
+ gen_kwargs.update({
1027
+ "do_sample": True,
1028
+ "temperature": temperature
1029
+ })
1030
+
1031
+ gen_kwargs["pad_token_id"] = self.tokenizer.eos_token_id if self.tokenizer else None
1032
+
1033
+ outputs = self.pipe(prompt, **gen_kwargs)
1034
+ response = outputs[0]['generated_text']
1035
+
1036
+ # Extract response after prompt
1037
+ if prompt in response:
1038
+ response = response.split(prompt)[-1].strip()
1039
+
1040
+ # Post-process to stop at verbose endings
1041
+ stop_phrases = [
1042
+ "\n\nHowever, please note",
1043
+ "\n\nAdditionally,",
1044
+ "\n\nLet me know",
1045
+ "\n\nIf you have",
1046
+ "\n\nHere's another",
1047
+ "\n\nQUESTION:",
1048
+ "\n\nLet's break",
1049
+ "\n\nHave a great day",
1050
+ "\n\nI'm here to help"
1051
+ ]
1052
+ for phrase in stop_phrases:
1053
+ if phrase in response:
1054
+ response = response.split(phrase)[0].strip()
1055
+ break
1056
+
1057
+ return response
1058
+
1059
+ except Exception as e:
1060
+ return f"Error generating response: {e}"
1061
+
1062
+ def _chain_of_thought_reasoning(self, query: str, context: List[str], max_length: int = 500) -> Tuple[str, str]:
1063
+ """Generate response using Chain-of-Thought reasoning
1064
+
1065
+ Returns:
1066
+ Tuple of (full_reasoning, final_answer)
1067
+ """
1068
+ prompt = self._build_cot_prompt(query, context)
1069
+ # Use shorter max_length for CoT to prevent verbosity
1070
+ full_response = self._generate_with_decoding_strategy(prompt, max_length=min(max_length, 300))
1071
+
1072
+ # Extract reasoning steps (everything before ANSWER)
1073
+ reasoning = ""
1074
+ if "REASONING:" in full_response:
1075
+ reasoning_parts = full_response.split("REASONING:")
1076
+ if len(reasoning_parts) > 1:
1077
+ reasoning_section = reasoning_parts[1].split("ANSWER:")[0] if "ANSWER:" in reasoning_parts[1] else reasoning_parts[1]
1078
+ reasoning = reasoning_section.strip()
1079
+ elif "ANSWER:" in full_response:
1080
+ reasoning = full_response.split("ANSWER:")[0].strip()
1081
+ else:
1082
+ # Try to extract reasoning from numbered steps
1083
+ lines = full_response.split('\n')
1084
+ reasoning_lines = []
1085
+ for line in lines:
1086
+ if line.strip().startswith(('1.', '2.', '3.', '4.', '5.', 'Step', 'STEP')):
1087
+ reasoning_lines.append(line.strip())
1088
+ elif "ANSWER" in line.upper():
1089
+ break
1090
+ elif reasoning_lines: # Continue collecting if we've started
1091
+ reasoning_lines.append(line.strip())
1092
+ reasoning = '\n'.join(reasoning_lines)
1093
+
1094
+ # Extract final answer (everything after ANSWER:)
1095
+ final_answer = full_response
1096
+ if "ANSWER:" in full_response:
1097
+ answer_parts = full_response.split("ANSWER:")
1098
+ if len(answer_parts) > 1:
1099
+ answer_text = answer_parts[-1].strip()
1100
+ # Stop at common continuation markers
1101
+ stop_markers = [
1102
+ "\n\nHowever, please note",
1103
+ "\n\nAdditionally,",
1104
+ "\n\nLet me know",
1105
+ "\n\nIf you have",
1106
+ "\n\nHere's another",
1107
+ "\n\nQUESTION:",
1108
+ "\n\nLet's break",
1109
+ "\n\nHave a great day",
1110
+ "\n\nI'm here to help",
1111
+ "\n\nThese general guidelines",
1112
+ "\n\nIf you have any further"
1113
+ ]
1114
+ for marker in stop_markers:
1115
+ if marker in answer_text:
1116
+ answer_text = answer_text.split(marker)[0].strip()
1117
+ break
1118
+ # Also limit to first 2-3 sentences if it's still too long
1119
+ sentences = answer_text.split('. ')
1120
+ if len(sentences) > 3:
1121
+ answer_text = '. '.join(sentences[:3])
1122
+ if not answer_text.endswith('.'):
1123
+ answer_text += '.'
1124
+ final_answer = answer_text
1125
+
1126
+ # Clean up reasoning - remove verbose parts
1127
+ if reasoning:
1128
+ # Remove common verbose endings
1129
+ verbose_endings = [
1130
+ "However, please note",
1131
+ "Additionally,",
1132
+ "Let me know",
1133
+ "If you have",
1134
+ "Here's another",
1135
+ "Have a great day",
1136
+ "I'm here to help"
1137
+ ]
1138
+ for ending in verbose_endings:
1139
+ if ending in reasoning:
1140
+ reasoning = reasoning.split(ending)[0].strip()
1141
+ break
1142
+
1143
+ return reasoning or "Reasoning steps generated", final_answer
1144
+
1145
+ def _tree_of_thoughts_reasoning(self, query: str, context: List[str], max_length: int = 500,
1146
+ max_thoughts: int = 3) -> Tuple[str, str]:
1147
+ """Generate response using Tree-of-Thoughts reasoning
1148
+
1149
+ Returns:
1150
+ Tuple of (full_reasoning, final_answer)
1151
+ """
1152
+ thoughts = []
1153
+ reasoning_log = []
1154
+
1155
+ # Generate initial thoughts
1156
+ for i in range(max_thoughts):
1157
+ thought_prompt = self._build_tot_prompt(query, context,
1158
+ thought=f"Exploring approach {i+1}")
1159
+ thought = self._generate_with_decoding_strategy(thought_prompt, max_length // max_thoughts)
1160
+ thoughts.append(thought)
1161
+ reasoning_log.append(f"APPROACH {i+1}:\n{thought}\n")
1162
+
1163
+ # Evaluate thoughts and select best
1164
+ evaluation_prompt = f"""Evaluate these different reasoning approaches for answering the question:
1165
+
1166
+ QUESTION: {query}
1167
+
1168
+ APPROACHES:
1169
+ """
1170
+ for i, thought in enumerate(thoughts, 1):
1171
+ evaluation_prompt += f"\nAPPROACH {i}:\n{thought}\n"
1172
+
1173
+ evaluation_prompt += "\nWhich approach is most sound and complete? Provide the best answer based on the evaluation.\n\nBEST ANSWER:"
1174
+
1175
+ final_response = self._generate_with_decoding_strategy(evaluation_prompt, max_length)
1176
+
1177
+ full_reasoning = "\n".join(reasoning_log) + f"\n\nEVALUATION:\n{final_response}"
1178
+ return full_reasoning, final_response
1179
+
1180
+ def _reflexion_reasoning(self, query: str, context: List[str], max_length: int = 500,
1181
+ max_iterations: int = 2) -> Tuple[str, str]:
1182
+ """Generate response using Reflexion (self-reflection and improvement)
1183
+
1184
+ Returns:
1185
+ Tuple of (full_reasoning, final_answer)
1186
+ """
1187
+ reasoning_log = []
1188
+
1189
+ # Initial answer
1190
+ initial_prompt = self._build_reflexion_prompt(query, context)
1191
+ answer = self._generate_with_decoding_strategy(initial_prompt, max_length)
1192
+ reasoning_log.append(f"INITIAL ANSWER:\n{answer}\n")
1193
+
1194
+ # Reflection and improvement iterations
1195
+ for iteration in range(max_iterations):
1196
+ # Generate reflection
1197
+ reflection_prompt = f"""Review this answer for a fire evacuation safety question:
1198
+
1199
+ QUESTION: {query}
1200
+
1201
+ CURRENT ANSWER:
1202
+ {answer}
1203
+
1204
+ What could be improved? Consider:
1205
+ - Accuracy of safety information
1206
+ - Completeness of the response
1207
+ - Clarity and actionability
1208
+ - Missing critical safety factors
1209
+
1210
+ REFLECTION:"""
1211
+
1212
+ reflection = self._generate_with_decoding_strategy(reflection_prompt, max_length // 2)
1213
+ reasoning_log.append(f"ITERATION {iteration + 1} - REFLECTION:\n{reflection}\n")
1214
+
1215
+ # Generate improved answer
1216
+ improved_prompt = self._build_reflexion_prompt(query, context, answer, reflection)
1217
+ improved_answer = self._generate_with_decoding_strategy(improved_prompt, max_length)
1218
+ reasoning_log.append(f"ITERATION {iteration + 1} - IMPROVED ANSWER:\n{improved_answer}\n")
1219
+
1220
+ # Check if improvement is significant (simple heuristic)
1221
+ if len(improved_answer) > len(answer) * 0.8: # At least 80% of original length
1222
+ answer = improved_answer
1223
+ else:
1224
+ break # Stop if answer becomes too short
1225
+
1226
+ self.reflexion_history.append({
1227
+ "query": query,
1228
+ "final_answer": answer,
1229
+ "iterations": iteration + 1
1230
+ })
1231
+
1232
+ full_reasoning = "\n".join(reasoning_log)
1233
+ return full_reasoning, answer
1234
+
1235
+ def _cot_with_tools_reasoning(self, query: str, context: List[str], max_length: int = 500) -> Tuple[str, str]:
1236
+ """Generate response using Chain-of-Thought with tool integration
1237
+
1238
+ Returns:
1239
+ Tuple of (full_reasoning, final_answer)
1240
+ """
1241
+ reasoning_log = []
1242
+
1243
+ # Simulate tool calls (in real implementation, these would call actual tools)
1244
+ tool_results = []
1245
+
1246
+ # Tool 1: Route analysis
1247
+ if "route" in query.lower() or "path" in query.lower():
1248
+ tool_result = "Tool: Route Analyzer - Found 3 evacuation routes with risk scores"
1249
+ tool_results.append(tool_result)
1250
+ reasoning_log.append(f"TOOL CALL: {tool_result}\n")
1251
+
1252
+ # Tool 2: Risk calculator
1253
+ if "danger" in query.lower() or "risk" in query.lower():
1254
+ tool_result = "Tool: Risk Calculator - Calculated danger scores for all rooms"
1255
+ tool_results.append(tool_result)
1256
+ reasoning_log.append(f"TOOL CALL: {tool_result}\n")
1257
+
1258
+ # Tool 3: Sensor aggregator
1259
+ if "sensor" in query.lower() or "temperature" in query.lower() or "smoke" in query.lower():
1260
+ tool_result = "Tool: Sensor Aggregator - Aggregated sensor data from all rooms"
1261
+ tool_results.append(tool_result)
1262
+ reasoning_log.append(f"TOOL CALL: {tool_result}\n")
1263
+
1264
+ prompt = self._build_cot_with_tools_prompt(query, context, tool_results)
1265
+ response = self._generate_with_decoding_strategy(prompt, max_length)
1266
+
1267
+ reasoning_log.append(f"REASONING WITH TOOLS:\n{response}\n")
1268
+ full_reasoning = "\n".join(reasoning_log)
1269
+
1270
+ # Extract final answer
1271
+ final_answer = response
1272
+ if "ANSWER:" in response or "answer:" in response.lower():
1273
+ parts = response.split("ANSWER:") if "ANSWER:" in response else response.split("answer:")
1274
+ if len(parts) > 1:
1275
+ final_answer = parts[-1].strip()
1276
+
1277
+ return full_reasoning, final_answer
1278
+
1279
+ def generate_response(self, query: str, context: List[str] = None, max_length: int = 500,
1280
+ return_reasoning: bool = False) -> str:
1281
+ """
1282
+ Generate response using Llama model with context and advanced reasoning
1283
+
1284
+ Args:
1285
+ query: User query
1286
+ context: Optional context strings (if None, will retrieve from FAISS)
1287
+ max_length: Maximum response length
1288
+ return_reasoning: If True, returns tuple of (reasoning, answer), else just answer
1289
+
1290
+ Returns:
1291
+ If return_reasoning is True: Tuple of (reasoning_steps, final_answer)
1292
+ Otherwise: Just the final answer string
1293
+ """
1294
+ if not self.pipe and not self.model:
1295
+ raise ValueError("Model not loaded. Call download_model() first.")
1296
+
1297
+ # Retrieve context if not provided
1298
+ if context is None:
1299
+ search_results = self.search(query, k=3)
1300
+ context = [r["document"] for r in search_results]
1301
+
1302
+ # Route to appropriate reasoning method based on mode
1303
+ if self.reasoning_mode == ReasoningMode.CHAIN_OF_THOUGHT:
1304
+ reasoning, answer = self._chain_of_thought_reasoning(query, context, max_length)
1305
+ elif self.reasoning_mode == ReasoningMode.TREE_OF_THOUGHTS:
1306
+ reasoning, answer = self._tree_of_thoughts_reasoning(query, context, max_length)
1307
+ elif self.reasoning_mode == ReasoningMode.REFLEXION:
1308
+ reasoning, answer = self._reflexion_reasoning(query, context, max_length)
1309
+ elif self.reasoning_mode == ReasoningMode.COT_WITH_TOOLS:
1310
+ reasoning, answer = self._cot_with_tools_reasoning(query, context, max_length)
1311
+ else:
1312
+ # Standard mode - use enhanced prompt with decoding strategy
1313
+ context_text = "\n".join([f"- {ctx}" for ctx in context])
1314
+
1315
+ prompt = f"""You are an expert fire evacuation safety advisor. Use the following context about the building's fire safety status to answer the question.
1316
+
1317
+ CONTEXT:
1318
+ {context_text}
1319
+
1320
+ QUESTION: {query}
1321
+
1322
+ Provide a clear, safety-focused answer based on the context. If the context doesn't contain enough information, say so.
1323
+
1324
+ ANSWER:"""
1325
+
1326
+ answer = self._generate_with_decoding_strategy(prompt, max_length)
1327
+ reasoning = f"Standard reasoning mode - Direct answer generation.\n\n{answer}"
1328
+
1329
+ if return_reasoning:
1330
+ return reasoning, answer
1331
+ return answer
1332
+
1333
+ def set_reasoning_mode(self, mode: ReasoningMode):
1334
+ """Set the reasoning mode for future queries"""
1335
+ self.reasoning_mode = mode
1336
+ print(f"[OK] Reasoning mode set to: {mode.value}")
1337
+
1338
+ def set_decoding_strategy(self, strategy: DecodingStrategy):
1339
+ """Set the decoding strategy for future queries"""
1340
+ self.decoding_strategy = strategy
1341
+ print(f"[OK] Decoding strategy set to: {strategy.value}")
1342
+
1343
+ def query(self, question: str, k: int = 3, reasoning_mode: Optional[ReasoningMode] = None,
1344
+ show_reasoning: bool = True) -> Dict[str, Any]:
1345
+ """
1346
+ Complete RAG query: retrieve context and generate response with advanced reasoning
1347
+
1348
+ Args:
1349
+ question: User question
1350
+ k: Number of context documents to retrieve
1351
+ reasoning_mode: Optional override for reasoning mode (uses instance default if None)
1352
+ show_reasoning: If True, includes full reasoning steps in response
1353
+
1354
+ Returns:
1355
+ Dictionary with answer, context, metadata, reasoning information, and reasoning steps
1356
+ """
1357
+ # Retrieve relevant context
1358
+ search_results = self.search(question, k=k)
1359
+
1360
+ # Generate response with reasoning
1361
+ context = [r["document"] for r in search_results]
1362
+
1363
+ # Temporarily override reasoning mode if provided
1364
+ original_mode = self.reasoning_mode
1365
+ if reasoning_mode is not None:
1366
+ self.reasoning_mode = reasoning_mode
1367
+
1368
+ try:
1369
+ reasoning, answer = self.generate_response(question, context, return_reasoning=True)
1370
+ finally:
1371
+ # Restore original mode
1372
+ self.reasoning_mode = original_mode
1373
+
1374
+ result = {
1375
+ "question": question,
1376
+ "answer": answer,
1377
+ "context": context,
1378
+ "reasoning_mode": self.reasoning_mode.value,
1379
+ "decoding_strategy": self.decoding_strategy.value,
1380
+ "sources": [
1381
+ {
1382
+ "type": r["metadata"].get("type"),
1383
+ "room_id": r["metadata"].get("room_id"),
1384
+ "route_id": r["metadata"].get("route_id"),
1385
+ "relevance_score": 1.0 / (1.0 + r["distance"])
1386
+ }
1387
+ for r in search_results
1388
+ ]
1389
+ }
1390
+
1391
+ if show_reasoning:
1392
+ result["reasoning_steps"] = reasoning
1393
+
1394
+ return result
1395
+
1396
+ def save_index(self, index_path: str, metadata_path: str):
1397
+ """Save FAISS index and metadata"""
1398
+ if self.index:
1399
+ faiss.write_index(self.index, index_path)
1400
+ with open(metadata_path, 'wb') as f:
1401
+ pickle.dump({
1402
+ "documents": self.documents,
1403
+ "metadata": self.metadata
1404
+ }, f)
1405
+ print(f"[OK] Saved index to {index_path} and metadata to {metadata_path}")
1406
+
1407
+ def load_index(self, index_path: str, metadata_path: str):
1408
+ """Load FAISS index and metadata"""
1409
+ self.index = faiss.read_index(index_path)
1410
+ with open(metadata_path, 'rb') as f:
1411
+ data = pickle.load(f)
1412
+ self.documents = data["documents"]
1413
+ self.metadata = data["metadata"]
1414
+ print(f"[OK] Loaded index with {self.index.ntotal} vectors")
1415
+
1416
+ def compare_reasoning_modes(self, question: str, k: int = 3) -> Dict[str, Any]:
1417
+ """
1418
+ Compare all reasoning modes for a given question
1419
+
1420
+ Args:
1421
+ question: User question
1422
+ k: Number of context documents to retrieve
1423
+
1424
+ Returns:
1425
+ Dictionary with answers from all reasoning modes
1426
+ """
1427
+ # Retrieve context once
1428
+ search_results = self.search(question, k=k)
1429
+ context = [r["document"] for r in search_results]
1430
+
1431
+ results = {
1432
+ "question": question,
1433
+ "context": context,
1434
+ "sources": [
1435
+ {
1436
+ "type": r["metadata"].get("type"),
1437
+ "room_id": r["metadata"].get("room_id"),
1438
+ "route_id": r["metadata"].get("route_id"),
1439
+ "relevance_score": 1.0 / (1.0 + r["distance"])
1440
+ }
1441
+ for r in search_results
1442
+ ],
1443
+ "answers": {}
1444
+ }
1445
+
1446
+ # Save original mode
1447
+ original_mode = self.reasoning_mode
1448
+
1449
+ # Test each reasoning mode
1450
+ for mode in ReasoningMode:
1451
+ try:
1452
+ self.reasoning_mode = mode
1453
+ reasoning, answer = self.generate_response(question, context, return_reasoning=True)
1454
+ results["answers"][mode.value] = {
1455
+ "answer": answer,
1456
+ "reasoning": reasoning,
1457
+ "length": len(answer)
1458
+ }
1459
+ except Exception as e:
1460
+ results["answers"][mode.value] = {
1461
+ "error": str(e)
1462
+ }
1463
+
1464
+ # Restore original mode
1465
+ self.reasoning_mode = original_mode
1466
+
1467
+ return results
1468
+
1469
+ # === Gradio integration ===
1470
+ _rag_instance: Optional[FireEvacuationRAG] = None
1471
+
1472
+
1473
+ def _init_rag() -> FireEvacuationRAG:
1474
+ """Initialize and cache the RAG system for Gradio use."""
1475
+ global _rag_instance
1476
+ if _rag_instance is not None:
1477
+ return _rag_instance
1478
+
1479
+ # Configuration for the Gradio-backed RAG instance
1480
+ USE_UNSLOTH = True
1481
+ USE_8BIT = False
1482
+ UNSLOTH_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"
1483
+
1484
+ # Set model directory to absolute path
1485
+ MODEL_DIR = r"D:\github\cse499\models"
1486
+
1487
+ # Create fire evacuation system
1488
+ floor_plan = create_sample_floor_plan()
1489
+ sensor_system = create_sample_fire_scenario(floor_plan)
1490
+ pathfinder = PathFinder(floor_plan, sensor_system)
1491
+
1492
+ # Export data and build index
1493
+ exporter = FireEvacuationDataExporter(floor_plan, sensor_system, pathfinder)
1494
+ json_data = exporter.export_to_json("fire_evacuation_data.json", start_location="R1")
1495
+
1496
+ # Initialize RAG
1497
+ if USE_UNSLOTH:
1498
+ rag = FireEvacuationRAG(
1499
+ model_name=UNSLOTH_MODEL,
1500
+ model_dir=MODEL_DIR,
1501
+ use_unsloth=True,
1502
+ load_in_4bit=False,
1503
+ max_seq_length=2048,
1504
+ reasoning_mode=ReasoningMode.CHAIN_OF_THOUGHT,
1505
+ decoding_strategy=DecodingStrategy.NUCLEUS,
1506
+ )
1507
+ else:
1508
+ rag = FireEvacuationRAG(
1509
+ model_name="nvidia/Llama-3.1-Minitron-4B-Width-Base",
1510
+ model_dir=MODEL_DIR,
1511
+ use_8bit=USE_8BIT,
1512
+ reasoning_mode=ReasoningMode.CHAIN_OF_THOUGHT,
1513
+ decoding_strategy=DecodingStrategy.NUCLEUS,
1514
+ )
1515
+
1516
+ rag.download_model()
1517
+ rag.load_embedder()
1518
+ rag.build_index_from_json(json_data)
1519
+
1520
+ _rag_instance = rag
1521
+ return rag
1522
+
1523
+
1524
+ def gradio_answer(question: str) -> str:
1525
+ """Gradio callback: take a text question, return LLM/RAG answer."""
1526
+ question = (question or "").strip()
1527
+ if not question:
1528
+ return "Please enter a question about fire evacuation or building safety."
1529
+
1530
+ rag = _init_rag()
1531
+ result = rag.query(question, k=3, show_reasoning=False)
1532
+ return result.get("answer", "No answer generated.")
1533
+
1534
+
1535
+ if __name__ == "__main__":
1536
+ iface = gr.Interface(
1537
+ fn=gradio_answer,
1538
+ inputs=gr.Textbox(lines=3, label="Fire Evacuation Question"),
1539
+ outputs=gr.Textbox(lines=6, label="LLM Recommendation"),
1540
+ title="Fire Evacuation RAG Advisor",
1541
+ description="Ask about evacuation routes, dangers, and exits in the simulated building.",
1542
+ )
1543
+ iface.launch()
1544
+