import psycopg2
import os
from dotenv import load_dotenv

load_dotenv()

def setup_mcq_table(drop_existing_table: bool = False) -> bool:
    """Create (or recreate) the ``mcq_questions`` table and its indexes.

    Connection settings come from the environment (after ``load_dotenv``):
    ``POSTGRES_HOST`` (default ``localhost``), ``POSTGRES_PORT`` (default
    ``5432``), ``POSTGRES_USER``, ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``.

    All statements run inside a single transaction. If any step fails (for
    example the ``vector`` extension cannot be created), everything is rolled
    back. A requested drop therefore never leaves the database without the
    table.

    Args:
        drop_existing_table: If True, ``DROP TABLE IF EXISTS mcq_questions
            CASCADE`` is issued first. This discards all stored rows and any
            objects that depend on the table.

    Returns:
        True if the setup was committed, False if an error occurred. Errors
        are printed as ``Error: <message>`` rather than raised.
    """
    conn = None
    try:
        conn = psycopg2.connect(
            host=os.getenv("POSTGRES_HOST", "localhost"),
            port=os.getenv("POSTGRES_PORT", "5432"),
            user=os.getenv("POSTGRES_USER"),
            password=os.getenv("POSTGRES_PASSWORD"),
            dbname=os.getenv("POSTGRES_DB")
        )
        # Run everything in one transaction: ``with conn`` commits when the
        # block exits normally and rolls back if it raises. DROP TABLE,
        # CREATE EXTENSION and a plain (non-CONCURRENTLY) CREATE INDEX are
        # all permitted inside a transaction in PostgreSQL.
        conn.autocommit = False

        with conn, conn.cursor() as cur:
            if drop_existing_table:
                print("Dropping existing mcq_questions table...")
                cur.execute("DROP TABLE IF EXISTS mcq_questions CASCADE;")

            print("Creating mcq_questions table...")

            # 1. The ``vector`` column type below comes from the pgvector
            #    extension, so it must exist before the table is created.
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")

            # 2. Main table. ``embedding`` holds a 1536-dimensional vector;
            #    NOTE(review): the dimension must match whatever embedding
            #    model writes to this column — confirm against the writer.
            cur.execute("""
                CREATE TABLE IF NOT EXISTS mcq_questions (
                    id SERIAL PRIMARY KEY,
                    curriculum TEXT,
                    grade TEXT NOT NULL,
                    subject TEXT NOT NULL,
                    unit TEXT NOT NULL,
                    concept TEXT NOT NULL,
                    question_text TEXT NOT NULL,
                    question_type TEXT,
                    difficulty_level INTEGER,
                    blooms_level TEXT,
                    is_arabic BOOLEAN NOT NULL,
                    correct_answer TEXT NOT NULL,
                    wrong_answer_1 TEXT,
                    wrong_answer_2 TEXT,
                    wrong_answer_3 TEXT,
                    wrong_answer_4 TEXT,
                    question_image_url TEXT,
                    correct_image_url TEXT,
                    wrong_image_url_1 TEXT,
                    wrong_image_url_2 TEXT,
                    wrong_image_url_3 TEXT,
                    wrong_image_url_4 TEXT,
                    hint TEXT,
                    embedding vector(1536),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                );
            """)

            # 3. Approximate-nearest-neighbour index using cosine distance.
            #    NOTE(review): HNSW requires a pgvector version that supports
            #    it — confirm the server's pgvector version.
            print("Creating vector index...")
            cur.execute("""
                CREATE INDEX IF NOT EXISTS idx_mcq_embedding 
                ON mcq_questions USING hnsw (embedding vector_cosine_ops);
            """)

            # 4. Composite B-tree index over the topic-filter columns. It can
            #    serve queries that filter on a leftmost prefix of this list.
            cur.execute("""
                CREATE INDEX IF NOT EXISTS idx_mcq_topic 
                ON mcq_questions(curriculum, grade, is_arabic, subject, unit, concept);
            """)

        # Reached only after the transaction above has been committed.
        print("MCQ table setup complete with Vector support.")
        return True

    except Exception as e:
        # Script-level boundary: report the failure and signal it to the
        # caller via the return value. The transaction has already been
        # rolled back by ``with conn``.
        print(f"Error: {e}")
        return False
    finally:
        if conn is not None:
            conn.close()

if __name__ == "__main__":
    print("Setting up the MCQ table structure...")

    # DROP_MCQ_TABLE opts into a destructive rebuild. Only these spellings
    # (case-insensitive, surrounding whitespace ignored) enable it. Any other
    # value, or an unset variable, keeps the existing table.
    truthy_spellings = ("1", "true", "yes")
    raw_flag = os.getenv("DROP_MCQ_TABLE", "false")
    should_drop = raw_flag.strip().lower() in truthy_spellings

    banner = "**************************************************"
    for line in (banner, f"Drop existing table: {should_drop}", banner):
        print(line)

    setup_mcq_table(drop_existing_table=should_drop)