import torch
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from typing import List
import numpy as np
import os
from config import TTSConfig, inference_config

class TTSModel:
    """
    A class that encapsulates a Coqui XTTS model with batch processing support.

    Lifecycle: construct, call :meth:`load` once, check ``is_loaded``, then
    call :meth:`synthesize_chunk` or :meth:`synthesize_batch`.

    Note that inference is performed one text at a time; ``batch_size`` only
    controls how many texts are grouped together for progress reporting,
    CUDA cache clearing and out-of-memory recovery.
    """

    def __init__(self, config: "TTSConfig", use_gpu: bool = False, batch_size: int = 4):
        """
        Store settings; no model is loaded until :meth:`load` is called.

        Args:
            config: Project TTS configuration. This class reads its
                ``language``, ``speaker_directory``, ``config_path``,
                ``model_name_or_path`` and ``vocab_path`` attributes.
            use_gpu: When true, the model is moved to CUDA after loading and
                the CUDA cache is emptied after every batch.
            batch_size: Number of texts grouped per batch in
                :meth:`synthesize_batch`. Must be at least 1 by the time
                :meth:`synthesize_batch` is called.
        """
        self.config = config
        self.use_gpu = use_gpu
        self.batch_size = batch_size
        self.model = None
        self.gpt_cond_latent = None
        self.speaker_embedding = None
        self.is_loaded = False

    @staticmethod
    def _collect_speaker_wavs(speaker_dir: str) -> List[str]:
        """
        Return the sorted paths of all WAV files directly inside *speaker_dir*.

        The extension match is case-insensitive, and the result is sorted
        because ``os.listdir`` returns entries in arbitrary order; sorting
        keeps the reference list passed to the model reproducible across runs.

        Raises:
            ValueError: If the directory does not exist or holds no WAV files.
        """
        if not os.path.isdir(speaker_dir):
            raise ValueError(f"Speaker directory not found: {speaker_dir}")

        wav_paths = sorted(
            os.path.join(speaker_dir, name)
            for name in os.listdir(speaker_dir)
            if name.lower().endswith('.wav')
        )
        if not wav_paths:
            raise ValueError(f"No .wav files found in directory: {speaker_dir}")
        return wav_paths

    def load(self):
        """
        Loads the model and computes speaker latents from a directory of WAV files.

        This method never raises: any failure is printed and reported through
        ``self.is_loaded`` being ``False``. On failure the model and latents
        are reset to ``None`` so a partially initialised model is not kept
        alive.
        """
        print(f"Loading model for language: '{self.config.language}'...")
        try:
            speaker_wav_paths = self._collect_speaker_wavs(self.config.speaker_directory)
            print(f"Found {len(speaker_wav_paths)} reference audio files for voice cloning.")

            # Load the base model
            conf = XttsConfig()
            conf.load_json(self.config.config_path)
            self.model = Xtts.init_from_config(conf)
            self.model.load_checkpoint(
                conf,
                checkpoint_dir=self.config.model_name_or_path,
                vocab_path=self.config.vocab_path,
                use_deepspeed=False
            )
            if self.use_gpu:
                self.model.cuda()

            print(f"Computing speaker characteristics from {len(speaker_wav_paths)} files...")
            self.gpt_cond_latent, self.speaker_embedding = self.model.get_conditioning_latents(
                audio_path=speaker_wav_paths
            )

            self.is_loaded = True
            print(f"Model for '{self.config.language}' loaded successfully with batch size {self.batch_size}.")
        except Exception as e:
            # Deliberate best-effort boundary: callers inspect `is_loaded`
            # instead of handling an exception.
            print(f"FATAL ERROR: Could not load model for '{self.config.language}'. Error: {e}")
            self.model = None
            self.gpt_cond_latent = None
            self.speaker_embedding = None
            self.is_loaded = False

    def _require_loaded(self):
        """Raise ``RuntimeError`` unless :meth:`load` completed successfully."""
        if not self.is_loaded:
            raise RuntimeError(f"Model for '{self.config.language}' is not loaded.")

    def _infer(self, text: str):
        """
        Run one inference call and return the ``"wav"`` entry of its output.

        Sampling parameters come from the module-level ``inference_config``.
        Gradients are disabled because this is inference only.
        """
        with torch.no_grad():
            out = self.model.inference(
                text=text,
                language=self.config.language,
                speaker_embedding=self.speaker_embedding,
                gpt_cond_latent=self.gpt_cond_latent,
                temperature=inference_config["temperature"],
                length_penalty=inference_config["length_penalty"],
                repetition_penalty=inference_config["repetition_penalty"],
                enable_text_splitting=inference_config["enable_text_splitting"]
            )
        return out["wav"]

    def synthesize_chunk(self, text: str):
        """
        Synthesize a single piece of text.

        Args:
            text: Text to speak; passed to the model unchanged.

        Returns:
            The ``"wav"`` entry of the model's inference output.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        self._require_loaded()
        return self._infer(text)

    def synthesize_batch(self, texts: List[str]) -> List[np.ndarray]:
        """
        Synthesize several texts, returning one audio result per input in order.

        Texts are processed in groups of ``self.batch_size``. If a CUDA
        out-of-memory ``RuntimeError`` occurs inside a group, the CUDA cache
        is emptied and the group is resumed from the text that failed;
        results already produced for that group are kept.

        Args:
            texts: Texts to speak. An element that is not a ``str`` is treated
                as an iterable of strings and joined with single spaces.
                Every text is stripped of surrounding whitespace.

        Returns:
            A list with the ``"wav"`` output for each text; empty if *texts*
            is empty.

        Raises:
            RuntimeError: If the model has not been loaded, or if inference
                fails for a reason other than running out of memory (or runs
                out of memory again during the retry).
            ValueError: If ``self.batch_size`` is less than 1.
        """
        self._require_loaded()
        if not texts:
            return []
        if self.batch_size < 1:
            # A zero step makes range() fail and a negative step yields an
            # empty range, which would silently drop every text.
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size!r}")

        normalized = [t if isinstance(t, str) else " ".join(t) for t in texts]
        all_audio = []

        for start in range(0, len(normalized), self.batch_size):
            batch_texts = normalized[start:start + self.batch_size]
            print(f"Processing batch {start // self.batch_size + 1}: {len(batch_texts)} chunks")

            batch_audio = []
            hit_oom = False
            try:
                for text in batch_texts:
                    batch_audio.append(self._infer(text.strip()))
            except RuntimeError as e:
                if "out of memory" not in str(e):
                    raise
                hit_oom = True

            if hit_oom:
                # Recovery runs outside the `except` block on purpose: while
                # the handler is active the exception keeps the failing
                # frames (and their tensors) referenced, so memory could not
                # be reclaimed before the retry.
                print("GPU OOM error. Falling back to sequential processing for this batch.")
                if self.use_gpu:
                    torch.cuda.empty_cache()
                for text in batch_texts[len(batch_audio):]:
                    batch_audio.append(self._infer(text.strip()))

            all_audio.extend(batch_audio)
            if self.use_gpu:
                torch.cuda.empty_cache()

        return all_audio
