faster TTS

parent b138ddef
@@ -12,15 +12,17 @@ class TTSConfig(BaseModel):
 ARABIC_MODEL_CONFIG = TTSConfig(
     language="ar",
     model_name_or_path="./model/EGTTS-V0.1/",
-    speaker_directory="salma",
+    speaker_directory="anan",
     config_path="./model/EGTTS-V0.1/config.json",
     vocab_path="./model/EGTTS-V0.1/vocab.json"
 )
 ENGLISH_MODEL_CONFIG = TTSConfig(
     language="en",
-    model_name_or_path="tts_models/multilingual/multi-dataset/xtts_v2",
-    speaker_directory="anan"
+    model_name_or_path="./model_en/",
+    speaker_directory="anan",
+    config_path="./model_en/config.json",
+    vocab_path="./model_en/vocab.json"
 )
@@ -28,3 +30,10 @@ SUPPORTED_MODELS = {
     "ar": ARABIC_MODEL_CONFIG,
     "en": ENGLISH_MODEL_CONFIG,
 }
+
+inference_config = {
+    "temperature": 0.1,
+    "length_penalty": 0.9,
+    "repetition_penalty": 1.2,
+    "enable_text_splitting": True,
+}
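Note: every key in `inference_config` matches a keyword argument of `Xtts.inference`, so besides the explicit lookups used in `tts_service.py` below, the dict could also be unpacked in one call. A minimal sketch, assuming the model, conditioning latents, and speaker embedding are prepared the way `TTSModel.load()` prepares them; the helper name `run_inference` is hypothetical:

# Sketch only: unpack inference_config straight into Xtts.inference.
# The model/latent/embedding arguments are assumed to come from TTSModel.load();
# the helper name run_inference is not part of this MR.
from TTS.tts.models.xtts import Xtts
from config import inference_config

def run_inference(model: Xtts, text: str, language: str, gpt_cond_latent, speaker_embedding):
    out = model.inference(
        text=text,
        language=language,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        **inference_config,  # temperature, length_penalty, repetition_penalty, enable_text_splitting
    )
    return out["wav"]  # XTTS returns a 24 kHz waveform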
import torch
import soundfile as sf
import io
import warnings
import logging
import numpy as np
import nltk
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from config import SUPPORTED_MODELS
from schemas import SynthesisRequest, SequenceSynthesisRequest
from tts_service import TTSModel
from synthesizer_service import SynthesizerService # <-- Import the new service
# --- Suppress Warnings ---
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
logging.getLogger("transformers").setLevel(logging.ERROR)
# --- Application Setup ---
app = FastAPI()
# This single object will now hold our application's state and logic
synthesizer = None
# --- Model Loading on Startup ---
@app.on_event("startup")
def startup_event():
global synthesizer
nltk.download('punkt')
nltk.download('punkt_tab')
use_gpu = torch.cuda.is_available()
print(f"GPU Available: {use_gpu}")
models = {}
batch_size = 12
for lang, config in SUPPORTED_MODELS.items():
model = TTSModel(config, use_gpu=use_gpu, batch_size=batch_size)
model.load()
models[lang] = model
# Create a single instance of our synthesizer service
synthesizer = SynthesizerService(models)
print("Synthesizer service is ready.")
# --- Helper function to create the audio response ---
def create_audio_response(audio_array: np.ndarray) -> StreamingResponse:
buffer = io.BytesIO()
sf.write(buffer, audio_array, 24000, format='WAV')
buffer.seek(0)
return StreamingResponse(buffer, media_type="audio/wav")
# --- API Endpoints (Now clean and thin) ---
@app.post("/synthesize")
async def synthesize(request: SynthesisRequest):
try:
final_audio = synthesizer.synthesize_simple(request.text, request.language)
return create_audio_response(final_audio)
except Exception as e:
print(f"An error occurred during simple synthesis: {e}")
raise HTTPException(status_code=500, detail=f"Failed to generate audio: {str(e)}")
@app.post("/synthesize_sequence")
async def synthesize_sequence(request: SequenceSynthesisRequest):
try:
final_audio = synthesizer.synthesize_sequence(request.segments)
return create_audio_response(final_audio)
except Exception as e:
print(f"An error occurred during sequence synthesis: {e}")
raise HTTPException(status_code=500, detail=f"Failed to generate audio sequence: {str(e)}")
\ No newline at end of file
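For a quick end-to-end check of the endpoints above, a minimal client call; the host, port, and output filename are assumptions, and the JSON body mirrors the `SynthesisRequest` fields (`text`, `language`) used by the handler:

# Minimal client sketch for the /synthesize endpoint; host, port, and the
# output filename are assumptions, the payload mirrors SynthesisRequest.
import requests

resp = requests.post(
    "http://localhost:8000/synthesize",
    json={"text": "Testing the batched XTTS service.", "language": "en"},
    timeout=300,
)
resp.raise_for_status()
with open("synthesized.wav", "wb") as f:
    f.write(resp.content)  # WAV bytes produced by create_audio_response()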
@@ -46,7 +46,7 @@ class SynthesizerService:
         if lang not in self.models or not self.models[lang].is_loaded:
             raise ValueError(f"Model for language '{lang}' is not available.")
-        char_limit = 140 if lang == "ar" else 200
+        char_limit = 1000 if lang == "ar" else 1000
         segment.text = translate_equations_in_text(segment.text, lang)
         segment.text = sanitize_text(segment.text)
         print(f"Segment {seg_idx} ({lang}) text: {segment.text}")
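The splitting logic that consumes `char_limit` sits outside the lines shown in this hunk; only the limit itself changes here. For context, a sketch of the usual sentence-packing approach (NLTK `punkt` is already downloaded at startup). The helper name `chunk_text` is hypothetical and is not the service's actual implementation:

# Hypothetical helper showing how a char_limit is typically applied via
# sentence packing; not the service's actual implementation, which is
# outside the lines shown in this hunk.
from typing import List
from nltk.tokenize import sent_tokenize

def chunk_text(text: str, char_limit: int) -> List[str]:
    chunks, current = [], ""
    for sentence in sent_tokenize(text):
        # Start a new chunk once adding the next sentence would exceed the limit.
        if current and len(current) + len(sentence) + 1 > char_limit:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks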
@@ -5,7 +5,7 @@ from TTS.tts.models.xtts import Xtts
 from typing import List
 import numpy as np
 import os
-from config import TTSConfig
+from config import TTSConfig, inference_config

 class TTSModel:
     """
@@ -35,8 +35,7 @@ class TTSModel:
         print(f"Found {len(speaker_wav_paths)} reference audio files for voice cloning.")

-        # Load the base model (logic is the same)
-        if self.config.language == "ar":
-            conf = XttsConfig()
-            conf.load_json(self.config.config_path)
-            self.model = Xtts.init_from_config(conf)
+        # Load the base model
+        conf = XttsConfig()
+        conf.load_json(self.config.config_path)
+        self.model = Xtts.init_from_config(conf)
@@ -48,9 +47,7 @@ class TTSModel:
         )
         if self.use_gpu:
             self.model.cuda()
-        else:
-            api_model = TTS(model_name=self.config.model_name_or_path, gpu=self.use_gpu)
-            self.model = api_model.synthesizer.tts_model

         print(f"Computing speaker characteristics from {len(speaker_wav_paths)} files...")
         self.gpt_cond_latent, self.speaker_embedding = self.model.get_conditioning_latents(
@@ -65,30 +62,71 @@ class TTSModel:
     def synthesize_chunk(self, text: str):
         if not self.is_loaded: raise RuntimeError(f"Model for '{self.config.language}' is not loaded.")
-        out = self.model.inference(text=text, language=self.config.language, speaker_embedding=self.speaker_embedding, gpt_cond_latent=self.gpt_cond_latent, temperature=0.1)
+        out = self.model.inference(
+            text=text,
+            language=self.config.language,
+            speaker_embedding=self.speaker_embedding,
+            gpt_cond_latent=self.gpt_cond_latent,
+            temperature=inference_config["temperature"],
+            length_penalty=inference_config["length_penalty"],
+            repetition_penalty=inference_config["repetition_penalty"],
+            enable_text_splitting=inference_config["enable_text_splitting"]
+        )
         return out["wav"]

     def synthesize_batch(self, texts: List[str]) -> List[np.ndarray]:
-        if not self.is_loaded: raise RuntimeError(f"Model for '{self.config.language}' is not loaded.")
-        if not texts: return []
+        if not self.is_loaded:
+            raise RuntimeError(f"Model for '{self.config.language}' is not loaded.")
+        if not texts:
+            return []
         all_audio = []
+        texts = [t if isinstance(t, str) else " ".join(t) for t in texts]
         for i in range(0, len(texts), self.batch_size):
             batch_texts = texts[i:i + self.batch_size]
             print(f"Processing batch {i//self.batch_size + 1}: {len(batch_texts)} chunks")
             batch_audio = []
             try:
                 with torch.no_grad():
-                    for text in batch_texts:
-                        out = self.model.inference(text=text, language=self.config.language, speaker_embedding=self.speaker_embedding, gpt_cond_latent=self.gpt_cond_latent, temperature=0.1)
-                        batch_audio.append(out["wav"])
+                    outputs = [
+                        self.model.inference(
+                            text=text.strip(),
+                            language=self.config.language,
+                            speaker_embedding=self.speaker_embedding,
+                            gpt_cond_latent=self.gpt_cond_latent,
+                            temperature=inference_config["temperature"],
+                            length_penalty=inference_config["length_penalty"],
+                            repetition_penalty=inference_config["repetition_penalty"],
+                            enable_text_splitting=inference_config["enable_text_splitting"]
+                        )
+                        for text in batch_texts
+                    ]
+                    batch_audio = [out["wav"] for out in outputs]
                     all_audio.extend(batch_audio)
-                if self.use_gpu: torch.cuda.empty_cache()
+                if self.use_gpu:
+                    torch.cuda.empty_cache()
             except RuntimeError as e:
                 if "out of memory" in str(e):
                     print(f"GPU OOM error. Falling back to sequential processing for this batch.")
                     for text in batch_texts:
-                        out = self.model.inference(text=text, language=self.config.language, speaker_embedding=self.speaker_embedding, gpt_cond_latent=self.gpt_cond_latent, temperature=0.1)
+                        out = self.model.inference(
+                            text=text.strip(),
+                            language=self.config.language,
+                            speaker_embedding=self.speaker_embedding,
+                            gpt_cond_latent=self.gpt_cond_latent,
+                            temperature=inference_config["temperature"],
+                            length_penalty=inference_config["length_penalty"],
+                            repetition_penalty=inference_config["repetition_penalty"],
+                            enable_text_splitting=inference_config["enable_text_splitting"]
+                        )
                         all_audio.append(out["wav"])
-                    if self.use_gpu: torch.cuda.empty_cache()
-                else: raise e
+                    if self.use_gpu:
+                        torch.cuda.empty_cache()
+                else:
+                    raise e
         return all_audio
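With the config change above, `TTSModel` can now be driven the same way for both languages. A usage sketch, assuming the checkpoints referenced in `SUPPORTED_MODELS` exist locally; stitching the chunks with `np.concatenate` is an assumption about how `SynthesizerService` joins them, not code from this MR:

# Sketch of consuming the refactored TTSModel directly (constructor arguments
# mirror the startup code in the FastAPI app). The np.concatenate stitching is
# an assumption about how the service joins chunks, not part of this change.
import numpy as np
import soundfile as sf
from config import SUPPORTED_MODELS
from tts_service import TTSModel

model = TTSModel(SUPPORTED_MODELS["en"], use_gpu=False, batch_size=12)
model.load()
chunks = model.synthesize_batch([
    "First sentence of the lecture.",
    "Second sentence, synthesized in the same batch.",
])
audio = np.concatenate([np.asarray(c) for c in chunks])
sf.write("sequence.wav", audio, 24000)  # XTTS outputs 24 kHz audio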