import os
import psycopg2
import pandas as pd
import json
from dotenv import load_dotenv

# Load variables from a .env file into the process environment so that the
# os.getenv(...) lookups in get_db_connection() can see them.
load_dotenv()

def get_db_connection():
    """Open and return a new psycopg2 connection built from environment settings.

    Reads POSTGRES_DB, POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST and
    POSTGRES_PORT. Every setting except the password falls back to a default
    ("embeddings_db", "db_admin", "localhost", 5432). The password is None
    when unset.
    """
    settings = {
        "dbname": os.getenv("POSTGRES_DB", "embeddings_db"),
        "user": os.getenv("POSTGRES_USER", "db_admin"),
        "password": os.getenv("POSTGRES_PASSWORD"),
        "host": os.getenv("POSTGRES_HOST", "localhost"),
        "port": os.getenv("POSTGRES_PORT", 5432),
    }
    return psycopg2.connect(**settings)

def _missing_to_none(value):
    """Return None for a pandas-missing scalar (NaN/None/NaT), else *value* unchanged.

    pandas represents empty CSV cells as float NaN. Passing that straight to
    psycopg2 would send a float NaN to the database rather than SQL NULL.
    """
    return None if pd.isna(value) else value


def _optional_int(value):
    """Convert *value* to int, mapping a missing (NaN/None) cell to None."""
    return None if pd.isna(value) else int(value)


def _parse_bool(value):
    """Interpret a CSV cell as a boolean.

    Plain ``bool()`` is wrong for CSV data: ``bool("False")`` and
    ``bool(float("nan"))`` are both True. Real booleans and numbers are
    converted with ``bool()``. Strings must be a recognized true/false token.

    Raises:
        ValueError: if the cell is missing or is an unrecognized string.
    """
    if isinstance(value, str):
        token = value.strip().lower()
        if token in {"true", "t", "yes", "y", "1"}:
            return True
        if token in {"false", "f", "no", "n", "0"}:
            return False
        raise ValueError(f"Unrecognized boolean value: {value!r}")
    if pd.isna(value):
        raise ValueError("Missing boolean value")
    return bool(value)


def _row_to_record(row):
    """Convert one DataFrame row into the parameter tuple for the INSERT statement.

    Raises:
        ValueError, TypeError, OverflowError: if a field cannot be converted
            (e.g. an unparsable or missing embedding, or a non-integer chunk index).
    """
    return (
        _missing_to_none(row["Grade"]),
        _missing_to_none(row["Subject"]),
        _missing_to_none(row["Unit"]),
        _missing_to_none(row["Concept"]),
        _missing_to_none(row["Lesson"]),
        _optional_int(row["From page"]),
        _optional_int(row["To page"]),
        int(row["Chunk index"]),
        _missing_to_none(row["Chunk text"]),
        _parse_bool(row["Is Arabic"]),
        json.loads(row["Embedding"]),  # JSON text -> Python list
    )


def insert_chunks_from_csv(csv_file: str, batch_size: int = 50):
    """Load chunk rows from *csv_file* and insert them into ``educational_chunks``.

    Rows are inserted with ``executemany`` and committed every *batch_size*
    rows. A database failure part-way through therefore leaves earlier
    batches committed. Rows that cannot be converted are reported and skipped.
    The connection and cursor are always closed, even if an insert fails.

    Args:
        csv_file: Path to the CSV file. It must contain every column in
            ``required_cols``.
        batch_size: Number of rows per ``executemany``/commit round. Must be >= 1.

    Raises:
        ValueError: if *batch_size* < 1 or a required column is missing.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    df = pd.read_csv(csv_file)

    required_cols = [
        "Grade", "Subject", "Unit", "Concept", "Lesson",
        "From page", "To page", "Chunk index", "Chunk text",
        "Is Arabic", "Embedding"
    ]

    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Missing required column in CSV: {col}")

    insert_query = """
        INSERT INTO educational_chunks
        (grade, subject, unit, concept, lesson,
         from_page, to_page, chunk_index, chunk_text,
         is_arabic, embedding)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    """

    buffer = []
    inserted = 0
    skipped = 0

    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            for idx, row in df.iterrows():
                try:
                    buffer.append(_row_to_record(row))
                except (ValueError, TypeError, OverflowError) as e:
                    print(f"Skipping row {idx} due to error: {e}")
                    skipped += 1
                    continue

                if len(buffer) >= batch_size:
                    cur.executemany(insert_query, buffer)
                    conn.commit()
                    inserted += len(buffer)
                    print(f"Inserted {len(buffer)} rows...")
                    buffer = []

            if buffer:
                cur.executemany(insert_query, buffer)
                conn.commit()
                inserted += len(buffer)
                print(f"Inserted final {len(buffer)} rows.")
    finally:
        # Closing without commit discards the in-flight (failed) batch.
        conn.close()

    if skipped:
        print(f"Finished: {inserted} rows inserted, {skipped} rows skipped.")
    else:
        print("All data inserted successfully.")

if __name__ == "__main__":
    import sys

    # Default input file. An alternative CSV path may be passed as the
    # first command-line argument.
    csv_file = (
        sys.argv[1] if len(sys.argv) > 1
        else "Prime6_en_chunked_with_embeddings.csv"
    )
    insert_chunks_from_csv(csv_file)
