import os
import sys
import threading
import time
import uuid

import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from openai import OpenAI
from pydantic import BaseModel, Field
from scipy.io import wavfile

"""
RVC Voice Agent API
This FastAPI application integrates OpenAI's TTS capabilities with the RVC voice conversion model.
It accepts text input, generates speech using OpenAI, processes it through RVC, and returns the final audio.
"""

# Text clean-up helper from the local project. If the module is missing, fall
# back to a pass-through so the app can still start (e.g. during testing), but
# say so loudly: in production this silently disables text pre-processing.
try:
    from text_processing import prepare_text_for_audio
except ImportError as _text_processing_err:
    print(
        f">>> [Startup] Warning: text_processing unavailable "
        f"({_text_processing_err}); text will NOT be pre-processed."
    )

    def prepare_text_for_audio(text):
        """Return *text* unchanged (fallback used when ``text_processing`` is missing)."""
        return text

# 1. Setup Environment
# Put the current working directory on sys.path so the project-local `configs`
# and `infer` packages imported below can be found.
# NOTE(review): this assumes the server is launched from the RVC project root — confirm.
now_dir = os.getcwd()
sys.path.append(now_dir)
# Load variables from a .env file (python-dotenv) into the environment, if present;
# OPENAI_API_KEY is read from the environment further down.
load_dotenv()

# Set paths for RVC
# Exported before the RVC imports below; presumably RVC reads these from the
# environment when locating weights/indexes — TODO confirm. Keep this order.
os.environ["weight_root"] = "assets/weights"
os.environ["index_root"] = "logs"

# Deliberately imported here, after the sys.path / environment setup above,
# rather than with the other imports at the top of the file.
from configs.config import Config
from infer.modules.vc.modules import VC

# --- CONFIGURATION ---
# OpenAI credentials come from the process environment (None if unset).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# RVC checkpoint name (passed to VC.get_vc at startup) and its feature index file.
MODEL_NAME = "anan-40.pth"
INDEX_PATH = "logs/anan-40/added_IVF3961_Flat_nprobe_1_anan-40_v2.index"

# Tuned RVC inference settings, fed positionally into vc_single for every request.
# NOTE(review): index_rate is 0, so INDEX_PATH is passed but presumably not blended in.
RVC_PARAMS = dict(
    f0_up_key=-3,
    f0_method="rmvpe",
    index_rate=0,
    filter_radius=3,
    resample_sr=0,
    rms_mix_rate=0.25,
    protect=0.33,
)

# --- GLOBAL VARIABLES ---
app = FastAPI(title="RVC Voice Agent API")
# Both are assigned by the startup handler; they stay None until it has run.
vc_instance = openai_client = None

# --- DATA MODELS ---
class TextRequest(BaseModel):
    """Request body for ``POST /generate_audio``."""

    # Text to synthesize. Empty input is rejected up front (HTTP 422) instead of
    # failing later inside the OpenAI call with a generic 500.
    text: str = Field(..., min_length=1)
    # Optional OpenAI TTS speed multiplier, default 1.0. The bounds enforce the
    # documented 0.25 - 4.0 range so bad values get a 422 rather than a 500.
    speed: float = Field(default=1.0, ge=0.25, le=4.0)

# --- HELPER FUNCTIONS ---
def perform_warmup(vc):
    """Run one throwaway RVC inference before serving real requests.

    Writes one second of 16 kHz int16 silence to a temporary WAV, converts it
    with the same settings real requests use (``RVC_PARAMS`` / ``INDEX_PATH``),
    and always deletes the file afterwards. Presumably this pays one-time
    initialization costs up front.

    Best-effort: any failure (including being unable to write the temp file)
    is printed as a warning and never propagates, so startup continues.

    Args:
        vc: The loaded VC instance; must provide ``vc_single``.
    """
    print(">>> [Warmup] Running dummy inference...")
    dummy_path = "/dev/shm/warmup.wav"
    sr = 16000
    try:
        # The write lives inside the try: a missing or unwritable /dev/shm must
        # not abort startup for what is only a warmup.
        wavfile.write(dummy_path, sr, np.zeros(sr, dtype=np.int16))
        info, audio_tuple = vc.vc_single(
            0,
            dummy_path,
            RVC_PARAMS["f0_up_key"],
            None,
            RVC_PARAMS["f0_method"],
            INDEX_PATH,
            "",
            RVC_PARAMS["index_rate"],
            RVC_PARAMS["filter_radius"],
            RVC_PARAMS["resample_sr"],
            RVC_PARAMS["rms_mix_rate"],
            RVC_PARAMS["protect"],
        )
        # vc_single may report failure through its return value instead of
        # raising; only claim success when audio actually came back.
        if audio_tuple is None or any(part is None for part in audio_tuple):
            print(f">>> [Warmup] Warning: {info}")
        else:
            print(">>> [Warmup] Success! System is ready.")
    except Exception as e:
        print(f">>> [Warmup] Warning: {e}")
    finally:
        if os.path.exists(dummy_path):
            os.remove(dummy_path)

def cleanup_files(file_paths: list):
    """Best-effort deletion of temporary files.

    Used as a background task (run after the response is sent) and on error
    paths. Files that are already gone are skipped silently. Any other OS
    error is logged and skipped, so one failure does not stop the remaining
    deletions.

    Args:
        file_paths: Paths to delete.
    """
    for path in file_paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Never created, or already removed; nothing to do.
            continue
        except OSError as e:
            print(f">>> [Cleanup] Warning: could not delete {path}: {e}")

# --- LIFECYCLE EVENTS ---
@app.on_event("startup")
async def startup_event():
    """Create the OpenAI client and load the RVC model once, before serving.

    Calls ``sys.exit(1)`` if OPENAI_API_KEY is missing or the model fails to
    load. Otherwise it finishes with a warmup inference.
    """
    global vc_instance, openai_client

    # Every request needs the OpenAI client, so a missing key is fatal.
    if OPENAI_API_KEY:
        openai_client = OpenAI(api_key=OPENAI_API_KEY)
    else:
        print("CRITICAL ERROR: OPENAI_API_KEY not found.")
        sys.exit(1)

    print(">>> [Startup] Initializing RVC Config...")
    rvc_config = Config()
    rvc_config.weight_root, rvc_config.index_root = "assets/weights", "logs"

    vc_instance = VC(rvc_config)
    print(f">>> [Startup] Loading Model '{MODEL_NAME}' to GPU...")
    try:
        vc_instance.get_vc(MODEL_NAME)
    except Exception as load_err:
        print(f"CRITICAL ERROR loading model: {load_err}")
        sys.exit(1)

    perform_warmup(vc_instance)

# --- API ENDPOINTS ---
@app.get("/")
def root():
    """Status endpoint: report that the service is up and which model it uses."""
    payload = {"status": "running"}
    payload["model"] = MODEL_NAME
    return payload

# The blocking stages of /generate_audio run in worker threads (see below), so
# several requests can be in flight at once. This lock keeps RVC inference on
# the single shared VC instance one-at-a-time; nothing here establishes that
# VC is safe to call concurrently.
_RVC_LOCK = threading.Lock()


def _synthesize_openai(text: str, speed: float, out_path: str) -> None:
    """Render *text* with OpenAI TTS and write the returned WAV bytes to *out_path*.

    Blocking (network I/O): call it through a thread pool from async code.
    """
    response = openai_client.audio.speech.create(
        model="gpt-4o-mini-tts",
        voice="alloy",
        input=text,
        speed=speed,
        response_format="wav",
    )
    with open(out_path, "wb") as f:
        f.write(response.content)


def _convert_rvc(in_path: str, out_path: str) -> None:
    """Voice-convert the WAV at *in_path* with RVC and write the result to *out_path*.

    Blocking (model inference): call it through a thread pool from async code.

    Raises:
        RuntimeError: if ``vc_single`` returns no audio.
    """
    with _RVC_LOCK:
        info, audio_tuple = vc_instance.vc_single(
            0,
            in_path,
            RVC_PARAMS["f0_up_key"],
            None,
            RVC_PARAMS["f0_method"],
            INDEX_PATH,
            "",
            RVC_PARAMS["index_rate"],
            RVC_PARAMS["filter_radius"],
            RVC_PARAMS["resample_sr"],
            RVC_PARAMS["rms_mix_rate"],
            RVC_PARAMS["protect"],
        )

    # vc_single may signal failure through its return value instead of raising.
    # Treat both a bare None and a (None, None) pair as failure, so the caller
    # sees vc_single's info text rather than an opaque wavfile.write error.
    if audio_tuple is None or any(part is None for part in audio_tuple):
        raise RuntimeError(f"RVC Conversion returned None. Info: {info}")

    tgt_sr, audio_opt = audio_tuple
    wavfile.write(out_path, tgt_sr, audio_opt)


@app.post("/generate_audio")
async def generate_audio(request: TextRequest, background_tasks: BackgroundTasks):
    """Synthesize ``request.text`` with OpenAI TTS, convert it with RVC, and return a WAV.

    The OpenAI call and RVC inference are blocking, so they run in FastAPI's
    thread pool instead of on the event loop. This keeps other requests
    (including the ``/`` status endpoint) responsive during synthesis.

    Temporary files live in /dev/shm. They are removed immediately on failure,
    and by a background task after the response has been sent on success.

    Raises:
        HTTPException: 500 if the OpenAI or the RVC stage fails.
    """
    start_time = time.time()

    unique_id = str(uuid.uuid4())
    temp_openai_path = f"/dev/shm/openai_{unique_id}.wav"
    final_output_path = f"/dev/shm/rvc_{unique_id}.wav"
    temp_files = [temp_openai_path, final_output_path]

    # --- STEP 1: PRE-PROCESS TEXT ---
    print(f"Original Text: {request.text[:50]}...")
    clean_text = prepare_text_for_audio(request.text)
    print(f"Cleaned Text:  {clean_text[:50]}...")

    # --- STEP 2: OpenAI TTS ---
    try:
        await run_in_threadpool(
            _synthesize_openai, clean_text, request.speed, temp_openai_path
        )
    except Exception as e:
        # Don't leave a partially written file behind.
        cleanup_files(temp_files)
        print(f"OpenAI Error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenAI Error: {str(e)}") from e

    # --- STEP 3: RVC Inference ---
    try:
        await run_in_threadpool(_convert_rvc, temp_openai_path, final_output_path)
    except Exception as e:
        cleanup_files(temp_files)
        print(f"RVC Error: {e}")
        raise HTTPException(status_code=500, detail=f"RVC Error: {str(e)}") from e

    # --- STEP 4: Return & Cleanup ---
    # Background tasks run after the response is sent, so the output file
    # still exists while FileResponse streams it.
    background_tasks.add_task(cleanup_files, temp_files)

    total_time = time.time() - start_time
    print(f"Request processed in {total_time:.4f}s with speed {request.speed}")

    return FileResponse(
        final_output_path,
        media_type="audio/wav",
        filename="response.wav",
    )

if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly as a script.
    import uvicorn

    uvicorn.run(app, port=5000, host="0.0.0.0")