Create app.py
app.py (ADDED)
@@ -0,0 +1,1544 @@
"""
Llama + FAISS RAG System for Fire Evacuation with Advanced Reasoning

This module implements a RAG (Retrieval-Augmented Generation) system for fire evacuation scenarios
with advanced LLM reasoning techniques, including:

1. Chain-of-Thought (CoT) Prompting:
   - Enables step-by-step reasoning through intermediate steps
   - Improves complex problem-solving capabilities
   - Reference: https://arxiv.org/pdf/2201.11903

2. Tree-of-Thoughts (ToT):
   - Maintains multiple reasoning paths
   - Self-evaluates progress through intermediate thoughts
   - Enables a deliberate reasoning process
   - Reference: https://arxiv.org/pdf/2305.10601

3. Reflexion:
   - Reinforces language agents through linguistic feedback
   - Self-reflection and iterative improvement
   - Reference: https://arxiv.org/pdf/2303.11366

4. CoT with Tools:
   - Combines CoT prompting with external tools
   - Interleaves reasoning and tool usage
   - Reference: https://arxiv.org/pdf/2303.09014

5. Advanced Decoding Strategies:
   - Greedy: deterministic, highest-probability continuation
   - Sampling: random sampling with temperature
   - Beam Search: explores multiple candidate sequences
   - Nucleus (Top-p): samples from the top-p probability mass
   - Temperature: temperature-scaled sampling

Downloads the Llama model, creates the JSON dataset, builds the FAISS index, and provides RAG querying.
"""
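
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal end-to-end flow under assumed helper signatures; the exact
# create_sample_fire_scenario / PathFinder arguments and the Gradio wiring
# further down in app.py may differ:
#
#     floor_plan = create_sample_floor_plan()
#     sensors = create_sample_fire_scenario(floor_plan)   # assumed signature
#     pathfinder = PathFinder(floor_plan, sensors)         # assumed signature
#     exporter = FireEvacuationDataExporter(floor_plan, sensors, pathfinder)
#     data = exporter.export_to_json("fire_evacuation_data.json", start_location="R1")
#
#     rag = FireEvacuationRAG(use_unsloth=True, load_in_4bit=True)
#     rag.download_model()
#     rag.build_index_from_json(data)
# -----------------------------------------------------------------------------
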
import unsloth
import json
import os
import pickle
import glob
import re
from typing import List, Dict, Any, Optional, Tuple

from pathlib import Path
from enum import Enum
import copy

import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
import gradio as gr

# Project imports (use helper_files package)
from helper_files.floor_plan import create_sample_floor_plan, FloorPlan
from helper_files.sensor_system import create_sample_fire_scenario, SensorSystem
from helper_files.pathfinding import PathFinder


class FireEvacuationDataExporter:
    """Exports fire evacuation system data to JSON format"""

    def __init__(self, floor_plan: FloorPlan, sensor_system: SensorSystem, pathfinder: PathFinder):
        self.floor_plan = floor_plan
        self.sensor_system = sensor_system
        self.pathfinder = pathfinder

    def export_room_data(self, room_id: str) -> Dict[str, Any]:
        """Export comprehensive room data to JSON"""
        room = self.floor_plan.get_room(room_id)
        sensor = self.sensor_system.get_sensor_reading(room_id)

        if not room or not sensor:
            return {}

        return {
            "room_id": room_id,
            "name": room.name,
            "room_type": room.room_type,
            "position": room.position,
            "size": room.size,
            "has_oxygen_cylinder": room.has_oxygen_cylinder,
            "has_fire_extinguisher": room.has_fire_extinguisher,
            "connected_to": [conn[0] for conn in room.connected_to],
            "sensor_data": {
                "fire_detected": sensor.fire_detected,
                "smoke_level": round(sensor.smoke_level, 2),
                "temperature_c": round(sensor.temperature, 1),
                "oxygen_pct": round(sensor.oxygen_level, 1),
                "visibility_pct": round(sensor.visibility, 1),
                "structural_integrity_pct": round(sensor.structural_integrity, 1),
                "fire_growth_rate": round(sensor.fire_growth_rate, 2),
                "flashover_risk": round(sensor.flashover_risk, 2),
                "backdraft_risk": round(sensor.backdraft_risk, 2),
                "heat_radiation": round(sensor.heat_radiation, 2),
                "fire_type": sensor.fire_type,
                "carbon_monoxide_ppm": round(sensor.carbon_monoxide, 1),
                "carbon_dioxide_ppm": round(sensor.carbon_dioxide, 1),
                "hydrogen_cyanide_ppm": round(sensor.hydrogen_cyanide, 2),
                "hydrogen_chloride_ppm": round(sensor.hydrogen_chloride, 2),
                "wind_direction": round(sensor.wind_direction, 1),
                "wind_speed": round(sensor.wind_speed, 2),
                "air_pressure": round(sensor.air_pressure, 2),
                "humidity": round(sensor.humidity, 1),
                "occupancy_density": round(sensor.occupancy_density, 2),
                "mobility_limitations": sensor.mobility_limitations,
                "panic_level": round(sensor.panic_level, 2),
                "evacuation_progress": round(sensor.evacuation_progress, 1),
                "sprinkler_active": sensor.sprinkler_active,
                "emergency_lighting": sensor.emergency_lighting,
                "elevator_available": sensor.elevator_available,
                "stairwell_clear": sensor.stairwell_clear,
                "exit_accessible": sensor.exit_accessible,
                "exit_capacity": sensor.exit_capacity,
                "ventilation_active": sensor.ventilation_active,
                "time_since_fire_start": sensor.time_since_fire_start,
                "estimated_time_to_exit": sensor.estimated_time_to_exit,
                "emergency_comm_working": sensor.emergency_comm_working,
                "wifi_signal_strength": round(sensor.wifi_signal_strength, 1),
                "danger_score": round(sensor.calculate_danger_score(), 1),
                "passable": sensor.is_passable()
            }
        }

    def export_route_data(self, start_location: str = "R1") -> Dict[str, Any]:
        """Export all evacuation routes with detailed information"""
        routes = self.pathfinder.find_all_evacuation_routes(start_location)

        route_data = {
            "timestamp_sec": 0,
            "start_location": start_location,
            "total_routes": len(routes),
            "routes": []
        }

        for idx, (exit_id, path, risk) in enumerate(routes, 1):
            route_info = {
                "route_id": f"Route {idx}",
                "exit": exit_id,
                "path": path,
                "metrics": {
                    "avg_danger": round(risk['avg_danger'], 2),
                    "max_danger": round(risk['max_danger'], 2),
                    "max_danger_location": risk['max_danger_location'],
                    "total_danger": round(risk['total_danger'], 2),
                    "path_length": risk['path_length'],
                    "has_fire": risk['has_fire'],
                    "has_oxygen_hazard": risk['has_oxygen_hazard'],
                    "passable": risk['passable'],
                    "risk_factors": risk['risk_factors']
                },
                "nodes": []
            }

            # Add detailed node information
            for room_id in path:
                node_data = self.export_room_data(room_id)
                if node_data:
                    route_info["nodes"].append(node_data)

            route_data["routes"].append(route_info)

        return route_data

    def export_all_rooms(self) -> List[Dict[str, Any]]:
        """Export all rooms as separate documents"""
        all_rooms = []
        for room_id in self.floor_plan.rooms:
            room_data = self.export_room_data(room_id)
            if room_data:
                all_rooms.append(room_data)
        return all_rooms

    def export_to_json(self, output_path: str, start_location: str = "R1"):
        """Export complete dataset to JSON file"""
        data = {
            "floor_plan": {
                "floor_name": self.floor_plan.floor_name,
                "total_rooms": len(self.floor_plan.rooms),
                "exits": self.floor_plan.exits
            },
            "all_rooms": self.export_all_rooms(),
            "evacuation_routes": self.export_route_data(start_location)
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        print(f"[OK] Exported data to {output_path}")
        return data

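# Shape of the exported JSON (sketch, derived from export_to_json above; field
# lists abbreviated):
#
#     {
#       "floor_plan": {"floor_name": ..., "total_rooms": ..., "exits": [...]},
#       "all_rooms": [<room document with nested "sensor_data">, ...],
#       "evacuation_routes": {
#         "start_location": "R1", "total_routes": <n>,
#         "routes": [{"route_id", "exit", "path", "metrics", "nodes"}, ...]
#       }
#     }
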
class ReasoningMode(Enum):
    """Enumeration of reasoning modes"""
    STANDARD = "standard"
    CHAIN_OF_THOUGHT = "chain_of_thought"
    TREE_OF_THOUGHTS = "tree_of_thoughts"
    REFLEXION = "reflexion"
    COT_WITH_TOOLS = "cot_with_tools"


class DecodingStrategy(Enum):
    """Enumeration of decoding strategies"""
    GREEDY = "greedy"
    SAMPLING = "sampling"
    BEAM_SEARCH = "beam_search"
    NUCLEUS = "nucleus"
    TEMPERATURE = "temperature"

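# Example (illustrative): both enums are passed straight into the constructor
# of FireEvacuationRAG defined below; all other arguments keep their defaults.
#
#     rag = FireEvacuationRAG(
#         reasoning_mode=ReasoningMode.TREE_OF_THOUGHTS,
#         decoding_strategy=DecodingStrategy.BEAM_SEARCH,
#     )
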
class FireEvacuationRAG:
    """RAG system using FAISS for retrieval and Llama for generation with advanced reasoning"""

    def __init__(self, model_name: str = "nvidia/Llama-3.1-Minitron-4B-Width-Base", model_dir: str = "./models",
                 use_8bit: bool = False, use_unsloth: bool = False, load_in_4bit: bool = True, max_seq_length: int = 2048,
                 reasoning_mode: ReasoningMode = ReasoningMode.CHAIN_OF_THOUGHT,
                 decoding_strategy: DecodingStrategy = DecodingStrategy.NUCLEUS):
        self.model_name = model_name
        self.model_dir = model_dir
        self.local_model_path = os.path.join(model_dir, model_name.replace("/", "_"))
        self.use_8bit = use_8bit
        self.use_unsloth = use_unsloth
        self.load_in_4bit = load_in_4bit
        self.max_seq_length = max_seq_length
        self.reasoning_mode = reasoning_mode
        self.decoding_strategy = decoding_strategy
        self.tokenizer = None
        self.model = None
        self.pipe = None
        self.embedder = None
        self.index = None
        self.documents = []
        self.metadata = []
        self.reflexion_history = []  # Store reflection history for Reflexion

        # Create model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs(self.local_model_path, exist_ok=True)

        print(f"Initializing RAG system with model: {model_name}")
        print(f"Model will be saved to: {self.local_model_path}")
        print(f"Reasoning mode: {reasoning_mode.value}")
        print(f"Decoding strategy: {decoding_strategy.value}")
        if use_unsloth:
            print("[*] Unsloth enabled (faster loading and inference)")
            if load_in_4bit:
                print(" - 4-bit quantization enabled (very fast, low memory)")
        elif use_8bit:
            print("[!] 8-bit quantization enabled (faster loading, lower memory, slight quality trade-off)")

    def _check_model_files_exist(self, model_path: str) -> bool:
        """Check if model files actually exist (not just config.json)"""
        required_files = [
            "config.json",
            "model.safetensors.index.json"  # Check for sharded model index
        ]

        # Check for at least one model file
        model_file_patterns = [
            "model.safetensors",
            "pytorch_model.bin",
            "model-*.safetensors"  # Sharded models
        ]

        config_exists = os.path.exists(os.path.join(model_path, "config.json"))
        if not config_exists:
            return False

        # Check for model weight files
        for pattern in model_file_patterns:
            if glob.glob(os.path.join(model_path, pattern)):
                return True

        # Check for sharded model index
        if os.path.exists(os.path.join(model_path, "model.safetensors.index.json")):
            return True

        return False

    def download_model(self):
        """Download and load the Llama model, saving weights to local directory"""
        print("Downloading Llama model (this may take a while)...")
        print(f"Model weights will be saved to: {self.local_model_path}")

        # Use Unsloth if enabled (much faster loading) - PRIMARY METHOD
        if self.use_unsloth:
            try:
                from unsloth import FastLanguageModel
                from transformers import TextStreamer
                print("[*] Using Unsloth for fast model loading...")

                # Check if model name indicates it's already quantized (contains "bnb-4bit" or "bnb-8bit")
                is_pre_quantized = "bnb-4bit" in self.model_name.lower() or "bnb-8bit" in self.model_name.lower()

                # For pre-quantized models, don't set load_in_4bit (model is already quantized)
                # For non-quantized models, check if bitsandbytes is available
                if self.load_in_4bit and not is_pre_quantized:
                    try:
                        import bitsandbytes
                        print("[OK] bitsandbytes available for 4-bit quantization")
                    except ImportError:
                        print("[!] bitsandbytes not found. 4-bit quantization requires bitsandbytes.")
                        print(" Install with: pip install bitsandbytes")
                        print(" Falling back to full precision...")
                        self.load_in_4bit = False

                # Check if model exists locally
                if self._check_model_files_exist(self.local_model_path):
                    print(f"Loading from local path: {self.local_model_path}")
                    model_path = self.local_model_path
                else:
                    print(f"Downloading model: {self.model_name}")
                    model_path = self.model_name

                # ==== Load Model with Unsloth (exact pattern from user) ====
                dtype = None  # Auto-detect dtype

                # Try loading with proper error handling for bitsandbytes
                # The model config might have quantization settings that trigger bitsandbytes check
                max_retries = 2
                for attempt in range(max_retries):
                    try:
                        # For pre-quantized models, don't specify load_in_4bit (it's already quantized)
                        if is_pre_quantized or attempt > 0:
                            print("[OK] Loading model without quantization parameters...")
                            # Don't pass any quantization parameters
                            load_kwargs = {
                                "model_name": model_path,
                                "max_seq_length": self.max_seq_length,
                                "dtype": dtype,
                            }
                        else:
                            # For non-quantized models, try quantization if requested
                            load_kwargs = {
                                "model_name": model_path,
                                "max_seq_length": self.max_seq_length,
                                "dtype": dtype,
                            }
                            if self.load_in_4bit:
                                load_kwargs["load_in_4bit"] = True

                        self.model, self.tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)
                        break  # Success, exit retry loop

                    except (ImportError, Exception) as quant_error:
                        error_str = str(quant_error)
                        is_bitsandbytes_error = (
                            "bitsandbytes" in error_str.lower() or
                            "PackageNotFoundError" in error_str or
                            "No package metadata" in error_str or
                            "quantization_config" in error_str.lower()
                        )

                        if is_bitsandbytes_error and attempt < max_retries - 1:
                            print(f"[!] Attempt {attempt + 1}: bitsandbytes error detected.")
                            print(f" Error: {error_str[:150]}...")
                            print(" Retrying without quantization parameters...")
                            continue  # Retry without quantization
                        elif is_bitsandbytes_error:
                            print("[!] bitsandbytes required but not installed.")
                            print(" Options:")
                            print(" 1. Install bitsandbytes: pip install bitsandbytes")
                            print(" 2. Use a non-quantized model")
                            print(" 3. Set USE_UNSLOTH=False to use standard loading")
                            raise ImportError(
                                "bitsandbytes is required for this model. "
                                "Install with: pip install bitsandbytes"
                            ) from quant_error
                        else:
                            # Re-raise if it's a different error
                            raise

                # Optimize for inference
                FastLanguageModel.for_inference(self.model)

                print("[OK] Model loaded successfully with Unsloth!")

                # Verify device
                if torch.cuda.is_available():
                    actual_device = next(self.model.parameters()).device
                    print(f"[OK] Model loaded on {actual_device}!")
                    allocated = torch.cuda.memory_allocated(0) / 1024**3
                    print(f"[OK] GPU Memory allocated: {allocated:.2f} GB")
                else:
                    print("[OK] Model loaded on CPU!")

                # Set pipe to model for compatibility (we'll use model directly in generation)
                self.pipe = self.model  # Store model reference for compatibility checks

                return  # Exit early, Unsloth loading complete

            except ImportError:
                print("[!] Unsloth not installed. Falling back to standard loading.")
                print(" Install with: pip install unsloth")
                self.use_unsloth = False  # Disable unsloth for this session
            except Exception as e:
                print(f"[!] Unsloth loading failed: {e}")
                print(" Falling back to standard loading...")
                self.use_unsloth = False

        # Standard loading (original code)
        # Check GPU availability and optimize settings
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda":
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"[OK] GPU detected: {gpu_name}")
            print(f"[OK] GPU Memory: {gpu_memory:.2f} GB")
            # Use bfloat16 for faster loading and inference on GPU
            torch_dtype = torch.bfloat16
            print("[OK] Using bfloat16 precision for faster loading")
        else:
            print("[!] No GPU detected, using CPU (will be slower)")
            torch_dtype = torch.float32
            print("[OK] Using float32 precision for CPU")

        # Check for optimized attention implementation
        try:
            import flash_attn  # noqa: F401
            attn_impl = 'flash_attention_2'
            print("[OK] FlashAttention2 available - using for optimal performance")
        except ImportError:
            attn_impl = 'sdpa'  # Scaled Dot Product Attention (built into PyTorch)
            print("[OK] Using SDPA (Scaled Dot Product Attention) for faster inference")

        # Check for 8-bit quantization support
        use_quantization = False
        if self.use_8bit and device == "cuda":
            try:
                import bitsandbytes
                use_quantization = True
                print("[OK] 8-bit quantization available - will use for faster loading")
            except ImportError:
                print("[!] 8-bit requested but bitsandbytes not installed, using full precision")
                print(" Install with: pip install bitsandbytes")

        try:
            # Check if model already exists locally with actual model files
            if self._check_model_files_exist(self.local_model_path):
                print(f"Found existing model at {self.local_model_path}, loading from local...")
                model_path = self.local_model_path
                load_from_local = True
            else:
                print("Downloading model from HuggingFace...")
                model_path = self.model_name
                load_from_local = False

            # Load tokenizer
            print("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True
            )

            # Save tokenizer locally if downloaded (wrap in try-except to avoid crashes)
            if not load_from_local:
                try:
                    print("Saving tokenizer to local directory...")
                    self.tokenizer.save_pretrained(self.local_model_path)
                    print(f"[OK] Tokenizer saved to {self.local_model_path}")
                except Exception as save_err:
                    print(f"[!] Warning: Could not save tokenizer locally: {save_err}")
                    print("Continuing without local save...")

            # Load model with optimizations
            print("Loading model with optimizations...")
            load_kwargs = {
                "trust_remote_code": True,
                "low_cpu_mem_usage": True,  # Reduces memory usage during loading
                "_attn_implementation": attn_impl,  # Optimized attention
            }

            # Add quantization or dtype
            if use_quantization:
                from transformers import BitsAndBytesConfig
                load_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0
                )
                print("[OK] Using 8-bit quantization for faster loading and lower memory")
            else:
                load_kwargs["torch_dtype"] = torch_dtype

            # Use device_map="auto" for GPU, manual placement for CPU
            if device == "cuda":
                try:
                    load_kwargs["device_map"] = "auto"
                    print("[OK] Using device_map='auto' for optimal GPU memory management")
                except Exception as e:
                    print(f"[!] device_map='auto' failed, using manual GPU placement: {e}")
                    load_kwargs.pop("device_map", None)

            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                **load_kwargs
            )

            # Manual device placement if device_map wasn't used
            if device == "cuda" and "device_map" not in load_kwargs:
                self.model = self.model.cuda()
                print("[OK] Model moved to GPU")

            # Save model locally if downloaded (wrap in try-except to handle DTensor errors)
            if not load_from_local:
                try:
                    print("Saving model weights to local directory (this may take a while)...")
                    self.model.save_pretrained(
                        self.local_model_path,
                        safe_serialization=True  # Use safetensors format
                    )
                    print(f"[OK] Model saved to {self.local_model_path}")
                except ImportError as import_err:
                    if "DTensor" in str(import_err):
                        print(f"[!] Warning: Could not save model due to PyTorch/transformers compatibility issue: {import_err}")
                        print("This is a known issue with certain versions. Model will work but won't be saved locally.")
                        print("Continuing without local save...")
                    else:
                        raise
                except Exception as save_err:
                    print(f"[!] Warning: Could not save model locally: {save_err}")
                    print("Continuing without local save...")

            # Create pipeline with optimizations
            print("Creating pipeline...")
            pipeline_kwargs = {
                "model": self.model,
                "tokenizer": self.tokenizer,
            }
            if device == "cuda":
                pipeline_kwargs["device_map"] = "auto"

            self.pipe = pipeline("text-generation", **pipeline_kwargs)

            # Verify model device
            if device == "cuda":
                actual_device = next(self.model.parameters()).device
                print(f"[OK] Model loaded successfully on {actual_device}!")
                if torch.cuda.is_available():
                    allocated = torch.cuda.memory_allocated(0) / 1024**3
                    print(f"[OK] GPU Memory allocated: {allocated:.2f} GB")
            else:
                print("[OK] Model loaded successfully on CPU!")

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Falling back to pipeline-only loading...")
            try:
                # Determine device and dtype for fallback
                device = "cuda" if torch.cuda.is_available() else "cpu"
                torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

                # Try loading from local path first (only if model files actually exist)
                if self._check_model_files_exist(self.local_model_path):
                    print(f"Attempting to load from local path: {self.local_model_path}")
                    pipeline_kwargs = {
                        "model": self.local_model_path,
                        "trust_remote_code": True,
                        "torch_dtype": torch_dtype,
                    }
                    if device == "cuda":
                        pipeline_kwargs["device_map"] = "auto"
                    self.pipe = pipeline("text-generation", **pipeline_kwargs)
                    # Extract tokenizer from pipeline if available
                    if hasattr(self.pipe, 'tokenizer'):
                        self.tokenizer = self.pipe.tokenizer
                else:
                    print(f"Downloading model: {self.model_name}")
                    pipeline_kwargs = {
                        "model": self.model_name,
                        "trust_remote_code": True,
                        "torch_dtype": torch_dtype,
                    }
                    if device == "cuda":
                        pipeline_kwargs["device_map"] = "auto"
                    self.pipe = pipeline("text-generation", **pipeline_kwargs)
                    # Extract tokenizer from pipeline if available
                    if hasattr(self.pipe, 'tokenizer'):
                        self.tokenizer = self.pipe.tokenizer

                # Try to save after loading (but don't fail if it doesn't work)
                try:
                    if hasattr(self.pipe, 'model') and hasattr(self.pipe.model, 'save_pretrained'):
                        print("Attempting to save downloaded model to local directory...")
                        self.pipe.model.save_pretrained(self.local_model_path, safe_serialization=True)
                        if hasattr(self.pipe, 'tokenizer'):
                            self.pipe.tokenizer.save_pretrained(self.local_model_path)
                        print("[OK] Model saved successfully")
                except ImportError as import_err:
                    if "DTensor" in str(import_err):
                        print("[!] Warning: Could not save model due to compatibility issue. Model will work but won't be saved locally.")
                    else:
                        print(f"[!] Warning: Could not save model: {import_err}")
                except Exception as save_err:
                    print(f"[!] Warning: Could not save model locally: {save_err}")

            except Exception as e2:
                print(f"Pipeline loading also failed: {e2}")
                raise

    def load_embedder(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load sentence transformer for embeddings, saving to local directory"""
        embedder_dir = os.path.join(self.model_dir, "embedder", model_name.replace("/", "_"))
        os.makedirs(embedder_dir, exist_ok=True)

        print(f"Loading embedding model: {model_name}...")
        print(f"Embedder will be cached in: {embedder_dir}")

        # Check if embedder exists locally (check for actual model files, not just config)
        config_path = os.path.join(embedder_dir, "config.json")
        has_model_files = False
        if os.path.exists(config_path):
            # Check if model files exist
            model_files = glob.glob(os.path.join(embedder_dir, "*.safetensors")) + \
                          glob.glob(os.path.join(embedder_dir, "pytorch_model.bin"))
            if model_files or os.path.exists(os.path.join(embedder_dir, "model.safetensors.index.json")):
                has_model_files = True

        if has_model_files:
            print(f"Loading embedder from local cache: {embedder_dir}")
            self.embedder = SentenceTransformer(embedder_dir)
        else:
            print("Downloading embedder from HuggingFace...")
            self.embedder = SentenceTransformer(model_name, cache_folder=embedder_dir)
            # Try to save to local directory (but don't fail if it doesn't work)
            try:
                self.embedder.save(embedder_dir)
                print(f"[OK] Embedder saved to {embedder_dir}")
            except ImportError as import_err:
                if "DTensor" in str(import_err):
                    print(f"[!] Warning: Could not save embedder due to PyTorch/transformers compatibility issue: {import_err}")
                    print("This is a known issue with certain versions. Embedder will work but won't be saved locally.")
                    print("Continuing without local save...")
                else:
                    print(f"[!] Warning: Could not save embedder: {import_err}")
            except Exception as save_err:
                print(f"[!] Warning: Could not save embedder locally: {save_err}")
                print("Continuing without local save...")

        print("[OK] Embedding model loaded!")

    def build_faiss_index(self, documents: List[str], metadata: List[Dict] = None):
        """
        Build FAISS index from documents

        Args:
            documents: List of text documents to index
            metadata: Optional metadata for each document
        """
        if not self.embedder:
            self.load_embedder()

        print(f"Building FAISS index for {len(documents)} documents...")

        # Generate embeddings
        embeddings = self.embedder.encode(documents, show_progress_bar=True)
        embeddings = np.array(embeddings).astype('float32')

        # Get dimension
        dimension = embeddings.shape[1]

        # Create FAISS index (L2 distance)
        self.index = faiss.IndexFlatL2(dimension)

        # Add embeddings to index
        self.index.add(embeddings)

        # Store documents and metadata
        self.documents = documents
        self.metadata = metadata if metadata else [{}] * len(documents)

        print(f"[OK] FAISS index built with {self.index.ntotal} vectors")

    def build_index_from_json(self, json_data: Dict[str, Any]):
        """Build FAISS index from exported JSON data"""
        documents = []
        metadata = []

        # Add room documents
        for room in json_data.get("all_rooms", []):
            # Create text representation
            room_text = self._room_to_text(room)
            documents.append(room_text)
            metadata.append({
                "type": "room",
                "room_id": room.get("room_id"),
                "data": room
            })

        # Add route documents
        for route in json_data.get("evacuation_routes", {}).get("routes", []):
            route_text = self._route_to_text(route)
            documents.append(route_text)
            metadata.append({
                "type": "route",
                "route_id": route.get("route_id"),
                "exit": route.get("exit"),
                "data": route
            })

        # Build index
        self.build_faiss_index(documents, metadata)

    def _room_to_text(self, room: Dict[str, Any]) -> str:
        """Convert room data to searchable text"""
        sensor = room.get("sensor_data", {})

        text_parts = [
            f"Room {room.get('room_id')} ({room.get('name')})",
            f"Type: {room.get('room_type')}",
        ]

        if room.get("has_oxygen_cylinder"):
            text_parts.append("[!] OXYGEN CYLINDER PRESENT - EXPLOSION RISK")

        if sensor.get("fire_detected"):
            text_parts.append("[FIRE] FIRE DETECTED")

        text_parts.extend([
            f"Temperature: {sensor.get('temperature_c')}°C",
            f"Smoke level: {sensor.get('smoke_level')}",
            f"Oxygen: {sensor.get('oxygen_pct')}%",
            f"Visibility: {sensor.get('visibility_pct')}%",
            f"Structural integrity: {sensor.get('structural_integrity_pct')}%",
            f"Danger score: {sensor.get('danger_score')}",
            f"Passable: {sensor.get('passable')}"
        ])

        if sensor.get("carbon_monoxide_ppm", 0) > 50:
            text_parts.append(f"[!] HIGH CARBON MONOXIDE: {sensor.get('carbon_monoxide_ppm')} ppm")

        if sensor.get("flashover_risk", 0) > 0.5:
            text_parts.append(f"[!] FLASHOVER RISK: {sensor.get('flashover_risk')*100:.0f}%")

        if not sensor.get("exit_accessible", True):
            text_parts.append("[!] EXIT BLOCKED")

        if sensor.get("occupancy_density", 0) > 0.7:
            text_parts.append(f"[!] HIGH CROWD DENSITY: {sensor.get('occupancy_density')*100:.0f}%")

        return " | ".join(text_parts)

    def _route_to_text(self, route: Dict[str, Any]) -> str:
        """Convert route data to searchable text"""
        metrics = route.get("metrics", {})

        text_parts = [
            f"{route.get('route_id')} to {route.get('exit')}",
            f"Path: {' → '.join(route.get('path', []))}",
            f"Average danger: {metrics.get('avg_danger')}",
            f"Max danger: {metrics.get('max_danger')} at {metrics.get('max_danger_location')}",
            f"Passable: {metrics.get('passable')}",
            f"Has fire: {metrics.get('has_fire')}",
            f"Has oxygen hazard: {metrics.get('has_oxygen_hazard')}"
        ]

        risk_factors = metrics.get("risk_factors", [])
        if risk_factors:
            text_parts.append(f"Risks: {', '.join(risk_factors[:3])}")

        return " | ".join(text_parts)

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """
        Search FAISS index for relevant documents

        Args:
            query: Search query
            k: Number of results to return

        Returns:
            List of relevant documents with metadata
        """
        if not self.index or not self.embedder:
            raise ValueError("Index not built. Call build_faiss_index() first.")

        # Encode query
        query_embedding = self.embedder.encode([query])
        query_embedding = np.array(query_embedding).astype('float32')

        # Search
        distances, indices = self.index.search(query_embedding, k)

        # Return results
        results = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.documents):
                results.append({
                    "document": self.documents[idx],
                    "metadata": self.metadata[idx],
                    "distance": float(distances[0][i])
                })

        return results

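    # Retrieval example (illustrative): once build_index_from_json() has been
    # called, search() returns the k nearest documents with their metadata:
    #
    #     hits = rag.search("Which rooms have fire or blocked exits?", k=3)
    #     for hit in hits:
    #         print(hit["distance"], hit["metadata"]["type"], hit["document"][:80])
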
| 798 |
+
def _build_cot_prompt(self, query: str, context: List[str]) -> str:
|
| 799 |
+
"""Build Chain-of-Thought prompt with step-by-step reasoning"""
|
| 800 |
+
context_text = "\n".join([f"- {ctx}" for ctx in context])
|
| 801 |
+
|
| 802 |
+
prompt = f"""You are an expert fire evacuation safety advisor. Use the following context to answer the question concisely.
|
| 803 |
+
|
| 804 |
+
CONTEXT:
|
| 805 |
+
{context_text}
|
| 806 |
+
|
| 807 |
+
QUESTION: {query}
|
| 808 |
+
|
| 809 |
+
Think step by step, then provide a brief answer:
|
| 810 |
+
|
| 811 |
+
REASONING:
|
| 812 |
+
1. Analyze available information
|
| 813 |
+
2. Identify key safety factors
|
| 814 |
+
3. Evaluate risks and prioritize
|
| 815 |
+
4. Conclude with recommendation
|
| 816 |
+
|
| 817 |
+
ANSWER:"""
|
| 818 |
+
return prompt
|
| 819 |
+
|
| 820 |
+
def _build_tot_prompt(self, query: str, context: List[str], thought: str = "") -> str:
|
| 821 |
+
"""Build Tree-of-Thoughts prompt for exploring multiple reasoning paths"""
|
| 822 |
+
context_text = "\n".join([f"- {ctx}" for ctx in context])
|
| 823 |
+
|
| 824 |
+
if not thought:
|
| 825 |
+
prompt = f"""You are an expert fire evacuation safety advisor. Use the following context to explore different reasoning approaches.
|
| 826 |
+
|
| 827 |
+
CONTEXT:
|
| 828 |
+
{context_text}
|
| 829 |
+
|
| 830 |
+
QUESTION: {query}
|
| 831 |
+
|
| 832 |
+
Let's explore different reasoning approaches to solve this problem:
|
| 833 |
+
|
| 834 |
+
APPROACH 1 - Safety-First Analysis:
|
| 835 |
+
"""
|
| 836 |
+
else:
|
| 837 |
+
prompt = f"""CONTEXT:
|
| 838 |
+
{context_text}
|
| 839 |
+
|
| 840 |
+
QUESTION: {query}
|
| 841 |
+
|
| 842 |
+
CURRENT THOUGHT: {thought}
|
| 843 |
+
|
| 844 |
+
Evaluate this thought:
|
| 845 |
+
- Is this reasoning sound?
|
| 846 |
+
- What are the strengths and weaknesses?
|
| 847 |
+
- What alternative approaches should we consider?
|
| 848 |
+
|
| 849 |
+
EVALUATION:
|
| 850 |
+
"""
|
| 851 |
+
return prompt
|
| 852 |
+
|
| 853 |
+
def _build_reflexion_prompt(self, query: str, context: List[str], previous_answer: str = "",
|
| 854 |
+
reflection: str = "") -> str:
|
| 855 |
+
"""Build Reflexion prompt for self-reflection and improvement"""
|
| 856 |
+
context_text = "\n".join([f"- {ctx}" for ctx in context])
|
| 857 |
+
|
| 858 |
+
if not previous_answer:
|
| 859 |
+
# Initial answer
|
| 860 |
+
prompt = f"""You are an expert fire evacuation safety advisor. Use the following context to answer the question.
|
| 861 |
+
|
| 862 |
+
CONTEXT:
|
| 863 |
+
{context_text}
|
| 864 |
+
|
| 865 |
+
QUESTION: {query}
|
| 866 |
+
|
| 867 |
+
Provide a clear, safety-focused answer based on the context.
|
| 868 |
+
|
| 869 |
+
ANSWER:"""
|
| 870 |
+
else:
|
| 871 |
+
# Reflection phase
|
| 872 |
+
prompt = f"""You are an expert fire evacuation safety advisor. Review and improve your previous answer.
|
| 873 |
+
|
| 874 |
+
CONTEXT:
|
| 875 |
+
{context_text}
|
| 876 |
+
|
| 877 |
+
QUESTION: {query}
|
| 878 |
+
|
| 879 |
+
PREVIOUS ANSWER:
|
| 880 |
+
{previous_answer}
|
| 881 |
+
|
| 882 |
+
REFLECTION:
|
| 883 |
+
{reflection}
|
| 884 |
+
|
| 885 |
+
Now provide an improved answer based on your reflection:
|
| 886 |
+
|
| 887 |
+
IMPROVED ANSWER:"""
|
| 888 |
+
return prompt
|
| 889 |
+
|
| 890 |
+
def _build_cot_with_tools_prompt(self, query: str, context: List[str], tool_results: List[str] = None) -> str:
|
| 891 |
+
"""Build Chain-of-Thought prompt with tool integration"""
|
| 892 |
+
context_text = "\n".join([f"- {ctx}" for ctx in context])
|
| 893 |
+
|
| 894 |
+
tool_text = ""
|
| 895 |
+
if tool_results:
|
| 896 |
+
tool_text = "\nTOOL RESULTS:\n" + "\n".join([f"- {result}" for result in tool_results])
|
| 897 |
+
|
| 898 |
+
prompt = f"""You are an expert fire evacuation safety advisor. Use the following context and tool results to answer the question.
|
| 899 |
+
|
| 900 |
+
CONTEXT:
|
| 901 |
+
{context_text}
|
| 902 |
+
{tool_text}
|
| 903 |
+
|
| 904 |
+
QUESTION: {query}
|
| 905 |
+
|
| 906 |
+
Let's solve this step by step, using both the context and tool results:
|
| 907 |
+
|
| 908 |
+
STEP 1 - Understand the question and available data:
|
| 909 |
+
"""
|
| 910 |
+
return prompt
|
| 911 |
+
|
| 912 |
+
def _generate_with_decoding_strategy(self, prompt: str, max_length: int = 500,
|
| 913 |
+
temperature: float = 0.7, top_p: float = 0.9,
|
| 914 |
+
num_beams: int = 3, stop_sequences: List[str] = None) -> str:
|
| 915 |
+
"""Generate response using specified decoding strategy"""
|
| 916 |
+
if not self.pipe and not self.model:
|
| 917 |
+
raise ValueError("Model not loaded. Call download_model() first.")
|
| 918 |
+
|
| 919 |
+
try:
|
| 920 |
+
if self.use_unsloth and self.model:
|
| 921 |
+
inputs = self.tokenizer(
|
| 922 |
+
prompt,
|
| 923 |
+
return_tensors="pt",
|
| 924 |
+
truncation=True,
|
| 925 |
+
max_length=self.max_seq_length
|
| 926 |
+
).to(self.model.device)
|
| 927 |
+
|
| 928 |
+
# Configure generation parameters based on decoding strategy
|
| 929 |
+
gen_kwargs = {
|
| 930 |
+
"max_new_tokens": max_length,
|
| 931 |
+
"pad_token_id": self.tokenizer.eos_token_id,
|
| 932 |
+
"eos_token_id": self.tokenizer.eos_token_id,
|
| 933 |
+
}
|
| 934 |
+
|
| 935 |
+
if self.decoding_strategy == DecodingStrategy.GREEDY:
|
| 936 |
+
gen_kwargs.update({
|
| 937 |
+
"do_sample": False,
|
| 938 |
+
"num_beams": 1
|
| 939 |
+
})
|
| 940 |
+
elif self.decoding_strategy == DecodingStrategy.SAMPLING:
|
| 941 |
+
gen_kwargs.update({
|
| 942 |
+
"do_sample": True,
|
| 943 |
+
"temperature": temperature,
|
| 944 |
+
"top_k": 50
|
| 945 |
+
})
|
| 946 |
+
elif self.decoding_strategy == DecodingStrategy.BEAM_SEARCH:
|
| 947 |
+
gen_kwargs.update({
|
| 948 |
+
"do_sample": False,
|
| 949 |
+
"num_beams": num_beams,
|
| 950 |
+
"early_stopping": True
|
| 951 |
+
})
|
| 952 |
+
elif self.decoding_strategy == DecodingStrategy.NUCLEUS:
|
| 953 |
+
gen_kwargs.update({
|
| 954 |
+
"do_sample": True,
|
| 955 |
+
"temperature": temperature,
|
| 956 |
+
"top_p": top_p,
|
| 957 |
+
"top_k": 0
|
| 958 |
+
})
|
| 959 |
+
elif self.decoding_strategy == DecodingStrategy.TEMPERATURE:
|
| 960 |
+
gen_kwargs.update({
|
| 961 |
+
"do_sample": True,
|
| 962 |
+
"temperature": temperature
|
| 963 |
+
})
|
| 964 |
+
|
| 965 |
+
with torch.no_grad():
|
| 966 |
+
outputs = self.model.generate(**inputs, **gen_kwargs)
|
| 967 |
+
|
| 968 |
+
response = self.tokenizer.batch_decode(
|
| 969 |
+
outputs,
|
| 970 |
+
skip_special_tokens=True
|
| 971 |
+
)[0]
|
| 972 |
+
|
| 973 |
+
# Extract response after prompt
|
| 974 |
+
if prompt in response:
|
| 975 |
+
response = response.split(prompt)[-1].strip()
|
| 976 |
+
|
| 977 |
+
# Post-process to stop at verbose endings
|
| 978 |
+
stop_phrases = [
|
| 979 |
+
"\n\nHowever, please note",
|
| 980 |
+
"\n\nAdditionally,",
|
| 981 |
+
"\n\nLet me know",
|
| 982 |
+
"\n\nIf you have",
|
| 983 |
+
"\n\nHere's another",
|
| 984 |
+
"\n\nQUESTION:",
|
| 985 |
+
"\n\nLet's break",
|
| 986 |
+
"\n\nHave a great day",
|
| 987 |
+
"\n\nI'm here to help"
|
| 988 |
+
]
|
| 989 |
+
for phrase in stop_phrases:
|
| 990 |
+
if phrase in response:
|
| 991 |
+
response = response.split(phrase)[0].strip()
|
| 992 |
+
break
|
| 993 |
+
|
| 994 |
+
return response
|
| 995 |
+
else:
|
| 996 |
+
# Use pipeline for standard models
|
| 997 |
+
gen_kwargs = {
|
| 998 |
+
"max_length": len(self.tokenizer.encode(prompt)) + max_length,
|
| 999 |
+
"num_return_sequences": 1,
|
| 1000 |
+
}
|
| 1001 |
+
|
| 1002 |
+
if self.decoding_strategy == DecodingStrategy.GREEDY:
|
| 1003 |
+
gen_kwargs.update({
|
| 1004 |
+
"do_sample": False
|
| 1005 |
+
})
|
| 1006 |
+
elif self.decoding_strategy == DecodingStrategy.SAMPLING:
|
| 1007 |
+
gen_kwargs.update({
|
| 1008 |
+
"do_sample": True,
|
| 1009 |
+
"temperature": temperature,
|
| 1010 |
+
"top_k": 50
|
| 1011 |
+
})
|
| 1012 |
+
elif self.decoding_strategy == DecodingStrategy.BEAM_SEARCH:
|
| 1013 |
+
gen_kwargs.update({
|
| 1014 |
+
"do_sample": False,
|
| 1015 |
+
"num_beams": num_beams,
|
| 1016 |
+
"early_stopping": True
|
| 1017 |
+
})
|
| 1018 |
+
elif self.decoding_strategy == DecodingStrategy.NUCLEUS:
|
| 1019 |
+
gen_kwargs.update({
|
| 1020 |
+
"do_sample": True,
|
| 1021 |
+
"temperature": temperature,
|
| 1022 |
+
"top_p": top_p,
|
| 1023 |
+
"top_k": 0
|
| 1024 |
+
})
|
| 1025 |
+
elif self.decoding_strategy == DecodingStrategy.TEMPERATURE:
|
| 1026 |
+
gen_kwargs.update({
|
| 1027 |
+
"do_sample": True,
|
| 1028 |
+
"temperature": temperature
|
| 1029 |
+
})
|
| 1030 |
+
|
| 1031 |
+
gen_kwargs["pad_token_id"] = self.tokenizer.eos_token_id if self.tokenizer else None
|
| 1032 |
+
|
| 1033 |
+
outputs = self.pipe(prompt, **gen_kwargs)
|
| 1034 |
+
response = outputs[0]['generated_text']
|
| 1035 |
+
|
| 1036 |
+
# Extract response after prompt
|
| 1037 |
+
if prompt in response:
|
| 1038 |
+
response = response.split(prompt)[-1].strip()
|
| 1039 |
+
|
| 1040 |
+
# Post-process to stop at verbose endings
|
| 1041 |
+
stop_phrases = [
|
| 1042 |
+
"\n\nHowever, please note",
|
| 1043 |
+
"\n\nAdditionally,",
|
| 1044 |
+
"\n\nLet me know",
|
| 1045 |
+
"\n\nIf you have",
|
| 1046 |
+
"\n\nHere's another",
|
| 1047 |
+
"\n\nQUESTION:",
|
| 1048 |
+
"\n\nLet's break",
|
| 1049 |
+
"\n\nHave a great day",
|
| 1050 |
+
"\n\nI'm here to help"
|
| 1051 |
+
]
|
| 1052 |
+
for phrase in stop_phrases:
|
| 1053 |
+
if phrase in response:
|
| 1054 |
+
response = response.split(phrase)[0].strip()
|
| 1055 |
+
break
|
| 1056 |
+
|
| 1057 |
+
return response
|
| 1058 |
+
|
| 1059 |
+
except Exception as e:
|
| 1060 |
+
return f"Error generating response: {e}"
|
| 1061 |
+
|
| 1062 |
+
    def _chain_of_thought_reasoning(self, query: str, context: List[str], max_length: int = 500) -> Tuple[str, str]:
        """Generate response using Chain-of-Thought reasoning

        Returns:
            Tuple of (full_reasoning, final_answer)
        """
        prompt = self._build_cot_prompt(query, context)
        # Use shorter max_length for CoT to prevent verbosity
        full_response = self._generate_with_decoding_strategy(prompt, max_length=min(max_length, 300))

        # Extract reasoning steps (everything before ANSWER)
        reasoning = ""
        if "REASONING:" in full_response:
            reasoning_parts = full_response.split("REASONING:")
            if len(reasoning_parts) > 1:
                reasoning_section = reasoning_parts[1].split("ANSWER:")[0] if "ANSWER:" in reasoning_parts[1] else reasoning_parts[1]
                reasoning = reasoning_section.strip()
        elif "ANSWER:" in full_response:
            reasoning = full_response.split("ANSWER:")[0].strip()
        else:
            # Try to extract reasoning from numbered steps
            lines = full_response.split('\n')
            reasoning_lines = []
            for line in lines:
                if line.strip().startswith(('1.', '2.', '3.', '4.', '5.', 'Step', 'STEP')):
                    reasoning_lines.append(line.strip())
                elif "ANSWER" in line.upper():
                    break
                elif reasoning_lines:  # Continue collecting if we've started
                    reasoning_lines.append(line.strip())
            reasoning = '\n'.join(reasoning_lines)

        # Extract final answer (everything after ANSWER:)
        final_answer = full_response
        if "ANSWER:" in full_response:
            answer_parts = full_response.split("ANSWER:")
            if len(answer_parts) > 1:
                answer_text = answer_parts[-1].strip()
                # Stop at common continuation markers
                stop_markers = [
                    "\n\nHowever, please note",
                    "\n\nAdditionally,",
                    "\n\nLet me know",
                    "\n\nIf you have",
                    "\n\nHere's another",
                    "\n\nQUESTION:",
                    "\n\nLet's break",
                    "\n\nHave a great day",
                    "\n\nI'm here to help",
                    "\n\nThese general guidelines",
                    "\n\nIf you have any further"
                ]
                for marker in stop_markers:
                    if marker in answer_text:
                        answer_text = answer_text.split(marker)[0].strip()
                        break
                # Also limit to the first 3 sentences if it's still too long
                sentences = answer_text.split('. ')
                if len(sentences) > 3:
                    answer_text = '. '.join(sentences[:3])
                    if not answer_text.endswith('.'):
                        answer_text += '.'
                final_answer = answer_text

        # Clean up reasoning - remove verbose parts
        if reasoning:
            # Remove common verbose endings
            verbose_endings = [
                "However, please note",
                "Additionally,",
                "Let me know",
                "If you have",
                "Here's another",
                "Have a great day",
                "I'm here to help"
            ]
            for ending in verbose_endings:
                if ending in reasoning:
                    reasoning = reasoning.split(ending)[0].strip()
                    break

        return reasoning or "Reasoning steps generated", final_answer

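    # Illustrative note on the parsing in _chain_of_thought_reasoning above (hypothetical
    # model output, not executed): the CoT prompt is expected to elicit text shaped like
    #   "REASONING:\n1. R1's corridor shows heavy smoke.\n2. The route via R3 is clear.\nANSWER: Take the R3 corridor to Exit B."
    # which splits into reasoning = the two numbered steps and
    # final_answer = "Take the R3 corridor to Exit B."
    # If neither marker appears, only lines starting with "1.", "2.", ..., "Step"/"STEP"
    # are collected as reasoning.
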
    def _tree_of_thoughts_reasoning(self, query: str, context: List[str], max_length: int = 500,
                                    max_thoughts: int = 3) -> Tuple[str, str]:
        """Generate response using Tree-of-Thoughts reasoning

        Returns:
            Tuple of (full_reasoning, final_answer)
        """
        thoughts = []
        reasoning_log = []

        # Generate initial thoughts
        for i in range(max_thoughts):
            thought_prompt = self._build_tot_prompt(query, context,
                                                    thought=f"Exploring approach {i+1}")
            thought = self._generate_with_decoding_strategy(thought_prompt, max_length // max_thoughts)
            thoughts.append(thought)
            reasoning_log.append(f"APPROACH {i+1}:\n{thought}\n")

        # Evaluate thoughts and select best
        evaluation_prompt = f"""Evaluate these different reasoning approaches for answering the question:

QUESTION: {query}

APPROACHES:
"""
        for i, thought in enumerate(thoughts, 1):
            evaluation_prompt += f"\nAPPROACH {i}:\n{thought}\n"

        evaluation_prompt += "\nWhich approach is most sound and complete? Provide the best answer based on the evaluation.\n\nBEST ANSWER:"

        final_response = self._generate_with_decoding_strategy(evaluation_prompt, max_length)

        full_reasoning = "\n".join(reasoning_log) + f"\n\nEVALUATION:\n{final_response}"
        return full_reasoning, final_response

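    # Budget note for _tree_of_thoughts_reasoning above: each candidate approach is
    # generated with a length budget of max_length // max_thoughts (500 // 3 = 166 with
    # the defaults), while the final evaluation pass gets the full max_length, so the
    # selected answer is not artificially truncated.
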
    def _reflexion_reasoning(self, query: str, context: List[str], max_length: int = 500,
                             max_iterations: int = 2) -> Tuple[str, str]:
        """Generate response using Reflexion (self-reflection and improvement)

        Returns:
            Tuple of (full_reasoning, final_answer)
        """
        reasoning_log = []

        # Initial answer
        initial_prompt = self._build_reflexion_prompt(query, context)
        answer = self._generate_with_decoding_strategy(initial_prompt, max_length)
        reasoning_log.append(f"INITIAL ANSWER:\n{answer}\n")

        # Reflection and improvement iterations
        for iteration in range(max_iterations):
            # Generate reflection
            reflection_prompt = f"""Review this answer for a fire evacuation safety question:

QUESTION: {query}

CURRENT ANSWER:
{answer}

What could be improved? Consider:
- Accuracy of safety information
- Completeness of the response
- Clarity and actionability
- Missing critical safety factors

REFLECTION:"""

            reflection = self._generate_with_decoding_strategy(reflection_prompt, max_length // 2)
            reasoning_log.append(f"ITERATION {iteration + 1} - REFLECTION:\n{reflection}\n")

            # Generate improved answer
            improved_prompt = self._build_reflexion_prompt(query, context, answer, reflection)
            improved_answer = self._generate_with_decoding_strategy(improved_prompt, max_length)
            reasoning_log.append(f"ITERATION {iteration + 1} - IMPROVED ANSWER:\n{improved_answer}\n")

            # Check if improvement is significant (simple heuristic)
            if len(improved_answer) > len(answer) * 0.8:  # At least 80% of original length
                answer = improved_answer
            else:
                break  # Stop if answer becomes too short

        self.reflexion_history.append({
            "query": query,
            "final_answer": answer,
            "iterations": iteration + 1
        })

        full_reasoning = "\n".join(reasoning_log)
        return full_reasoning, answer

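    # Heuristic note for _reflexion_reasoning above: the 0.8 length check means a revision
    # only replaces the current answer if it is at least 80% as long. For example, a
    # 450-character answer is replaced by a 400-character revision (400 > 360), but a
    # 300-character revision stops the loop early.
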
    def _cot_with_tools_reasoning(self, query: str, context: List[str], max_length: int = 500) -> Tuple[str, str]:
        """Generate response using Chain-of-Thought with tool integration

        Returns:
            Tuple of (full_reasoning, final_answer)
        """
        reasoning_log = []

        # Simulate tool calls (in real implementation, these would call actual tools)
        tool_results = []

        # Tool 1: Route analysis
        if "route" in query.lower() or "path" in query.lower():
            tool_result = "Tool: Route Analyzer - Found 3 evacuation routes with risk scores"
            tool_results.append(tool_result)
            reasoning_log.append(f"TOOL CALL: {tool_result}\n")

        # Tool 2: Risk calculator
        if "danger" in query.lower() or "risk" in query.lower():
            tool_result = "Tool: Risk Calculator - Calculated danger scores for all rooms"
            tool_results.append(tool_result)
            reasoning_log.append(f"TOOL CALL: {tool_result}\n")

        # Tool 3: Sensor aggregator
        if "sensor" in query.lower() or "temperature" in query.lower() or "smoke" in query.lower():
            tool_result = "Tool: Sensor Aggregator - Aggregated sensor data from all rooms"
            tool_results.append(tool_result)
            reasoning_log.append(f"TOOL CALL: {tool_result}\n")

        prompt = self._build_cot_with_tools_prompt(query, context, tool_results)
        response = self._generate_with_decoding_strategy(prompt, max_length)

        reasoning_log.append(f"REASONING WITH TOOLS:\n{response}\n")
        full_reasoning = "\n".join(reasoning_log)

        # Extract final answer
        final_answer = response
        if "ANSWER:" in response or "answer:" in response.lower():
            parts = response.split("ANSWER:") if "ANSWER:" in response else response.split("answer:")
            if len(parts) > 1:
                final_answer = parts[-1].strip()

        return full_reasoning, final_answer

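    # Illustrative note for _cot_with_tools_reasoning above: tool selection is purely
    # keyword-based. A query such as "Is there a route that avoids smoke?" triggers the
    # Route Analyzer ("route") and the Sensor Aggregator ("smoke"), while the Risk
    # Calculator stays idle because neither "danger" nor "risk" appears in the query.
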
    def generate_response(self, query: str, context: List[str] = None, max_length: int = 500,
                          return_reasoning: bool = False) -> str:
        """
        Generate response using Llama model with context and advanced reasoning

        Args:
            query: User query
            context: Optional context strings (if None, will retrieve from FAISS)
            max_length: Maximum response length
            return_reasoning: If True, returns tuple of (reasoning, answer), else just answer

        Returns:
            If return_reasoning is True: Tuple of (reasoning_steps, final_answer)
            Otherwise: Just the final answer string
        """
        if not self.pipe and not self.model:
            raise ValueError("Model not loaded. Call download_model() first.")

        # Retrieve context if not provided
        if context is None:
            search_results = self.search(query, k=3)
            context = [r["document"] for r in search_results]

        # Route to appropriate reasoning method based on mode
        if self.reasoning_mode == ReasoningMode.CHAIN_OF_THOUGHT:
            reasoning, answer = self._chain_of_thought_reasoning(query, context, max_length)
        elif self.reasoning_mode == ReasoningMode.TREE_OF_THOUGHTS:
            reasoning, answer = self._tree_of_thoughts_reasoning(query, context, max_length)
        elif self.reasoning_mode == ReasoningMode.REFLEXION:
            reasoning, answer = self._reflexion_reasoning(query, context, max_length)
        elif self.reasoning_mode == ReasoningMode.COT_WITH_TOOLS:
            reasoning, answer = self._cot_with_tools_reasoning(query, context, max_length)
        else:
            # Standard mode - use enhanced prompt with decoding strategy
            context_text = "\n".join([f"- {ctx}" for ctx in context])

            prompt = f"""You are an expert fire evacuation safety advisor. Use the following context about the building's fire safety status to answer the question.

CONTEXT:
{context_text}

QUESTION: {query}

Provide a clear, safety-focused answer based on the context. If the context doesn't contain enough information, say so.

ANSWER:"""

            answer = self._generate_with_decoding_strategy(prompt, max_length)
            reasoning = f"Standard reasoning mode - Direct answer generation.\n\n{answer}"

        if return_reasoning:
            return reasoning, answer
        return answer

    def set_reasoning_mode(self, mode: ReasoningMode):
        """Set the reasoning mode for future queries"""
        self.reasoning_mode = mode
        print(f"[OK] Reasoning mode set to: {mode.value}")

    def set_decoding_strategy(self, strategy: DecodingStrategy):
        """Set the decoding strategy for future queries"""
        self.decoding_strategy = strategy
        print(f"[OK] Decoding strategy set to: {strategy.value}")

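    # Usage sketch for the setters above (assumes an initialized instance, e.g. the one
    # built by _init_rag further below, here called rag):
    #   rag.set_reasoning_mode(ReasoningMode.REFLEXION)
    #   rag.set_decoding_strategy(DecodingStrategy.TEMPERATURE)
    #   # subsequent calls to query()/generate_response() use these settings
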
    def query(self, question: str, k: int = 3, reasoning_mode: Optional[ReasoningMode] = None,
              show_reasoning: bool = True) -> Dict[str, Any]:
        """
        Complete RAG query: retrieve context and generate response with advanced reasoning

        Args:
            question: User question
            k: Number of context documents to retrieve
            reasoning_mode: Optional override for reasoning mode (uses instance default if None)
            show_reasoning: If True, includes full reasoning steps in response

        Returns:
            Dictionary with answer, context, metadata, reasoning information, and reasoning steps
        """
        # Retrieve relevant context
        search_results = self.search(question, k=k)

        # Generate response with reasoning
        context = [r["document"] for r in search_results]

        # Temporarily override reasoning mode if provided
        original_mode = self.reasoning_mode
        if reasoning_mode is not None:
            self.reasoning_mode = reasoning_mode
        mode_used = self.reasoning_mode  # record the mode actually used before restoring

        try:
            reasoning, answer = self.generate_response(question, context, return_reasoning=True)
        finally:
            # Restore original mode
            self.reasoning_mode = original_mode

        result = {
            "question": question,
            "answer": answer,
            "context": context,
            "reasoning_mode": mode_used.value,
            "decoding_strategy": self.decoding_strategy.value,
            "sources": [
                {
                    "type": r["metadata"].get("type"),
                    "room_id": r["metadata"].get("room_id"),
                    "route_id": r["metadata"].get("route_id"),
                    "relevance_score": 1.0 / (1.0 + r["distance"])
                }
                for r in search_results
            ]
        }

        if show_reasoning:
            result["reasoning_steps"] = reasoning

        return result

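    # Usage sketch for query() above (illustrative; assumes the model and FAISS index
    # are already loaded):
    #   result = rag.query("Which exit is safest from R1?", k=3,
    #                      reasoning_mode=ReasoningMode.CHAIN_OF_THOUGHT)
    #   print(result["answer"])
    #   print(result["reasoning_steps"])   # present because show_reasoning defaults to True
    #   for src in result["sources"]:
    #       print(src["room_id"], src["relevance_score"])
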
    def save_index(self, index_path: str, metadata_path: str):
        """Save FAISS index and metadata"""
        if self.index:
            faiss.write_index(self.index, index_path)
            with open(metadata_path, 'wb') as f:
                pickle.dump({
                    "documents": self.documents,
                    "metadata": self.metadata
                }, f)
            print(f"[OK] Saved index to {index_path} and metadata to {metadata_path}")

    def load_index(self, index_path: str, metadata_path: str):
        """Load FAISS index and metadata"""
        self.index = faiss.read_index(index_path)
        with open(metadata_path, 'rb') as f:
            data = pickle.load(f)
            self.documents = data["documents"]
            self.metadata = data["metadata"]
        print(f"[OK] Loaded index with {self.index.ntotal} vectors")

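    # Usage sketch for save_index/load_index above: persisting the index avoids
    # re-embedding the evacuation data between runs (file names are illustrative):
    #   rag.save_index("fire_index.faiss", "fire_metadata.pkl")
    #   ...
    #   rag.load_index("fire_index.faiss", "fire_metadata.pkl")
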
    def compare_reasoning_modes(self, question: str, k: int = 3) -> Dict[str, Any]:
        """
        Compare all reasoning modes for a given question

        Args:
            question: User question
            k: Number of context documents to retrieve

        Returns:
            Dictionary with answers from all reasoning modes
        """
        # Retrieve context once
        search_results = self.search(question, k=k)
        context = [r["document"] for r in search_results]

        results = {
            "question": question,
            "context": context,
            "sources": [
                {
                    "type": r["metadata"].get("type"),
                    "room_id": r["metadata"].get("room_id"),
                    "route_id": r["metadata"].get("route_id"),
                    "relevance_score": 1.0 / (1.0 + r["distance"])
                }
                for r in search_results
            ],
            "answers": {}
        }

        # Save original mode
        original_mode = self.reasoning_mode

        # Test each reasoning mode
        for mode in ReasoningMode:
            try:
                self.reasoning_mode = mode
                reasoning, answer = self.generate_response(question, context, return_reasoning=True)
                results["answers"][mode.value] = {
                    "answer": answer,
                    "reasoning": reasoning,
                    "length": len(answer)
                }
            except Exception as e:
                results["answers"][mode.value] = {
                    "error": str(e)
                }

        # Restore original mode
        self.reasoning_mode = original_mode

        return results

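    # Usage sketch for compare_reasoning_modes above (illustrative; each mode runs a
    # full generation, so this is slow):
    #   comparison = rag.compare_reasoning_modes("Which rooms should be avoided?")
    #   for mode_name, entry in comparison["answers"].items():
    #       print(mode_name, entry.get("answer", entry.get("error")))
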
# === Gradio integration ===
_rag_instance: Optional[FireEvacuationRAG] = None


def _init_rag() -> FireEvacuationRAG:
    """Initialize and cache the RAG system for Gradio use."""
    global _rag_instance
    if _rag_instance is not None:
        return _rag_instance

    # Configuration (match original defaults, but without noisy prints)
    USE_UNSLOTH = True
    USE_8BIT = False
    UNSLOTH_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct"

    # Set model directory to absolute path
    MODEL_DIR = r"D:\github\cse499\models"

    # Create fire evacuation system
    floor_plan = create_sample_floor_plan()
    sensor_system = create_sample_fire_scenario(floor_plan)
    pathfinder = PathFinder(floor_plan, sensor_system)

    # Export data and build index
    exporter = FireEvacuationDataExporter(floor_plan, sensor_system, pathfinder)
    json_data = exporter.export_to_json("fire_evacuation_data.json", start_location="R1")

    # Initialize RAG
    if USE_UNSLOTH:
        rag = FireEvacuationRAG(
            model_name=UNSLOTH_MODEL,
            model_dir=MODEL_DIR,
            use_unsloth=True,
            load_in_4bit=False,
            max_seq_length=2048,
            reasoning_mode=ReasoningMode.CHAIN_OF_THOUGHT,
            decoding_strategy=DecodingStrategy.NUCLEUS,
        )
    else:
        rag = FireEvacuationRAG(
            model_name="nvidia/Llama-3.1-Minitron-4B-Width-Base",
            model_dir=MODEL_DIR,
            use_8bit=USE_8BIT,
            reasoning_mode=ReasoningMode.CHAIN_OF_THOUGHT,
            decoding_strategy=DecodingStrategy.NUCLEUS,
        )

    rag.download_model()
    rag.load_embedder()
    rag.build_index_from_json(json_data)

    _rag_instance = rag
    return rag


def gradio_answer(question: str) -> str:
    """Gradio callback: take a text question, return LLM/RAG answer."""
    question = (question or "").strip()
    if not question:
        return "Please enter a question about fire evacuation or building safety."

    rag = _init_rag()
    result = rag.query(question, k=3, show_reasoning=False)
    return result.get("answer", "No answer generated.")


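# Quick smoke test outside Gradio (illustrative; assumes the model weights and the
# MODEL_DIR configured in _init_rag are available on this machine, and that the
# first call will be slow while the model and index load):
#   from app import gradio_answer
#   print(gradio_answer("Which exit should I use from R1?"))
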
if __name__ == "__main__":
    iface = gr.Interface(
        fn=gradio_answer,
        inputs=gr.Textbox(lines=3, label="Fire Evacuation Question"),
        outputs=gr.Textbox(lines=6, label="LLM Recommendation"),
        title="Fire Evacuation RAG Advisor",
        description="Ask about evacuation routes, dangers, and exits in the simulated building.",
    )
    iface.launch()