PGVectorService

parent 2d14ec4a
from .enums import MessageType, ResponseStatus, StudentNationality from .enums import MessageType, ResponseStatus, StudentNationality, Models
from .config import AppConfig from .config import AppConfig
\ No newline at end of file
...@@ -14,3 +14,10 @@ class ResponseStatus(str, Enum): ...@@ -14,3 +14,10 @@ class ResponseStatus(str, Enum):
class StudentNationality(str, Enum): class StudentNationality(str, Enum):
EGYPTIAN = "egyptian" EGYPTIAN = "egyptian"
SAUDI = "saudi" SAUDI = "saudi"
class Models(str, Enum):
chat = "gpt-5-nano"
tts = "gpt-4o-mini-tts"
embedding = "text-embedding-3-small"
transcription = "whisper-1"
...@@ -4,4 +4,5 @@ from .health_service import HealthService ...@@ -4,4 +4,5 @@ from .health_service import HealthService
from .response_service import ResponseService from .response_service import ResponseService
from .response_manager import ResponseManager from .response_manager import ResponseManager
from .openai_service import OpenAIService from .openai_service import OpenAIService
from .agent_service import AgentService from .agent_service import AgentService
\ No newline at end of file from .pgvector_service import PGVectorService
\ No newline at end of file
...@@ -2,11 +2,13 @@ import logging ...@@ -2,11 +2,13 @@ import logging
import os import os
from typing import List, Dict from typing import List, Dict
from fastapi import HTTPException from fastapi import HTTPException
from openai import OpenAI
import sys import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core import StudentNationality
from core import StudentNationality, Models
from services.pgvector_service import PGVectorService
from services.openai_service import OpenAIService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -19,17 +21,18 @@ SYSTEM_PROMPTS: Dict[StudentNationality, str] = { ...@@ -19,17 +21,18 @@ SYSTEM_PROMPTS: Dict[StudentNationality, str] = {
class AgentService: class AgentService:
"""Service class for handling AI agent conversations using OpenAI GPT""" """Service class for handling AI agent conversations using OpenAI GPT and optional PGVector"""
def __init__(self): def __init__(self, use_pgvector: bool = False):
self.api_key = os.getenv("OPENAI_API_KEY") self.openai_service = OpenAIService()
if not self.api_key: if not self.openai_service.is_available():
logger.warning("Warning: OPENAI_API_KEY not found. Agent service will be disabled.") logger.warning("Warning: OPENAI_API_KEY not found. Agent service will be disabled.")
self.client = None self.client = None
else: else:
self.client = OpenAI(api_key=self.api_key) self.client = self.openai_service.client
self.conversations: Dict[str, List[Dict[str, str]]] = {} self.conversations: Dict[str, List[Dict[str, str]]] = {}
self.pgvector = PGVectorService() if use_pgvector else None
def is_available(self) -> bool: def is_available(self) -> bool:
return self.client is not None return self.client is not None
...@@ -52,17 +55,19 @@ class AgentService: ...@@ -52,17 +55,19 @@ class AgentService:
self, self,
user_message: str, user_message: str,
conversation_id: str = "default", conversation_id: str = "default",
model: str = "gpt-5-nano", model: str = Models.chat,
temperature: float = 1.0, temperature: float = 1.0,
nationality: StudentNationality = StudentNationality.EGYPTIAN nationality: StudentNationality = StudentNationality.EGYPTIAN,
top_k: int = 3
) -> str: ) -> str:
"""Generate a GPT response, optionally enriched with pgvector results"""
if not self.is_available(): if not self.is_available():
raise HTTPException(status_code=500, detail="Agent service not available") raise HTTPException(status_code=500, detail="Agent service not available")
try: try:
self.add_message_to_history(user_message, "user", conversation_id) self.add_message_to_history(user_message, "user", conversation_id)
# 🟢 اختر الـ system prompt المناسب للجنسية # Pick system prompt
system_prompt = SYSTEM_PROMPTS.get(nationality, SYSTEM_PROMPTS[StudentNationality.EGYPTIAN]) system_prompt = SYSTEM_PROMPTS.get(nationality, SYSTEM_PROMPTS[StudentNationality.EGYPTIAN])
messages = [] messages = []
...@@ -75,6 +80,18 @@ class AgentService: ...@@ -75,6 +80,18 @@ class AgentService:
}) })
messages.extend(conversation_history) messages.extend(conversation_history)
# If pgvector is enabled → enrich with nearest neighbors
if self.pgvector:
query_embedding = self.openai_service.generate_embedding(user_message)
neighbors = self.pgvector.search_nearest(query_embedding, limit=top_k)
if neighbors:
context_message = "Knowledge base search results:\n"
for n in neighbors:
context_message += f"- ID {n['id']} (distance {n['distance']:.4f})\n"
messages.append({"role": "system", "content": context_message})
# Generate AI response
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=model, model=model,
messages=messages, messages=messages,
...@@ -92,11 +109,20 @@ class AgentService: ...@@ -92,11 +109,20 @@ class AgentService:
logger.error(f"Error generating AI response: {e}") logger.error(f"Error generating AI response: {e}")
raise HTTPException(status_code=500, detail=f"AI response generation failed: {str(e)}") raise HTTPException(status_code=500, detail=f"AI response generation failed: {str(e)}")
def search_similar(self, query_embedding: List[float], top_k: int = 3):
"""Optional nearest neighbor search if PGVector is enabled"""
if not self.pgvector:
raise HTTPException(status_code=400, detail="PGVector service not enabled")
return self.pgvector.search_nearest(query_embedding, limit=top_k)
# ----------------- Suggested Test ----------------- # ----------------- Suggested Test -----------------
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
agent = AgentService()
# Agent with pgvector enabled
agent = AgentService(use_pgvector=True)
if agent.is_available(): if agent.is_available():
reply = agent.generate_response("هو يعني إيه ذَرّة؟", model="gpt-5-nano", nationality=StudentNationality.EGYPTIAN) reply = agent.generate_response("هو يعني إيه ذَرّة؟", model="gpt-5-nano", nationality=StudentNationality.EGYPTIAN)
print("AI:", reply) print("AI:", reply)
......
...@@ -2,13 +2,16 @@ import os ...@@ -2,13 +2,16 @@ import os
import time import time
import tempfile import tempfile
import io import io
from typing import Optional from typing import Optional, List
from fastapi import HTTPException from fastapi import HTTPException
from openai import OpenAI from openai import OpenAI
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core import Models
class OpenAIService: class OpenAIService:
"""Service class for handling OpenAI API operations (TTS and Whisper)""" """Service class for handling OpenAI API operations (TTS, Whisper, Embeddings)"""
def __init__(self): def __init__(self):
self.api_key = os.getenv("OPENAI_API_KEY") self.api_key = os.getenv("OPENAI_API_KEY")
...@@ -22,40 +25,24 @@ class OpenAIService: ...@@ -22,40 +25,24 @@ class OpenAIService:
"""Check if OpenAI service is available""" """Check if OpenAI service is available"""
return self.client is not None return self.client is not None
# ------------------- Whisper -------------------
def transcribe_audio(self, file_content: bytes, filename: str, language: Optional[str] = "ar") -> str: def transcribe_audio(self, file_content: bytes, filename: str, language: Optional[str] = "ar") -> str:
""" """Transcribe audio using OpenAI Whisper"""
Transcribe audio using OpenAI Whisper
Args:
file_content: Audio file content as bytes
filename: Original filename for context
language: Language code (optional, defaults to Arabic)
Returns:
Transcribed text
Raises:
HTTPException: If transcription fails or service unavailable
"""
if not self.is_available(): if not self.is_available():
raise HTTPException(status_code=500, detail="OpenAI service not available") raise HTTPException(status_code=500, detail="OpenAI service not available")
try: try:
# Create file-like object for the API
audio_file = io.BytesIO(file_content) audio_file = io.BytesIO(file_content)
audio_file.name = filename audio_file.name = filename
print(f"Transcribing audio: {filename}") print(f"Transcribing audio: {filename}")
# Call Whisper API
transcript = self.client.audio.transcriptions.create( transcript = self.client.audio.transcriptions.create(
model="whisper-1", model=Models.transcription,
file=audio_file, file=audio_file,
language=language if language else None # Auto-detect if None language=language if language else None
) )
transcribed_text = transcript.text.strip() transcribed_text = transcript.text.strip()
if not transcribed_text: if not transcribed_text:
raise ValueError("Empty transcription result") raise ValueError("Empty transcription result")
...@@ -66,36 +53,22 @@ class OpenAIService: ...@@ -66,36 +53,22 @@ class OpenAIService:
print(f"Error during transcription: {e}") print(f"Error during transcription: {e}")
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
def generate_speech(self, text: str, voice: str = "alloy", model: str = "tts-1") -> str: # ------------------- TTS -------------------
""" def generate_speech(self, text: str, voice: str = "alloy") -> str:
Generate speech from text using OpenAI TTS """Generate speech from text using OpenAI TTS"""
Args:
text: Text to convert to speech
voice: Voice to use (alloy, echo, fable, onyx, nova, shimmer)
model: TTS model to use (tts-1 or tts-1-hd)
Returns:
Path to temporary file containing the generated audio
Raises:
HTTPException: If TTS generation fails or service unavailable
"""
if not self.is_available(): if not self.is_available():
raise HTTPException(status_code=500, detail="OpenAI service not available") raise HTTPException(status_code=500, detail="OpenAI service not available")
temp_file_path = None temp_file_path = None
try: try:
# Create temporary file
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
temp_file_path = temp_file.name temp_file_path = temp_file.name
temp_file.close() temp_file.close()
print(f"Generating TTS audio: {text[:50]}...") print(f"Generating TTS audio: {text[:50]}...")
# Generate audio using OpenAI TTS
with self.client.audio.speech.with_streaming_response.create( with self.client.audio.speech.with_streaming_response.create(
model=model, model=Models.tts,
voice=voice, voice=voice,
input=text, input=text,
response_format="mp3" response_format="mp3"
...@@ -106,13 +79,41 @@ class OpenAIService: ...@@ -106,13 +79,41 @@ class OpenAIService:
return temp_file_path return temp_file_path
except Exception as e: except Exception as e:
# Clean up temp file on error
if temp_file_path and os.path.exists(temp_file_path): if temp_file_path and os.path.exists(temp_file_path):
os.unlink(temp_file_path) os.unlink(temp_file_path)
print(f"Error during TTS generation: {e}") print(f"Error during TTS generation: {e}")
raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}") raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
# ------------------- Embeddings -------------------
def generate_embedding(self, text: str) -> List[float]:
"""
Generate an embedding vector for input text.
Args:
text: Input string
model: Embedding model (default: text-embedding-3-small)
Returns:
List[float]: Embedding vector
"""
if not self.is_available():
raise HTTPException(status_code=500, detail="OpenAI service not available")
try:
response = self.client.embeddings.create(
model=Models.embedding,
input=text
)
embedding = response.data[0].embedding
if not embedding:
raise ValueError("Empty embedding generated")
return embedding
except Exception as e:
print(f"Error during embedding generation: {e}")
raise HTTPException(status_code=500, detail=f"Embedding generation failed: {str(e)}")
# ------------------- Utils -------------------
def cleanup_temp_file(self, file_path: str) -> None: def cleanup_temp_file(self, file_path: str) -> None:
"""Clean up temporary file""" """Clean up temporary file"""
if file_path and os.path.exists(file_path): if file_path and os.path.exists(file_path):
...@@ -120,4 +121,4 @@ class OpenAIService: ...@@ -120,4 +121,4 @@ class OpenAIService:
os.unlink(file_path) os.unlink(file_path)
print(f"Cleaned up temporary file: {file_path}") print(f"Cleaned up temporary file: {file_path}")
except Exception as e: except Exception as e:
print(f"Warning: Could not clean up temp file {file_path}: {e}") print(f"Warning: Could not clean up temp file {file_path}: {e}")
\ No newline at end of file
import os
import psycopg2
from psycopg2.extras import RealDictCursor
class PGVectorService:
"""Service for managing embeddings with PostgreSQL pgvector"""
def __init__(self):
self.conn = psycopg2.connect(
host=os.getenv("POSTGRES_HOST", "postgres"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASSWORD"),
dbname=os.getenv("POSTGRES_DB"),
)
def insert_embedding(self, id: int, embedding: list):
"""Insert or update an embedding"""
with self.conn.cursor() as cur:
cur.execute(
"""
INSERT INTO embeddings_table (id, embedding)
VALUES (%s, %s)
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;
""",
(id, embedding),
)
self.conn.commit()
def search_nearest(self, query_embedding: list, limit: int = 3):
"""Search nearest embeddings using cosine distance (<-> operator)"""
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"""
SELECT id, embedding, embedding <-> %s AS distance
FROM embeddings_table
ORDER BY embedding <-> %s
LIMIT %s;
""",
(query_embedding, query_embedding, limit),
)
return cur.fetchall()
def close(self):
if self.conn:
self.conn.close()
import os
import psycopg2
from psycopg2.extras import RealDictCursor
# Read credentials from environment variables
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgres")
POSTGRES_USER = os.getenv("POSTGRES_USER")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
POSTGRES_DB = os.getenv("POSTGRES_DB")
def get_db_connection():
conn = psycopg2.connect(
host=POSTGRES_HOST,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DB
)
return conn
# Example usage
if __name__ == "__main__":
conn = get_db_connection()
cur = conn.cursor(cursor_factory=RealDictCursor)
# Get all table names in the public schema
cur.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_type = 'BASE TABLE';
""")
tables = cur.fetchall()
for table in tables:
table_name = table['table_name']
print(f"Contents of table: {table_name}")
# Fetch all rows in the table
cur.execute(f"SELECT * FROM {table_name};")
rows = cur.fetchall()
for row in rows:
print(row)
print("\n")
cur.close()
conn.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment