import os
import psycopg2
import openai
from psycopg2.extras import execute_values
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment so the
# os.getenv() calls below (and in get_db_connection) can see them, then hand
# the OpenAI API key to the module-level openai client used by get_embedding.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

def get_db_connection():
    """Open and return a new psycopg2 connection configured from the environment.

    Each connection setting comes from a POSTGRES_* environment variable, with
    these fallbacks: database ``embeddings_db``, user ``db_admin``, host
    ``localhost``, port ``5432``. There is no fallback password; if
    ``POSTGRES_PASSWORD`` is unset, ``None`` is passed through. When
    ``POSTGRES_PORT`` is set, its string value is passed as-is.

    The caller owns the returned connection and is responsible for closing it.
    """
    env = os.getenv
    settings = {
        "dbname": env("POSTGRES_DB", "embeddings_db"),
        "user": env("POSTGRES_USER", "db_admin"),
        "password": env("POSTGRES_PASSWORD"),
        "host": env("POSTGRES_HOST", "localhost"),
        "port": env("POSTGRES_PORT", 5432),
    }
    return psycopg2.connect(**settings)

def chunk_text(text, chunk_size=500, overlap=50):
    """Split *text* into fixed-size, overlapping character windows.

    Consecutive chunks share ``overlap`` characters: each chunk after the
    first starts ``chunk_size - overlap`` characters after the previous one.
    The final chunk ends exactly at the end of *text* and may be shorter
    than ``chunk_size``. Empty text yields an empty list.

    Args:
        text: The string to split.
        chunk_size: Maximum length of each chunk, in characters. Must be > 0.
        overlap: Number of characters shared by consecutive chunks. Must
            satisfy ``0 <= overlap < chunk_size`` so the window always advances.

    Returns:
        A list of substrings of *text*, in order.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), "
            f"got {overlap}"
        )

    chunks = []
    text_len = len(text)
    start = 0
    while start < text_len:
        end = min(text_len, start + chunk_size)
        chunks.append(text[start:end])
        # Once a chunk reaches the end of the text we are done; stepping back
        # by `overlap` here would re-emit the same tail chunk forever.
        if end == text_len:
            break
        # Advances by chunk_size - overlap >= 1 because overlap < chunk_size.
        start = end - overlap
    return chunks

def get_embedding(text):
    """Return the embedding vector for *text* using ``text-embedding-3-small``.

    Issues one ``openai.embeddings.create`` request and returns the
    ``embedding`` attribute of the first item in the response's ``data``.
    """
    model_name = "text-embedding-3-small"
    result = openai.embeddings.create(input=text, model=model_name)
    first_item = result.data[0]
    return first_item.embedding

def main(max_lessons=2):
    """Chunk lesson texts, embed each chunk, and store the embeddings.

    Reads ``(id, lesson_text)`` rows from the ``lessons`` table (rows whose
    ``lesson_text`` is NULL are skipped by the query). Each text is split with
    ``chunk_text(chunk_size=500, overlap=50)`` and every chunk is embedded with
    ``get_embedding``. All resulting
    ``(lesson_id, chunk_index, chunk_text, embedding)`` rows are inserted into
    ``lesson_embeddings`` with one ``execute_values`` call followed by a
    single commit.

    Args:
        max_lessons: Stop after this many lessons have been processed. The
            default of 2 preserves the script's original test-mode behavior.
            Pass ``None`` to process every lesson. Must be ``None`` or >= 1.

    Raises:
        ValueError: If ``max_lessons`` is not ``None`` and is less than 1.

    The cursor and connection are always closed, even when fetching,
    embedding, or inserting raises; in that case ``commit`` is never called.
    """
    if max_lessons is not None and max_lessons < 1:
        raise ValueError(f"max_lessons must be None or >= 1, got {max_lessons}")

    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            print("Fetching lessons...")
            cur.execute("SELECT id, lesson_text FROM lessons WHERE lesson_text IS NOT NULL;")
            lessons = cur.fetchall()
            total_lessons = len(lessons)
            print(f"Found {total_lessons} lessons to process.")

            # Rows are buffered in memory and written in one batch at the end,
            # so a failure part-way through inserts nothing.
            all_rows = []
            for idx, (lesson_id, lesson_text) in enumerate(lessons, start=1):
                chunks = chunk_text(lesson_text, chunk_size=500, overlap=50)
                for i, chunk in enumerate(chunks):
                    embedding = get_embedding(chunk)
                    all_rows.append((lesson_id, i, chunk, embedding))

                progress = (idx / total_lessons) * 100
                print(f"Lesson {idx}/{total_lessons} complete ({progress:.2f}% done, {len(chunks)} chunks)")

                # Test mode: stop after the first `max_lessons` lessons.
                if max_lessons is not None and idx >= max_lessons:
                    print(f"Stopping after first {max_lessons} lessons (test mode).")
                    break

            if all_rows:
                query = """
                INSERT INTO lesson_embeddings (lesson_id, chunk_index, chunk_text, embedding)
                VALUES %s
                """
                execute_values(cur, query, all_rows)
                conn.commit()
    finally:
        # The cursor is closed by its `with` block; make sure the connection is
        # released too, even if anything above raised.
        conn.close()

    print(f"Inserted {len(all_rows)} embeddings into the database.")

if __name__ == "__main__":
    main()
