import psycopg2
import os

schema_sql = """
-- Create students table
CREATE TABLE IF NOT EXISTS students (
    id SERIAL PRIMARY KEY,
    student_id VARCHAR(50) UNIQUE NOT NULL,
    student_name VARCHAR(100),
    grade VARCHAR(20),
    language BOOLEAN,
    nationality VARCHAR(20) NOT NULL DEFAULT 'EGYPTIAN'
);

-- Create chat_history table
CREATE TABLE IF NOT EXISTS chat_history (
    id SERIAL PRIMARY KEY,
    student_id VARCHAR(50) NOT NULL,
    role VARCHAR(20) NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
    content TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (student_id) REFERENCES students(student_id) ON DELETE CASCADE
);

-- Create indexes for better performance
CREATE INDEX IF NOT EXISTS idx_chat_history_student_id ON chat_history(student_id);
CREATE INDEX IF NOT EXISTS idx_chat_history_created_at ON chat_history(created_at);
CREATE INDEX IF NOT EXISTS idx_students_nationality ON students(nationality);

-- Insert dummy data for testing
INSERT INTO students (student_id, student_name, grade, language, nationality) VALUES 
    ('student_001', 'Ahmed Ali', 'prime4', TRUE, 'EGYPTIAN'),
    ('student_002', 'Sara Hassan', 'prime6', FALSE, 'SAUDI'),
    ('student_003', 'Mona Adel', 'prime5', TRUE, 'EGYPTIAN'),
    ('student_004', 'Omar Youssef', 'prime6', FALSE, 'SAUDI')
ON CONFLICT (student_id) DO NOTHING;
"""

drop_all_tables_sql = """
DO $$
DECLARE
    rec RECORD;
BEGIN
    -- drop all tables in public schema
    FOR rec IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP
        EXECUTE 'DROP TABLE IF EXISTS "' || rec.tablename || '" CASCADE';
    END LOOP;
END $$;
"""

# Connection parameters come from the POSTGRES_* environment variables;
# host and port fall back to localhost:5432 when unset.
conn = psycopg2.connect(
    host=os.getenv("POSTGRES_HOST", "localhost"),
    port=os.getenv("POSTGRES_PORT", "5432"),
    user=os.getenv("POSTGRES_USER"),
    password=os.getenv("POSTGRES_PASSWORD"),
    dbname=os.getenv("POSTGRES_DB")
)

# By default every table in the public schema is dropped before the schema is
# (re)created. Set DROP_ALL_TABLES to 0/false/no to keep existing tables.
drop_existing = os.getenv("DROP_ALL_TABLES", "1").strip().lower() not in ("0", "false", "no")

try:
    # Drop, create and verify inside ONE transaction (PostgreSQL DDL is
    # transactional): if schema_sql fails, the drop is rolled back instead of
    # leaving an empty database. ``with conn`` commits on success and rolls
    # back on an exception, but does not close the connection.
    with conn, conn.cursor() as cur:
        if drop_existing:
            cur.execute(drop_all_tables_sql)
        cur.execute(schema_sql)

        # Verifications: Select from students and chat_history tables
        print("Students table rows:")
        cur.execute("SELECT * FROM students ORDER BY id;")
        students = cur.fetchall()
        for row in students:
            print(row)

        print("\nChat_history table rows:")
        cur.execute("SELECT * FROM chat_history ORDER BY id;")
        chat_history = cur.fetchall()
        for row in chat_history:
            print(row)
finally:
    # Always release the connection, even if a statement above raised.
    conn.close()
