PGVectorService

parent 2d14ec4a
from .enums import MessageType, ResponseStatus, StudentNationality
from .enums import MessageType, ResponseStatus, StudentNationality, Models
from .config import AppConfig
\ No newline at end of file
......@@ -14,3 +14,10 @@ class ResponseStatus(str, Enum):
class StudentNationality(str, Enum):
EGYPTIAN = "egyptian"
SAUDI = "saudi"
class Models(str, Enum):
chat = "gpt-5-nano"
tts = "gpt-4o-mini-tts"
embedding = "text-embedding-3-small"
transcription = "whisper-1"
......@@ -4,4 +4,5 @@ from .health_service import HealthService
from .response_service import ResponseService
from .response_manager import ResponseManager
from .openai_service import OpenAIService
from .agent_service import AgentService
\ No newline at end of file
from .agent_service import AgentService
from .pgvector_service import PGVectorService
\ No newline at end of file
......@@ -2,11 +2,13 @@ import logging
import os
from typing import List, Dict
from fastapi import HTTPException
from openai import OpenAI
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core import StudentNationality
from core import StudentNationality, Models
from services.pgvector_service import PGVectorService
from services.openai_service import OpenAIService
logger = logging.getLogger(__name__)
......@@ -19,17 +21,18 @@ SYSTEM_PROMPTS: Dict[StudentNationality, str] = {
class AgentService:
"""Service class for handling AI agent conversations using OpenAI GPT"""
"""Service class for handling AI agent conversations using OpenAI GPT and optional PGVector"""
def __init__(self):
self.api_key = os.getenv("OPENAI_API_KEY")
if not self.api_key:
def __init__(self, use_pgvector: bool = False):
self.openai_service = OpenAIService()
if not self.openai_service.is_available():
logger.warning("Warning: OPENAI_API_KEY not found. Agent service will be disabled.")
self.client = None
else:
self.client = OpenAI(api_key=self.api_key)
self.client = self.openai_service.client
self.conversations: Dict[str, List[Dict[str, str]]] = {}
self.pgvector = PGVectorService() if use_pgvector else None
def is_available(self) -> bool:
return self.client is not None
......@@ -52,17 +55,19 @@ class AgentService:
self,
user_message: str,
conversation_id: str = "default",
model: str = "gpt-5-nano",
model: str = Models.chat,
temperature: float = 1.0,
nationality: StudentNationality = StudentNationality.EGYPTIAN
nationality: StudentNationality = StudentNationality.EGYPTIAN,
top_k: int = 3
) -> str:
"""Generate a GPT response, optionally enriched with pgvector results"""
if not self.is_available():
raise HTTPException(status_code=500, detail="Agent service not available")
try:
self.add_message_to_history(user_message, "user", conversation_id)
# 🟢 اختر الـ system prompt المناسب للجنسية
# Pick system prompt
system_prompt = SYSTEM_PROMPTS.get(nationality, SYSTEM_PROMPTS[StudentNationality.EGYPTIAN])
messages = []
......@@ -75,6 +80,18 @@ class AgentService:
})
messages.extend(conversation_history)
# If pgvector is enabled → enrich with nearest neighbors
if self.pgvector:
query_embedding = self.openai_service.generate_embedding(user_message)
neighbors = self.pgvector.search_nearest(query_embedding, limit=top_k)
if neighbors:
context_message = "Knowledge base search results:\n"
for n in neighbors:
context_message += f"- ID {n['id']} (distance {n['distance']:.4f})\n"
messages.append({"role": "system", "content": context_message})
# Generate AI response
response = self.client.chat.completions.create(
model=model,
messages=messages,
......@@ -92,11 +109,20 @@ class AgentService:
logger.error(f"Error generating AI response: {e}")
raise HTTPException(status_code=500, detail=f"AI response generation failed: {str(e)}")
def search_similar(self, query_embedding: List[float], top_k: int = 3):
"""Optional nearest neighbor search if PGVector is enabled"""
if not self.pgvector:
raise HTTPException(status_code=400, detail="PGVector service not enabled")
return self.pgvector.search_nearest(query_embedding, limit=top_k)
# ----------------- Suggested Test -----------------
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
agent = AgentService()
# Agent with pgvector enabled
agent = AgentService(use_pgvector=True)
if agent.is_available():
reply = agent.generate_response("هو يعني إيه ذَرّة؟", model="gpt-5-nano", nationality=StudentNationality.EGYPTIAN)
print("AI:", reply)
......
......@@ -2,13 +2,16 @@ import os
import time
import tempfile
import io
from typing import Optional
from typing import Optional, List
from fastapi import HTTPException
from openai import OpenAI
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core import Models
class OpenAIService:
"""Service class for handling OpenAI API operations (TTS and Whisper)"""
"""Service class for handling OpenAI API operations (TTS, Whisper, Embeddings)"""
def __init__(self):
self.api_key = os.getenv("OPENAI_API_KEY")
......@@ -22,40 +25,24 @@ class OpenAIService:
"""Check if OpenAI service is available"""
return self.client is not None
# ------------------- Whisper -------------------
def transcribe_audio(self, file_content: bytes, filename: str, language: Optional[str] = "ar") -> str:
"""
Transcribe audio using OpenAI Whisper
Args:
file_content: Audio file content as bytes
filename: Original filename for context
language: Language code (optional, defaults to Arabic)
Returns:
Transcribed text
Raises:
HTTPException: If transcription fails or service unavailable
"""
"""Transcribe audio using OpenAI Whisper"""
if not self.is_available():
raise HTTPException(status_code=500, detail="OpenAI service not available")
try:
# Create file-like object for the API
audio_file = io.BytesIO(file_content)
audio_file.name = filename
print(f"Transcribing audio: {filename}")
# Call Whisper API
transcript = self.client.audio.transcriptions.create(
model="whisper-1",
model=Models.transcription,
file=audio_file,
language=language if language else None # Auto-detect if None
language=language if language else None
)
transcribed_text = transcript.text.strip()
if not transcribed_text:
raise ValueError("Empty transcription result")
......@@ -66,36 +53,22 @@ class OpenAIService:
print(f"Error during transcription: {e}")
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
def generate_speech(self, text: str, voice: str = "alloy", model: str = "tts-1") -> str:
"""
Generate speech from text using OpenAI TTS
Args:
text: Text to convert to speech
voice: Voice to use (alloy, echo, fable, onyx, nova, shimmer)
model: TTS model to use (tts-1 or tts-1-hd)
Returns:
Path to temporary file containing the generated audio
Raises:
HTTPException: If TTS generation fails or service unavailable
"""
# ------------------- TTS -------------------
def generate_speech(self, text: str, voice: str = "alloy") -> str:
"""Generate speech from text using OpenAI TTS"""
if not self.is_available():
raise HTTPException(status_code=500, detail="OpenAI service not available")
temp_file_path = None
try:
# Create temporary file
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
temp_file_path = temp_file.name
temp_file.close()
print(f"Generating TTS audio: {text[:50]}...")
# Generate audio using OpenAI TTS
with self.client.audio.speech.with_streaming_response.create(
model=model,
model=Models.tts,
voice=voice,
input=text,
response_format="mp3"
......@@ -106,13 +79,41 @@ class OpenAIService:
return temp_file_path
except Exception as e:
# Clean up temp file on error
if temp_file_path and os.path.exists(temp_file_path):
os.unlink(temp_file_path)
print(f"Error during TTS generation: {e}")
raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
# ------------------- Embeddings -------------------
def generate_embedding(self, text: str) -> List[float]:
"""
Generate an embedding vector for input text.
Args:
text: Input string
model: Embedding model (default: text-embedding-3-small)
Returns:
List[float]: Embedding vector
"""
if not self.is_available():
raise HTTPException(status_code=500, detail="OpenAI service not available")
try:
response = self.client.embeddings.create(
model=Models.embedding,
input=text
)
embedding = response.data[0].embedding
if not embedding:
raise ValueError("Empty embedding generated")
return embedding
except Exception as e:
print(f"Error during embedding generation: {e}")
raise HTTPException(status_code=500, detail=f"Embedding generation failed: {str(e)}")
# ------------------- Utils -------------------
def cleanup_temp_file(self, file_path: str) -> None:
"""Clean up temporary file"""
if file_path and os.path.exists(file_path):
......@@ -120,4 +121,4 @@ class OpenAIService:
os.unlink(file_path)
print(f"Cleaned up temporary file: {file_path}")
except Exception as e:
print(f"Warning: Could not clean up temp file {file_path}: {e}")
\ No newline at end of file
print(f"Warning: Could not clean up temp file {file_path}: {e}")
import os
import psycopg2
from psycopg2.extras import RealDictCursor
class PGVectorService:
"""Service for managing embeddings with PostgreSQL pgvector"""
def __init__(self):
self.conn = psycopg2.connect(
host=os.getenv("POSTGRES_HOST", "postgres"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASSWORD"),
dbname=os.getenv("POSTGRES_DB"),
)
def insert_embedding(self, id: int, embedding: list):
"""Insert or update an embedding"""
with self.conn.cursor() as cur:
cur.execute(
"""
INSERT INTO embeddings_table (id, embedding)
VALUES (%s, %s)
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;
""",
(id, embedding),
)
self.conn.commit()
def search_nearest(self, query_embedding: list, limit: int = 3):
"""Search nearest embeddings using cosine distance (<-> operator)"""
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"""
SELECT id, embedding, embedding <-> %s AS distance
FROM embeddings_table
ORDER BY embedding <-> %s
LIMIT %s;
""",
(query_embedding, query_embedding, limit),
)
return cur.fetchall()
def close(self):
if self.conn:
self.conn.close()
import os
import psycopg2
from psycopg2.extras import RealDictCursor
# Read credentials from environment variables
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "postgres")
POSTGRES_USER = os.getenv("POSTGRES_USER")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
POSTGRES_DB = os.getenv("POSTGRES_DB")
def get_db_connection():
conn = psycopg2.connect(
host=POSTGRES_HOST,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DB
)
return conn
# Example usage
if __name__ == "__main__":
conn = get_db_connection()
cur = conn.cursor(cursor_factory=RealDictCursor)
# Get all table names in the public schema
cur.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_type = 'BASE TABLE';
""")
tables = cur.fetchall()
for table in tables:
table_name = table['table_name']
print(f"Contents of table: {table_name}")
# Fetch all rows in the table
cur.execute(f"SELECT * FROM {table_name};")
rows = cur.fetchall()
for row in rows:
print(row)
print("\n")
cur.close()
conn.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment