import os
import re
from typing import List, Optional

import psycopg2
from psycopg2.extras import RealDictCursor
# Import the pgvector adapter
from pgvector.psycopg2 import register_vector


class PGVectorService:
    """Service for managing embeddings with PostgreSQL pgvector.

    Wraps a single psycopg2 connection. Every public method runs inside
    ``with self.conn:``, which commits the transaction when the block succeeds
    and rolls it back when it raises. As a result the connection is never left
    idle in an open transaction after a read. It is also never stuck in an
    aborted transaction after a failed statement.

    Distances use pgvector's ``<->`` operator, which is Euclidean (L2)
    distance; smaller values mean closer vectors.
    """

    # Columns returned by the educational_chunks searches ("distance" is appended).
    _CHUNK_COLUMNS = "id, grade, subject, unit, concept, lesson, chunk_text, is_arabic"

    def __init__(self):
        """Open a connection configured from POSTGRES_* environment variables.

        POSTGRES_HOST defaults to "postgres". POSTGRES_USER, POSTGRES_PASSWORD
        and POSTGRES_DB are read as-is (None when unset).
        """
        self.conn = psycopg2.connect(
            host=os.getenv("POSTGRES_HOST", "postgres"),
            user=os.getenv("POSTGRES_USER"),
            password=os.getenv("POSTGRES_PASSWORD"),
            dbname=os.getenv("POSTGRES_DB"),
        )
        try:
            # Register the vector type with the connection
            register_vector(self.conn)
        except Exception:
            # Don't leak the freshly opened connection if registration fails.
            self.conn.close()
            raise

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _contains_pattern(text) -> str:
        """Build an ILIKE pattern that matches *text* as a literal substring.

        Backslash (PostgreSQL's default LIKE escape character), ``%`` and ``_``
        are escaped so that they match themselves instead of acting as
        wildcards. The backslash must be escaped first so that the escapes
        added afterwards are not doubled.
        """
        escaped = (
            str(text)
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        return f"%{escaped}%"

    def _fetchall(self, sql, params):
        """Run a read query in its own transaction and return all rows as dicts.

        The ``with self.conn`` wrapper commits after the read (ending the
        transaction) or rolls back if the query fails.
        """
        with self.conn, self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(sql, params)
            return cur.fetchall()

    def _search_chunks(self, query_embedding, conditions, filter_params, limit):
        """Run a nearest-neighbour search over educational_chunks.

        *conditions* are SQL fragments built only from fixed literals inside
        this class, never from caller input. Their ``%s`` placeholders are
        filled from *filter_params*, so caller values always travel as bound
        parameters.
        """
        where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""
        sql = f"""
            SELECT {self._CHUNK_COLUMNS},
                   embedding <-> %s::vector AS distance
            FROM educational_chunks
            {where_clause}
            ORDER BY embedding <-> %s::vector
            LIMIT %s;
        """
        # The query embedding is bound twice: once for the SELECT, once for ORDER BY.
        params = [query_embedding, *filter_params, query_embedding, limit]
        return self._fetchall(sql, params)

    # --------------------------------------------------------------- public API

    def insert_embedding(self, id: int, embedding: list):
        """Insert an embedding, or overwrite the existing one with the same id.

        Uses ``ON CONFLICT (id)``, which requires a unique constraint on
        ``embeddings_table.id``. The statement is committed on success and
        rolled back on failure, so a failed insert does not leave the
        connection in an aborted transaction.
        """
        with self.conn, self.conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO embeddings_table (id, embedding)
                VALUES (%s, %s)
                ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;
                """,
                (id, embedding),
            )

    def search_nearest(self, query_embedding: list, limit: int = 3):
        """Return up to *limit* rows of embeddings_table closest to *query_embedding*.

        Uses Euclidean (L2) distance, pgvector's ``<->`` operator. For
        unit-normalised embeddings this gives the same ordering as cosine
        distance (pgvector's ``<=>``). Each row is a dict with ``id``,
        ``embedding`` and ``distance``.

        The query vector is cast explicitly with ``::vector``, as in the other
        searches, so that a plain Python list (sent as a numeric array) can be
        compared against the vector column.
        """
        return self._fetchall(
            """
            SELECT id, embedding, embedding <-> %s::vector AS distance
            FROM embeddings_table
            ORDER BY embedding <-> %s::vector
            LIMIT %s;
            """,
            (query_embedding, query_embedding, limit),
        )

    def search_filtered_nearest(
        self,
        query_embedding: list,
        grade: int,
        subject: str,
        is_arabic: bool,
        limit: int = 3
    ):
        """Search nearest educational chunks filtered by grade, subject and language.

        *subject* is matched case-insensitively as a literal substring (LIKE
        wildcards in it are escaped). Returns up to *limit* dict rows ordered
        by L2 distance.
        """
        return self._search_chunks(
            query_embedding,
            ["grade = %s", "subject ILIKE %s", "is_arabic = %s"],
            [grade, self._contains_pattern(subject), is_arabic],
            limit,
        )

    def search_flexible_filtered_nearest(
        self,
        query_embedding: list,
        grade: Optional[int] = None,
        subject: Optional[str] = None,
        is_arabic: Optional[bool] = None,
        limit: int = 3
    ):
        """Search nearest educational chunks, applying only the filters given.

        Each of *grade*, *subject* and *is_arabic* is ignored when None.
        *subject* is matched case-insensitively as a literal substring. With no
        filters, the whole table is searched. Returns up to *limit* dict rows
        ordered by L2 distance.
        """
        conditions = []
        filter_params = []

        if grade is not None:
            conditions.append("grade = %s")
            filter_params.append(grade)

        if subject is not None:
            conditions.append("subject ILIKE %s")
            filter_params.append(self._contains_pattern(subject))

        if is_arabic is not None:
            conditions.append("is_arabic = %s")
            filter_params.append(is_arabic)

        return self._search_chunks(query_embedding, conditions, filter_params, limit)

    def get_subjects_by_grade_and_language(self, grade: str, is_arabic: bool) -> List[str]:
        """Return the distinct subjects available for a grade and language, sorted.

        The grade number is the first run of decimal digits in *grade*. For
        example, "Grade 5 (2023)" gives 5 and an Arabic-Indic "١٠" gives 10.
        When *grade* is empty or contains no digits, subjects for the language
        across all grades are returned instead.
        """
        # Take only the first digit run: concatenating every digit would turn
        # "Grade 5 (2023)" into 52023.
        match = re.search(r"\d+", grade) if grade else None

        if match:
            sql = """
                SELECT DISTINCT subject
                FROM educational_chunks
                WHERE grade = %s AND is_arabic = %s
                ORDER BY subject;
            """
            params = (int(match.group()), is_arabic)
        else:
            # Fallback if grade parsing fails
            sql = """
                SELECT DISTINCT subject
                FROM educational_chunks
                WHERE is_arabic = %s
                ORDER BY subject;
            """
            params = (is_arabic,)

        return [row["subject"] for row in self._fetchall(sql, params)]

    def close(self):
        """Close the underlying database connection."""
        if self.conn:
            self.conn.close()

    def __enter__(self):
        """Allow ``with PGVectorService() as svc:``; the connection closes on exit."""
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False