mirror of https://github.com/coqui-ai/TTS.git

Update ForwardTTS for multi-speaker

This commit is contained in:
parent 0ebc2a400e
commit aa25f70b95
@@ -13,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
@@ -31,9 +32,6 @@ class ForwardTTSArgs(Coqpit):
         hidden_channels (int):
             Number of base hidden channels of the model. Defaults to 512.
 
-        num_speakers (int):
-            Number of speakers for the speaker embedding layer. Defaults to 0.
-
         use_aligner (bool):
             Whether to use aligner network to learn the text to speech alignment or use pre-computed durations.
             If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the
@@ -86,12 +84,6 @@ class ForwardTTSArgs(Coqpit):
         decoder_params (str):
             Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
 
-        use_d_vetor (bool):
-            Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
-
-        d_vector_dim (int):
-            Number of channels of the d-vectors. Defaults to 0.
-
         detach_duration_predictor (bool):
             Detach the input to the duration predictor from the earlier computation graph so that the duration loss
             does not pass to the earlier layers. Defaults to True.
@@ -99,12 +91,26 @@ class ForwardTTSArgs(Coqpit):
         max_duration (int):
             Maximum duration accepted by the model. Defaults to 75.
+
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 0.
+
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
+
+        speaker_embedding_channels (int):
+            Number of speaker embedding channels. Defaults to 256.
+
+        use_d_vector_file (bool):
+            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+
+        d_vector_dim (int):
+            Number of d-vector channels. Defaults to 0.
     """
 
     num_chars: int = None
     out_channels: int = 80
     hidden_channels: int = 384
-    num_speakers: int = 0
     use_aligner: bool = True
     use_pitch: bool = True
     pitch_predictor_hidden_channels: int = 256
@@ -125,10 +131,14 @@ class ForwardTTSArgs(Coqpit):
     decoder_params: dict = field(
         default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
     )
-    use_d_vector: bool = False
-    d_vector_dim: int = 0
     detach_duration_predictor: bool = False
     max_duration: int = 75
+    num_speakers: int = 1
+    use_speaker_embedding: bool = False
+    speakers_file: str = None
+    use_d_vector_file: bool = False
+    d_vector_dim: int = None
+    d_vector_file: str = None
 
 
 class ForwardTTS(BaseTTS):
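
For orientation, the two multi-speaker modes these new fields enable could be configured roughly as follows. A minimal sketch, using the import path shown in the class docstring below; `num_chars` and the file paths are placeholders:

    from TTS.tts.models.fast_pitch import ForwardTTSArgs

    # Learned-embedding mode: init_multispeaker() creates an nn.Embedding
    # sized (num_speakers, hidden_channels) for these settings.
    args_emb = ForwardTTSArgs(
        num_chars=120,                  # placeholder vocabulary size
        num_speakers=4,
        use_speaker_embedding=True,
        speakers_file="speakers.json",  # placeholder path
    )

    # External d-vector mode: speaker vectors come from a precomputed file
    # and are projected to hidden_channels when the dims differ.
    args_dvec = ForwardTTSArgs(
        num_chars=120,
        use_d_vector_file=True,
        d_vector_file="d_vectors.json",  # placeholder path
        d_vector_dim=256,
    )
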
@@ -150,6 +160,8 @@ class ForwardTTS(BaseTTS):
 
     Args:
         config (Coqpit): Model coqpit class.
+        speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models.
+            Defaults to None.
 
     Examples:
         >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs
@@ -158,10 +170,13 @@ class ForwardTTS(BaseTTS):
     """
 
     # pylint: disable=dangerous-default-value
-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
 
         super().__init__(config)
 
+        self.speaker_manager = speaker_manager
+        self.init_multispeaker(config)
+
         self.max_duration = self.args.max_duration
         self.use_aligner = self.args.use_aligner
         self.use_pitch = self.args.use_pitch
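
With the new signature, callers build the speaker manager first and hand it to the model. A minimal sketch, assuming `config` is a prepared model Coqpit and that the `SpeakerManager` has been populated from a speakers or d-vector file (constructor arguments elided, as they depend on the chosen mode):

    from TTS.tts.utils.speakers import SpeakerManager

    speaker_manager = SpeakerManager()  # assumed: populate from speakers/d-vector files
    model = ForwardTTS(config, speaker_manager=speaker_manager)
    print(model.num_speakers)  # taken from speaker_manager by init_multispeaker()
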
@@ -178,7 +193,7 @@ class ForwardTTS(BaseTTS):
             self.args.hidden_channels,
             self.args.encoder_type,
             self.args.encoder_params,
-            self.args.d_vector_dim,
+            self.embedded_speaker_dim,
         )
 
         if self.args.positional_encoding:
@@ -192,7 +207,7 @@ class ForwardTTS(BaseTTS):
             )
 
         self.duration_predictor = DurationPredictor(
-            self.args.hidden_channels + self.args.d_vector_dim,
+            self.args.hidden_channels + self.embedded_speaker_dim,
             self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
@@ -200,7 +215,7 @@ class ForwardTTS(BaseTTS):
 
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.args.d_vector_dim,
+                self.args.hidden_channels + self.embedded_speaker_dim,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
@@ -212,19 +227,37 @@ class ForwardTTS(BaseTTS):
                 padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
             )
 
-        if self.args.num_speakers > 1 and not self.args.use_d_vector:
-            # speaker embedding layer
-            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
-            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
-
-        if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
-            self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
-
         if self.args.use_aligner:
             self.aligner = AlignmentNetwork(
                 in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
             )
 
+    def init_multispeaker(self, config: Coqpit):
+        """Init for multi-speaker training.
+
+        Args:
+            config (Coqpit): Model configuration.
+        """
+        self.embedded_speaker_dim = 0
+        # init speaker manager
+        if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding):
+            raise ValueError(
+                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
+            )
+        # set number of speakers
+        if self.speaker_manager is not None:
+            self.num_speakers = self.speaker_manager.num_speakers
+        # init d-vector embedding
+        if config.use_d_vector_file:
+            self.embedded_speaker_dim = config.d_vector_dim
+            if self.args.d_vector_dim != self.args.hidden_channels:
+                self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+        # init speaker embedding layer
+        if config.use_speaker_embedding and not config.use_d_vector_file:
+            print(" > Init speaker_embedding layer.")
+            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels)
+            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
+
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
         """Generate an attention mask from the durations.
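
Restating the dispatch in `init_multispeaker` as standalone code may help; this is an illustrative summary of its outcomes, not part of the API:

    # - single speaker:        embedded_speaker_dim = 0, no extra layers
    # - use_speaker_embedding: emb_g = nn.Embedding(num_speakers, hidden_channels),
    #                          summed directly onto the encoder output
    # - use_d_vector_file:     embedded_speaker_dim = d_vector_dim, with a 1x1
    #                          Conv1d projection (proj_g) when the dims differ
    def resolve_speaker_conditioning(config):
        """Return (embedded_speaker_dim, needs_projection) for a config (illustrative)."""
        if config.use_d_vector_file:
            return config.d_vector_dim, config.d_vector_dim != config.hidden_channels
        return 0, False  # embedding mode and single-speaker both keep dim 0
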
@@ -289,18 +322,6 @@ class ForwardTTS(BaseTTS):
         o_dr = torch.round(o_dr)
         return o_dr
 
-    @staticmethod
-    def _concat_speaker_embedding(o_en, g):
-        g_exp = g.expand(-1, -1, o_en.size(-1))  # [B, C, T_en]
-        o_en = torch.cat([o_en, g_exp], 1)
-        return o_en
-
-    def _sum_speaker_embedding(self, x, g):
-        # project g to decoder dim.
-        if hasattr(self, "proj_g"):
-            g = self.proj_g(g)
-        return x + g
-
     def _forward_encoder(
         self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
@@ -309,7 +330,7 @@ class ForwardTTS(BaseTTS):
         1. Embed speaker IDs if multi-speaker mode.
         2. Embed character sequences.
         3. Run the encoder network.
-        4. Concat speaker embedding to the encoder output for the duration predictor.
+        4. Sum encoder outputs and speaker embeddings.
 
         Args:
             x (torch.LongTensor): Input sequence IDs.
@@ -327,19 +348,18 @@ class ForwardTTS(BaseTTS):
             - g: :math:`(B, C)`
         """
         if hasattr(self, "emb_g"):
-            g = nn.functional.normalize(self.emb_g(g))  # [B, C, 1]
+            g = self.emb_g(g)  # [B, C, 1]
         if g is not None:
             g = g.unsqueeze(-1)
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
-        # speaker conditioning for duration predictor
+        # speaker conditioning
+        # TODO: try different ways of conditioning
         if g is not None:
-            o_en_dp = self._concat_speaker_embedding(o_en, g)
-        else:
-            o_en_dp = o_en
-        return o_en, o_en_dp, x_mask, g, x_emb
+            o_en = o_en + g
+        return o_en, x_mask, g, x_emb
 
     def _forward_decoder(
         self,
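
The conditioning change shows up in the tensor shapes: the speaker vector is now broadcast over the time axis and summed onto the encoder output, rather than concatenated along channels. A self-contained sketch with made-up sizes:

    import torch
    from torch import nn

    hidden_channels, num_speakers = 384, 4
    emb_g = nn.Embedding(num_speakers, hidden_channels)
    nn.init.uniform_(emb_g.weight, -0.1, 0.1)

    speaker_ids = torch.tensor([0, 3])           # [B]
    g = emb_g(speaker_ids).unsqueeze(-1)         # [B, C, 1]
    o_en = torch.randn(2, hidden_channels, 33)   # encoder output stand-in [B, C, T]
    o_en = o_en + g                              # broadcast sum over the time axis
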
@@ -373,9 +393,6 @@ class ForwardTTS(BaseTTS):
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
-        # speaker embedding
-        if g is not None:
-            o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask, g=g)
         return o_de.transpose(1, 2), attn.transpose(1, 2)
@@ -457,6 +474,19 @@ class ForwardTTS(BaseTTS):
         alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
         return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
 
+    def _set_speaker_input(self, aux_input: Dict):
+        d_vectors = aux_input.get("d_vectors", None)
+        speaker_ids = aux_input.get("speaker_ids", None)
+
+        if d_vectors is not None and speaker_ids is not None:
+            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
+
+        if speaker_ids is not None and not hasattr(self, "emb_g"):
+            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
+
+        g = speaker_ids if speaker_ids is not None else d_vectors
+        return g
+
     def forward(
         self,
         x: torch.LongTensor,
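
The helper accepts exactly one of the two keys and returns it unchanged. A quick sketch of the accepted inputs, assuming `model` is an already-built instance and the shapes are illustrative:

    import torch

    # Embedding mode: integer speaker IDs, one per batch item.
    g = model._set_speaker_input({"speaker_ids": torch.tensor([0, 2])})

    # D-vector mode: precomputed speaker vectors, [B, d_vector_dim].
    g = model._set_speaker_input({"d_vectors": torch.randn(2, 256)})

    # Passing both keys at once raises ValueError.
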
@@ -487,17 +517,17 @@ class ForwardTTS(BaseTTS):
             - g: :math:`[B, C]`
             - pitch: :math:`[B, 1, T]`
         """
-        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
+        g = self._set_speaker_input(aux_input)
         # compute sequence masks
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
         # encoder pass
-        o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
+        o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
         if self.args.detach_duration_predictor:
-            o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
+            o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
         else:
-            o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+            o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # generate attn mask from predicted durations
         o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
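
As a reminder of what `detach_duration_predictor` buys here: `o_en.detach()` shares data with `o_en` but is cut out of the autograd graph, so the duration loss cannot push gradients into the encoder. A minimal standalone demonstration:

    import torch

    x = torch.randn(3, requires_grad=True)
    h = x * 2                           # stand-in for the encoder output
    dur_loss = (h.detach() ** 2).sum()  # duration branch sees a detached copy
    spec_loss = (h ** 2).sum()          # decoder branch keeps the graph
    (dur_loss + spec_loss).backward()
    # x.grad reflects only spec_loss; the detached branch contributed nothing.
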
@@ -517,10 +547,12 @@ class ForwardTTS(BaseTTS):
         o_pitch = None
         avg_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
+            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(
+            o_en, dr, x_mask, y_lengths, g=None
+        )  # TODO: maybe pass speaker embedding (g) too
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
             "durations_log": o_dr_log.squeeze(1),  # [B, T]
@@ -551,22 +583,22 @@ class ForwardTTS(BaseTTS):
             - x_lengths: [B]
             - g: [B, C]
         """
-        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
+        g = self._set_speaker_input(aux_input)
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
         # encoder pass
-        o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
+        o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
-        o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+        o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
         y_lengths = o_dr.sum(1)
         # pitch predictor pass
         o_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
+            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
         outputs = {
             "model_outputs": o_de,
             "alignments": attn,
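
Putting the pieces together, multi-speaker inference through the updated entry points might look like the following. A hedged sketch, assuming a built model, made-up token IDs, and an inference() signature that accepts aux_input as used above:

    import torch

    model.eval()
    x = torch.randint(0, 120, (1, 42))  # placeholder token IDs [B, T]
    with torch.no_grad():
        outputs = model.inference(x, aux_input={"speaker_ids": torch.tensor([1])})
    mel = outputs["model_outputs"]      # decoder output, [B, T_de, C]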