Update ForwardTTS for multi-speaker

commit aa25f70b95 (parent 0ebc2a400e)
Author: Eren Gölge
Date:   2021-10-20 18:16:41 +00:00
1 changed file with 89 additions and 57 deletions


@@ -13,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
@@ -31,9 +32,6 @@ class ForwardTTSArgs(Coqpit):
         hidden_channels (int):
             Number of base hidden channels of the model. Defaults to 512.
-        num_speakers (int):
-            Number of speakers for the speaker embedding layer. Defaults to 0.
         use_aligner (bool):
             Whether to use aligner network to learn the text to speech alignment or use pre-computed durations.
             If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the
@@ -86,12 +84,6 @@ class ForwardTTSArgs(Coqpit):
         decoder_params (str):
             Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
-        use_d_vetor (bool):
-            Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
-        d_vector_dim (int):
-            Number of channels of the d-vectors. Defaults to 0.
         detach_duration_predictor (bool):
             Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss
             does not pass to the earlier layers. Defaults to True.
@@ -99,12 +91,26 @@ class ForwardTTSArgs(Coqpit):
         max_duration (int):
             Maximum duration accepted by the model. Defaults to 75.
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 0.
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
+        speaker_embedding_channels (int):
+            Number of speaker embedding channels. Defaults to 256.
+        use_d_vector_file (bool):
+            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+        d_vector_dim (int):
+            Number of d-vector channels. Defaults to 0.
     """

     num_chars: int = None
     out_channels: int = 80
     hidden_channels: int = 384
-    num_speakers: int = 0
     use_aligner: bool = True
     use_pitch: bool = True
     pitch_predictor_hidden_channels: int = 256
@@ -125,10 +131,14 @@ class ForwardTTSArgs(Coqpit):
     decoder_params: dict = field(
         default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
     )
-    use_d_vector: bool = False
-    d_vector_dim: int = 0
     detach_duration_predictor: bool = False
     max_duration: int = 75
+    num_speakers: int = 1
+    use_speaker_embedding: bool = False
+    speakers_file: str = None
+    use_d_vector_file: bool = False
+    d_vector_dim: int = None
+    d_vector_file: str = None


 class ForwardTTS(BaseTTS):
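
For context, a minimal sketch of the two multi-speaker setups these new fields enable. This is not part of the commit; the import path follows the docstring example later in this diff, and the speaker count and file paths are made-up placeholders.

    from TTS.tts.models.fast_pitch import ForwardTTSArgs

    # (a) learned speaker-ID embedding, summed onto the encoder output
    args_emb = ForwardTTSArgs(num_speakers=4, use_speaker_embedding=True, speakers_file="speakers.json")

    # (b) precomputed d-vectors loaded from a file
    args_dvec = ForwardTTSArgs(use_d_vector_file=True, d_vector_file="d_vectors.json", d_vector_dim=256)
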
@@ -150,6 +160,8 @@ class ForwardTTS(BaseTTS):
     Args:
         config (Coqpit): Model coqpit class.
+        speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models.
+            Defaults to None.

     Examples:
         >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs
@@ -158,10 +170,13 @@ class ForwardTTS(BaseTTS):
     """

     # pylint: disable=dangerous-default-value
-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):

         super().__init__(config)

+        self.speaker_manager = speaker_manager
+        self.init_multispeaker(config)
+
         self.max_duration = self.args.max_duration
         self.use_aligner = self.args.use_aligner
         self.use_pitch = self.args.use_pitch
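
A hedged sketch of the new construction path: the SpeakerManager is built outside the model and handed in, and the constructor then calls init_multispeaker(config). The SpeakerManager keyword argument and the file path below are assumptions for illustration, not taken from this commit.

    from TTS.tts.utils.speakers import SpeakerManager

    speaker_manager = SpeakerManager(speaker_id_file_path="speakers.json")  # assumed kwarg name
    model = ForwardTTS(config, speaker_manager=speaker_manager)
    # inside __init__: the manager is stored on the model, then init_multispeaker(config)
    # sizes the speaker layers from config.use_speaker_embedding / config.use_d_vector_file
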
@@ -178,7 +193,7 @@ class ForwardTTS(BaseTTS):
             self.args.hidden_channels,
             self.args.encoder_type,
             self.args.encoder_params,
-            self.args.d_vector_dim,
+            self.embedded_speaker_dim,
         )

         if self.args.positional_encoding:
@@ -192,7 +207,7 @@ class ForwardTTS(BaseTTS):
         )

         self.duration_predictor = DurationPredictor(
-            self.args.hidden_channels + self.args.d_vector_dim,
+            self.args.hidden_channels + self.embedded_speaker_dim,
             self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
@@ -200,7 +215,7 @@ class ForwardTTS(BaseTTS):
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.args.d_vector_dim,
+                self.args.hidden_channels + self.embedded_speaker_dim,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
@@ -212,19 +227,37 @@ class ForwardTTS(BaseTTS):
                 padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
             )

-        if self.args.num_speakers > 1 and not self.args.use_d_vector:
-            # speaker embedding layer
-            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
-            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
-
-        if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
-            self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
-
         if self.args.use_aligner:
             self.aligner = AlignmentNetwork(
                 in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
             )

+    def init_multispeaker(self, config: Coqpit):
+        """Init for multi-speaker training.
+
+        Args:
+            config (Coqpit): Model configuration.
+        """
+        self.embedded_speaker_dim = 0
+        # init speaker manager
+        if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding):
+            raise ValueError(
+                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
+            )
+        # set number of speakers
+        if self.speaker_manager is not None:
+            self.num_speakers = self.speaker_manager.num_speakers
+        # init d-vector embedding
+        if config.use_d_vector_file:
+            self.embedded_speaker_dim = config.d_vector_dim
+            if self.args.d_vector_dim != self.args.hidden_channels:
+                self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+        # init speaker embedding layer
+        if config.use_speaker_embedding and not config.use_d_vector_file:
+            print(" > Init speaker_embedding layer.")
+            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels)
+            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
+
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
         """Generate an attention mask from the durations.
@@ -289,18 +322,6 @@ class ForwardTTS(BaseTTS):
             o_dr = torch.round(o_dr)
         return o_dr

-    @staticmethod
-    def _concat_speaker_embedding(o_en, g):
-        g_exp = g.expand(-1, -1, o_en.size(-1))  # [B, C, T_en]
-        o_en = torch.cat([o_en, g_exp], 1)
-        return o_en
-
-    def _sum_speaker_embedding(self, x, g):
-        # project g to decoder dim.
-        if hasattr(self, "proj_g"):
-            g = self.proj_g(g)
-        return x + g
-
     def _forward_encoder(
         self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
@@ -309,7 +330,7 @@ class ForwardTTS(BaseTTS):
         1. Embed speaker IDs if multi-speaker mode.
         2. Embed character sequences.
         3. Run the encoder network.
-        4. Concat speaker embedding to the encoder output for the duration predictor.
+        4. Sum encoder outputs and speaker embeddings.

         Args:
             x (torch.LongTensor): Input sequence IDs.
@@ -327,19 +348,18 @@ class ForwardTTS(BaseTTS):
                - g: :math:`(B, C)`
         """
         if hasattr(self, "emb_g"):
-            g = nn.functional.normalize(self.emb_g(g))  # [B, C, 1]
+            g = self.emb_g(g)  # [B, C, 1]
         if g is not None:
             g = g.unsqueeze(-1)
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
-        # speaker conditioning for duration predictor
+        # speaker conditioning
+        # TODO: try different ways of conditioning
         if g is not None:
-            o_en_dp = self._concat_speaker_embedding(o_en, g)
-        else:
-            o_en_dp = o_en
-        return o_en, o_en_dp, x_mask, g, x_emb
+            o_en = o_en + g
+        return o_en, x_mask, g, x_emb

     def _forward_decoder(
         self,
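
The concat-based conditioning (and the separate o_en_dp tensor for the duration predictor) is replaced by a broadcast sum over the time axis. A small, self-contained shape check of that operation; the values are arbitrary and not from the commit.

    import torch

    B, C, T_en = 2, 384, 13           # batch, hidden channels, text length
    o_en = torch.randn(B, C, T_en)    # encoder output
    g = torch.randn(B, C, 1)          # speaker embedding / d-vector, unsqueezed
    o_en = o_en + g                   # broadcasts over time; output shape is unchanged
    assert o_en.shape == (B, C, T_en)
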
@@ -373,9 +393,6 @@ class ForwardTTS(BaseTTS):
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
-        # speaker embedding
-        if g is not None:
-            o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask, g=g)
         return o_de.transpose(1, 2), attn.transpose(1, 2)
@@ -457,6 +474,19 @@ class ForwardTTS(BaseTTS):
         alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
         return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas

+    def _set_speaker_input(self, aux_input: Dict):
+        d_vectors = aux_input.get("d_vectors", None)
+        speaker_ids = aux_input.get("speaker_ids", None)
+
+        if d_vectors is not None and speaker_ids is not None:
+            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
+
+        if speaker_ids is not None and not hasattr(self, "emb_g"):
+            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
+
+        g = speaker_ids if speaker_ids is not None else d_vectors
+        return g
+
     def forward(
         self,
         x: torch.LongTensor,
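
The new _set_speaker_input helper reads the speaker conditioning from aux_input and enforces that speaker IDs and d-vectors are mutually exclusive. A sketch of the two call patterns; model is assumed to be a ForwardTTS instance and the tensor values are placeholders.

    import torch

    # speaker-ID conditioning (requires use_speaker_embedding, i.e. emb_g exists)
    aux_input = {"speaker_ids": torch.LongTensor([3]), "d_vectors": None}

    # d-vector conditioning (alternative; passing both raises ValueError)
    # aux_input = {"speaker_ids": None, "d_vectors": torch.randn(1, 256)}

    g = model._set_speaker_input(aux_input)  # returns speaker_ids or d_vectors
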
@@ -487,17 +517,17 @@ class ForwardTTS(BaseTTS):
             - g: :math:`[B, C]`
             - pitch: :math:`[B, 1, T]`
         """
-        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
+        g = self._set_speaker_input(aux_input)
         # compute sequence masks
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
         # encoder pass
-        o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
+        o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
         if self.args.detach_duration_predictor:
-            o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
+            o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
         else:
-            o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+            o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # generate attn mask from predicted durations
         o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
@@ -517,10 +547,12 @@ class ForwardTTS(BaseTTS):
         o_pitch = None
         avg_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
+            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(
+            o_en, dr, x_mask, y_lengths, g=None
+        )  # TODO: maybe pass speaker embedding (g) too
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
             "durations_log": o_dr_log.squeeze(1),  # [B, T]
@@ -551,22 +583,22 @@ class ForwardTTS(BaseTTS):
            - x_lengths: [B]
            - g: [B, C]
         """
-        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
+        g = self._set_speaker_input(aux_input)
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
         # encoder pass
-        o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
+        o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
-        o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+        o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
         y_lengths = o_dr.sum(1)
         # pitch predictor pass
         o_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
+            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
         outputs = {
             "model_outputs": o_de,
             "alignments": attn,