import os
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Optional, Dict, Any
import logging
import json
from pgvector.psycopg2 import register_vector
from services.connection_pool import ConnectionPool

logger = logging.getLogger(__name__)


class PGVectorService:
    """Enhanced service for managing embeddings with PostgreSQL pgvector and curriculum structure awareness"""

    def __init__(self, pool_handler: 'ConnectionPool'):
        """Keep a reference to the pool and register the pgvector adapter.

        Args:
            pool_handler: Object whose ``get_connection()`` yields a database
                connection when used as a context manager.
        """
        self.pool_handler = pool_handler
        # Checking out a connection here doubles as a connectivity test.
        # NOTE(review): register_vector() is applied to this one connection
        # only; whether other pooled connections receive the adapter depends
        # on ConnectionPool -- confirm.
        with pool_handler.get_connection() as connection:
            register_vector(connection)

    def insert_embedding(self, id: int, embedding: list):
        """Upsert one row of ``embeddings_table``.

        Args:
            id: Row identifier; an existing row with this id has its
                embedding overwritten.
            embedding: Vector to store.
        """
        upsert_sql = """
                    INSERT INTO embeddings_table (id, embedding)
                    VALUES (%s, %s)
                    ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;
                    """
        with self.pool_handler.get_connection() as conn, conn.cursor() as cur:
            cur.execute(upsert_sql, (id, embedding))
            # Commit before the cursor is closed, as the write is a single statement.
            conn.commit()

    def search_nearest(self, query_embedding: list, limit: int = 3):
        """Return the rows of ``embeddings_table`` closest to ``query_embedding``.

        Ranking uses pgvector's ``<->`` operator, which is L2 (Euclidean)
        distance. It is *not* cosine distance (that operator is ``<=>``).

        Args:
            query_embedding: Query vector.
            limit: Maximum number of rows to return.

        Returns:
            Dict-like rows (RealDictCursor) with ``id``, ``embedding`` and
            ``distance``, nearest first.
        """
        with self.pool_handler.get_connection() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                # The explicit ::vector cast matches the other search methods
                # and lets a plain Python list (sent by psycopg2 as an ARRAY)
                # be compared against the vector column.
                cur.execute(
                    """
                    SELECT id, embedding, embedding <-> %s::vector AS distance
                    FROM embeddings_table
                    ORDER BY embedding <-> %s::vector
                    LIMIT %s;
                    """,
                    (query_embedding, query_embedding, limit),
                )
                return cur.fetchall()

    def search_filtered_nearest(
        self,
        query_embedding: list,
        grade: int,
        subject: str,
        is_arabic: bool,
        limit: int = 3
    ):
        """Find the nearest ``educational_chunks`` rows within one grade/subject/language.

        Args:
            query_embedding: Query vector, cast to ``vector`` in SQL.
            grade: Exact grade to match.
            subject: Matched case-insensitively as a substring (``ILIKE %subject%``).
            is_arabic: Language flag to match.
            limit: Maximum number of rows to return.

        Returns:
            Dict-like rows ordered by ascending ``<->`` distance.
        """
        subject_pattern = f"%{subject}%"
        nearest_sql = """
                    SELECT id, grade, subject, unit, concept, lesson, chunk_text, 
                           is_arabic, embedding <-> %s::vector AS distance
                    FROM educational_chunks
                    WHERE grade = %s 
                      AND subject ILIKE %s 
                      AND is_arabic = %s
                    ORDER BY embedding <-> %s::vector
                    LIMIT %s;
                    """
        # The embedding appears twice: once for the SELECT list, once for ORDER BY.
        args = (query_embedding, grade, subject_pattern, is_arabic, query_embedding, limit)
        with self.pool_handler.get_connection() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(nearest_sql, args)
            return cur.fetchall()

    def search_with_curriculum_context(
        self,
        query_embedding: list,
        grade: int,
        subject: str,
        is_arabic: bool,
        limit: int = 3
    ):
        """Run a filtered nearest-neighbour search and attach curriculum context.

        Args:
            query_embedding: Query vector.
            grade: Grade to filter on.
            subject: Subject to filter on.
            is_arabic: Language flag to filter on.
            limit: Maximum number of rows to return.

        Returns:
            The rows from ``search_filtered_nearest``, each with an extra
            ``curriculum_context`` key produced by ``_build_curriculum_context``.
        """
        # Load the curriculum on its own connection *before* searching. Doing
        # it while a search connection is already checked out would require
        # two pooled connections per call.
        curriculum = self.get_curriculum_structure(grade, is_arabic, subject)

        # Same query as the plain filtered search, so reuse it instead of
        # keeping a second copy of the SQL.
        results = self.search_filtered_nearest(query_embedding, grade, subject, is_arabic, limit)

        for result in results:
            result['curriculum_context'] = self._build_curriculum_context(
                result, curriculum, grade, is_arabic
            )
        return results

    def search_flexible_filtered_nearest(
        self,
        query_embedding: list,
        grade: Optional[int] = None,
        subject: Optional[str] = None,
        is_arabic: Optional[bool] = None,
        limit: int = 3
    ):
        """Nearest-neighbour search over ``educational_chunks`` with optional filters.

        Any filter left as ``None`` is omitted from the WHERE clause; with no
        filters at all the whole table is searched.

        Args:
            query_embedding: Query vector, cast to ``vector`` in SQL.
            grade: Exact grade to match, if given.
            subject: Case-insensitive substring to match, if given.
            is_arabic: Language flag to match, if given.
            limit: Maximum number of rows to return.

        Returns:
            Dict-like rows ordered by ascending ``<->`` distance.
        """
        # Each active filter is a (SQL fragment, bound value) pair; only the
        # fixed fragments are ever interpolated into the statement.
        filters = []
        if grade is not None:
            filters.append(("grade = %s", grade))
        if subject is not None:
            filters.append(("subject ILIKE %s", f"%{subject}%"))
        if is_arabic is not None:
            filters.append(("is_arabic = %s", is_arabic))

        where_clause = ("WHERE " + " AND ".join(fragment for fragment, _ in filters)) if filters else ""

        # Placeholder order: SELECT embedding, filter values, ORDER BY embedding, LIMIT.
        params = [query_embedding]
        params.extend(value for _, value in filters)
        params.extend((query_embedding, limit))

        flexible_sql = f"""
                    SELECT id, grade, subject, unit, concept, lesson, chunk_text, 
                           is_arabic, embedding <-> %s::vector AS distance
                    FROM educational_chunks
                    {where_clause}
                    ORDER BY embedding <-> %s::vector
                    LIMIT %s;
                    """
        with self.pool_handler.get_connection() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(flexible_sql, params)
            return cur.fetchall()

    def get_curriculum_structure(self, grade: int, is_arabic: bool, subject: str = "Science") -> Optional[Dict]:
        """Fetch the stored ``curriculum_data`` for one grade/language/subject.

        Args:
            grade: Grade to look up.
            is_arabic: Language flag to look up.
            subject: Subject to look up (exact match).

        Returns:
            The ``curriculum_data`` column of the matching row, or ``None``
            when no row matches.
        """
        lookup_sql = """
                    SELECT curriculum_data, created_at
                    FROM curriculum_structure
                    WHERE grade = %s AND is_arabic = %s AND subject = %s;
                    """
        with self.pool_handler.get_connection() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(lookup_sql, (grade, is_arabic, subject))
            row = cur.fetchone()
            return row['curriculum_data'] if row else None

    def get_units_for_grade(self, grade: int, is_arabic: bool, subject: str = "Science") -> List[str]:
        """List the unit names of a curriculum, in stored order.

        Units without a (non-empty) ``name`` are skipped. Returns an empty
        list when the curriculum is missing or has no ``units`` key.
        """
        curriculum = self.get_curriculum_structure(grade, is_arabic, subject)
        if not curriculum or 'units' not in curriculum:
            return []

        candidate_names = (entry.get('name', '') for entry in curriculum['units'])
        return [name for name in candidate_names if name]

    def get_concepts_for_unit(self, grade: int, unit_name: str, is_arabic: bool, subject: str = "Science") -> List[str]:
        """List the concept names belonging to the unit(s) named ``unit_name``.

        Every unit whose ``name`` equals ``unit_name`` contributes its
        concepts; concepts without a (non-empty) ``name`` are skipped.
        Returns an empty list when the curriculum or the unit is not found.
        """
        curriculum = self.get_curriculum_structure(grade, is_arabic, subject)
        if not curriculum or 'units' not in curriculum:
            return []

        return [
            concept_name
            for unit in curriculum['units']
            if unit.get('name') == unit_name and 'concepts' in unit
            for concept_name in (entry.get('name', '') for entry in unit['concepts'])
            if concept_name
        ]

    def get_subjects_by_grade_and_language(self, grade: int, is_arabic: bool) -> List[str]:
        """List the subjects available for a grade and language.

        ``curriculum_structure`` is consulted first; ``educational_chunks`` is
        queried only when the former has no subjects for this grade/language.

        Returns:
            Subject names sorted by the database, possibly empty.
        """
        from_curriculum_sql = """
                    SELECT DISTINCT subject
                    FROM curriculum_structure
                    WHERE grade = %s AND is_arabic = %s
                    ORDER BY subject;
                    """
        from_chunks_sql = """
                    SELECT DISTINCT subject
                    FROM educational_chunks
                    WHERE grade = %s AND is_arabic = %s
                    ORDER BY subject;
                    """
        key = (grade, is_arabic)
        with self.pool_handler.get_connection() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
            # Try each source in priority order; the first non-empty answer wins.
            for source_sql in (from_curriculum_sql, from_chunks_sql):
                cur.execute(source_sql, key)
                subjects = [row['subject'] for row in cur.fetchall()]
                if subjects:
                    return subjects
            return []

    def insert_curriculum_structure(self, grade: int, is_arabic: bool, subject: str, curriculum_data: Dict):
        """Upsert the curriculum structure for one grade/language/subject.

        ``curriculum_data`` is serialised with ``json.dumps``. When a row for
        the same (grade, is_arabic, subject) exists, its data is replaced and
        ``created_at`` is reset to the current timestamp.
        """
        upsert_sql = """
                    INSERT INTO curriculum_structure (grade, is_arabic, subject, curriculum_data)
                    VALUES (%s, %s, %s, %s)
                    ON CONFLICT (grade, is_arabic, subject) 
                    DO UPDATE SET curriculum_data = EXCLUDED.curriculum_data, created_at = CURRENT_TIMESTAMP;
                    """
        with self.pool_handler.get_connection() as conn:
            with conn.cursor() as cur:
                serialized = json.dumps(curriculum_data)
                cur.execute(upsert_sql, (grade, is_arabic, subject, serialized))
                conn.commit()

    def _build_curriculum_context(self, chunk_result: Dict, curriculum: Optional[Dict], grade: int, is_arabic: bool) -> Dict:
        """Describe where a chunk sits inside its curriculum.

        Args:
            chunk_result: Search row; its ``unit`` and ``concept`` keys are read.
            curriculum: Curriculum structure with a ``units`` list, or ``None``.
            grade: Grade shown in the ``position`` string.
            is_arabic: Selects the language of ``navigation_hint``.

        Returns:
            Dict with ``position``, ``related_concepts`` (other named concepts
            of the same unit), ``unit_overview`` (the unit's description) and
            ``navigation_hint``. A placeholder dict is returned when there is
            no curriculum.
        """
        if not curriculum:
            return {
                "position": "Unknown",
                "related_concepts": [],
                "unit_overview": "",
                "navigation_hint": ""
            }

        chunk_unit = chunk_result.get('unit', '')
        chunk_concept = chunk_result.get('concept', '')

        # Only the first unit whose name equals the chunk's unit is considered.
        matched_unit = None
        if 'units' in curriculum:
            matched_unit = next(
                (candidate for candidate in curriculum['units'] if candidate.get('name') == chunk_unit),
                None,
            )

        # Siblings: every named concept of that unit except the chunk's own.
        siblings = []
        if matched_unit is not None and 'concepts' in matched_unit:
            sibling_names = (entry.get('name', '') for entry in matched_unit['concepts'])
            siblings = [name for name in sibling_names if name != chunk_concept and name]

        hint = ""
        if matched_unit:
            unit_name = matched_unit.get('name', '')
            preview = ', '.join(siblings[:3])  # at most three siblings are mentioned
            if is_arabic:
                hint = f"هذا جزء من {unit_name}"
                suffix = f"، والذي يتضمن أيضاً: {preview}"
            else:
                hint = f"This is part of {unit_name}"
                suffix = f", which also covers: {preview}"
            if siblings:
                hint += suffix

        overview = matched_unit.get('description', '') if matched_unit else ""
        return {
            "position": f"Grade {grade} → {chunk_unit} → {chunk_concept}",
            "related_concepts": siblings,
            "unit_overview": overview,
            "navigation_hint": hint
        }

    def classify_query_type(self, query: str, grade: int, is_arabic: bool) -> str:
        """Classify a user query by simple phrase matching.

        Overview phrases are checked before navigation phrases. English
        matching is case-insensitive; Arabic matching uses the query as given.
        ``grade`` is accepted but not used by the classification.

        Returns:
            ``"overview"``, ``"navigation"`` or ``"specific_content"``.
        """
        query_lower = query.lower()

        if is_arabic:
            text = query
            overview_phrases = (
                "ماذا ندرس", "أظهر المنهج", "ما هي المواضيع", "ما المنهج",
                "ما نتعلم", "المحتويات", "الوحدات", "الفصول",
            )
            navigation_phrases = (
                "ما في الوحدة", "أخبرني عن الوحدة", "محتوى الوحدة", "مفاهيم الوحدة",
            )
        else:
            text = query_lower
            overview_phrases = (
                "what do we study", "show curriculum", "what topics", "what subjects",
                "curriculum overview", "table of contents", "units", "chapters",
            )
            navigation_phrases = (
                "what's in unit", "tell me about unit", "unit content", "concepts in unit",
            )

        # Order matters: a query matching both groups is an overview query.
        for label, phrases in (("overview", overview_phrases), ("navigation", navigation_phrases)):
            if any(phrase in text for phrase in phrases):
                return label
        return "specific_content"

    def get_overview_response(self, grade: int, is_arabic: bool, subject: str = "Science") -> str:
        """Render the whole curriculum (units and their concepts) as text.

        Args:
            grade: Grade whose curriculum is rendered.
            is_arabic: Selects Arabic or English labels.
            subject: Subject to load; it is also shown in the header. For
                Arabic output the default ``"Science"`` is shown as "العلوم",
                any other subject is shown as passed.

        Returns:
            A header line followed by one section per unit, or an apology
            message when no curriculum is stored.
        """
        curriculum = self.get_curriculum_structure(grade, is_arabic, subject)
        if not curriculum:
            if is_arabic:
                return f"عذراً، لا يوجد منهج متاح للصف {grade}"
            return f"Sorry, no curriculum available for Grade {grade}"

        # Pick the language-specific labels once so both languages share the
        # same rendering loop below.
        if is_arabic:
            subject_label = "العلوم" if subject == "Science" else subject
            header = f"📚 منهج {subject_label} للصف {grade}\n\n"
            unit_word, concept_word = "الوحدة", "المفهوم"
        else:
            header = f"📚 {subject} Curriculum for Grade {grade}\n\n"
            unit_word, concept_word = "Unit", "Concept"

        units = curriculum['units'] if 'units' in curriculum else []
        parts = [header]
        for index, unit in enumerate(units, 1):
            # A unit without a name falls back to its ordinal label.
            unit_name = unit.get('name', f'{unit_word} {index}')
            parts.append(f"**{unit_word} {index}: {unit_name}**\n")
            for concept in (unit['concepts'] if 'concepts' in unit else []):
                concept_name = concept.get('name', '')
                if concept_name:
                    parts.append(f"├── {concept_word} {concept.get('number', '')}: {concept_name}\n")
            parts.append("\n")

        return "".join(parts)

    def get_unit_navigation_response(self, query: str, grade: int, is_arabic: bool, subject: str = "Science") -> str:
        """Answer a "what is in unit N / concept M" style query.

        The first number in ``query`` is taken as the 1-based unit number and
        the second, if present, as the concept number. Digits are normalised
        with ``int()``, so Arabic-Indic digits and leading zeros work for both
        numbers.

        Args:
            query: Raw user query.
            grade: Grade whose curriculum is consulted.
            is_arabic: Selects Arabic or English labels.
            subject: Subject to load; it is also shown in the header. For
                Arabic output the default ``"Science"`` is shown as "العلوم".

        Returns:
            The description of the requested concept, or the unit's concept
            list when no concept was requested or found, or the general
            overview when the curriculum or the unit cannot be resolved.
        """
        import re  # function-scope import: the module has no file-level `re` import

        curriculum = self.get_curriculum_structure(grade, is_arabic, subject)
        if not curriculum:
            return self.get_overview_response(grade, is_arabic, subject)

        tokens = re.findall(r'\d+', query)
        unit_index = int(tokens[0]) - 1 if tokens else None  # 0-based
        # Accept the concept number both as typed and in canonical decimal
        # form: "٢" or "02" in the query must match a stored 2.
        concept_keys = {tokens[1], str(int(tokens[1]))} if len(tokens) >= 2 else set()

        units = curriculum.get('units') or []
        if unit_index is None or not 0 <= unit_index < len(units):
            # No usable unit number: fall back to the general overview.
            return self.get_overview_response(grade, is_arabic, subject)

        if is_arabic:
            subject_label = "العلوم" if subject == "Science" else subject
            header = f"📖 من منهج {subject_label} للصف {grade}:\n\n"
            unit_word, concept_word, concepts_heading = "الوحدة", "المفهوم", "المفاهيم:\n"
        else:
            header = f"📖 From Grade {grade} {subject} Curriculum:\n\n"
            unit_word, concept_word, concepts_heading = "Unit", "Concept", "Concepts:\n"

        unit = units[unit_index]
        unit_number = unit_index + 1
        unit_name = unit.get('name', f'{unit_word} {unit_number}')
        parts = [header, f"**{unit_word} {unit_number}: {unit_name}**\n\n"]
        concepts = unit['concepts'] if 'concepts' in unit else None

        # A specific concept was requested: return its description if found.
        if concept_keys and concepts is not None:
            for concept in concepts:
                number = str(concept.get('number'))
                if number in concept_keys:
                    parts.append(f"**{concept_word} {number}: {concept.get('name', '')}**\n\n")
                    parts.append(f"{concept.get('description', '')}\n")
                    return "".join(parts)

        # Otherwise list the unit's named concepts.
        if concepts is not None:
            parts.append(concepts_heading)
            parts.extend(
                f"├── {concept_word} {concept.get('number', '')}: {concept.get('name', '')}\n"
                for concept in concepts
                if concept.get('name', '')
            )
        return "".join(parts)

    def setup_curriculum_table(self):
        """Create the ``curriculum_structure`` table and its indexes if absent.

        All statements use ``IF NOT EXISTS`` and are committed together.
        """
        ddl_statements = (
            """
                    CREATE TABLE IF NOT EXISTS curriculum_structure (
                        id SERIAL PRIMARY KEY,
                        grade INTEGER NOT NULL,
                        is_arabic BOOLEAN NOT NULL,
                        subject VARCHAR(100) NOT NULL DEFAULT 'Science',
                        curriculum_data JSONB NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        UNIQUE(grade, is_arabic, subject)
                    );
                """,
            # Lookup indexes: by grade/language and by subject.
            "CREATE INDEX IF NOT EXISTS idx_curriculum_grade_lang ON curriculum_structure(grade, is_arabic);",
            "CREATE INDEX IF NOT EXISTS idx_curriculum_subject ON curriculum_structure(subject);",
        )
        with self.pool_handler.get_connection() as conn, conn.cursor() as cur:
            for statement in ddl_statements:
                cur.execute(statement)
            conn.commit()
            logger.info("Curriculum structure table setup complete")

    def get_all_available_curricula(self) -> List[Dict]:
        """List every stored curriculum with its title.

        Returns:
            Dict-like rows with ``grade``, ``is_arabic``, ``subject``,
            ``title`` (the ``title`` field of the JSON data) and
            ``created_at``, ordered by grade, language and subject.
        """
        listing_sql = """
                    SELECT grade, is_arabic, subject, 
                           curriculum_data->>'title' as title,
                           created_at
                    FROM curriculum_structure
                    ORDER BY grade, is_arabic, subject;
                """
        with self.pool_handler.get_connection() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(listing_sql)
            return cur.fetchall()

    def verify_recent_insertions(self, limit: int = 5):
        """Print the most recently inserted educational chunks.

        A console diagnostic for checking that an ingestion run wrote rows.
        "Most recent" means highest ``id``. Database errors are reported on
        stdout instead of being raised.

        Args:
            limit: Maximum number of rows to display.
        """
        print("\n" + "="*50)
        print("🔍 Verifying recent embeddings in the database...")
        print("="*50)

        try:
            with self.pool_handler.get_connection() as conn:
                with conn.cursor(cursor_factory=RealDictCursor) as cur:
                    # Highest ids first, capped at ``limit`` rows.
                    cur.execute(
                        """
                        SELECT id, grade, subject, unit, concept, chunk_text, is_arabic
                        FROM educational_chunks
                        ORDER BY id DESC
                        LIMIT %s;
                        """,
                        (limit,)
                    )

                    results = cur.fetchall()

                    if not results:
                        print("❌ No data found in the 'educational_chunks' table.")
                        return

                    print(f"✅ Found {len(results)} recent records. Here they are:\n")
                    for row in results:
                        # Guard against a NULL chunk_text: slicing None would
                        # raise and abort the listing halfway through.
                        preview = (row['chunk_text'] or '')[:80]
                        print(f"  - ID: {row['id']}, Grade: {row['grade']}, Arabic: {row['is_arabic']}")
                        print(f"    Unit: {row['unit']}")
                        print(f"    Concept: {row['concept']}")
                        print(f"    Text: '{preview}...'\n")
            print("="*50)

        except Exception as e:
            # Deliberate best-effort boundary: this is a diagnostic helper, so
            # failures are shown to the operator rather than propagated.
            print(f"❌ Database verification failed: {e}")

    def insert_mcqs(self, mcq_list: List[Dict]):
        """Insert a batch of multiple-choice questions into ``mcq_questions``.

        Args:
            mcq_list: One dict per question. Keys missing from a dict are
                inserted as NULL. An empty list is a no-op.
        """
        if not mcq_list:
            return

        # Column order of the INSERT below; each row tuple follows it exactly.
        columns = (
            'curriculum', 'grade', 'subject', 'unit', 'concept', 'question_text',
            'question_type', 'difficulty_level', 'is_arabic', 'correct_answer',
            'wrong_answer_1', 'wrong_answer_2', 'wrong_answer_3', 'wrong_answer_4',
            'question_image_url', 'correct_image_url', 'wrong_image_url_1',
            'wrong_image_url_2', 'wrong_image_url_3', 'wrong_image_url_4', 'hint',
        )
        insert_query = """
                    INSERT INTO mcq_questions (
                        curriculum, grade, subject, unit, concept, question_text,
                        question_type, difficulty_level, is_arabic, correct_answer,
                        wrong_answer_1, wrong_answer_2, wrong_answer_3, wrong_answer_4,
                        question_image_url, correct_image_url, wrong_image_url_1,
                        wrong_image_url_2, wrong_image_url_3, wrong_image_url_4, hint
                    ) VALUES (
                        %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                        %s, %s, %s, %s, %s, %s, %s
                    );
                """

        with self.pool_handler.get_connection() as conn, conn.cursor() as cur:
            # .get() tolerates questions that lack some of the fields.
            rows = [tuple(question.get(column) for column in columns) for question in mcq_list]
            cur.executemany(insert_query, rows)
            conn.commit()
            logger.info(f"Successfully inserted {len(mcq_list)} MCQs into the database.")

    def get_mcqs(self, curriculum: str, grade: str, subject: str, unit: str, concept: str, is_arabic: bool, limit: Optional[int] = 10) -> List[Dict]:
        """Fetch stored MCQs for one curriculum topic and language.

        Args:
            curriculum, grade, subject, unit, concept, is_arabic: Exact-match
                filters on the corresponding ``mcq_questions`` columns.
            limit: Maximum number of questions; ``None`` returns all matches.

        Returns:
            Dict-like rows, newest first (by ``created_at``).
        """
        query = """
                    SELECT *
                    FROM mcq_questions
                    WHERE curriculum = %s AND grade = %s AND subject = %s AND unit = %s AND concept = %s AND is_arabic = %s
                    ORDER BY created_at DESC
                """
        params = (curriculum, grade, subject, unit, concept, is_arabic)

        # The LIMIT clause is appended only when a cap was requested.
        if limit is None:
            query += ";"
        else:
            query += " LIMIT %s;"
            params += (limit,)

        with self.pool_handler.get_connection() as conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(query, params)
            return cur.fetchall()
            

    def get_distinct_curricula_from_structure(self) -> List[str]:
        """Return the distinct, non-empty curriculum titles in ``curriculum_structure``."""
        with self.pool_handler.get_connection() as conn, conn.cursor() as cur:
            cur.execute("SELECT DISTINCT curriculum_data->>'title' FROM curriculum_structure ORDER BY 1;")
            titles = (record[0] for record in cur.fetchall())
            # Rows without a title (NULL or empty) are dropped.
            return [title for title in titles if title]

    def get_distinct_grades_from_structure(self, curriculum: str) -> List[str]:
        """Return the distinct grades stored for a curriculum title, as strings."""
        # The grade column is cast to text in SQL so callers receive strings.
        grades_sql = """
                    SELECT DISTINCT grade::text FROM curriculum_structure
                    WHERE curriculum_data->>'title' = %s ORDER BY 1;
                """
        with self.pool_handler.get_connection() as conn, conn.cursor() as cur:
            cur.execute(grades_sql, (curriculum,))
            grades = (record[0] for record in cur.fetchall())
            return [grade for grade in grades if grade]

    def get_distinct_subjects_from_structure(self, curriculum: str, grade: str) -> List[str]:
        """Return the distinct subjects stored for a curriculum title and grade.

        ``grade`` arrives as a string and is converted with ``int()`` before
        being compared with the integer ``grade`` column.
        """
        subjects_sql = """
                    SELECT DISTINCT subject FROM curriculum_structure
                    WHERE curriculum_data->>'title' = %s AND grade = %s ORDER BY 1;
                """
        with self.pool_handler.get_connection() as conn, conn.cursor() as cur:
            cur.execute(subjects_sql, (curriculum, int(grade)))
            subjects = (record[0] for record in cur.fetchall())
            return [subject for subject in subjects if subject]

    def get_distinct_units_from_structure(self, curriculum: str, grade: str, subject: str) -> List[str]:
        """Return the distinct unit names stored for a curriculum, grade and subject.

        Unit names are read from the ``units`` array inside the JSONB
        ``curriculum_data`` column; ``grade`` is converted with ``int()``.
        """
        # jsonb_array_elements() turns each entry of the 'units' array into a row.
        units_sql = """
                    SELECT DISTINCT unit->>'name'
                    FROM curriculum_structure, jsonb_array_elements(curriculum_data->'units') AS unit
                    WHERE curriculum_data->>'title' = %s AND grade = %s AND subject = %s
                    ORDER BY 1;
                """
        with self.pool_handler.get_connection() as conn, conn.cursor() as cur:
            cur.execute(units_sql, (curriculum, int(grade), subject))
            unit_names = (record[0] for record in cur.fetchall())
            return [unit_name for unit_name in unit_names if unit_name]

    def get_distinct_concepts_from_structure(self, curriculum: str, grade: str, subject: str, unit: str) -> List[str]:
        """Return the distinct concept names of one unit.

        The unit is located by name inside the JSONB ``units`` array of the
        row matching the curriculum title, grade (converted with ``int()``)
        and subject.
        """
        # Two nested expansions: units of the curriculum, then concepts of each unit.
        concepts_sql = """
                    SELECT DISTINCT concept->>'name'
                    FROM curriculum_structure,
                         jsonb_array_elements(curriculum_data->'units') AS u,
                         jsonb_array_elements(u->'concepts') AS concept
                    WHERE curriculum_data->>'title' = %s
                      AND grade = %s
                      AND subject = %s
                      AND u->>'name' = %s
                    ORDER BY 1;
                """
        with self.pool_handler.get_connection() as conn, conn.cursor() as cur:
            cur.execute(concepts_sql, (curriculum, int(grade), subject, unit))
            concept_names = (record[0] for record in cur.fetchall())
            return [concept_name for concept_name in concept_names if concept_name]