mirror of https://github.com/coqui-ai/TTS.git
feat: add adjust_speech_rate function to modify speech speed with more robust latent resizing; also add the missing TTS speed parameter plumbing.
This commit is contained in:
parent
dbf1a08a0d
commit
26128be422
@@ -283,6 +283,7 @@ class TTS(nn.Module):
             style_text=None,
             reference_speaker_name=None,
             split_sentences=split_sentences,
+            speed=speed,
             **kwargs,
         )
         return wav
@@ -330,13 +331,13 @@ class TTS(nn.Module):
                 Additional arguments for the model.
         """
         self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
 
         wav = self.tts(
             text=text,
             speaker=speaker,
             language=language,
             speaker_wav=speaker_wav,
             split_sentences=split_sentences,
+            speed=speed,
             **kwargs,
         )
         self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
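For orientation, a minimal usage sketch of the new speed argument on the high-level API; the model name, reference clip, and output path below are placeholders, and the interpretation of speed (above 1.0 faster, below 1.0 slower) is assumed from the length-scale handling later in this commit rather than stated here.

from TTS.api import TTS

# Hypothetical call site; model name and file paths are placeholders.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
tts.tts_to_file(
    text="The speech rate can now be set per call.",
    speaker_wav="reference.wav",   # placeholder reference voice clip
    language="en",
    file_path="output.wav",
    speed=1.3,                     # assumed: >1.0 speaks faster, <1.0 slower
)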
@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union
 
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
@@ -76,6 +77,33 @@ class BaseTTS(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
+    def adjust_speech_rate(self, gpt_latents, length_scale):
+        if abs(length_scale - 1.0) < 1e-6:
+            return gpt_latents
+
+        B, L, D = gpt_latents.shape
+        target_length = int(L * length_scale)
+
+        assert target_length > 0, f"Invalid target length: {target_length}"
+
+        try:
+            resized = F.interpolate(
+                gpt_latents.transpose(1, 2),
+                size=target_length,
+                mode="linear",
+                align_corners=True
+            ).transpose(1, 2)
+
+            if torch.isnan(resized).any():
+                print("Warning: NaN values detected on adjust speech rate")
+                return gpt_latents
+
+            return resized
+
+        except RuntimeError as e:
+            print(f"Interpolation failed: {e}")
+            return gpt_latents
+
     def init_multispeaker(self, config: Coqpit, data: List = None):
         """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
         `in_channels` size of the connected layers.
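The new helper boils down to a time-axis resample of the GPT latents: F.interpolate with mode="linear" expects (batch, channels, length), so the (B, L, D) latents are transposed, resized to int(L * length_scale) frames, and transposed back, with a NaN check and a fallback to the original latents on failure. A standalone sketch of that resize on dummy shapes (the dimensions are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

gpt_latents = torch.randn(2, 100, 1024)    # dummy (batch, frames, features) latents
length_scale = 1.5                         # stretch: more frames, slower speech

resized = F.interpolate(
    gpt_latents.transpose(1, 2),           # (B, L, D) -> (B, D, L) for 1D interpolation
    size=int(gpt_latents.shape[1] * length_scale),
    mode="linear",
    align_corners=True,
).transpose(1, 2)                          # back to (B, L', D)

print(resized.shape)                       # torch.Size([2, 150, 1024])

Using an explicit integer size (rather than a float scale_factor) makes the output length deterministic, which is what the assert on target_length protects.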
@@ -379,7 +379,7 @@ class Xtts(BaseTTS):
 
         return gpt_cond_latents, speaker_embedding
 
-    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
+    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed: float = 1.0, **kwargs):
         """Synthesize speech with the given input text.
 
         Args:
@@ -409,14 +409,14 @@ class Xtts(BaseTTS):
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
         if speaker_id is not None:
             gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
-            return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
+            return self.inference(text, language, gpt_cond_latent, speaker_embedding, speed=speed, **settings)
         settings.update({
             "gpt_cond_len": config.gpt_cond_len,
             "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
         })
-        return self.full_inference(text, speaker_wav, language, **settings)
+        return self.full_inference(text, speaker_wav, language, speed=speed, **settings)
 
     @torch.inference_mode()
     def full_inference(
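A usage sketch of the extended synthesize signature; the checkpoint and config paths are placeholders and the loading calls follow the usual XTTS pattern rather than anything introduced by this commit. Only the speed keyword is new here.

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/config.json")                                    # placeholder path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)   # placeholder path

outputs = model.synthesize(
    "This sentence is rendered slightly slower.",
    config,
    speaker_wav="reference.wav",   # placeholder reference clip
    language="en",
    speed=0.8,                     # forwarded to inference / full_inference
)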
@@ -436,6 +436,7 @@ class Xtts(BaseTTS):
         gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
+        speed: float = 1.0,
         **hf_generate_kwargs,
     ):
         """
@@ -496,6 +497,7 @@ class Xtts(BaseTTS):
             top_k=top_k,
             top_p=top_p,
             do_sample=do_sample,
+            speed=speed,
             **hf_generate_kwargs,
         )
 
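On the lower-level path, speed ends up as a keyword of inference; a sketch of calling it directly with precomputed conditioning latents (get_conditioning_latents and the paths follow the public XTTS usage pattern and are not part of this diff):

# Assumes `model` is loaded as in the earlier sketch.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["reference.wav"],   # placeholder reference clip(s)
)
out = model.inference(
    "Direct inference with an explicit speech rate.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=1.2,                      # consumed via length_scale inside inference
)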
@@ -569,10 +571,7 @@ class Xtts(BaseTTS):
                 )
 
                 if length_scale != 1.0:
-                    gpt_latents = F.interpolate(
-                        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
-                    ).transpose(1, 2)
-
+                    gpt_latents = self.adjust_speech_rate(gpt_latents, length_scale)
                 gpt_latents_list.append(gpt_latents.cpu())
                 wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
 
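Inside inference, length_scale is derived from the requested speed (in the upstream code it is effectively the reciprocal of the rate; that mapping is assumed here rather than shown in the hunk above). The practical difference from the old inline call is that adjust_speech_rate resolves the target length explicitly and falls back gracefully on failure; a sketch contrasting the two resize styles:

import torch
import torch.nn.functional as F

speed = 1.25
length_scale = 1.0 / speed                 # assumed mapping, not part of this diff

latents = torch.randn(1, 100, 1024)        # dummy (B, L, D) latents

# Old inline style: output length implied by a float scale factor.
old = F.interpolate(
    latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)

# New helper style: explicit integer target length, validated and NaN-checked upstream.
target_length = int(latents.shape[1] * length_scale)
new = F.interpolate(
    latents.transpose(1, 2), size=target_length, mode="linear", align_corners=True
).transpose(1, 2)

print(old.shape, new.shape)                # both (1, 80, 1024) for these values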