Commit 754e2486 authored by salma

Add pydantic validation for simpler code

parent 8ea8f2f5
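The pattern throughout this commit: student and search records that were previously passed around as plain dicts become pydantic models, so call sites move from key lookups to attribute access and invalid values fail at model construction time. A minimal before/after sketch of that pattern:

    # Before: untyped dict access
    student_name = student_info['student_name']

    # After: validated pydantic model (StudentProfile), attribute access
    student_name = student_info.student_name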
# Network settings
-bind = "0.0.0.0:8000"
+bind = "0.0.0.0:8001"
# Worker settings
worker_class = "uvicorn.workers.UvicornWorker"
...
@@ -58,7 +58,7 @@ async def lifespan(app: FastAPI):
        pass
    if hasattr(app.state.container, 'agent_service'):
-        app.state.container.agent_service.close()
+        app.state.container.pool_handler.close_all()
        print("Database connection pool closed.")
def create_app() -> FastAPI:
...
@@ -105,7 +105,7 @@ async def websocket_endpoint(websocket: WebSocket, room_id: str, student_id: str
    # 3. Update Participants in DB
    logger.info(f"Fetching student info for {student_id}")
    student_info = container.mcq_service.db_service.get_student_info(student_id)
-   student_name = student_info['student_name'] if student_info else "Unknown Student"
+   student_name = student_info.student_name if student_info else "Unknown Student"
    room_data = redis_client.hgetall(room_key)
    participants = json.loads(room_data.get("participants", "{}"))
...
from .response import WebhookResponse
from .message import TextMessage
from .mcq import QuestionResponse, QuizResponse, MCQListResponse
+from .agent import (
+    StudentProfile,
+    SearchResult,
+    CurriculumContext,
+    AgentRequest,
+    ChatMessage
+)
\ No newline at end of file
# --- START OF FILE schemas/agent.py ---
from pydantic import BaseModel, Field, validator
from typing import Optional, List, Dict, Union
from enum import Enum
import os
from pathlib import Path
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core import StudentNationality, StudyLanguage
# --- Student Context Models ---
class StudentProfile(BaseModel):
    student_id: str
    student_name: str
    grade: int = Field(..., ge=1, le=12)  # Validate grade is between 1 and 12
    study_language: StudyLanguage
    nationality: Union[StudentNationality, str]  # Handle both Enum obj and string from DB
    is_arabic: bool

    @validator('nationality', pre=True)
    def parse_nationality(cls, v):
        if isinstance(v, str):
            # Normalize string input to the Enum
            try:
                return StudentNationality(v.lower().strip())
            except ValueError:
                return StudentNationality.EGYPTIAN  # Fallback
        return v


# --- Search/Content Models ---
class CurriculumContext(BaseModel):
    position: Optional[str] = None
    related_concepts: List[str] = []
    unit_overview: Optional[str] = None
    navigation_hint: Optional[str] = None


class SearchResult(BaseModel):
    id: int
    chunk_text: str
    unit: Optional[str] = None
    concept: Optional[str] = None
    lesson: Optional[str] = None
    distance: float
    curriculum_context: Optional[CurriculumContext] = None

    # Allow extra fields from DB that we might not explicitly need yet
    class Config:
        extra = "ignore"


# --- Interaction Models ---
class AgentRequest(BaseModel):
    user_message: str
    student_id: str
    subject: str = "Science"
    model: str = "gpt-4o"  # Default model
    temperature: float = Field(0.3, ge=0.0, le=2.0)
    top_k: int = Field(3, ge=1, le=10)


class ChatMessage(BaseModel):
    role: str
    content: str
\ No newline at end of file
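A minimal usage sketch of the new StudentProfile validation (the sample values below are made up for illustration):

    # Illustrative values only -- not taken from the codebase
    profile = StudentProfile(
        student_id="s-001",
        student_name="Test Student",
        grade=4,
        study_language=StudyLanguage.ARABIC,
        nationality="  Egyptian ",  # raw DB string; the pre-validator normalizes it
        is_arabic=True,
    )
    assert profile.nationality == StudentNationality.EGYPTIAN
    # An out-of-range grade (e.g. 0 or 13) now raises a pydantic ValidationError
    # instead of silently flowing through the agent pipeline.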
@@ -3,8 +3,9 @@ import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
-from typing import Dict, Tuple
+from typing import List
from core import StudentNationality, StudyLanguage
+from schemas import StudentProfile, SearchResult
logger = logging.getLogger(__name__)
@@ -18,14 +19,15 @@ class ContextGenerator:
        self.openai_service = openai_service
        self.pgvector = pgvector_service
-    def generate_enhanced_context(self, search_results: list[Dict], student_info: Dict, query_type: str) -> str:
+    def generate_enhanced_context(self, search_results: List[SearchResult], student_info: StudentProfile, query_type: str) -> str:
        """Generate enhanced context with JSON-based curriculum structure awareness"""
        if not search_results:
            return ""
-        is_arabic = student_info['is_arabic']
-        study_language = student_info['study_language']
-        grade = student_info['grade']
+        # Access via Dot Notation
+        is_arabic = student_info.is_arabic
+        study_language = student_info.study_language
+        grade = student_info.grade
        if study_language == StudyLanguage.ENGLISH:
            context_message = f"📚 من المنهج الإنجليزي لمادة العلوم للصف {grade}:\n\n"
@@ -33,10 +35,10 @@ class ContextGenerator:
            context_message = f"📚 من المنهج العربي لمادة العلوم للصف {grade}:\n\n"
        for result in search_results:
-            # Basic information
-            unit_info = f"الوحدة: {result['unit']}" if result.get('unit') else ""
-            concept_info = f"المفهوم: {result['concept']}" if result.get('concept') else ""
-            lesson_info = f"الدرس: {result['lesson']}" if result.get('lesson') else ""
+            # Access via Dot Notation
+            unit_info = f"الوحدة: {result.unit}" if result.unit else ""
+            concept_info = f"المفهوم: {result.concept}" if result.concept else ""
+            lesson_info = f"الدرس: {result.lesson}" if result.lesson else ""
            # Build header
            context_parts = [info for info in [unit_info, concept_info, lesson_info] if info]
@@ -44,16 +46,16 @@ class ContextGenerator:
            context_message += f"**{' → '.join(context_parts)}**\n"
            # Add content
-            context_message += f"{result['chunk_text']}\n"
+            context_message += f"{result.chunk_text}\n"
            # Add curriculum context if available
-            if 'curriculum_context' in result:
-                ctx = result['curriculum_context']
-                if ctx.get('navigation_hint'):
-                    context_message += f"\n💡 {ctx['navigation_hint']}\n"
-                if ctx.get('related_concepts') and query_type == "specific_content":
-                    related = ', '.join(ctx['related_concepts'][:3])
+            if result.curriculum_context:
+                ctx = result.curriculum_context
+                if ctx.navigation_hint:
+                    context_message += f"\n💡 {ctx.navigation_hint}\n"
+                if ctx.related_concepts and query_type == "specific_content":
+                    related = ', '.join(ctx.related_concepts[:3])
                    if is_arabic:
                        context_message += f"🔗 مفاهيم ذات صلة: {related}\n"
                    else:
@@ -69,22 +71,24 @@ class ContextGenerator:
        return context_message
-    def search_enhanced_content(self, query: str, student_info: Dict, subject: str, top_k: int = 3) -> list[Dict]:
+    def search_enhanced_content(self, query: str, student_info: StudentProfile, subject: str, top_k: int = 3) -> List[SearchResult]:
        """Search for enhanced content with curriculum context"""
        if not self.pgvector:
            return []
        try:
            query_embedding = self.openai_service.generate_embedding(query)
+            # PGVector now returns List[SearchResult]
            search_results = self.pgvector.search_with_curriculum_context(
                query_embedding=query_embedding,
-                grade=student_info['grade'],
+                grade=student_info.grade,
                subject=subject,
-                is_arabic=student_info['is_arabic'],
+                is_arabic=student_info.is_arabic,
                limit=top_k
            )
-            relevant_results = [r for r in search_results if r['distance'] < 1.3] if search_results else []
+            # Access via Dot Notation
+            relevant_results = [r for r in search_results if r.distance < 1.3] if search_results else []
            return relevant_results
        except Exception as e:
...
@@ -3,6 +3,7 @@ import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from typing import Dict, Any
+from schemas import StudentProfile
from core import StudentNationality, StudyLanguage, Models
import logging
@@ -93,7 +94,7 @@ class QueryHandler:
            return "لا يمكن الحصول على سياق المحادثة."
-    def classify_query_type(self, query: str, student_info: Dict[str, Any], student_id: str) -> str:
+    def classify_query_type(self, query: str, student_info: StudentProfile, student_id: str) -> str:
        """
        Enhanced query classification. It first checks for a specific, rule-based
        pattern for 'game_help' and then falls back to the LLM for other cases.
@@ -109,8 +110,8 @@ class QueryHandler:
        if not self.openai_service.is_available():
            return "specific_content"
-        is_arabic: bool = student_info.get('is_arabic', True)
-        grade: int = student_info.get('grade', 4)
+        is_arabic = student_info.is_arabic
+        grade = student_info.grade
        conversation_context = self.get_recent_conversation_context(student_id, max_messages=5)
@@ -173,24 +174,18 @@ class QueryHandler:
            logger.warning(f"Error in query classification: {e}, defaulting to 'specific_content'")
            return "specific_content"
-    def handle_general_chat_query(self, query: str, student_info: Dict[str, Any]) -> str:
+    def handle_general_chat_query(self, query: str, student_info: StudentProfile) -> str:
        """Handle general chat queries using only student information"""
-        student_name: str = student_info.get('student_name', 'الطالب')
-        grade: int = student_info.get('grade', 4)
-        nationality_str: str = student_info.get('nationality', 'egyptian')
-        is_arabic: bool = student_info.get('is_arabic', True)
+        student_name = student_info.student_name
+        grade = student_info.grade
+        nationality_enum = student_info.nationality
+        is_arabic = student_info.is_arabic
        study_lang = "العربية" if is_arabic else "الإنجليزية"
-        # Map nationality string to enum
-        nationality_mapping = {
-            'egyptian': StudentNationality.EGYPTIAN,
-            'saudi': StudentNationality.SAUDI
-        }
-        nationality_enum = nationality_mapping.get(nationality_str.lower().strip(), StudentNationality.EGYPTIAN)
-        # Get template with fallback
+        nationality_str = nationality_enum.value if hasattr(nationality_enum, 'value') else str(nationality_enum)
        template = GENERAL_CHAT_CONTEXTS.get(nationality_enum)
        if not template:
            logger.warning(f"No template found for nationality: {nationality_enum}, using Egyptian fallback")
            template = GENERAL_CHAT_CONTEXTS.get(StudentNationality.EGYPTIAN)
@@ -229,18 +224,18 @@ class QueryHandler:
        """
-    def handle_overview_query(self, student_info: Dict[str, Any], subject: str = "Science") -> str:
+    def handle_overview_query(self, student_info: StudentProfile, subject: str = "Science") -> str:
        """Handle curriculum overview queries using JSON-based data"""
        if not self.pgvector:
-            if student_info['study_language'] == StudyLanguage.ARABIC:
-                return f"عذراً، لا يمكنني عرض المنهج حالياً للصف {student_info['grade']}"
+            if student_info.study_language == StudyLanguage.ARABIC:
+                return f"عذراً، لا يمكنني عرض المنهج حالياً للصف {student_info.grade}"
            else:
-                return f"Sorry, I cannot show the curriculum for Grade {student_info['grade']} right now"
+                return f"Sorry, I cannot show the curriculum for Grade {student_info.grade} right now"
        try:
            return self.pgvector.get_overview_response(
-                student_info['grade'],
-                student_info['is_arabic'],
+                student_info.grade,
+                student_info.is_arabic,
                subject
            )
        except Exception as e:
@@ -250,7 +245,7 @@ class QueryHandler:
        else:
            return f"Sorry, there was an error showing the curriculum for Grade {student_info['grade']}"
-    def handle_navigation_query(self, query: str, student_info: Dict[str, Any], subject: str = "Science") -> str:
+    def handle_navigation_query(self, query: str, student_info: StudentProfile, subject: str = "Science") -> str:
        """Handle unit/concept navigation queries using JSON structure"""
        if not self.pgvector:
            return self.handle_overview_query(student_info, subject)
@@ -258,8 +253,8 @@ class QueryHandler:
        try:
            return self.pgvector.get_unit_navigation_response(
                query,
-                student_info['grade'],
-                student_info['is_arabic'],
+                student_info.grade,
+                student_info.is_arabic,
                subject
            )
        except Exception as e:
...
@@ -5,6 +5,7 @@ from fastapi import HTTPException
from services.agent_helpers.agent_prompts import SYSTEM_PROMPTS
import logging
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+from schemas import StudentProfile
from core import StudentNationality, StudyLanguage, Models
logger = logging.getLogger(__name__)
@@ -37,18 +38,12 @@ class ResponseGenerator:
        except Exception as e:
            logger.error(f"Error adding message to history for {student_id}: {e}")
-    def prepare_system_prompt(self, student_info: Dict) -> str:
+    def prepare_system_prompt(self, student_info: StudentProfile) -> str:
        """Prepare system prompt based on student information"""
-        student_name = student_info.get('student_name', 'الطالب').split()[0]
-        study_language = student_info['study_language']
-        # Map nationality
-        nationality_lower = student_info['nationality'].lower().strip()
-        nationality_mapping = {
-            'egyptian': StudentNationality.EGYPTIAN,
-            'saudi': StudentNationality.SAUDI
-        }
-        nationality = nationality_mapping.get(nationality_lower, StudentNationality.EGYPTIAN)
+        # Dot notation
+        student_name = student_info.student_name.split()[0]
+        study_language = student_info.study_language
+        nationality = student_info.nationality
        # Get appropriate system prompt
        prompt_key = (nationality, study_language)
@@ -57,12 +52,12 @@ class ResponseGenerator:
        formatted_base_prompt = base_system_prompt.format(
            student_name=student_name,
-            grade=student_info['grade']
+            grade=student_info.grade
        )
-        # Add Socratic instructions if any
+        # Add Socratic instructions
        socratic_instructions = self.pedagogy_service.get_socratic_instructions(
-            student_info['grade'], student_info['nationality']
+            student_info.grade, student_info.nationality
        )
        if socratic_instructions:
            formatted_base_prompt += f"\n\n{socratic_instructions}"
@@ -87,9 +82,8 @@ class ResponseGenerator:
        student_info = self.db_service.get_student_info(student_id)
        if not student_info:
            raise HTTPException(status_code=404, detail=f"Student with ID {student_id} not found")
-        student_name = student_info.get('student_name', 'الطالب').split()[0]
-        study_language = student_info['study_language']
+        student_name = student_info.student_name.split()[0]
+        study_language = student_info.study_language
        # Add user message to DB
        self.add_message_to_history(student_id, user_message, "user")
...
@@ -16,6 +16,7 @@ from services.agent_helpers.response_generator import ResponseGenerator
from services.agent_helpers.tashkeel_agent import TashkeelAgent
from services.agent_helpers.tashkeel_fixer import apply_fixes, custom_fixes
from services.tts.tts_manager import get_tts_service
+from schemas import AgentRequest
logger = logging.getLogger(__name__)
@@ -65,6 +66,7 @@ class AgentService:
    def generate_response(self, user_message: str, student_id: str, subject: str = "Science",
                          model: str = Models.chat, temperature: float = 0.3, top_k: int = 3):
        """
        Main response generation method, now handles both string and dictionary responses.
        """
@@ -86,12 +88,3 @@ class AgentService:
        print(f"response: {response}")
        return response
-    def close(self):
-        """Close database connection pools"""
-        if self.pool_handler:
-            try:
-                self.pool_handler.close_all()
-            except Exception as e:
-                logger.error(f"Error closing connection pools: {e}")
\ No newline at end of file
@@ -9,6 +9,7 @@ import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core import StudyLanguage
+from schemas import StudentProfile
logger = logging.getLogger(__name__)
@@ -30,7 +31,7 @@ class ChatDatabaseService:
                result = cur.fetchone()
                return result["nationality"] if result else None
-    def get_student_info(self, student_id: str) -> Optional[Dict]:
+    def get_student_info(self, student_id: str) -> Optional[StudentProfile]:
        """Get complete student information with explicit language awareness"""
        with self.pool_handler.get_connection() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
@@ -44,16 +45,16 @@ class ChatDatabaseService:
                )
                result = cur.fetchone()
                if result:
-                    # Convert boolean to explicit language enum
                    study_language = StudyLanguage.ARABIC if result['language'] else StudyLanguage.ENGLISH
-                    return {
-                        'student_id': result['student_id'],
-                        'student_name': result['student_name'],
-                        'grade': result['grade'],
-                        'study_language': study_language,  # Explicit language enum
-                        'is_arabic': result['language'],  # Keep for backward compatibility
-                        'nationality': result['nationality']
-                    }
+                    return StudentProfile(
+                        student_id=result['student_id'],
+                        student_name=result['student_name'],
+                        grade=result['grade'],
+                        study_language=study_language,
+                        nationality=result['nationality'],
+                        is_arabic=result['language']
+                    )
                return None
    def get_student_grade_and_language(self, student_id: str) -> Optional[Tuple[int, bool, StudyLanguage]]:
@@ -192,16 +193,24 @@ class ChatDatabaseService:
        if not student_info:
            return None
-        nationality_desc = "مصري" if student_info['nationality'].lower() == "egyptian" else "سعودي"
-        language_desc = "بالعربي" if student_info['study_language'] == StudyLanguage.ARABIC else "بالإنجليزي"
+        is_egyptian = False
+        if hasattr(student_info.nationality, 'value'):
+            # If it's the Enum object
+            is_egyptian = student_info.nationality.value.lower() == "egyptian"
+        else:
+            # If strictly string fallback
+            is_egyptian = str(student_info.nationality).lower() == "egyptian"
+        nationality_desc = "مصري" if is_egyptian else "سعودي"
+        language_desc = "بالعربي" if student_info.study_language == StudyLanguage.ARABIC else "بالإنجليزي"
        return {
            "student_id": student_id,
-            "student_name": student_info['student_name'],
-            "study_language": student_info['study_language'].value,
-            "nationality": student_info['nationality'],
-            "grade": str(student_info['grade']),
-            "description": f"طالب {nationality_desc} في الصف {student_info['grade']} يدرس {language_desc}"
+            "student_name": student_info.student_name,
+            "study_language": student_info.study_language.value,
+            "nationality": str(student_info.nationality.value if hasattr(student_info.nationality, 'value') else student_info.nationality),
+            "grade": str(student_info.grade),
+            "description": f"طالب {nationality_desc} في الصف {student_info.grade} يدرس {language_desc}"
        }
    def get_students_by_language(self, study_language: StudyLanguage) -> List[Dict]:
@@ -230,3 +239,5 @@ class ChatDatabaseService:
            }
            for row in results
        ]
\ No newline at end of file
@@ -6,6 +6,10 @@ import logging
import json
from pgvector.psycopg2 import register_vector
from services.connection_pool import ConnectionPool
+import sys
+import os
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+from schemas import SearchResult, CurriculumContext
logger = logging.getLogger(__name__)
@@ -81,7 +85,7 @@ class PGVectorService:
        subject: str,
        is_arabic: bool,
        limit: int = 3
-    ):
+    ) -> List[SearchResult]:  # <--- Update return type hint
        """Enhanced search that includes curriculum position context"""
        with self.pool_handler.get_connection() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
@@ -104,13 +108,27 @@ class PGVectorService:
                )
                results = cur.fetchall()
-                # Enhance results with curriculum context
+                # Convert to Pydantic Models
+                pydantic_results = []
                for result in results:
-                    result['curriculum_context'] = self._build_curriculum_context(
+                    # Build context dict
+                    ctx_dict = self._build_curriculum_context(
                        result, curriculum, grade, is_arabic
                    )
-                return results
+                    # Create Model
+                    search_result = SearchResult(
+                        id=result['id'],
+                        chunk_text=result['chunk_text'],
+                        unit=result.get('unit'),
+                        concept=result.get('concept'),
+                        lesson=result.get('lesson'),
+                        distance=result['distance'],
+                        curriculum_context=CurriculumContext(**ctx_dict)  # Validate nested
+                    )
+                    pydantic_results.append(search_result)
+                return pydantic_results
    def search_flexible_filtered_nearest(
        self,
...
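Because SearchResult sets extra = "ignore", building it from a raw result row tolerates columns the schema does not model yet (in line with the "Allow extra fields from DB" comment in schemas/agent.py). A small illustrative check with made-up row values:

    row = {"id": 7, "chunk_text": "water cycle...", "distance": 0.42, "embedding_model": "ada-002"}  # hypothetical row
    sr = SearchResult(**row)  # the extra "embedding_model" key is ignored, not an error
    assert sr.unit is None and sr.distance == 0.42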