import io
import logging
import threading
import warnings

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response, StreamingResponse

from config import SUPPORTED_MODELS
from schemas import SynthesisRequest
from tts_service import TTSModel
from utils import split_text_into_chunks

# --- Suppress Warnings ---
# Keep server output readable: ignore UserWarning/FutureWarning noise and only
# let the `transformers` logger report ERROR and above.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings('ignore', category=_ignored_category)
del _ignored_category
logging.getLogger("transformers").setLevel(logging.ERROR)

# --- Application Setup ---
app = FastAPI()

# Application state: language code -> TTSModel instance, filled in at startup.
models = dict()

# --- Model Loading on Startup ---
@app.on_event("startup")
def load_all_models():
    """Create and load one TTSModel for every entry in SUPPORTED_MODELS.

    Registered as a FastAPI startup handler. Each loaded model is stored in the
    module-level ``models`` dict under its SUPPORTED_MODELS key (the language
    code the endpoint looks up). The GPU is used when torch reports CUDA.
    """
    gpu_enabled = torch.cuda.is_available()
    print(f"GPU Available: {gpu_enabled}")

    for language_code, model_config in SUPPORTED_MODELS.items():
        tts = TTSModel(model_config, use_gpu=gpu_enabled)
        tts.load()
        models[language_code] = tts

# --- API Endpoint ---
# Output sample rate in Hz, hard-coded for every model.
# NOTE(review): confirm every entry in SUPPORTED_MODELS really emits 24 kHz audio.
_SAMPLE_RATE = 24000
# Silent gap inserted between consecutive synthesized chunks, in milliseconds.
_CHUNK_GAP_MS = 300
# Maximum characters per text chunk (includes a safety buffer); languages not
# listed here use _DEFAULT_CHAR_LIMIT.
_CHAR_LIMITS = {"ar": 140}
_DEFAULT_CHAR_LIMIT = 220

logger = logging.getLogger(__name__)

# Inference runs in worker threads, so concurrent requests could otherwise call
# the model in parallel. The model's thread-safety and the GPU memory headroom
# for parallel runs are unknown, so inference is kept to one call at a time.
_inference_lock = threading.Lock()


def _render_wav(model, text_chunks):
    """Synthesize *text_chunks* with *model* and return one WAV file as bytes.

    Consecutive chunks are separated by ``_CHUNK_GAP_MS`` of silence. This
    performs blocking model inference, so it must run off the event loop.
    *text_chunks* must be non-empty.
    """
    silence = np.zeros(_SAMPLE_RATE * _CHUNK_GAP_MS // 1000, dtype=np.float32)
    total = len(text_chunks)
    pieces = []
    with _inference_lock:
        for index, chunk in enumerate(text_chunks, start=1):
            logger.debug("Synthesizing chunk %d/%d: %r", index, total, chunk)
            if pieces:
                # Gap only *between* chunks, never before the first one.
                pieces.append(silence)
            pieces.append(model.synthesize_chunk(chunk))

    buffer = io.BytesIO()
    sf.write(buffer, np.concatenate(pieces), _SAMPLE_RATE, format='WAV')
    return buffer.getvalue()


@app.post("/synthesize")
async def synthesize(request: SynthesisRequest):
    """Synthesize ``request.text`` in ``request.language`` and return WAV audio.

    Returns:
        An ``audio/wav`` response holding the complete audio on success.
        Otherwise a JSON ``{"error": ...}`` body with status 503 if no loaded
        model exists for the language, 400 if the text yields no chunks, or
        500 if synthesis fails.
    """
    model = models.get(request.language)
    if model is None or not model.is_loaded:
        return JSONResponse(content={"error": f"The model for language '{request.language}' is not available."}, status_code=503)

    char_limit = _CHAR_LIMITS.get(request.language, _DEFAULT_CHAR_LIMIT)
    try:
        text_chunks = split_text_into_chunks(request.text, char_limit)
        logger.info("Text split into %d chunks.", len(text_chunks))
        if not text_chunks:
            # Nothing to synthesize; np.concatenate([]) would otherwise raise.
            return JSONResponse(content={"error": "No text to synthesize."}, status_code=400)
        # Blocking inference/encoding runs in a worker thread so the event
        # loop keeps serving other requests meanwhile.
        wav_bytes = await run_in_threadpool(_render_wav, model, text_chunks)
    except Exception:
        # Request boundary: record the full traceback, return a generic error.
        logger.exception("An error occurred during audio generation")
        return JSONResponse(content={"error": "Failed to generate audio"}, status_code=500)

    # The WAV is already fully in memory, so send it as a single body
    # (this also sets Content-Length).
    return Response(content=wav_bytes, media_type="audio/wav")