feat: add adjust_speech_rate function to modify speech speed with more durable latents. also missed tts speed implementations added.

This commit is contained in:
isikhi 2024-12-28 23:08:08 +03:00
parent dbf1a08a0d
commit 26128be422
3 changed files with 36 additions and 8 deletions

View File

@ -283,6 +283,7 @@ class TTS(nn.Module):
style_text=None, style_text=None,
reference_speaker_name=None, reference_speaker_name=None,
split_sentences=split_sentences, split_sentences=split_sentences,
speed=speed,
**kwargs, **kwargs,
) )
return wav return wav
@ -330,13 +331,13 @@ class TTS(nn.Module):
Additional arguments for the model. Additional arguments for the model.
""" """
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs) self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
wav = self.tts( wav = self.tts(
text=text, text=text,
speaker=speaker, speaker=speaker,
language=language, language=language,
speaker_wav=speaker_wav, speaker_wav=speaker_wav,
split_sentences=split_sentences, split_sentences=split_sentences,
speed=speed,
**kwargs, **kwargs,
) )
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out) self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)

View File

@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -76,6 +77,33 @@ class BaseTTS(BaseTrainerModel):
else: else:
raise ValueError("config must be either a *Config or *Args") raise ValueError("config must be either a *Config or *Args")
def adjust_speech_rate(self, gpt_latents, length_scale):
if abs(length_scale - 1.0) < 1e-6:
return gpt_latents
B, L, D = gpt_latents.shape
target_length = int(L * length_scale)
assert target_length > 0, f"Invalid target length: {target_length}"
try:
resized = F.interpolate(
gpt_latents.transpose(1, 2),
size=target_length,
mode="linear",
align_corners=True
).transpose(1, 2)
if torch.isnan(resized).any():
print("Warning: NaN values detected on adjust speech rate")
return gpt_latents
return resized
except RuntimeError as e:
print(f"Interpolation failed: {e}")
return gpt_latents
def init_multispeaker(self, config: Coqpit, data: List = None): def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers. `in_channels` size of the connected layers.

View File

@ -379,7 +379,7 @@ class Xtts(BaseTTS):
return gpt_cond_latents, speaker_embedding return gpt_cond_latents, speaker_embedding
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs): def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed: float = 1.0, **kwargs):
"""Synthesize speech with the given input text. """Synthesize speech with the given input text.
Args: Args:
@ -409,14 +409,14 @@ class Xtts(BaseTTS):
settings.update(kwargs) # allow overriding of preset settings with kwargs settings.update(kwargs) # allow overriding of preset settings with kwargs
if speaker_id is not None: if speaker_id is not None:
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values() gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings) return self.inference(text, language, gpt_cond_latent, speaker_embedding, speed=speed, **settings)
settings.update({ settings.update({
"gpt_cond_len": config.gpt_cond_len, "gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len, "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len, "max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs, "sound_norm_refs": config.sound_norm_refs,
}) })
return self.full_inference(text, speaker_wav, language, **settings) return self.full_inference(text, speaker_wav, language, speed=speed, **settings)
@torch.inference_mode() @torch.inference_mode()
def full_inference( def full_inference(
@ -436,6 +436,7 @@ class Xtts(BaseTTS):
gpt_cond_chunk_len=6, gpt_cond_chunk_len=6,
max_ref_len=10, max_ref_len=10,
sound_norm_refs=False, sound_norm_refs=False,
speed: float = 1.0,
**hf_generate_kwargs, **hf_generate_kwargs,
): ):
""" """
@ -496,6 +497,7 @@ class Xtts(BaseTTS):
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
do_sample=do_sample, do_sample=do_sample,
speed=speed,
**hf_generate_kwargs, **hf_generate_kwargs,
) )
@ -569,10 +571,7 @@ class Xtts(BaseTTS):
) )
if length_scale != 1.0: if length_scale != 1.0:
gpt_latents = F.interpolate( gpt_latents = self.adjust_speech_rate(gpt_latents, length_scale)
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)
gpt_latents_list.append(gpt_latents.cpu()) gpt_latents_list.append(gpt_latents.cpu())
wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())