import os
import sys
import time
import traceback
import numpy as np
from scipy.io import wavfile
from dotenv import load_dotenv
from openai import OpenAI

# Setup
# Put the current working directory on the import path so the project-local
# packages imported further down (`configs`, `infer`) can be resolved.
# NOTE(review): this presumably requires launching the script from the RVC
# project root -- confirm.
now_dir = os.getcwd()
sys.path.append(now_dir)
# Load variables from a .env file into the process environment; the
# OPENAI_API_KEY constant below is read from that environment.
load_dotenv()

# Set before the project imports below.
# NOTE(review): presumably read by the RVC modules to locate model weights and
# index files -- confirm against configs/ and infer/.
os.environ["weight_root"] = "assets/weights"
os.environ["index_root"] = "logs"

# Project-local imports: intentionally placed after the path and environment
# setup above rather than in the top-of-file import group.
from configs.config import Config
from infer.modules.vc.modules import VC

# --- CONFIGURATION ---
# API key handed to the OpenAI client; None when the variable is not set,
# in which case main() reports the problem and returns.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# Model name passed to VC.get_vc() at start-up.
MODEL_NAME = "anan-40.pth"
# Index file passed to every VC.vc_single() call.
INDEX_PATH = "logs/anan-40/added_IVF3961_Flat_nprobe_1_anan-40_v2.index"

# Arguments forwarded to VC.vc_single() for each user request in main().
RVC_PARAMS = dict(
    f0_up_key=0,
    f0_method="rmvpe",
    index_rate=0,
    filter_radius=3,
    resample_sr=0,
    rms_mix_rate=0.25,
    protect=0.33,
)

def init_rvc_model():
    """Create the VC engine, load MODEL_NAME into it and run the warm-up pass.

    Progress is reported on stdout in three numbered steps.

    Returns:
        The VC instance after ``get_vc(MODEL_NAME)`` and ``perform_warmup``.

    Exits the process with status 1 if ``get_vc`` raises.
    """
    print(">>> [1/3] Initializing RVC Configuration...")
    settings = Config()
    # Point the configuration at the same directories the environment
    # variables at the top of the file use.
    for attr, value in (("weight_root", "assets/weights"), ("index_root", "logs")):
        setattr(settings, attr, value)
    engine = VC(settings)

    print(f">>> [2/3] Loading Model '{MODEL_NAME}' to GPU...")
    try:
        engine.get_vc(MODEL_NAME)
    except Exception as load_error:
        # Without a loaded model nothing else can work, so stop right here.
        print(f"CRITICAL ERROR loading model: {load_error}")
        sys.exit(1)

    print(">>> [3/3] Warming up (Loading RMVPE & Hubert into VRAM)...")
    perform_warmup(engine)
    return engine

def perform_warmup(vc):
    """Run one throw-away conversion on a second of silence.

    Writes a 16 kHz, 16-bit silent WAV to ``/dev/shm/warmup.wav``, passes it
    through ``vc.vc_single`` and deletes the file again. The intent (per the
    start-up banner) is that the engine loads what it needs before the first
    real request.

    Warm-up is best-effort: any ``Exception`` -- including failure to create
    the temporary file -- is reported as a warning and does not propagate.

    Args:
        vc: Engine object exposing ``vc_single`` with the same positional
            signature used in ``main``.
    """
    dummy_path = "/dev/shm/warmup.wav"
    sr = 16000
    try:
        # One second of silence (sr samples at sr Hz). Written inside the
        # try block so that a missing/unwritable /dev/shm only produces a
        # warning instead of aborting start-up.
        wavfile.write(dummy_path, sr, np.zeros(sr, dtype=np.int16))
        info, audio_tuple = vc.vc_single(
            0, dummy_path, 0, None, "rmvpe", INDEX_PATH, "", 0.75, 3, 0, 0.25, 0.33
        )
        # NOTE(review): the upstream RVC engine reports failures through its
        # return value -- (message, None) or (traceback, (None, None)) --
        # rather than raising; confirm against the installed version.
        if audio_tuple is None or audio_tuple[1] is None:
            print(f">>> Warm-up warning: {info}")
        else:
            print(">>> Warm-up Complete! System is hot.")
    except Exception as e:
        print(f">>> Warm-up warning: {e}")
    finally:
        # Always clean up, even on KeyboardInterrupt/SystemExit.
        if os.path.exists(dummy_path):
            os.remove(dummy_path)
def generate_openai_audio(client, text, output_path):
    """Synthesize *text* via ``client.audio.speech.create`` and save it as WAV.

    Args:
        client: Object exposing ``audio.speech.create(...)`` whose result has
            a ``content`` attribute holding the audio bytes.
        text: Text to synthesize.
        output_path: File the returned bytes are written to (binary mode).

    Returns:
        True when the audio was requested and written; False when any
        exception occurred along the way (the error is printed).
    """
    print(f"\n--- Generating OpenAI Audio ({len(text)} chars) ---")
    began_at = time.time()
    request = {
        "model": "gpt-4o-mini-tts",
        "voice": "alloy",
        "input": text,
        "response_format": "wav",
    }
    try:
        speech = client.audio.speech.create(**request)

        with open(output_path, "wb") as sink:
            sink.write(speech.content)

        # The reported time covers both the API call and the file write.
        elapsed = time.time() - began_at
        print(f"OpenAI TTS Time: {elapsed:.4f} seconds")
    except Exception as err:
        print(f"OpenAI Error: {err}")
        return False
    return True

def main():
    """Interactive loop: read text, synthesize it with OpenAI, convert with RVC.

    Returns immediately (after printing an error) when OPENAI_API_KEY is not
    set. Otherwise initializes the RVC engine once, then repeatedly:

    1. reads a line from stdin (``exit``/``quit`` ends the loop, blank lines
       are skipped),
    2. writes the OpenAI TTS audio to ``/dev/shm/temp_openai.wav``,
    3. runs ``vc_single`` on that file with RVC_PARAMS and saves the result
       to ``output_response.wav``,
    4. removes the temporary TTS file.

    The loop ends on Ctrl+C or when stdin reaches end-of-file; any other
    exception is printed with its traceback and the loop continues.
    """
    if not OPENAI_API_KEY:
        print("Error: OPENAI_API_KEY not found.")
        return

    vc_instance = init_rvc_model()
    openai_client = OpenAI(api_key=OPENAI_API_KEY)

    print("\n" + "="*60)
    print(" 🚀 AWS SYSTEM READY (HOT-LOADED)")
    print("="*60)

    # These never change between requests, so define them once.
    temp_input = "/dev/shm/temp_openai.wav"
    output_wav = "output_response.wav"

    while True:
        try:
            user_text = input("\n📝 Enter text: ")
            # Compare the trimmed text so "exit " or " quit" also stop the loop.
            command = user_text.strip()
            if command.lower() in ("exit", "quit"):
                break
            if not command:
                continue

            try:
                # 1. OpenAI Generation
                if not generate_openai_audio(openai_client, user_text, temp_input):
                    continue

                # 2. RVC conversion
                print("--- Converting Voice (RVC) ---")
                rvc_start = time.time()

                info, audio_tuple = vc_instance.vc_single(
                    0,
                    temp_input,
                    RVC_PARAMS["f0_up_key"],
                    None,
                    RVC_PARAMS["f0_method"],
                    INDEX_PATH,
                    "",
                    RVC_PARAMS["index_rate"],
                    RVC_PARAMS["filter_radius"],
                    RVC_PARAMS["resample_sr"],
                    RVC_PARAMS["rms_mix_rate"],
                    RVC_PARAMS["protect"],
                )

                rvc_end = time.time()

                # NOTE(review): the upstream RVC engine signals failure as
                # either (message, None) or (traceback, (None, None)) instead
                # of raising, so both shapes are treated as a failed
                # conversion -- confirm against the installed version.
                tgt_sr, audio_opt = audio_tuple if audio_tuple is not None else (None, None)
                if audio_opt is not None:
                    wavfile.write(output_wav, tgt_sr, audio_opt)
                    print(f"✅ Success! Saved to: {output_wav}")
                    print(f"⚡ RVC Inference Time: {rvc_end - rvc_start:.4f} seconds")
                else:
                    print(f"❌ RVC Conversion Failed. Info: {info}")
            finally:
                # Remove the intermediate TTS file on every path: success,
                # failed TTS (possibly partial file), or an exception.
                if os.path.exists(temp_input):
                    os.remove(temp_input)

        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed (e.g. piped input ran out). Without
            # this the generic handler below would loop forever.
            print("\nStopping...")
            break
        except Exception as e:
            print(f"Unexpected Error: {e}")
            traceback.print_exc()

# Start the interactive loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()