import torch
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

from config import TTSConfig

class TTSModel:
    """
    A class that encapsulates a Coqui TTS model, handling loading,
    speaker latent calculation, and inference. This is the core OOP abstraction.

    Two loading paths are supported (see ``load``):

    * ``config.language == "ar"``: a local fine-tuned XTTS checkpoint built from
      ``config.config_path``, ``config.model_name_or_path`` and ``config.vocab_path``.
    * any other language: a model obtained through the high-level ``TTS`` API,
      using ``config.model_name_or_path`` as the model name.

    In both cases the speaker conditioning latents are computed once from
    ``config.speaker_wav`` and reused for every synthesized chunk.
    """

    def __init__(self, config: "TTSConfig", use_gpu: bool = False):
        self.config = config
        self.use_gpu = use_gpu
        self.model = None              # underlying TTS model, set by a successful load()
        self.gpt_cond_latent = None    # first value from model.get_conditioning_latents
        self.speaker_embedding = None  # second value from model.get_conditioning_latents
        self.is_loaded = False         # True only after load() fully succeeded

    def load(self) -> bool:
        """Load the model and compute the speaker conditioning latents.

        Errors are reported on stdout instead of being raised; callers are
        expected to check ``is_loaded`` (or the return value). On failure, any
        partially-initialized model and any previously computed latents are
        discarded, so a failed (re)load never leaves mismatched state behind.

        Returns:
            ``True`` if loading succeeded, ``False`` otherwise (same as ``is_loaded``).
        """
        print(f"Loading model for language: '{self.config.language}'...")
        try:
            model = self._build_model()
            # Both loading paths must yield a model exposing get_conditioning_latents();
            # for the API path this presumably requires an XTTS model name — TODO confirm.
            gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
                audio_path=self._speaker_wav_paths()
            )
        except Exception as e:
            print(f"FATAL ERROR: Could not load model for '{self.config.language}'. Error: {e}")
            # Drop everything so a half-built model or stale latents from an
            # earlier load are not kept alive (or accidentally reused).
            self.model = None
            self.gpt_cond_latent = None
            self.speaker_embedding = None
            self.is_loaded = False
            return False

        # Commit state only once every step has succeeded.
        self.model = model
        self.gpt_cond_latent = gpt_cond_latent
        self.speaker_embedding = speaker_embedding
        self.is_loaded = True
        print(f"Model for '{self.config.language}' loaded successfully.")
        return True

    def _build_model(self):
        """Instantiate the underlying model for ``self.config`` (latents not computed)."""
        if self.config.language == "ar":
            # Local, fine-tuned XTTS checkpoint.
            conf = XttsConfig()
            conf.load_json(self.config.config_path)
            model = Xtts.init_from_config(conf)
            model.load_checkpoint(
                conf,
                checkpoint_dir=self.config.model_name_or_path,
                vocab_path=self.config.vocab_path,
                use_deepspeed=False,
            )
            if self.use_gpu:
                model.cuda()
            return model

        # High-level API model; GPU placement is requested through the `gpu` flag.
        api_model = TTS(model_name=self.config.model_name_or_path, gpu=self.use_gpu)
        return api_model.synthesizer.tts_model

    def _speaker_wav_paths(self) -> list:
        """Return ``config.speaker_wav`` as a list of reference-audio paths.

        A single path is wrapped in a one-element list. A list or tuple of paths
        is passed through unchanged, so the result is never a nested list.
        """
        wav = self.config.speaker_wav
        if isinstance(wav, (list, tuple)):
            return list(wav)
        return [wav]

    def synthesize_chunk(self, text: str, temperature: float = 0.1):
        """Run inference on a single text chunk.

        Args:
            text: The text to synthesize.
            temperature: Sampling temperature forwarded to ``model.inference``
                (default 0.1).

        Returns:
            The ``"wav"`` entry of the model's inference output.

        Raises:
            RuntimeError: If ``load()`` has not completed successfully.
        """
        if not self.is_loaded:
            raise RuntimeError(f"Model for language '{self.config.language}' is not loaded.")

        out = self.model.inference(
            text=text,
            language=self.config.language,
            speaker_embedding=self.speaker_embedding,
            gpt_cond_latent=self.gpt_cond_latent,
            temperature=temperature,
        )
        return out["wav"]