import os
import psycopg2
from psycopg2.extras import RealDictCursor
from typing import List, Dict, Optional, Tuple
import logging

# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)


class ChatDatabaseService:
    """Service for managing students and their chat history in PostgreSQL.

    Holds a single psycopg2 connection for its whole lifetime. Every public
    method runs its statements inside ``with self.conn:``, which commits the
    transaction when the block succeeds and rolls it back when it raises.
    This matters because psycopg2 opens a transaction implicitly:

    * without a rollback, one failed statement would leave the shared
      connection in an aborted transaction and make every later call fail;
    * without a commit, read-only methods would leave the connection idle
      inside an open transaction.

    The instance can be used as a context manager to guarantee ``close()``::

        with ChatDatabaseService() as db:
            db.add_message(student_id, "user", "hello")
    """

    def __init__(self):
        # Connection settings come from the environment; the host falls back
        # to "postgres" (presumably a container/service name - confirm).
        self.conn = psycopg2.connect(
            host=os.getenv("POSTGRES_HOST", "postgres"),
            user=os.getenv("POSTGRES_USER"),
            password=os.getenv("POSTGRES_PASSWORD"),
            dbname=os.getenv("POSTGRES_DB"),
        )

    def __enter__(self) -> "ChatDatabaseService":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Students
    # ------------------------------------------------------------------

    def get_student_nationality(self, student_id: str) -> Optional[str]:
        """Return the student's nationality, or None if no such student exists."""
        with self.conn, self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(
                "SELECT nationality FROM students WHERE student_id = %s",
                (student_id,),
            )
            row = cur.fetchone()
        return row["nationality"] if row else None

    def get_student_info(self, student_id: str) -> Optional[Dict]:
        """Return the student's record as a dict, or None if not found.

        Keys: ``student_id``, ``student_name``, ``grade``, ``is_arabic`` and
        ``nationality``. The ``is_arabic`` value is the ``language`` column.
        """
        with self.conn, self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(
                """
                SELECT student_id, student_name, grade, language, nationality
                FROM students
                WHERE student_id = %s
                """,
                (student_id,),
            )
            row = cur.fetchone()
        if not row:
            return None
        return {
            'student_id': row['student_id'],
            'student_name': row['student_name'],
            'grade': row['grade'],
            # The boolean `language` column is exposed to callers as `is_arabic`.
            'is_arabic': row['language'],
            'nationality': row['nationality'],
        }

    def get_student_grade_and_language(self, student_id: str) -> Optional[Tuple[int, bool]]:
        """Return ``(grade, language)`` for the student, or None if not found."""
        with self.conn, self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(
                "SELECT grade, language FROM students WHERE student_id = %s",
                (student_id,),
            )
            row = cur.fetchone()
        if not row:
            return None
        return (row["grade"], row["language"])

    def update_student_info(self, student_id: str, grade: Optional[int] = None,
                            language: Optional[bool] = None, nationality: Optional[str] = None):
        """Update the given fields of a student; fields left as None are untouched.

        Does nothing (and issues no query) when every field is None.
        """
        # Column names are fixed literals below; only the values are passed
        # as query parameters, so building the SET clause with an f-string
        # is not an injection risk.
        candidates = (
            ("grade", grade),
            ("language", language),
            ("nationality", nationality),
        )
        assignments = [f"{column} = %s" for column, value in candidates if value is not None]
        params = [value for _, value in candidates if value is not None]
        if not assignments:
            return

        params.append(student_id)
        with self.conn, self.conn.cursor() as cur:
            cur.execute(
                f"""
                UPDATE students
                SET {', '.join(assignments)}
                WHERE student_id = %s
                """,
                params,
            )

    def create_student(self, student_id: str, student_name: str, grade: int,
                       language: bool, nationality: str = 'EGYPTIAN'):
        """Insert a new student; does nothing if ``student_id`` already exists."""
        with self.conn, self.conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO students (student_id, student_name, grade, language, nationality)
                VALUES (%s, %s, %s, %s, %s)
                ON CONFLICT (student_id) DO NOTHING;
                """,
                (student_id, student_name, grade, language, nationality),
            )

    # ------------------------------------------------------------------
    # Chat history
    # ------------------------------------------------------------------

    def get_chat_history(self, student_id: str, limit: int = 20) -> List[Dict[str, str]]:
        """Return up to ``limit`` of the student's most recent messages, oldest first.

        Each item is ``{"role": ..., "content": ...}``.
        """
        with self.conn, self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            # Fetch newest-first so LIMIT keeps the most recent rows. `id` is
            # a tiebreaker for equal timestamps, so the order is deterministic
            # and matches the ordering used by limit_history(). This presumes
            # `id` grows with insertion order (e.g. a serial) - confirm.
            cur.execute(
                """
                SELECT role, content
                FROM chat_history
                WHERE student_id = %s
                ORDER BY created_at DESC, id DESC
                LIMIT %s;
                """,
                (student_id, limit),
            )
            rows = cur.fetchall()
        # Reverse to chronological order (oldest first) for the caller.
        return [{"role": r["role"], "content": r["content"]} for r in reversed(rows)]

    def add_message(self, student_id: str, role: str, content: str):
        """Append one message to the student's chat history."""
        with self.conn, self.conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO chat_history (student_id, role, content)
                VALUES (%s, %s, %s);
                """,
                (student_id, role, content),
            )

    def clear_history(self, student_id: str):
        """Delete every chat_history row for the student, including system messages."""
        with self.conn, self.conn.cursor() as cur:
            cur.execute(
                "DELETE FROM chat_history WHERE student_id = %s",
                (student_id,),
            )

    def limit_history(self, student_id: str, max_messages: int = 40):
        """Keep only the ``max_messages`` most recent non-system messages.

        Rows with ``role = 'system'`` are never deleted and do not count
        toward the limit.
        """
        with self.conn, self.conn.cursor() as cur:
            # The subquery uses the same ordering as get_chat_history(), so
            # both methods agree on which messages are "most recent".
            cur.execute(
                """
                DELETE FROM chat_history
                WHERE student_id = %s
                AND role != 'system'
                AND id NOT IN (
                    SELECT id FROM chat_history
                    WHERE student_id = %s AND role != 'system'
                    ORDER BY created_at DESC, id DESC
                    LIMIT %s
                );
                """,
                (student_id, student_id, max_messages),
            )

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self):
        """Close the database connection; safe to call more than once."""
        if self.conn is not None and not self.conn.closed:
            self.conn.close()