import os
import sys
import numpy as np
import faiss
from sklearn.cluster import MiniBatchKMeans
from tqdm import tqdm

"""
This script manually creates a Faiss index for voice conversion features.
It loads pre-extracted features, trains a Faiss index on them,
and saves the populated index to disk.
"""

# --- CONFIGURATION ---
# exp_name    : experiment folder under <root_dir>/logs
# version     : model version tag (only used in the output index file name)
# feature_dim : width of each feature vector; also selects the feature folder
exp_name = "anan-40"
version = "v2"
feature_dim = 768
# ---------------------

# Every path hangs off the RVC checkout:
#   <root_dir>/logs/<exp_name>/3_feature<feature_dim>
root_dir = "/home/ec2-user/RVC/Retrieval-based-Voice-Conversion-WebUI"
exp_dir = os.path.join(root_dir, "logs", exp_name)
feature_dir = os.path.join(exp_dir, "3_feature" + str(feature_dim))

print(f"--- Starting Index Training for {exp_name} ---")

# 1. Load Features
# isdir (rather than exists) also rejects a regular file sitting at this path,
# which would otherwise pass the check and then make os.listdir() raise.
if not os.path.isdir(feature_dir):
    print(f"Error: Directory not found: {feature_dir}")
    sys.exit(1)

# Only .npy files are treated as features. Stray entries (editor backups,
# .DS_Store, ...) would otherwise make np.load() raise part-way through.
# Sorting keeps the load order deterministic across runs.
feature_files = sorted(
    name for name in os.listdir(feature_dir) if name.endswith(".npy")
)
if not feature_files:
    print("Error: Feature directory is empty! You need to run Feature Extraction first.")
    sys.exit(1)

print(f"Found {len(feature_files)} feature files.")

# --- Progress Bar for Loading ---
npys = [
    np.load(os.path.join(feature_dir, name))
    for name in tqdm(feature_files, desc="Loading Features", unit="file")
]

# Stack the per-file arrays along axis 0 into one matrix.
big_npy = np.concatenate(npys, 0)
# The per-file arrays are now duplicated inside big_npy; release them so the
# following steps do not run with twice the feature memory held.
del npys
print(f"Total features loaded: {big_npy.shape}")

# 2. Shuffle
print("Shuffling features...")
# Draw a random ordering of the row indices and apply it in a single
# fancy-indexing pass; the result is a reordered copy of the rows.
row_order = np.random.permutation(big_npy.shape[0])
big_npy = big_npy[row_order]

# 3. K-Means (Optimization for large datasets)
# Above 200k rows the feature matrix is replaced by 10,000 cluster centres.
# This is best-effort: if clustering raises, the error is reported and the
# full feature matrix is kept unchanged.
n_rows = big_npy.shape[0]
if n_rows > 2e5:
    print(f"Dataset is large ({n_rows} rows). Applying K-Means clustering...")
    try:
        kmeans = MiniBatchKMeans(
            n_clusters=10000,
            verbose=True,
            batch_size=2048,
            compute_labels=False,
            init="random",
        )
        kmeans.fit(big_npy)
        big_npy = kmeans.cluster_centers_
    except Exception as err:
        print(f"K-Means failed: {err}")

# 4. Train Index
print("Training Faiss Index (Log outputs below)...")

# Faiss works on C-contiguous float32 matrices. This returns the same array
# when the features already have that layout, and converts otherwise (for
# example if a feature file was stored as float64).
big_npy = np.ascontiguousarray(big_npy, dtype=np.float32)

# Number of inverted lists: grows with sqrt(N) but is capped at N // 39 so
# each list has at least ~39 training vectors. Clamped to 1 because with
# fewer than 39 vectors the cap is 0, which would produce the invalid
# factory string "IVF0,Flat".
n_vectors = big_npy.shape[0]
n_ivf = max(1, min(int(16 * np.sqrt(n_vectors)), n_vectors // 39))
index = faiss.index_factory(feature_dim, f"IVF{n_ivf},Flat")
index_ivf = faiss.extract_index_ivf(index)
# Probe a single inverted list at search time (reflected in the file name).
index_ivf.nprobe = 1

# --- ENABLE VERBOSE MODE (Shows internal progress) ---
index.verbose = True
# ---------------------------------------------------

index.train(big_npy)

# 5. Add Data and Save
# The file name records the index layout (list count, nprobe) together with
# the experiment name and version.
index_name = f"added_IVF{n_ivf}_Flat_nprobe_1_{exp_name}_{version}.index"
save_path = os.path.join(exp_dir, index_name)

print("Adding data to index...")
batch_size_add = 8192

# --- Progress Bar for Indexing ---
# Vectors are added in fixed-size slices so the bar advances once per batch;
# the range's length is the number of batches (last one may be shorter).
batch_starts = range(0, big_npy.shape[0], batch_size_add)
for start in tqdm(batch_starts, desc="Populating Index", total=len(batch_starts), unit="batch"):
    stop = start + batch_size_add
    index.add(big_npy[start:stop])

faiss.write_index(index, save_path)
banner = "--------------------------------------------------"
print(banner)
print(f"SUCCESS! Index saved to: {save_path}")
print(banner)