handle vector database

parent 90af444c
......@@ -8,7 +8,7 @@ RUN pip install --no-cache-dir -r requirements.txt
COPY . .
#just keep the container running without doing anything
#CMD ["sh", "-c", "while :; do sleep 10; done"]
CMD ["sh", "-c", "while :; do sleep 10; done"]
#run the app automatically when the container starts
CMD ["python", "main.py"]
#CMD ["python", "main.py"]
......@@ -6,7 +6,9 @@ schema_sql = """
CREATE TABLE IF NOT EXISTS students (
id SERIAL PRIMARY KEY,
student_id VARCHAR(50) UNIQUE NOT NULL,
grade INTEGER,
student_name VARCHAR(100),
grade VARCHAR(20),
language BOOLEAN,
nationality VARCHAR(20) NOT NULL DEFAULT 'EGYPTIAN'
);
......@@ -26,11 +28,11 @@ CREATE INDEX IF NOT EXISTS idx_chat_history_created_at ON chat_history(created_a
CREATE INDEX IF NOT EXISTS idx_students_nationality ON students(nationality);
-- Insert dummy data for testing
INSERT INTO students (student_id, grade, nationality) VALUES
('student_001', 3, 'EGYPTIAN'),
('student_002', 4, 'SAUDI'),
('student_003', 2, 'EGYPTIAN'),
('student_004', 5, 'SAUDI')
INSERT INTO students (student_id, student_name, grade, language, nationality) VALUES
('student_001', 'Ahmed Ali', 'prime4', TRUE, 'EGYPTIAN'),
('student_002', 'Sara Hassan', 'prime6', FALSE, 'SAUDI'),
('student_003', 'Mona Adel', 'prime5', TRUE, 'EGYPTIAN'),
('student_004', 'Omar Youssef', 'prime6', FALSE, 'SAUDI')
ON CONFLICT (student_id) DO NOTHING;
"""
......@@ -57,7 +59,7 @@ conn.autocommit = True
with conn.cursor() as cur:
# Drop all existing tables (uncomment if needed)
#cur.execute(drop_all_tables_sql)
cur.execute(drop_all_tables_sql)
cur.execute(schema_sql)
# Verifications: Select from students and chat_history tables
......
import logging
import os
from typing import List, Dict
from typing import List, Dict, Optional
from fastapi import HTTPException
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
......@@ -15,24 +15,22 @@ logger = logging.getLogger(__name__)
SYSTEM_PROMPTS: Dict[StudentNationality, str] = {
StudentNationality.EGYPTIAN: """
إنت مُدرّس كيميا لطفل في ابتدائي.
StudentNationality.EGYPTIAN: """
إنت مُدرّس لطفل في ابتدائي.
رد باللهجة المصريّة الطبيعيّة.
خلي الكلام بسيط وواضح، كفاية يوصّل الفكرة من غير تطويل.
الجمل تكون قصيرة نسبيًّا، بس مش ناقصة.
استخدم التشكيل بس على الكلمات اللي ممكن الـTTS ينطقها غلط.
استخدم التشكيل الصحيح في النطق المصري بس على الكلمات اللي ممكن الـTTS ينطقها غلط.
اشرح المعلومة خطوة خطوة من غير تكرار.
ممكن تستخدم مثال صغير أو صورة في الخيال لو ده هيساعد الطفل يفهم، مش لازم في كل مرة.
خليك بتحكي كأنها لقطة من الحياة، مش شرح كتاب.
الرموز الكيميائية زي H2O أو CO2 تكتب زي ما هي.
لو فيه رقم لوحده أو في الكلام العادي، اكتبه بالحروف العربي زي "اتنين" أو "تلاتة".
الهدف: رد قصير ومباشر يعلّم، من غير زيادة كلام ولا سطحية.
"""
,
""",
StudentNationality.SAUDI: """
إنت معلّم كيميا لطفل في ابتدائي.
إنت معلّم لطفل في ابتدائي.
رد باللهجة السعوديّة البسيطة.
خل الكلام واضح وقصير يكفي يوصّل الفكرة.
الجمل تكون قصيرة نسبيًّا، لكن لا تكون ناقصة.
......@@ -44,7 +42,6 @@ StudentNationality.EGYPTIAN: """
لو فيه رقم لوحده أو في النص العادي، اكتبه بالحروف العربي زي "اثنين" أو "ثلاثة".
الهدف: رد بسيط وذكي يوضّح الفكرة من غير زيادة ولا سطحية.
"""
}
......@@ -76,55 +73,43 @@ class AgentService:
# Limit history to prevent growth
self.db_service.limit_history(student_id, max_messages=38)
def get_available_subjects(self, student_id: str) -> List[str]:
"""Get available subjects for the student based on their grade and language"""
if not self.pgvector:
return []
student_info = self.db_service.get_student_info(student_id)
if not student_info:
return []
return self.pgvector.get_subjects_by_grade_and_language(
student_info['grade'],
student_info['is_arabic']
)
def generate_response(
self,
user_message: str,
student_id: str,
subject: str = "Science",
model: str = Models.chat,
temperature: float = 1.0,
top_k: int = 3
) -> str:
"""Generate AI response using database memory"""
"""Generate AI response using database memory with subject filtering"""
if not self.is_available():
raise HTTPException(status_code=500, detail="Agent service not available")
try:
# Get student nationality from database
nationality_str = self.db_service.get_student_nationality(student_id)
if not nationality_str:
# Get complete student information from database
student_info = self.db_service.get_student_info(student_id)
if not student_info:
raise HTTPException(status_code=404, detail=f"Student with ID {student_id} not found")
logger.info(f"Retrieved nationality from DB: '{nationality_str}' for student: {student_id}")
# Debug the enum
print(f"DEBUG - StudentNationality enum:")
for enum_item in StudentNationality:
print(f" {enum_item.name} = '{enum_item.value}'")
# Debug the conversion
nationality_lower = nationality_str.lower().strip()
print(f"DEBUG - DB value: '{nationality_str}' -> lowercase: '{nationality_lower}'")
print(f"DEBUG - Is 'saudi' in enum values? {'saudi' in [e.value for e in StudentNationality]}")
print(f"DEBUG - Direct enum creation test:")
# Test direct enum creation
try:
test_saudi = StudentNationality('saudi')
print(f" StudentNationality('saudi') = {test_saudi}")
except Exception as e:
print(f" StudentNationality('saudi') failed: {e}")
try:
test_egyptian = StudentNationality('egyptian')
print(f" StudentNationality('egyptian') = {test_egyptian}")
except Exception as e:
print(f" StudentNationality('egyptian') failed: {e}")
# Convert string to StudentNationality enum
nationality_lower = nationality_str.lower().strip()
print(f"DEBUG - Looking for nationality: '{nationality_lower}'")
# Try explicit mapping first
logger.info(f"Retrieved student info from DB: {student_info} for student: {student_id}")
# Convert nationality string to StudentNationality enum
nationality_lower = student_info['nationality'].lower().strip()
nationality_mapping = {
'egyptian': StudentNationality.EGYPTIAN,
'saudi': StudentNationality.SAUDI
......@@ -132,9 +117,9 @@ class AgentService:
if nationality_lower in nationality_mapping:
nationality = nationality_mapping[nationality_lower]
logger.info(f"Successfully mapped '{nationality_str}' to {nationality}")
logger.info(f"Successfully mapped '{student_info['nationality']}' to {nationality}")
else:
logger.warning(f"Unknown nationality '{nationality_str}' ('{nationality_lower}') for student {student_id}, defaulting to EGYPTIAN")
logger.warning(f"Unknown nationality '{student_info['nationality']}' for student {student_id}, defaulting to EGYPTIAN")
nationality = StudentNationality.EGYPTIAN
# Add user message to database
......@@ -143,10 +128,11 @@ class AgentService:
# Get conversation history from database
conversation_history = self.get_conversation_history(student_id)
# Pick system prompt using the enum value
system_prompt = SYSTEM_PROMPTS.get(nationality, SYSTEM_PROMPTS[StudentNationality.EGYPTIAN])
logger.info(f"Using nationality: {nationality} for student: {student_id}")
print(f"DEBUG - Selected system_prompt: {system_prompt}") # Debug print
# Create subject-specific system prompt
base_system_prompt = SYSTEM_PROMPTS.get(nationality, SYSTEM_PROMPTS[StudentNationality.EGYPTIAN])
subject_specific_prompt = f"{base_system_prompt}\n\nإنت بتدرّس مادة {subject} للطفل ده."
logger.info(f"Using nationality: {nationality} and subject: {subject} for student: {student_id}")
# Prepare messages
messages = []
......@@ -155,26 +141,45 @@ class AgentService:
has_system_message = conversation_history and conversation_history[0].get("role") == "system"
if not has_system_message:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "system", "content": subject_specific_prompt})
# Add system message to database
self.add_message_to_history(student_id, system_prompt, "system")
self.add_message_to_history(student_id, subject_specific_prompt, "system")
# Add conversation history
messages.extend(conversation_history)
# Optional pgvector enrichment
# Enhanced pgvector enrichment with filtering
if self.pgvector:
try:
query_embedding = self.openai_service.generate_embedding(user_message)
neighbors = self.pgvector.search_nearest(query_embedding, limit=top_k)
# Search with filtering based on student info and subject
neighbors = self.pgvector.search_filtered_nearest(
query_embedding=query_embedding,
grade=student_info['grade'],
subject=subject,
is_arabic=student_info['is_arabic'],
limit=top_k
)
if neighbors:
context_message = "Knowledge base search results:\n"
for n in neighbors:
context_message += f"- ID {n['id']} (distance {n['distance']:.4f})\n"
context_message = f"معلومات من المنهج لمادة {subject} للصف {student_info['grade']}:\n"
for i, n in enumerate(neighbors, 1):
unit_info = f" - الوحدة: {n['unit']}" if n['unit'] else ""
concept_info = f" - المفهوم: {n['concept']}" if n['concept'] else ""
lesson_info = f" - الدرس: {n['lesson']}" if n['lesson'] else ""
context_message += f"\n{i}. {unit_info}{concept_info}{lesson_info}\n"
context_message += f"المحتوى: {n['chunk_text'][:200]}...\n"
context_message += f"(درجة التشابه: {n['distance']:.3f})\n"
messages.append({"role": "system", "content": context_message})
logger.info(f"Added {len(neighbors)} filtered knowledge base results for subject: {subject}")
else:
logger.info(f"No relevant content found for subject: {subject}, grade: {student_info['grade']}, Arabic: {student_info['is_arabic']}")
except Exception as e:
logger.warning(f"Error using pgvector: {e}")
logger.warning(f"Error using pgvector with filtering: {e}")
# Generate AI response
response = self.client.chat.completions.create(
......@@ -196,31 +201,86 @@ class AgentService:
logger.error(f"Error generating AI response: {e}")
raise HTTPException(status_code=500, detail=f"AI response generation failed: {str(e)}")
def search_similar(self, query_embedding: List[float], top_k: int = 3):
"""Optional pgvector search"""
def search_similar(self, query_embedding: List[float], student_id: str,
subject: str = "chemistry", top_k: int = 3):
"""Search similar content with student-specific filtering"""
if not self.pgvector:
raise HTTPException(status_code=400, detail="PGVector service not enabled")
return self.pgvector.search_nearest(query_embedding, limit=top_k)
student_info = self.db_service.get_student_info(student_id)
if not student_info:
raise HTTPException(status_code=404, detail=f"Student with ID {student_id} not found")
return self.pgvector.search_filtered_nearest(
query_embedding=query_embedding,
grade=student_info['grade'],
subject=subject,
is_arabic=student_info['is_arabic'],
limit=top_k
)
def update_student_subject_context(self, student_id: str, subject: str):
"""Update the system message for a new subject"""
try:
student_info = self.db_service.get_student_info(student_id)
if not student_info:
return False
# Clear existing history to reset context
self.db_service.clear_history(student_id)
# Set new system message with subject
nationality_lower = student_info['nationality'].lower().strip()
nationality_mapping = {
'egyptian': StudentNationality.EGYPTIAN,
'saudi': StudentNationality.SAUDI
}
nationality = nationality_mapping.get(nationality_lower, StudentNationality.EGYPTIAN)
base_system_prompt = SYSTEM_PROMPTS.get(nationality, SYSTEM_PROMPTS[StudentNationality.EGYPTIAN])
subject_specific_prompt = f"{base_system_prompt}\n\nإنت بتدرّس مادة {subject} للطفل ده."
self.add_message_to_history(student_id, subject_specific_prompt, "system")
logger.info(f"Updated subject context to {subject} for student {student_id} ({student_info['student_name']})")
return True
except Exception as e:
logger.error(f"Error updating subject context: {e}")
return False
def close(self):
"""Close database connection"""
"""Close database connections"""
if self.db_service:
self.db_service.close()
if self.pgvector:
self.pgvector.close()
# ----------------- Test -----------------
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
agent = AgentService(use_pgvector=False)
agent = AgentService(use_pgvector=True)
if agent.is_available():
try:
# Test with chemistry (default)
reply = agent.generate_response(
"هو يعني إيه ذَرّة؟",
student_id="student_001"
student_id="student_001",
subject="chemistry"
)
print("AI:", reply)
print("AI (Chemistry):", reply)
# Test with math
reply = agent.generate_response(
"إيه هو الجمع؟",
student_id="student_001",
subject="math"
)
print("AI (Math):", reply)
except Exception as e:
print(f"Test failed: {e}")
finally:
......
import os
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Tuple
import logging
logger = logging.getLogger(__name__)
......@@ -28,6 +28,39 @@ class ChatDatabaseService:
result = cur.fetchone()
return result["nationality"] if result else None
def get_student_info(self, student_id: str) -> Optional[Dict]:
"""Get complete student information from database"""
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"""
SELECT student_id, grade, language, nationality
FROM students
WHERE student_id = %s
""",
(student_id,)
)
result = cur.fetchone()
if result:
return {
'student_id': result['student_id'],
'grade': result['grade'],
'is_arabic': result['language'], # Convert language boolean to is_arabic
'nationality': result['nationality']
}
return None
def get_student_grade_and_language(self, student_id: str) -> Optional[Tuple[int, bool]]:
"""Get student grade and language preference"""
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"SELECT grade, language FROM students WHERE student_id = %s",
(student_id,)
)
result = cur.fetchone()
if result:
return (result["grade"], result["language"])
return None
def get_chat_history(self, student_id: str, limit: int = 20) -> List[Dict[str, str]]:
"""Get chat history for a student, returns in chronological order"""
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
......@@ -85,6 +118,51 @@ class ChatDatabaseService:
)
self.conn.commit()
def update_student_info(self, student_id: str, grade: Optional[int] = None,
language: Optional[bool] = None, nationality: Optional[str] = None):
"""Update student information"""
updates = []
params = []
if grade is not None:
updates.append("grade = %s")
params.append(grade)
if language is not None:
updates.append("language = %s")
params.append(language)
if nationality is not None:
updates.append("nationality = %s")
params.append(nationality)
if updates:
params.append(student_id)
with self.conn.cursor() as cur:
cur.execute(
f"""
UPDATE students
SET {', '.join(updates)}
WHERE student_id = %s
""",
params
)
self.conn.commit()
def create_student(self, student_id: str, student_name: str, grade: str,
language: bool, nationality: str = 'EGYPTIAN'):
"""Create a new student record"""
with self.conn.cursor() as cur:
cur.execute(
"""
INSERT INTO students (student_id, student_name, grade, language, nationality)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (student_id) DO NOTHING;
""",
(student_id, student_name, grade, language, nationality)
)
self.conn.commit()
def close(self):
if self.conn:
self.conn.close()
\ No newline at end of file
import os
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Optional
class PGVectorService:
......@@ -41,6 +42,106 @@ class PGVectorService:
)
return cur.fetchall()
def search_filtered_nearest(
self,
query_embedding: list,
grade: int,
subject: str,
is_arabic: bool,
limit: int = 3
):
"""Search nearest embeddings with filtering by grade, subject, and language"""
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"""
SELECT id, grade, subject, unit, concept, lesson, chunk_text,
is_arabic, embedding <-> %s AS distance
FROM educational_chunks
WHERE grade = %s
AND subject ILIKE %s
AND is_arabic = %s
ORDER BY embedding <-> %s
LIMIT %s;
""",
(query_embedding, grade, f"%{subject}%", is_arabic, query_embedding, limit),
)
return cur.fetchall()
def search_flexible_filtered_nearest(
self,
query_embedding: list,
grade: Optional[int] = None,
subject: Optional[str] = None,
is_arabic: Optional[bool] = None,
limit: int = 3
):
"""Search nearest embeddings with flexible filtering"""
conditions = []
params = [query_embedding]
if grade is not None:
conditions.append("grade = %s")
params.append(grade)
if subject is not None:
conditions.append("subject ILIKE %s")
params.append(f"%{subject}%")
if is_arabic is not None:
conditions.append("is_arabic = %s")
params.append(is_arabic)
where_clause = ""
if conditions:
where_clause = "WHERE " + " AND ".join(conditions)
# Add query_embedding again for ORDER BY
params.append(query_embedding)
params.append(limit)
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
f"""
SELECT id, grade, subject, unit, concept, lesson, chunk_text,
is_arabic, embedding <-> %s AS distance
FROM educational_chunks
{where_clause}
ORDER BY embedding <-> %s
LIMIT %s;
""",
params
)
return cur.fetchall()
def get_subjects_by_grade_and_language(self, grade: str, is_arabic: bool) -> List[str]:
"""Get available subjects for a specific grade and language"""
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
# Extract numeric part from grade string
grade_number = ''.join(filter(str.isdigit, grade)) if grade else None
if grade_number:
cur.execute(
"""
SELECT DISTINCT subject
FROM educational_chunks
WHERE grade = %s AND is_arabic = %s
ORDER BY subject;
""",
(int(grade_number), is_arabic)
)
else:
# Fallback if grade parsing fails
cur.execute(
"""
SELECT DISTINCT subject
FROM educational_chunks
WHERE is_arabic = %s
ORDER BY subject;
""",
(is_arabic,)
)
return [row['subject'] for row in cur.fetchall()]
def close(self):
if self.conn:
self.conn.close()
self.conn.close()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment