import numpy as np
from typing import Dict, List
from schemas import SequenceSynthesisRequest, Segment
from tts_service import TTSModel
from utils import split_text_into_chunks, sanitize_text, translate_equations_in_text

class SynthesizerService:
    """
    Orchestrates speech synthesis: text preparation, chunking, batched model
    calls and stitching of the resulting audio. It holds the application's
    state (the per-language models).
    """

    # Sample rate used to size the silence gap.
    # NOTE(review): assumes every model emits 24 kHz audio - confirm against TTSModel.
    SAMPLE_RATE_HZ = 24000
    # Length of the pause inserted between stitched audio pieces.
    SILENCE_MS = 300
    # Per-chunk character limits for the simple endpoint.
    SIMPLE_CHAR_LIMITS = {"ar": 140}
    SIMPLE_DEFAULT_CHAR_LIMIT = 220
    # Per-chunk character limit for the sequence endpoint. The original code
    # used 1000 for every language (both branches of its conditional were 1000).
    # NOTE(review): differs from the simple endpoint's limits - confirm intended.
    SEQUENCE_CHAR_LIMIT = 1000

    def __init__(self, models: Dict[str, "TTSModel"]):
        """Store the mapping from language code to its TTS model."""
        self.models = models

    def _make_silence(self) -> np.ndarray:
        """Return a float32 array of zeros lasting SILENCE_MS at SAMPLE_RATE_HZ."""
        return np.zeros(self.SAMPLE_RATE_HZ * self.SILENCE_MS // 1000, dtype=np.float32)

    @staticmethod
    def _join_with_silence(pieces, silence) -> np.ndarray:
        """Concatenate pieces with silence between consecutive pieces only.

        Raises:
            ValueError: if there are no pieces to join.
        """
        if len(pieces) == 0:
            raise ValueError("Audio generation resulted in empty output.")
        interleaved = []
        for position, piece in enumerate(pieces):
            if position:
                # Gap goes *between* pieces, never before the first or after the last.
                interleaved.append(silence)
            interleaved.append(piece)
        return np.concatenate(interleaved)

    def synthesize_simple(self, text: str, language: str) -> np.ndarray:
        """Handles the logic for the simple /synthesize endpoint.

        The text is equation-translated, sanitized and split into chunks; the
        chunks are synthesized in one batch and joined with a short silence
        between consecutive chunks.

        Args:
            text: Raw input text.
            language: Key into ``self.models``.

        Returns:
            The stitched audio as a single array.

        Raises:
            ValueError: if no loaded model exists for ``language`` or if the
                synthesis produced no audio.
        """
        model = self.models.get(language)
        if not model or not model.is_loaded:
            raise ValueError(f"Model for language '{language}' is not available.")

        char_limit = self.SIMPLE_CHAR_LIMITS.get(language, self.SIMPLE_DEFAULT_CHAR_LIMIT)
        text = translate_equations_in_text(text, language)
        text = sanitize_text(text)
        text_chunks = split_text_into_chunks(text, char_limit)
        print(f"Text split into {len(text_chunks)} chunks.")

        audio_chunks = model.synthesize_batch(text_chunks)

        # An empty result raises an explicit ValueError in the helper instead
        # of numpy's "need at least one array to concatenate".
        return self._join_with_silence(list(audio_chunks), self._make_silence())

    def synthesize_sequence(self, segments: List["Segment"]) -> np.ndarray:
        """Handles the complex logic for the /synthesize_sequence endpoint.

        Every segment is prepared and chunked, chunks are grouped by language
        so each model receives a single batch, and the results are put back
        in their original order. Chunks of one segment are concatenated
        directly; a short silence separates consecutive segments that
        produced audio.

        The input segments are not modified.

        Args:
            segments: Segments exposing ``language`` and ``text`` attributes.

        Returns:
            The stitched audio as a single array.

        Raises:
            ValueError: if a segment's language has no loaded model, if a model
                returns fewer audio chunks than requested, or if no audio was
                produced at all.
        """
        # One entry per text chunk, in playback order: (segment index, language, text).
        planned = []
        for seg_idx, segment in enumerate(segments):
            lang = segment.language
            model = self.models.get(lang)
            if model is None or not model.is_loaded:
                raise ValueError(f"Model for language '{lang}' is not available.")

            # Work on a local copy: writing the processed text back onto the
            # segment would make a second call re-process already processed text.
            prepared = sanitize_text(translate_equations_in_text(segment.text, lang))
            print(f"Segment {seg_idx} ({lang}) text: {prepared}")
            text_chunks = split_text_into_chunks(prepared, self.SEQUENCE_CHAR_LIMIT)
            print(f"chunks: {text_chunks}")

            planned.extend((seg_idx, lang, chunk) for chunk in text_chunks)

        # Group chunk positions by language so each model is called once.
        by_lang = {}
        for position, (_, lang, chunk) in enumerate(planned):
            by_lang.setdefault(lang, []).append((position, chunk))

        audio_results = [None] * len(planned)
        for lang, entries in by_lang.items():
            texts = [chunk for _, chunk in entries]
            print(f"Processing {len(texts)} {lang} chunks in parallel batches...")
            audio_chunks = self.models[lang].synthesize_batch(texts)

            # Scatter the batch results back to their original positions.
            for (position, _), audio in zip(entries, audio_chunks):
                audio_results[position] = audio

        # zip() stops at the shorter input, so a short batch result would
        # otherwise leave None entries that fail obscurely in np.concatenate.
        if any(audio is None for audio in audio_results):
            raise ValueError("Synthesis returned fewer audio chunks than requested.")

        per_segment = {}
        for (seg_idx, _, _), audio in zip(planned, audio_results):
            per_segment.setdefault(seg_idx, []).append(audio)

        # Only segments that produced audio appear here, so silence is never
        # appended after the last audible segment when trailing segments are empty.
        segment_audio = [np.concatenate(per_segment[seg_idx]) for seg_idx in sorted(per_segment)]
        return self._join_with_silence(segment_audio, self._make_silence())