import torch
import soundfile as sf
import io
import warnings
import logging
import numpy as np
import nltk

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse

from config import SUPPORTED_MODELS
from schemas import SynthesisRequest, SequenceSynthesisRequest
from tts_service import TTSModel
from synthesizer_service import SynthesizerService # <-- Import the new service

# --- Suppress Warnings ---
# Silence UserWarning and FutureWarning output process-wide, and restrict the
# "transformers" logger to ERROR-level records only.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings('ignore', category=_ignored_category)
del _ignored_category
logging.getLogger("transformers").setLevel(logging.ERROR)

# --- Application Setup ---
app = FastAPI()

# Shared service object used by the endpoints. It stays None until
# startup_event() runs and replaces it with a SynthesizerService.
synthesizer: "SynthesizerService | None" = None

# --- Model Loading on Startup ---
@app.on_event("startup")
def startup_event():
    """Prepare tokenizer data, load every configured TTS model, and build the service.

    Runs once when the application starts. Ensures the NLTK ``punkt`` and
    ``punkt_tab`` tokenizer resources are available, then builds one loaded
    ``TTSModel`` per entry in ``SUPPORTED_MODELS`` (keyed by the same language
    key). Finally it assigns a ``SynthesizerService`` over those models to the
    module-level ``synthesizer``.
    """
    global synthesizer

    # Only download tokenizer data that is not already installed locally:
    # nltk.download() consults the remote NLTK index even when the package is
    # present, which slows every startup and is noisy when offline.
    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            nltk.download(resource)

    use_gpu = torch.cuda.is_available()
    print(f"GPU Available: {use_gpu}")

    # Same batch size is passed to every TTSModel; its exact meaning is
    # defined by TTSModel.
    batch_size = 12
    models = {}
    for lang, model_config in SUPPORTED_MODELS.items():
        model = TTSModel(model_config, use_gpu=use_gpu, batch_size=batch_size)
        model.load()
        models[lang] = model

    # Create a single instance of our synthesizer service
    synthesizer = SynthesizerService(models)
    print("Synthesizer service is ready.")

# --- Helper function to create the audio response ---
def create_audio_response(audio_array: np.ndarray, sample_rate: int = 24000) -> StreamingResponse:
    """Encode ``audio_array`` as an in-memory WAV file and wrap it in a response.

    Args:
        audio_array: Audio samples in a form accepted by ``soundfile.write``.
        sample_rate: Sample rate written into the WAV header. The default of
            24000 matches the value this helper previously hard-coded; pass a
            different value for models that produce another rate.

    Returns:
        A ``StreamingResponse`` carrying the WAV bytes with media type
        ``audio/wav``.
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format='WAV')
    # Send the encoded file as a single chunk. Iterating a BytesIO directly
    # yields newline-delimited pieces, which would split binary WAV data into
    # many tiny, arbitrarily sized chunks.
    return StreamingResponse(iter([buffer.getvalue()]), media_type="audio/wav")

# --- API Endpoints ---

@app.post("/synthesize")
async def synthesize(request: SynthesisRequest):
    """Synthesize ``request.text`` in ``request.language`` and return it as WAV audio.

    Raises:
        HTTPException: 503 if the synthesizer has not been initialised yet;
            500 (with the underlying error message) if synthesis or encoding
            fails. HTTPExceptions raised downstream are passed through as-is.
    """
    if synthesizer is None:
        raise HTTPException(status_code=503, detail="Synthesizer service is not ready.")
    # NOTE(review): synthesize_simple() is called synchronously inside an
    # ``async def`` handler, so it blocks the event loop (and every other
    # request) while it runs. Offloading it, e.g. with
    # ``fastapi.concurrency.run_in_threadpool``, would fix that but would also
    # let requests use the models concurrently; confirm the models are
    # thread-safe before changing it.
    try:
        final_audio = synthesizer.synthesize_simple(request.text, request.language)
        return create_audio_response(final_audio)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        # logging.exception keeps the traceback, which print() discarded.
        logging.getLogger(__name__).exception(
            "An error occurred during simple synthesis: %s", e
        )
        raise HTTPException(status_code=500, detail=f"Failed to generate audio: {str(e)}") from e

@app.post("/synthesize_sequence")
async def synthesize_sequence(request: SequenceSynthesisRequest):
    """Synthesize ``request.segments`` into one clip and return it as WAV audio.

    Raises:
        HTTPException: 503 if the synthesizer has not been initialised yet;
            500 (with the underlying error message) if synthesis or encoding
            fails. HTTPExceptions raised downstream are passed through as-is.
    """
    if synthesizer is None:
        raise HTTPException(status_code=503, detail="Synthesizer service is not ready.")
    # NOTE(review): synthesize_sequence() is called synchronously inside an
    # ``async def`` handler, so it blocks the event loop while it runs; see
    # the same trade-off around thread-safety before offloading it.
    try:
        final_audio = synthesizer.synthesize_sequence(request.segments)
        return create_audio_response(final_audio)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        # logging.exception keeps the traceback, which print() discarded.
        logging.getLogger(__name__).exception(
            "An error occurred during sequence synthesis: %s", e
        )
        raise HTTPException(status_code=500, detail=f"Failed to generate audio sequence: {str(e)}") from e