import os
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Optional
import logging
from pgvector.psycopg2 import register_vector
from services.connection_pool import ConnectionPool

logger = logging.getLogger(__name__)


class PGVectorService:
    """Service for managing embeddings with PostgreSQL pgvector using a shared, robust connection pool.

    Each public method checks a connection out of ``pool_handler`` via
    ``get_connection()`` for the duration of one statement. All similarity
    searches order by pgvector's ``<->`` operator, which is Euclidean (L2)
    distance (cosine distance would be ``<=>``).
    """

    # Nearest-neighbour query over embeddings_table. The ::vector casts match
    # the educational_chunks queries and let a plain Python list (which
    # psycopg2 sends as an SQL ARRAY) be compared against the vector column.
    _NEAREST_SQL = """
        SELECT id, embedding, embedding <-> %s::vector AS distance
        FROM embeddings_table
        ORDER BY embedding <-> %s::vector
        LIMIT %s;
    """

    # Shared template for every educational_chunks similarity search.
    # {where_clause} is only ever filled with fixed SQL fragments defined in
    # this class (never caller-supplied text); all values travel as bound
    # parameters.
    _CHUNK_SEARCH_SQL = """
        SELECT id, grade, subject, unit, concept, lesson, chunk_text,
               is_arabic, embedding <-> %s::vector AS distance
        FROM educational_chunks
        {where_clause}
        ORDER BY embedding <-> %s::vector
        LIMIT %s;
    """

    def __init__(self, pool_handler: 'ConnectionPool'):
        """Store the pool handler and register the pgvector type.

        Args:
            pool_handler: Object exposing ``get_connection()`` as a context
                manager that yields a database connection.
        """
        self.pool_handler = pool_handler
        # Checking out a connection doubles as a connectivity test.
        # NOTE(review): register_vector is applied only to the single
        # connection handed out here; whether the other pooled connections
        # are registered depends on ConnectionPool -- confirm.
        with self.pool_handler.get_connection() as conn:
            register_vector(conn)

    def insert_embedding(self, id: int, embedding: list):
        """Insert an embedding row, or overwrite the embedding if ``id`` exists.

        Args:
            id: Primary key of the row in ``embeddings_table``.
            embedding: Vector to store for that row.
        """
        with self.pool_handler.get_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO embeddings_table (id, embedding)
                    VALUES (%s, %s)
                    ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;
                    """,
                    (id, embedding),
                )
                conn.commit()

    def search_nearest(self, query_embedding: list, limit: int = 3) -> list:
        """Return the rows of ``embeddings_table`` closest to ``query_embedding``.

        Rows are ordered by the ``<->`` operator, i.e. Euclidean (L2)
        distance in pgvector.

        Args:
            query_embedding: Vector to compare against the stored embeddings.
            limit: Maximum number of rows to return.

        Returns:
            Dict-like rows with keys ``id``, ``embedding`` and ``distance``,
            nearest first.
        """
        with self.pool_handler.get_connection() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(
                    self._NEAREST_SQL,
                    (query_embedding, query_embedding, limit),
                )
                return cur.fetchall()

    def search_filtered_nearest(
        self,
        query_embedding: list,
        grade: int,
        subject: str,
        is_arabic: bool,
        limit: int = 3
    ) -> list:
        """Search nearest chunks with mandatory grade, subject and language filters.

        All three filters are always applied. ``subject`` is matched as a
        case-insensitive substring (``ILIKE '%subject%'``).

        Args:
            query_embedding: Vector to compare against chunk embeddings.
            grade: Value the ``grade`` column must equal.
            subject: Substring the ``subject`` column must contain.
            is_arabic: Value the ``is_arabic`` column must equal.
            limit: Maximum number of rows to return.

        Returns:
            Dict-like rows from ``educational_chunks`` with an added
            ``distance`` key, nearest first.
        """
        return self._search_chunks(
            query_embedding,
            "WHERE grade = %s AND subject ILIKE %s AND is_arabic = %s",
            [grade, f"%{subject}%", is_arabic],
            limit,
        )

    def search_flexible_filtered_nearest(
        self,
        query_embedding: list,
        grade: Optional[int] = None,
        subject: Optional[str] = None,
        is_arabic: Optional[bool] = None,
        limit: int = 3
    ) -> list:
        """Search nearest chunks, applying only the filters that are not ``None``.

        Args:
            query_embedding: Vector to compare against chunk embeddings.
            grade: If given, the ``grade`` column must equal this value.
            subject: If given, the ``subject`` column must contain this
                substring (case-insensitive).
            is_arabic: If given, the ``is_arabic`` column must equal this value.
            limit: Maximum number of rows to return.

        Returns:
            Dict-like rows from ``educational_chunks`` with an added
            ``distance`` key, nearest first.
        """
        where_clause, filter_params = self._build_filters(grade, subject, is_arabic)
        return self._search_chunks(query_embedding, where_clause, filter_params, limit)

    @staticmethod
    def _build_filters(
        grade: Optional[int],
        subject: Optional[str],
        is_arabic: Optional[bool],
    ) -> tuple:
        """Build the optional WHERE clause and its bound values.

        Returns:
            ``(where_clause, params)``: ``where_clause`` is ``""`` when every
            filter is ``None``, otherwise ``"WHERE ..."`` with one ``%s`` per
            entry in ``params`` (same order).
        """
        conditions = []
        params = []

        # Compare against None explicitly so falsy values (grade 0,
        # is_arabic False, empty subject) still act as filters.
        if grade is not None:
            conditions.append("grade = %s")
            params.append(grade)

        if subject is not None:
            conditions.append("subject ILIKE %s")
            params.append(f"%{subject}%")

        if is_arabic is not None:
            conditions.append("is_arabic = %s")
            params.append(is_arabic)

        where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
        return where_clause, params

    def _search_chunks(
        self,
        query_embedding: list,
        where_clause: str,
        filter_params: list,
        limit: int,
    ) -> list:
        """Run the shared educational_chunks similarity query and fetch all rows."""
        sql = self._CHUNK_SEARCH_SQL.format(where_clause=where_clause)
        # Placeholder order in the template: SELECT distance, WHERE filters,
        # ORDER BY distance, LIMIT -- hence the embedding is bound twice.
        params = [query_embedding, *filter_params, query_embedding, limit]
        with self.pool_handler.get_connection() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(sql, params)
                return cur.fetchall()

    def get_subjects_by_grade_and_language(self, grade: int, is_arabic: bool) -> List[str]:
        """Get the distinct subjects available for a grade and language.

        Args:
            grade: Value the ``grade`` column must equal.
            is_arabic: Value the ``is_arabic`` column must equal.

        Returns:
            Distinct ``subject`` values from ``educational_chunks``, in the
            order produced by ``ORDER BY subject``.
        """
        with self.pool_handler.get_connection() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(
                    """
                    SELECT DISTINCT subject
                    FROM educational_chunks
                    WHERE grade = %s AND is_arabic = %s
                    ORDER BY subject;
                    """,
                    (grade, is_arabic)
                )
                return [row['subject'] for row in cur.fetchall()]
