add life span manager to guarantee graceful db shutdown

parent 5b55ec50
import os import os
from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from typing import Optional from typing import Optional
import uvicorn import uvicorn
import base64 import base64
from pathlib import Path from pathlib import Path
# Import your existing modules # Import your existing modules
from core import AppConfig, StudentNationality from core import AppConfig
from repositories import StorageRepository, MinIOStorageRepository from repositories import MinIOStorageRepository
from handlers import AudioMessageHandler, TextMessageHandler
from services import ( from services import (
AudioService, ChatService, HealthService, ResponseService, AudioService, ChatService, HealthService, ResponseService,
ResponseManager, OpenAIService, AgentService, ConnectionPool, PGVectorService, ChatDatabaseService, LanguageSegmentationService ResponseManager, OpenAIService, AgentService, ConnectionPool, LanguageSegmentationService
) )
class DIContainer: class DIContainer:
...@@ -30,7 +30,7 @@ class DIContainer: ...@@ -30,7 +30,7 @@ class DIContainer:
dbname=os.getenv("POSTGRES_DB"), dbname=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"), user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASSWORD"), password=os.getenv("POSTGRES_PASSWORD"),
host=os.getenv("DB_HOST"), # This is the crucial part host=os.getenv("DB_HOST"),
port=int(os.getenv("DB_PORT")) port=int(os.getenv("DB_PORT"))
) )
print(os.getenv("DB_HOST"), os.getenv("POSTGRES_DB"), os.getenv("POSTGRES_USER")) print(os.getenv("DB_HOST"), os.getenv("POSTGRES_DB"), os.getenv("POSTGRES_USER"))
...@@ -50,8 +50,30 @@ class DIContainer: ...@@ -50,8 +50,30 @@ class DIContainer:
self.response_service = ResponseService(self.response_manager, self.audio_service) self.response_service = ResponseService(self.response_manager, self.audio_service)
self.health_service = HealthService(self.storage_repo, self.config) self.health_service = HealthService(self.storage_repo, self.config)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Manages application startup and shutdown events for resource safety.
"""
# --- Code to run ON STARTUP ---
print("Application starting up...")
container = DIContainer()
app.state.container = container
print("DIContainer created and database pool initialized.")
yield # The application is now running and handling requests
# --- Code to run ON SHUTDOWN ---
print("Application shutting down...")
# This is the guaranteed, graceful shutdown call
app.state.container.agent_service.close()
print("Database connection pool closed successfully.")
def create_app() -> FastAPI: def create_app() -> FastAPI:
app = FastAPI(title="Unified Chat API with Local Agent") # Connect the lifespan manager to your FastAPI app instance
app = FastAPI(title="Unified Chat API with Local Agent", lifespan=lifespan)
# Fixed CORS configuration for CapRover # Fixed CORS configuration for CapRover
app.add_middleware( app.add_middleware(
...@@ -76,35 +98,33 @@ def create_app() -> FastAPI: ...@@ -76,35 +98,33 @@ def create_app() -> FastAPI:
expose_headers=["X-Response-Text"], expose_headers=["X-Response-Text"],
) )
# Initialize dependencies # NOTE: The container is now created and managed by the 'lifespan' function.
container = DIContainer() # No need to create it here.
# Print configuration
print("MinIO Endpoint:", container.config.minio_endpoint)
print("MinIO Bucket:", container.config.minio_bucket)
print("OpenAI Service Available:", container.openai_service.is_available())
print("Agent Service Available:", container.agent_service.is_available())
# Serve static files if the directory exists # Serve static files if the directory exists
static_path = Path("static") static_path = Path("static")
if static_path.exists(): if static_path.exists():
app.mount("/static", StaticFiles(directory=static_path), name="static") app.mount("/static", StaticFiles(directory=static_path), name="static")
@app.on_event("startup")
async def startup_event():
# Access the container from app state to print config on startup
container = app.state.container
print("MinIO Endpoint:", container.config.minio_endpoint)
print("MinIO Bucket:", container.config.minio_bucket)
print("OpenAI Service Available:", container.openai_service.is_available())
print("Agent Service Available:", container.agent_service.is_available())
@app.get("/chat-interface") @app.get("/chat-interface")
async def serve_audio_recorder(): async def serve_audio_recorder():
"""Serve the audio recorder HTML file""" """Serve the audio recorder HTML file"""
try: try:
# Try to serve from static directory first
static_file = Path("static/audio-recorder.html") static_file = Path("static/audio-recorder.html")
if static_file.exists(): if static_file.exists():
return FileResponse(static_file) return FileResponse(static_file)
# Fallback to current directory
current_file = Path("audio-recorder.html") current_file = Path("audio-recorder.html")
if current_file.exists(): if current_file.exists():
return FileResponse(current_file) return FileResponse(current_file)
# If no file found, return an error
raise HTTPException(status_code=404, detail="Audio recorder interface not found") raise HTTPException(status_code=404, detail="Audio recorder interface not found")
except Exception as e: except Exception as e:
print(f"Error serving audio recorder: {e}") print(f"Error serving audio recorder: {e}")
...@@ -112,57 +132,39 @@ def create_app() -> FastAPI: ...@@ -112,57 +132,39 @@ def create_app() -> FastAPI:
@app.post("/chat") @app.post("/chat")
async def chat_handler( async def chat_handler(
request: Request,
file: Optional[UploadFile] = File(None), file: Optional[UploadFile] = File(None),
text: Optional[str] = Form(None), text: Optional[str] = Form(None),
student_id: str = Form("student_001"), student_id: str = Form("student_001"),
game_context: Optional[str] = Form(None) game_context: Optional[str] = Form(None)
): ):
""" """Handles incoming chat messages using the shared container instance."""
Handles incoming chat messages (either text or audio). container = request.app.state.container
Generates responses locally using the agent service.
"""
try: try:
if not student_id.strip(): if not student_id.strip():
raise HTTPException(status_code=400, detail="Student ID is required") raise HTTPException(status_code=400, detail="Student ID is required")
print(f"Processing message for student: {student_id}")
print(f"Text: {text}")
print(f"File: {file.filename if file else 'None'}")
result = container.chat_service.process_message( result = container.chat_service.process_message(
student_id=student_id, student_id=student_id,
file=file, file=file,
text=text, text=text,
game_context=game_context game_context=game_context
) )
print(f"Chat service result: {result}")
return result return result
except Exception as e: except Exception as e:
print(f"Error in chat handler: {str(e)}") print(f"Error in chat handler: {str(e)}")
raise HTTPException(status_code=500, detail=f"Chat processing error: {str(e)}") raise HTTPException(status_code=500, detail=f"Chat processing error: {str(e)}")
@app.get("/get-audio-response") @app.get("/get-audio-response")
async def get_audio_response(student_id: str = "student_001"): async def get_audio_response(request: Request, student_id: str = "student_001"):
"""Fetches the agent's text and audio response with proper CORS headers.""" """Fetches the agent's text and audio response using the shared container."""
container = request.app.state.container
try: try:
print("Getting audio response...")
result = container.response_service.get_agent_response(student_id=student_id) result = container.response_service.get_agent_response(student_id=student_id)
if hasattr(result, 'status_code'): if hasattr(result, 'status_code'):
# This is already a Response object from response_service
print(f"Response headers: {dict(result.headers)}")
print(f"Response audio raw bytes size: {len(result.body) if result.body else 'N/A'}")
print(f"Response audio first 20 bytes: {result.body[:20] if result.body else 'N/A'}")
return result return result
# This should be unreachable if response_service always returns a Response object
print(f"Created response with headers: {dict(response.headers)}") return result
return response
except Exception as e: except Exception as e:
print(f"Error getting audio response: {str(e)}") print(f"Error getting audio response: {str(e)}")
raise HTTPException(status_code=500, detail=f"Audio response error: {str(e)}") raise HTTPException(status_code=500, detail=f"Audio response error: {str(e)}")
...@@ -170,34 +172,19 @@ def create_app() -> FastAPI: ...@@ -170,34 +172,19 @@ def create_app() -> FastAPI:
@app.options("/chat") @app.options("/chat")
async def chat_options(): async def chat_options():
"""Handle preflight CORS requests for chat endpoint""" """Handle preflight CORS requests for chat endpoint"""
return Response( return Response(status_code=204, headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Headers": "*"})
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "*",
"Access-Control-Max-Age": "86400"
}
)
@app.options("/get-audio-response") @app.options("/get-audio-response")
async def audio_response_options(): async def audio_response_options():
"""Handle preflight CORS requests for audio response endpoint""" """Handle preflight CORS requests for audio response endpoint"""
return Response( return Response(status_code=204, headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET, OPTIONS", "Access-Control-Allow-Headers": "*", "Access-Control-Expose-Headers": "X-Response-Text"})
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "*",
"Access-Control-Expose-Headers": "X-Response-Text",
"Access-Control-Max-Age": "86400"
}
)
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check(request: Request):
"""Health check endpoint with agent service status""" """Health check endpoint using the shared container."""
container = request.app.state.container
try: try:
health_status = container.health_service.get_health_status() health_status = container.health_service.get_health_status()
# Add agent service status
health_status.update({ health_status.update({
"openai_service_status": "available" if container.openai_service.is_available() else "unavailable", "openai_service_status": "available" if container.openai_service.is_available() else "unavailable",
"agent_service_status": "available" if container.agent_service.is_available() else "unavailable", "agent_service_status": "available" if container.agent_service.is_available() else "unavailable",
...@@ -211,50 +198,46 @@ def create_app() -> FastAPI: ...@@ -211,50 +198,46 @@ def create_app() -> FastAPI:
# Agent management endpoints # Agent management endpoints
@app.get("/conversation/stats") @app.get("/conversation/stats")
async def get_conversation_stats(student_id: str = "student_001"): async def get_conversation_stats(request: Request, student_id: str = "student_001"):
"""Get conversation statistics""" container = request.app.state.container
try: try:
return container.chat_service.get_agent_stats(student_id) return container.chat_service.get_agent_stats(student_id)
except Exception as e: except Exception as e:
print(f"Stats error: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/conversation/clear") @app.post("/conversation/clear")
async def clear_conversation(student_id: str = Form("student_001")): async def clear_conversation(request: Request, student_id: str = Form("student_001")):
"""Clear conversation history""" container = request.app.state.container
try: try:
return container.chat_service.clear_conversation(student_id) return container.chat_service.clear_conversation(student_id)
except Exception as e: except Exception as e:
print(f"Clear conversation error: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/agent/system-prompt") @app.post("/agent/system-prompt")
async def set_system_prompt(request: dict): async def set_system_prompt(req_body: dict, request: Request):
"""Update the agent's system prompt""" container = request.app.state.container
try: try:
prompt = request.get("prompt", "") prompt = req_body.get("prompt", "")
if not prompt: if not prompt:
raise HTTPException(status_code=400, detail="System prompt cannot be empty") raise HTTPException(status_code=400, detail="System prompt cannot be empty")
return container.chat_service.set_system_prompt(prompt) return container.chat_service.set_system_prompt(prompt)
except Exception as e: except Exception as e:
print(f"Set system prompt error: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/agent/system-prompt") @app.get("/agent/system-prompt")
async def get_system_prompt(): async def get_system_prompt(request: Request):
"""Get the current system prompt""" container = request.app.state.container
try: try:
return { return {
"system_prompt": container.agent_service.system_prompt, "system_prompt": container.agent_service.system_prompt,
"status": "success" "status": "success"
} }
except Exception as e: except Exception as e:
print(f"Get system prompt error: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/conversation/export") @app.get("/conversation/export")
async def export_conversation(student_id: str = "student_001"): async def export_conversation(request: Request, student_id: str = "student_001"):
"""Export conversation history""" container = request.app.state.container
try: try:
history = container.agent_service.export_conversation(student_id) history = container.agent_service.export_conversation(student_id)
return { return {
...@@ -263,73 +246,32 @@ def create_app() -> FastAPI: ...@@ -263,73 +246,32 @@ def create_app() -> FastAPI:
"total_messages": len(history) "total_messages": len(history)
} }
except Exception as e: except Exception as e:
print(f"Export conversation error: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/conversation/import") @app.post("/conversation/import")
async def import_conversation(request: dict): async def import_conversation(req_body: dict, request: Request):
"""Import conversation history""" container = request.app.state.container
try: try:
student_id = request.get("student_id", "student_001") student_id = req_body.get("student_id", "student_001")
messages = request.get("messages", []) messages = req_body.get("messages", [])
if not messages: if not messages:
raise HTTPException(status_code=400, detail="Messages list cannot be empty") raise HTTPException(status_code=400, detail="Messages list cannot be empty")
container.agent_service.import_conversation(messages, student_id) container.agent_service.import_conversation(messages, student_id)
return { return {"status": "success", "message": f"Imported {len(messages)} messages"}
"status": "success",
"message": f"Imported {len(messages)} messages to conversation {student_id}"
}
except Exception as e: except Exception as e:
print(f"Import conversation error: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/debug/test-response") @app.get("/debug/test-response")
async def debug_test_response(): async def debug_test_response():
"""Debug endpoint to test response generation""" """Debug endpoint to test response generation"""
try:
# Test basic response
test_text = "This is a test response" test_text = "This is a test response"
encoded_text = base64.b64encode(test_text.encode('utf-8')).decode('utf-8') encoded_text = base64.b64encode(test_text.encode('utf-8')).decode('utf-8')
return Response(content=b"test audio data", media_type="audio/mpeg", headers={"X-Response-Text": encoded_text, "Access-Control-Expose-Headers": "X-Response-Text"})
return Response(
content=b"test audio data",
media_type="audio/mpeg",
headers={
"X-Response-Text": encoded_text,
"Access-Control-Expose-Headers": "X-Response-Text"
}
)
except Exception as e:
print(f"Debug test error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/") @app.get("/")
async def root(): async def root():
"""Root endpoint with API info""" """Root endpoint with API info"""
return { return {"service": "Unified Chat API with Local Agent", "version": "2.2.0-lifespan", "status": "running"}
"service": "Unified Chat API with Local Agent",
"version": "2.1.0",
"description": "Unified backend for audio/text chat with a local AI agent.",
"status": "running",
"deployment": "CapRover",
"features": [
"Local AI agent responses using OpenAI GPT",
"Audio transcription using OpenAI Whisper",
"Text-to-speech using OpenAI TTS",
"Conversation history management",
"Student-specific conversations",
"CORS enabled for cross-origin requests"
],
"endpoints": {
"chat_interface": "/chat-interface (HTML interface)",
"chat": "/chat (accepts audio or text with student_id)",
"get_audio_response": "/get-audio-response (fetches agent's audio and text)",
"health": "/health (service health check)",
"debug": "/debug/test-response (test response generation)"
}
}
return app return app
......
...@@ -293,6 +293,3 @@ class AgentService: ...@@ -293,6 +293,3 @@ class AgentService:
except Exception as e: except Exception as e:
logger.error(f"Error closing connection pools: {e}") logger.error(f"Error closing connection pools: {e}")
def __del__(self):
"""Destructor to ensure connection pools are closed"""
self.close()
\ No newline at end of file
...@@ -40,43 +40,44 @@ class ResponseManager: ...@@ -40,43 +40,44 @@ class ResponseManager:
def get_response(self, student_id: str) -> Dict: def get_response(self, student_id: str) -> Dict:
""" """
Atomically gets the response for a student and removes it from Redis Gets the response for a student without deleting it.
to ensure it's claimed only once. This allows the client to safely retry the request if it fails.
The key will be cleaned up automatically by Redis when its TTL expires.
""" """
key = self._get_key(student_id) key = self._get_key(student_id)
# Use a pipeline to get and delete the key in a single, atomic operation # 1. Use a simple, non-destructive GET command. No pipeline needed.
pipe = self.redis.pipeline() json_value = self.redis.get(key)
pipe.get(key)
pipe.delete(key)
results = pipe.execute()
json_value = results[0]
if not json_value: if not json_value:
# If nothing was found, return the same empty structure as the old class return {"text": None, "audio_filepath": None, "audio_bytes": None}
return {"text": None, "audio_filename": None, "audio_bytes": None}
# If data was found, decode it # 2. Decode the payload as before.
payload = json.loads(json_value) payload = json.loads(json_value)
# Decode the Base64 string back into binary audio data
if payload.get("audio_bytes_b64"): if payload.get("audio_bytes_b64"):
payload["audio_bytes"] = base64.b64decode(payload["audio_bytes_b64"]) payload["audio_bytes"] = base64.b64decode(payload["audio_bytes_b64"])
else: else:
payload["audio_bytes"] = None payload["audio_bytes"] = None
# Remove the temporary key before returning
del payload["audio_bytes_b64"] del payload["audio_bytes_b64"]
return payload return payload
def clear_response(self, student_id: str) -> None: def clear_response(self, student_id: str) -> None:
"""Clears a response for a specific student from Redis.""" """
Clears a response for a specific student from Redis.
This is still important to call at the *beginning* of a new /chat request
to ensure old data is invalidated immediately.
"""
key = self._get_key(student_id) key = self._get_key(student_id)
self.redis.delete(key) self.redis.delete(key)
def is_response_fresh(self, student_id: str, max_age_seconds: int = 300) -> bool: def is_response_fresh(self, student_id: str) -> bool:
"""Checks if a response exists in Redis for the given student.""" """
Checks if a response exists in Redis for the given student.
This is much simpler and more reliable now.
"""
key = self._get_key(student_id) key = self._get_key(student_id)
# redis.exists() is the direct equivalent of checking if the key is present # redis.exists() returns the number of keys that exist (0 or 1 in this case).
return self.redis.exists(key) > 0 return self.redis.exists(key) > 0
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment