Commit 754e2486 authored by salma

Add pydantic validation for simpler code

parent 8ea8f2f5
# Network settings
bind = "0.0.0.0:8000"
bind = "0.0.0.0:8001"
# Worker settings
worker_class = "uvicorn.workers.UvicornWorker"
......
......@@ -58,7 +58,7 @@ async def lifespan(app: FastAPI):
pass
if hasattr(app.state.container, 'agent_service'):
app.state.container.agent_service.close()
app.state.container.pool_handler.close_all()
print("Database connection pool closed.")
def create_app() -> FastAPI:
......
......@@ -105,7 +105,7 @@ async def websocket_endpoint(websocket: WebSocket, room_id: str, student_id: str
# 3. Update Participants in DB
logger.info(f"Fetching student info for {student_id}")
student_info = container.mcq_service.db_service.get_student_info(student_id)
student_name = student_info['student_name'] if student_info else "Unknown Student"
student_name = student_info.student_name if student_info else "Unknown Student"
room_data = redis_client.hgetall(room_key)
participants = json.loads(room_data.get("participants", "{}"))
......
from .response import WebhookResponse
from .message import TextMessage
from .mcq import QuestionResponse, QuizResponse, MCQListResponse
\ No newline at end of file
from .mcq import QuestionResponse, QuizResponse, MCQListResponse
from .agent import (
StudentProfile,
SearchResult,
CurriculumContext,
AgentRequest,
ChatMessage
)
\ No newline at end of file
# --- START OF FILE schemas/agent.py ---
from pydantic import BaseModel, Field, validator
from typing import Optional, List, Dict, Union
from enum import Enum
import os
from pathlib import Path
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core import StudentNationality, StudyLanguage
# --- Student Context Models ---
class StudentProfile(BaseModel):
    """Student context model with validated fields.

    ``nationality`` may be supplied either as a ``StudentNationality`` member
    or as a raw string; strings are normalised by ``parse_nationality`` before
    field validation runs.
    """

    student_id: str
    student_name: str
    # Required; must lie in the inclusive range 1-12.
    grade: int = Field(..., ge=1, le=12)
    study_language: StudyLanguage
    # Union keeps the declared type permissive for enum objects and strings.
    nationality: Union[StudentNationality, str]
    is_arabic: bool

    @validator('nationality', pre=True)
    def parse_nationality(cls, v):
        """Coerce a raw nationality string to ``StudentNationality``.

        Non-string input is returned unchanged. String input is lower-cased
        and stripped; a value the enum does not recognise falls back to
        ``StudentNationality.EGYPTIAN`` instead of failing validation.
        """
        if not isinstance(v, str):
            return v

        normalized = v.lower().strip()
        try:
            return StudentNationality(normalized)
        except ValueError:
            # Unknown nationality string: use the Egyptian default.
            return StudentNationality.EGYPTIAN
# --- Search/Content Models ---
class CurriculumContext(BaseModel):
    """Optional curriculum metadata attached to a search result.

    Every field is optional: the text fields default to ``None`` and
    ``related_concepts`` defaults to an empty list.
    """

    position: Optional[str] = None
    related_concepts: List[str] = []
    unit_overview: Optional[str] = None
    navigation_hint: Optional[str] = None
class SearchResult(BaseModel):
    """One retrieved content chunk plus its optional curriculum context.

    ``id``, ``chunk_text`` and ``distance`` are required; ``unit``,
    ``concept``, ``lesson`` and ``curriculum_context`` default to ``None``.
    """

    id: int
    chunk_text: str
    unit: Optional[str] = None
    concept: Optional[str] = None
    lesson: Optional[str] = None
    distance: float
    curriculum_context: Optional[CurriculumContext] = None

    class Config:
        # Keys not declared above are dropped rather than rejected, so rows
        # carrying additional columns can still be loaded.
        extra = "ignore"
# --- Interaction Models ---
class AgentRequest(BaseModel):
    """Parameters for a single agent interaction.

    Only ``user_message`` and ``student_id`` are required; the remaining
    fields carry defaults and, where noted, range constraints.
    """

    user_message: str
    student_id: str
    subject: str = "Science"
    # Model name used when the caller does not specify one.
    model: str = "gpt-4o"
    # Defaults to 0.3; must lie in the inclusive range 0.0-2.0.
    temperature: float = Field(0.3, ge=0.0, le=2.0)
    # Defaults to 3; must lie in the inclusive range 1-10.
    top_k: int = Field(3, ge=1, le=10)
class ChatMessage(BaseModel):
    """A single chat message: a ``role`` string and its ``content`` text."""

    role: str
    content: str
\ No newline at end of file
......@@ -3,8 +3,9 @@ import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from typing import Dict, Tuple
from typing import List
from core import StudentNationality, StudyLanguage
from schemas import StudentProfile, SearchResult
logger = logging.getLogger(__name__)
......@@ -18,14 +19,15 @@ class ContextGenerator:
self.openai_service = openai_service
self.pgvector = pgvector_service
def generate_enhanced_context(self, search_results: list[Dict], student_info: Dict, query_type: str) -> str:
def generate_enhanced_context(self, search_results: List[SearchResult], student_info: StudentProfile, query_type: str) -> str:
"""Generate enhanced context with JSON-based curriculum structure awareness"""
if not search_results:
return ""
is_arabic = student_info['is_arabic']
study_language = student_info['study_language']
grade = student_info['grade']
# Access via Dot Notation
is_arabic = student_info.is_arabic
study_language = student_info.study_language
grade = student_info.grade
if study_language == StudyLanguage.ENGLISH:
context_message = f"📚 من المنهج الإنجليزي لمادة العلوم للصف {grade}:\n\n"
......@@ -33,10 +35,10 @@ class ContextGenerator:
context_message = f"📚 من المنهج العربي لمادة العلوم للصف {grade}:\n\n"
for result in search_results:
# Basic information
unit_info = f"الوحدة: {result['unit']}" if result.get('unit') else ""
concept_info = f"المفهوم: {result['concept']}" if result.get('concept') else ""
lesson_info = f"الدرس: {result['lesson']}" if result.get('lesson') else ""
# Access via Dot Notation
unit_info = f"الوحدة: {result.unit}" if result.unit else ""
concept_info = f"المفهوم: {result.concept}" if result.concept else ""
lesson_info = f"الدرس: {result.lesson}" if result.lesson else ""
# Build header
context_parts = [info for info in [unit_info, concept_info, lesson_info] if info]
......@@ -44,16 +46,16 @@ class ContextGenerator:
context_message += f"**{' → '.join(context_parts)}**\n"
# Add content
context_message += f"{result['chunk_text']}\n"
context_message += f"{result.chunk_text}\n"
# Add curriculum context if available
if 'curriculum_context' in result:
ctx = result['curriculum_context']
if ctx.get('navigation_hint'):
context_message += f"\n💡 {ctx['navigation_hint']}\n"
if result.curriculum_context:
ctx = result.curriculum_context
if ctx.navigation_hint:
context_message += f"\n💡 {ctx.navigation_hint}\n"
if ctx.get('related_concepts') and query_type == "specific_content":
related = ', '.join(ctx['related_concepts'][:3])
if ctx.related_concepts and query_type == "specific_content":
related = ', '.join(ctx.related_concepts[:3])
if is_arabic:
context_message += f"🔗 مفاهيم ذات صلة: {related}\n"
else:
......@@ -69,22 +71,24 @@ class ContextGenerator:
return context_message
def search_enhanced_content(self, query: str, student_info: Dict, subject: str, top_k: int = 3) -> list[Dict]:
def search_enhanced_content(self, query: str, student_info: StudentProfile, subject: str, top_k: int = 3) -> List[SearchResult]:
"""Search for enhanced content with curriculum context"""
if not self.pgvector:
return []
try:
query_embedding = self.openai_service.generate_embedding(query)
# PGVector now returns List[SearchResult]
search_results = self.pgvector.search_with_curriculum_context(
query_embedding=query_embedding,
grade=student_info['grade'],
grade=student_info.grade,
subject=subject,
is_arabic=student_info['is_arabic'],
is_arabic=student_info.is_arabic,
limit=top_k
)
relevant_results = [r for r in search_results if r['distance'] < 1.3] if search_results else []
# Access via Dot Notation
relevant_results = [r for r in search_results if r.distance < 1.3] if search_results else []
return relevant_results
except Exception as e:
......
......@@ -3,6 +3,7 @@ import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from typing import Dict, Any
from schemas import StudentProfile
from core import StudentNationality, StudyLanguage, Models
import logging
......@@ -93,7 +94,7 @@ class QueryHandler:
return "لا يمكن الحصول على سياق المحادثة."
def classify_query_type(self, query: str, student_info: Dict[str, Any], student_id: str) -> str:
def classify_query_type(self, query: str, student_info: StudentProfile, student_id: str) -> str:
"""
Enhanced query classification. It first checks for a specific, rule-based
pattern for 'game_help' and then falls back to the LLM for other cases.
......@@ -109,8 +110,8 @@ class QueryHandler:
if not self.openai_service.is_available():
return "specific_content"
is_arabic: bool = student_info.get('is_arabic', True)
grade: int = student_info.get('grade', 4)
is_arabic = student_info.is_arabic
grade = student_info.grade
conversation_context = self.get_recent_conversation_context(student_id, max_messages=5)
......@@ -173,24 +174,18 @@ class QueryHandler:
logger.warning(f"Error in query classification: {e}, defaulting to 'specific_content'")
return "specific_content"
def handle_general_chat_query(self, query: str, student_info: Dict[str, Any]) -> str:
def handle_general_chat_query(self, query: str, student_info: StudentProfile) -> str:
"""Handle general chat queries using only student information"""
student_name: str = student_info.get('student_name', 'الطالب')
grade: int = student_info.get('grade', 4)
nationality_str: str = student_info.get('nationality', 'egyptian')
is_arabic: bool = student_info.get('is_arabic', True)
student_name = student_info.student_name
grade = student_info.grade
nationality_enum = student_info.nationality
is_arabic = student_info.is_arabic
study_lang = "العربية" if is_arabic else "الإنجليزية"
# Map nationality string to enum
nationality_mapping = {
'egyptian': StudentNationality.EGYPTIAN,
'saudi': StudentNationality.SAUDI
}
nationality_enum = nationality_mapping.get(nationality_str.lower().strip(), StudentNationality.EGYPTIAN)
# Get template with fallback
nationality_str = nationality_enum.value if hasattr(nationality_enum, 'value') else str(nationality_enum)
template = GENERAL_CHAT_CONTEXTS.get(nationality_enum)
if not template:
logger.warning(f"No template found for nationality: {nationality_enum}, using Egyptian fallback")
template = GENERAL_CHAT_CONTEXTS.get(StudentNationality.EGYPTIAN)
......@@ -229,18 +224,18 @@ class QueryHandler:
"""
def handle_overview_query(self, student_info: Dict[str, Any], subject: str = "Science") -> str:
def handle_overview_query(self, student_info: StudentProfile, subject: str = "Science") -> str:
"""Handle curriculum overview queries using JSON-based data"""
if not self.pgvector:
if student_info['study_language'] == StudyLanguage.ARABIC:
return f"عذراً، لا يمكنني عرض المنهج حالياً للصف {student_info['grade']}"
if student_info.study_language == StudyLanguage.ARABIC:
return f"عذراً، لا يمكنني عرض المنهج حالياً للصف {student_info.grade}"
else:
return f"Sorry, I cannot show the curriculum for Grade {student_info['grade']} right now"
return f"Sorry, I cannot show the curriculum for Grade {student_info.grade} right now"
try:
return self.pgvector.get_overview_response(
student_info['grade'],
student_info['is_arabic'],
student_info.grade,
student_info.is_arabic,
subject
)
except Exception as e:
......@@ -250,7 +245,7 @@ class QueryHandler:
else:
return f"Sorry, there was an error showing the curriculum for Grade {student_info['grade']}"
def handle_navigation_query(self, query: str, student_info: Dict[str, Any], subject: str = "Science") -> str:
def handle_navigation_query(self, query: str, student_info: StudentProfile, subject: str = "Science") -> str:
"""Handle unit/concept navigation queries using JSON structure"""
if not self.pgvector:
return self.handle_overview_query(student_info, subject)
......@@ -258,8 +253,8 @@ class QueryHandler:
try:
return self.pgvector.get_unit_navigation_response(
query,
student_info['grade'],
student_info['is_arabic'],
student_info.grade,
student_info.is_arabic,
subject
)
except Exception as e:
......
......@@ -5,6 +5,7 @@ from fastapi import HTTPException
from services.agent_helpers.agent_prompts import SYSTEM_PROMPTS
import logging
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from schemas import StudentProfile
from core import StudentNationality, StudyLanguage, Models
logger = logging.getLogger(__name__)
......@@ -37,18 +38,12 @@ class ResponseGenerator:
except Exception as e:
logger.error(f"Error adding message to history for {student_id}: {e}")
def prepare_system_prompt(self, student_info: Dict) -> str:
def prepare_system_prompt(self, student_info: StudentProfile) -> str:
"""Prepare system prompt based on student information"""
student_name = student_info.get('student_name', 'الطالب').split()[0]
study_language = student_info['study_language']
# Map nationality
nationality_lower = student_info['nationality'].lower().strip()
nationality_mapping = {
'egyptian': StudentNationality.EGYPTIAN,
'saudi': StudentNationality.SAUDI
}
nationality = nationality_mapping.get(nationality_lower, StudentNationality.EGYPTIAN)
# Dot notation
student_name = student_info.student_name.split()[0]
study_language = student_info.study_language
nationality = student_info.nationality
# Get appropriate system prompt
prompt_key = (nationality, study_language)
......@@ -57,12 +52,12 @@ class ResponseGenerator:
formatted_base_prompt = base_system_prompt.format(
student_name=student_name,
grade=student_info['grade']
grade=student_info.grade
)
# Add Socratic instructions if any
# Add Socratic instructions
socratic_instructions = self.pedagogy_service.get_socratic_instructions(
student_info['grade'], student_info['nationality']
student_info.grade, student_info.nationality
)
if socratic_instructions:
formatted_base_prompt += f"\n\n{socratic_instructions}"
......@@ -87,9 +82,8 @@ class ResponseGenerator:
student_info = self.db_service.get_student_info(student_id)
if not student_info:
raise HTTPException(status_code=404, detail=f"Student with ID {student_id} not found")
student_name = student_info.get('student_name', 'الطالب').split()[0]
study_language = student_info['study_language']
student_name = student_info.student_name.split()[0]
study_language = student_info.study_language
# Add user message to DB
self.add_message_to_history(student_id, user_message, "user")
......
......@@ -16,6 +16,7 @@ from services.agent_helpers.response_generator import ResponseGenerator
from services.agent_helpers.tashkeel_agent import TashkeelAgent
from services.agent_helpers.tashkeel_fixer import apply_fixes, custom_fixes
from services.tts.tts_manager import get_tts_service
from schemas import AgentRequest
logger = logging.getLogger(__name__)
......@@ -65,6 +66,7 @@ class AgentService:
def generate_response(self, user_message: str, student_id: str, subject: str = "Science",
model: str = Models.chat, temperature: float = 0.3, top_k: int = 3):
"""
Main response generation method, now handles both string and dictionary responses.
"""
......@@ -86,12 +88,3 @@ class AgentService:
print(f"response: {response}")
return response
def close(self):
"""Close database connection pools"""
if self.pool_handler:
try:
self.pool_handler.close_all()
except Exception as e:
logger.error(f"Error closing connection pools: {e}")
\ No newline at end of file
......@@ -9,6 +9,7 @@ import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core import StudyLanguage
from schemas import StudentProfile
logger = logging.getLogger(__name__)
......@@ -30,7 +31,7 @@ class ChatDatabaseService:
result = cur.fetchone()
return result["nationality"] if result else None
def get_student_info(self, student_id: str) -> Optional[Dict]:
def get_student_info(self, student_id: str) -> Optional[StudentProfile]:
"""Get complete student information with explicit language awareness"""
with self.pool_handler.get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
......@@ -44,16 +45,16 @@ class ChatDatabaseService:
)
result = cur.fetchone()
if result:
# Convert boolean to explicit language enum
study_language = StudyLanguage.ARABIC if result['language'] else StudyLanguage.ENGLISH
return {
'student_id': result['student_id'],
'student_name': result['student_name'],
'grade': result['grade'],
'study_language': study_language, # Explicit language enum
'is_arabic': result['language'], # Keep for backward compatibility
'nationality': result['nationality']
}
return StudentProfile(
student_id=result['student_id'],
student_name=result['student_name'],
grade=result['grade'],
study_language=study_language,
nationality=result['nationality'],
is_arabic=result['language']
)
return None
def get_student_grade_and_language(self, student_id: str) -> Optional[Tuple[int, bool, StudyLanguage]]:
......@@ -192,16 +193,24 @@ class ChatDatabaseService:
if not student_info:
return None
nationality_desc = "مصري" if student_info['nationality'].lower() == "egyptian" else "سعودي"
language_desc = "بالعربي" if student_info['study_language'] == StudyLanguage.ARABIC else "بالإنجليزي"
is_egyptian = False
if hasattr(student_info.nationality, 'value'):
# If it's the Enum object
is_egyptian = student_info.nationality.value.lower() == "egyptian"
else:
# If strictly string fallback
is_egyptian = str(student_info.nationality).lower() == "egyptian"
nationality_desc = "مصري" if is_egyptian else "سعودي"
language_desc = "بالعربي" if student_info.study_language == StudyLanguage.ARABIC else "بالإنجليزي"
return {
"student_id": student_id,
"student_name": student_info['student_name'],
"study_language": student_info['study_language'].value,
"nationality": student_info['nationality'],
"grade": str(student_info['grade']),
"description": f"طالب {nationality_desc} في الصف {student_info['grade']} يدرس {language_desc}"
"student_name": student_info.student_name,
"study_language": student_info.study_language.value,
"nationality": str(student_info.nationality.value if hasattr(student_info.nationality, 'value') else student_info.nationality),
"grade": str(student_info.grade),
"description": f"طالب {nationality_desc} في الصف {student_info.grade} يدرس {language_desc}"
}
def get_students_by_language(self, study_language: StudyLanguage) -> List[Dict]:
......@@ -229,4 +238,6 @@ class ChatDatabaseService:
'study_language': study_language.value
}
for row in results
]
\ No newline at end of file
]
\ No newline at end of file
......@@ -6,6 +6,10 @@ import logging
import json
from pgvector.psycopg2 import register_vector
from services.connection_pool import ConnectionPool
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from schemas import SearchResult, CurriculumContext
logger = logging.getLogger(__name__)
......@@ -81,7 +85,7 @@ class PGVectorService:
subject: str,
is_arabic: bool,
limit: int = 3
):
) -> List[SearchResult]: # <--- Update return type hint
"""Enhanced search that includes curriculum position context"""
with self.pool_handler.get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
......@@ -104,13 +108,27 @@ class PGVectorService:
)
results = cur.fetchall()
# Enhance results with curriculum context
# Convert to Pydantic Models
pydantic_results = []
for result in results:
result['curriculum_context'] = self._build_curriculum_context(
# Build context dict
ctx_dict = self._build_curriculum_context(
result, curriculum, grade, is_arabic
)
# Create Model
search_result = SearchResult(
id=result['id'],
chunk_text=result['chunk_text'],
unit=result.get('unit'),
concept=result.get('concept'),
lesson=result.get('lesson'),
distance=result['distance'],
curriculum_context=CurriculumContext(**ctx_dict) # Validate nested
)
pydantic_results.append(search_result)
return results
return pydantic_results
def search_flexible_filtered_nearest(
self,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment