Local AI Assistant commited on
Commit
d23996b
·
1 Parent(s): 4133eb6

Update backend deployment with Firebase and RAG

Browse files
Files changed (3) hide show
  1. .gitignore +0 -0
  2. api.py +176 -200
  3. requirements.txt +1 -0
.gitignore ADDED
Binary file (443 Bytes). View file
 
api.py CHANGED
@@ -16,43 +16,31 @@ from chat_engine import ChatEngine
16
  from image_engine import ImageEngine
17
  import models
18
  import schemas
19
- from database import SessionLocal, engine
 
20
 
21
- # Create tables
22
- models.Base.metadata.create_all(bind=engine)
23
-
24
- app = FastAPI()
25
- # Force git update
26
-
27
- # Security Config
28
- SECRET_KEY = "your-secret-key-keep-it-secret" # In production, use env var
29
- ALGORITHM = "HS256"
30
- ACCESS_TOKEN_EXPIRE_MINUTES = 30
31
-
32
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
33
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
34
-
35
- # Enable CORS
36
- app.add_middleware(
37
- CORSMiddleware,
38
- allow_origins=["*"],
39
- allow_credentials=True,
40
- allow_methods=["*"],
41
- allow_headers=["*"],
42
- )
43
-
44
- from fastapi.responses import JSONResponse
45
-
46
- @app.exception_handler(Exception)
47
- async def global_exception_handler(request, exc):
48
- return JSONResponse(
49
- status_code=500,
50
- content={"detail": f"Internal Server Error: {str(exc)}"},
51
- )
52
 
53
- from fastapi import UploadFile, File
54
- import shutil
55
- from rag_engine import RAGEngine
 
56
 
57
  # Initialize engines
58
  print("Initializing AI Engines...")
@@ -61,165 +49,144 @@ image_engine = ImageEngine()
61
  rag_engine = RAGEngine()
62
  print("AI Engines Ready!")
63
 
64
- # Dependency
65
- def get_db():
66
- db = SessionLocal()
67
- try:
68
- yield db
69
- finally:
70
- db.close()
71
-
72
- # Auth Helpers
73
- def verify_password(plain_password, hashed_password):
74
- if len(plain_password) > 72:
75
- plain_password = plain_password[:72]
76
- return pwd_context.verify(plain_password, hashed_password)
77
-
78
- def get_password_hash(password):
79
- if len(password) > 72:
80
- password = password[:72]
81
- return pwd_context.hash(password)
82
-
83
- def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
84
- to_encode = data.copy()
85
- if expires_delta:
86
- expire = datetime.utcnow() + expires_delta
87
- else:
88
- expire = datetime.utcnow() + timedelta(minutes=15)
89
- to_encode.update({"exp": expire})
90
- encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
91
- return encoded_jwt
92
 
93
- async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
94
- credentials_exception = HTTPException(
95
- status_code=status.HTTP_401_UNAUTHORIZED,
96
- detail="Could not validate credentials",
97
- headers={"WWW-Authenticate": "Bearer"},
98
- )
99
  try:
100
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
101
- email: str = payload.get("sub")
102
- if email is None:
103
- raise credentials_exception
104
- token_data = schemas.TokenData(email=email)
105
- except JWTError:
106
- raise credentials_exception
107
- user = db.query(models.User).filter(models.User.email == token_data.email).first()
108
- if user is None:
109
- raise credentials_exception
110
- return user
111
-
112
- async def get_current_admin(current_user: models.User = Depends(get_current_user)):
113
- if not current_user.is_admin:
114
- raise HTTPException(status_code=403, detail="Not authorized")
115
- return current_user
116
-
117
- # Auth Endpoints
118
- @app.post("/register", response_model=schemas.User)
119
- def register(user: schemas.UserCreate, db: Session = Depends(get_db)):
120
- db_user = db.query(models.User).filter(models.User.email == user.email).first()
121
- if db_user:
122
- raise HTTPException(status_code=400, detail="Email already registered")
123
-
124
- hashed_password = get_password_hash(user.password)
125
-
126
- # Check if this is the Admin user
127
- is_admin = False
128
- if user.email == "[email protected]":
129
- is_admin = True
130
 
131
- db_user = models.User(
132
- email=user.email,
133
- hashed_password=hashed_password,
134
- full_name=user.full_name,
135
- company_name=user.company_name,
136
- is_admin=is_admin
137
- )
138
- db.add(db_user)
139
- db.commit()
140
- db.refresh(db_user)
141
- return db_user
142
-
143
- @app.post("/token", response_model=schemas.Token)
144
- async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
145
- user = db.query(models.User).filter(models.User.email == form_data.username).first()
146
- if not user or not verify_password(form_data.password, user.hashed_password):
147
  raise HTTPException(
148
  status_code=status.HTTP_401_UNAUTHORIZED,
149
- detail="Incorrect username or password",
150
  headers={"WWW-Authenticate": "Bearer"},
151
  )
152
- access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
153
- access_token = create_access_token(
154
- data={"sub": user.email}, expires_delta=access_token_expires
155
- )
156
- return {"access_token": access_token, "token_type": "bearer"}
157
 
158
- @app.get("/users/me", response_model=schemas.User)
159
- async def read_users_me(current_user: schemas.User = Depends(get_current_user)):
 
 
 
 
 
 
 
 
 
160
  return current_user
161
 
162
  # Conversation Endpoints
163
- @app.post("/conversations", response_model=schemas.Conversation)
164
- async def create_conversation(conversation: schemas.ConversationCreate, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
165
- db_conversation = models.Conversation(**conversation.dict(), user_id=current_user.id)
166
- db.add(db_conversation)
167
- db.commit()
168
- db.refresh(db_conversation)
169
- return db_conversation
 
 
 
 
 
 
 
 
170
 
171
- @app.get("/conversations", response_model=List[schemas.Conversation])
172
- async def get_conversations(current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
173
- return db.query(models.Conversation).filter(models.Conversation.user_id == current_user.id).order_by(models.Conversation.updated_at.desc()).all()
 
 
 
 
174
 
175
- @app.get("/conversations/{conversation_id}/messages", response_model=List[schemas.ChatMessage])
176
- async def get_conversation_messages(conversation_id: int, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
177
- conversation = db.query(models.Conversation).filter(models.Conversation.id == conversation_id, models.Conversation.user_id == current_user.id).first()
178
- if not conversation:
179
- raise HTTPException(status_code=404, detail="Conversation not found")
180
- return db.query(models.ChatMessage).filter(models.ChatMessage.conversation_id == conversation_id).order_by(models.ChatMessage.timestamp).all()
 
 
 
 
 
 
 
181
 
182
  # Saved Prompt Endpoints
183
- @app.post("/prompts", response_model=schemas.SavedPrompt)
184
- async def create_prompt(prompt: schemas.SavedPromptCreate, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
185
- db_prompt = models.SavedPrompt(**prompt.dict(), user_id=current_user.id)
186
- db.add(db_prompt)
187
- db.commit()
188
- db.refresh(db_prompt)
189
- return db_prompt
 
 
 
 
 
 
 
 
 
190
 
191
- @app.get("/prompts", response_model=List[schemas.SavedPrompt])
192
- async def get_prompts(current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
193
- return db.query(models.SavedPrompt).filter(models.SavedPrompt.user_id == current_user.id).order_by(models.SavedPrompt.created_at.desc()).all()
 
 
 
 
194
 
195
  @app.delete("/prompts/{prompt_id}")
196
- async def delete_prompt(prompt_id: int, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
197
- db_prompt = db.query(models.SavedPrompt).filter(models.SavedPrompt.id == prompt_id, models.SavedPrompt.user_id == current_user.id).first()
198
- if not db_prompt:
199
- raise HTTPException(status_code=404, detail="Prompt not found")
200
- db.delete(db_prompt)
201
- db.commit()
202
- return {"status": "success"}
 
 
 
203
 
204
  # Admin Endpoints
205
- @app.get("/admin/users", response_model=List[schemas.UserActivity])
206
- async def get_all_users(current_user: models.User = Depends(get_current_admin), db: Session = Depends(get_db)):
207
- # Get users with message count
208
- users = db.query(models.User).all()
209
- result = []
210
- for user in users:
211
- msg_count = db.query(func.count(models.ChatMessage.id)).filter(models.ChatMessage.user_id == user.id).scalar()
212
- prompt_count = db.query(func.count(models.SavedPrompt.id)).filter(models.SavedPrompt.user_id == user.id).scalar()
213
- user_data = schemas.UserActivity.from_orm(user)
214
- user_data.message_count = msg_count
215
- user_data.prompt_count = prompt_count
216
- result.append(user_data)
217
- return result
218
 
219
- @app.get("/admin/activity", response_model=List[schemas.ChatMessage])
220
- async def get_all_activity(current_user: models.User = Depends(get_current_admin), db: Session = Depends(get_db)):
221
- messages = db.query(models.ChatMessage).order_by(models.ChatMessage.timestamp.desc()).limit(100).all()
222
- return messages
 
223
 
224
  # Protected AI Endpoints
225
  class ChatRequest(BaseModel):
@@ -236,22 +203,29 @@ def read_root():
236
  return {"status": "Backend is running", "message": "Go to /docs to see the API"}
237
 
238
  @app.post("/chat")
239
- async def chat(request: ChatRequest, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
240
  # ... (Keep existing /chat for backward compatibility if needed, or redirect logic)
241
  # For now, let's keep /chat as blocking and add /chat/stream
242
  try:
243
- # Save User Message
244
- user_msg = models.ChatMessage(user_id=current_user.id, role="user", content=request.message)
245
- db.add(user_msg)
246
- db.commit()
247
-
248
  # Generate Response
249
  response = chat_engine.generate_response(request.message, request.history)
250
 
251
- # Save Assistant Message
252
- ai_msg = models.ChatMessage(user_id=current_user.id, role="assistant", content=response)
253
- db.add(ai_msg)
254
- db.commit()
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  return {"response": response}
257
  except Exception as e:
@@ -261,7 +235,7 @@ async def chat(request: ChatRequest, current_user: models.User = Depends(get_cur
261
 
262
  # RAG Endpoints
263
  @app.post("/upload")
264
- async def upload_file(file: UploadFile = File(...), current_user: models.User = Depends(get_current_user)):
265
  try:
266
  # Save file locally
267
  upload_dir = "uploads"
@@ -279,7 +253,7 @@ async def upload_file(file: UploadFile = File(...), current_user: models.User =
279
  raise HTTPException(status_code=500, detail=str(e))
280
 
281
  @app.post("/chat/stream")
282
- async def chat_stream(request: ChatRequest, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)):
283
  try:
284
  # Check for RAG context
285
  context = ""
@@ -289,21 +263,14 @@ async def chat_stream(request: ChatRequest, current_user: models.User = Depends(
289
  print(f"Found {len(rag_docs)} relevant documents.")
290
 
291
  # Save User Message
292
- user_msg = models.ChatMessage(
293
- user_id=current_user.id,
294
- conversation_id=request.conversation_id,
295
- role="user",
296
- content=request.message
297
- )
298
- db.add(user_msg)
299
- db.commit()
300
-
301
- # Update conversation timestamp
302
  if request.conversation_id:
303
- conversation = db.query(models.Conversation).filter(models.Conversation.id == request.conversation_id).first()
304
- if conversation:
305
- conversation.updated_at = datetime.utcnow()
306
- db.commit()
 
 
 
307
 
308
  async def stream_generator():
309
  full_response = ""
@@ -314,6 +281,15 @@ async def chat_stream(request: ChatRequest, current_user: models.User = Depends(
314
  full_response += token
315
  yield token
316
 
 
 
 
 
 
 
 
 
 
317
  print(f"Generated response for conv {request.conversation_id}")
318
 
319
  return StreamingResponse(stream_generator(), media_type="text/plain")
@@ -324,7 +300,7 @@ async def chat_stream(request: ChatRequest, current_user: models.User = Depends(
324
  raise HTTPException(status_code=500, detail=str(e))
325
 
326
  @app.post("/generate-image")
327
- async def generate_image(request: ImageRequest, current_user: models.User = Depends(get_current_user)):
328
  try:
329
  # Generate image to a temporary file
330
  filename = "temp_generated.png"
 
16
  from image_engine import ImageEngine
17
  import models
18
  import schemas
19
+ import firebase_admin
20
+ from firebase_admin import credentials, firestore, auth
21
 
22
+ # Initialize Firebase Admin
23
+ if not firebase_admin._apps:
24
+ if os.path.exists("serviceAccountKey.json"):
25
+ cred = credentials.Certificate("serviceAccountKey.json")
26
+ else:
27
+ # Try getting from env var (for Hugging Face)
28
+ key_json = os.environ.get("FIREBASE_SERVICE_ACCOUNT_KEY")
29
+ if key_json:
30
+ import json
31
+ cred_dict = json.loads(key_json)
32
+ cred = credentials.Certificate(cred_dict)
33
+ else:
34
+ print("Warning: No service account key found. Firebase features will fail.")
35
+ cred = None
36
+
37
+ if cred:
38
+ firebase_admin.initialize_app(cred)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ if firebase_admin._apps:
41
+ db = firestore.client()
42
+ else:
43
+ db = None
44
 
45
  # Initialize engines
46
  print("Initializing AI Engines...")
 
49
  rag_engine = RAGEngine()
50
  print("AI Engines Ready!")
51
 
52
+ # Auth Dependency
53
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ async def get_current_user(token: str = Depends(oauth2_scheme)):
 
 
 
 
 
56
  try:
57
+ decoded_token = auth.verify_id_token(token)
58
+ uid = decoded_token['uid']
59
+ # Get user data from Firestore
60
+ user_doc = db.collection('users').document(uid).get()
61
+ if not user_doc.exists:
62
+ # Create user if not exists (first login)
63
+ user_data = {
64
+ "email": decoded_token.get('email'),
65
+ "full_name": decoded_token.get('name', 'User'),
66
+ "created_at": datetime.utcnow(),
67
+ "is_admin": False
68
+ }
69
+ db.collection('users').document(uid).set(user_data)
70
+ return {**user_data, "id": uid}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ return {**user_doc.to_dict(), "id": uid}
73
+ except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  raise HTTPException(
75
  status_code=status.HTTP_401_UNAUTHORIZED,
76
+ detail=f"Invalid authentication credentials: {str(e)}",
77
  headers={"WWW-Authenticate": "Bearer"},
78
  )
 
 
 
 
 
79
 
80
+ async def get_current_admin(current_user: dict = Depends(get_current_user)):
81
+ if not current_user.get("is_admin"):
82
+ raise HTTPException(status_code=403, detail="Not authorized")
83
+ return current_user
84
+
85
+ # Auth Endpoints
86
+ # Note: Registration and Login are handled by Firebase on the Frontend.
87
+ # The backend only verifies the ID token via get_current_user.
88
+
89
+ @app.get("/users/me")
90
+ async def read_users_me(current_user: dict = Depends(get_current_user)):
91
  return current_user
92
 
93
  # Conversation Endpoints
94
+ @app.post("/conversations")
95
+ async def create_conversation(conversation: schemas.ConversationCreate, current_user: dict = Depends(get_current_user)):
96
+ try:
97
+ new_conv_ref = db.collection('conversations').document()
98
+ conv_data = {
99
+ "id": new_conv_ref.id,
100
+ "user_id": current_user['id'],
101
+ "title": conversation.title,
102
+ "created_at": datetime.utcnow(),
103
+ "updated_at": datetime.utcnow()
104
+ }
105
+ new_conv_ref.set(conv_data)
106
+ return conv_data
107
+ except Exception as e:
108
+ raise HTTPException(status_code=500, detail=str(e))
109
 
110
+ @app.get("/conversations")
111
+ async def get_conversations(current_user: dict = Depends(get_current_user)):
112
+ try:
113
+ docs = db.collection('conversations').where('user_id', '==', current_user['id']).order_by('updated_at', direction=firestore.Query.DESCENDING).stream()
114
+ return [doc.to_dict() for doc in docs]
115
+ except Exception as e:
116
+ raise HTTPException(status_code=500, detail=str(e))
117
 
118
+ @app.get("/conversations/{conversation_id}/messages")
119
+ async def get_conversation_messages(conversation_id: str, current_user: dict = Depends(get_current_user)):
120
+ try:
121
+ # Verify ownership
122
+ conv_ref = db.collection('conversations').document(conversation_id)
123
+ conv = conv_ref.get()
124
+ if not conv.exists or conv.to_dict()['user_id'] != current_user['id']:
125
+ raise HTTPException(status_code=404, detail="Conversation not found")
126
+
127
+ msgs = conv_ref.collection('messages').order_by('timestamp').stream()
128
+ return [msg.to_dict() for msg in msgs]
129
+ except Exception as e:
130
+ raise HTTPException(status_code=500, detail=str(e))
131
 
132
  # Saved Prompt Endpoints
133
+ @app.post("/prompts")
134
+ async def create_prompt(prompt: schemas.SavedPromptCreate, current_user: dict = Depends(get_current_user)):
135
+ try:
136
+ new_prompt_ref = db.collection('prompts').document()
137
+ prompt_data = {
138
+ "id": new_prompt_ref.id,
139
+ "user_id": current_user['id'],
140
+ "title": prompt.title,
141
+ "content": prompt.content,
142
+ "tags": prompt.tags,
143
+ "created_at": datetime.utcnow()
144
+ }
145
+ new_prompt_ref.set(prompt_data)
146
+ return prompt_data
147
+ except Exception as e:
148
+ raise HTTPException(status_code=500, detail=str(e))
149
 
150
+ @app.get("/prompts")
151
+ async def get_prompts(current_user: dict = Depends(get_current_user)):
152
+ try:
153
+ docs = db.collection('prompts').where('user_id', '==', current_user['id']).order_by('created_at', direction=firestore.Query.DESCENDING).stream()
154
+ return [doc.to_dict() for doc in docs]
155
+ except Exception as e:
156
+ raise HTTPException(status_code=500, detail=str(e))
157
 
158
  @app.delete("/prompts/{prompt_id}")
159
+ async def delete_prompt(prompt_id: str, current_user: dict = Depends(get_current_user)):
160
+ try:
161
+ prompt_ref = db.collection('prompts').document(prompt_id)
162
+ prompt = prompt_ref.get()
163
+ if not prompt.exists or prompt.to_dict()['user_id'] != current_user['id']:
164
+ raise HTTPException(status_code=404, detail="Prompt not found")
165
+ prompt_ref.delete()
166
+ return {"status": "success"}
167
+ except Exception as e:
168
+ raise HTTPException(status_code=500, detail=str(e))
169
 
170
  # Admin Endpoints
171
+ @app.get("/admin/users")
172
+ async def get_all_users(current_user: dict = Depends(get_current_admin)):
173
+ try:
174
+ users = db.collection('users').stream()
175
+ result = []
176
+ for user in users:
177
+ user_data = user.to_dict()
178
+ # Count messages (this might be expensive in Firestore, maybe skip or approximate)
179
+ # For now, let's just return user data
180
+ result.append(user_data)
181
+ return result
182
+ except Exception as e:
183
+ raise HTTPException(status_code=500, detail=str(e))
184
 
185
+ @app.get("/admin/activity")
186
+ async def get_all_activity(current_user: dict = Depends(get_current_admin)):
187
+ # This is hard in Firestore without a global collection group query
188
+ # For now, return empty or implement a specific 'activity' log collection
189
+ return []
190
 
191
  # Protected AI Endpoints
192
  class ChatRequest(BaseModel):
 
203
  return {"status": "Backend is running", "message": "Go to /docs to see the API"}
204
 
205
  @app.post("/chat")
206
+ async def chat(request: ChatRequest, current_user: dict = Depends(get_current_user)):
207
  # ... (Keep existing /chat for backward compatibility if needed, or redirect logic)
208
  # For now, let's keep /chat as blocking and add /chat/stream
209
  try:
 
 
 
 
 
210
  # Generate Response
211
  response = chat_engine.generate_response(request.message, request.history)
212
 
213
+ # Save to Firestore if conversation_id is present
214
+ if request.conversation_id:
215
+ conv_ref = db.collection('conversations').document(request.conversation_id)
216
+ # User Msg
217
+ conv_ref.collection('messages').add({
218
+ "role": "user",
219
+ "content": request.message,
220
+ "timestamp": datetime.utcnow()
221
+ })
222
+ # AI Msg
223
+ conv_ref.collection('messages').add({
224
+ "role": "assistant",
225
+ "content": response,
226
+ "timestamp": datetime.utcnow()
227
+ })
228
+ conv_ref.update({"updated_at": datetime.utcnow()})
229
 
230
  return {"response": response}
231
  except Exception as e:
 
235
 
236
  # RAG Endpoints
237
  @app.post("/upload")
238
+ async def upload_file(file: UploadFile = File(...), current_user: dict = Depends(get_current_user)):
239
  try:
240
  # Save file locally
241
  upload_dir = "uploads"
 
253
  raise HTTPException(status_code=500, detail=str(e))
254
 
255
  @app.post("/chat/stream")
256
+ async def chat_stream(request: ChatRequest, current_user: dict = Depends(get_current_user)):
257
  try:
258
  # Check for RAG context
259
  context = ""
 
263
  print(f"Found {len(rag_docs)} relevant documents.")
264
 
265
  # Save User Message
 
 
 
 
 
 
 
 
 
 
266
  if request.conversation_id:
267
+ conv_ref = db.collection('conversations').document(request.conversation_id)
268
+ conv_ref.collection('messages').add({
269
+ "role": "user",
270
+ "content": request.message,
271
+ "timestamp": datetime.utcnow()
272
+ })
273
+ conv_ref.update({"updated_at": datetime.utcnow()})
274
 
275
  async def stream_generator():
276
  full_response = ""
 
281
  full_response += token
282
  yield token
283
 
284
+ # Save AI Message after generation
285
+ if request.conversation_id:
286
+ conv_ref = db.collection('conversations').document(request.conversation_id)
287
+ conv_ref.collection('messages').add({
288
+ "role": "assistant",
289
+ "content": full_response,
290
+ "timestamp": datetime.utcnow()
291
+ })
292
+
293
  print(f"Generated response for conv {request.conversation_id}")
294
 
295
  return StreamingResponse(stream_generator(), media_type="text/plain")
 
300
  raise HTTPException(status_code=500, detail=str(e))
301
 
302
  @app.post("/generate-image")
303
+ async def generate_image(request: ImageRequest, current_user: dict = Depends(get_current_user)):
304
  try:
305
  # Generate image to a temporary file
306
  filename = "temp_generated.png"
requirements.txt CHANGED
@@ -19,6 +19,7 @@ langchain
19
  langchain-community
20
  langchain-text-splitters
21
  sentence-transformers
 
22
  faiss-cpu
23
  pypdf
24
  python-multipart
 
19
  langchain-community
20
  langchain-text-splitters
21
  sentence-transformers
22
+ firebase-admin
23
  faiss-cpu
24
  pypdf
25
  python-multipart