diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py
index b83f12d4..b2c41df5 100644
--- a/TTS/tts/models/forward_tts.py
+++ b/TTS/tts/models/forward_tts.py
@@ -13,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
 
 
@@ -31,9 +32,6 @@ class ForwardTTSArgs(Coqpit):
         hidden_channels (int):
             Number of base hidden channels of the model. Defaults to 512.
 
-        num_speakers (int):
-            Number of speakers for the speaker embedding layer. Defaults to 0.
-
         use_aligner (bool):
             Whether to use aligner network to learn the text to speech alignment or use pre-computed durations.
             If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the
@@ -86,12 +84,6 @@ class ForwardTTSArgs(Coqpit):
         decoder_params (str):
             Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
 
-        use_d_vetor (bool):
-            Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
-
-        d_vector_dim (int):
-            Number of channels of the d-vectors. Defaults to 0.
-
         detach_duration_predictor (bool):
             Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss
             does not pass to the earlier layers. Defaults to True.
@@ -99,12 +91,26 @@ class ForwardTTSArgs(Coqpit):
         max_duration (int):
             Maximum duration accepted by the model. Defaults to 75.
 
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 1.
+
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
+
+        speaker_embedding_channels (int):
+            Number of speaker embedding channels. Defaults to 256.
+
+        use_d_vector_file (bool):
+            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+
+        d_vector_dim (int):
+            Number of d-vector channels. Defaults to None.
+
     """
 
     num_chars: int = None
     out_channels: int = 80
     hidden_channels: int = 384
-    num_speakers: int = 0
     use_aligner: bool = True
     use_pitch: bool = True
     pitch_predictor_hidden_channels: int = 256
@@ -125,10 +131,14 @@ class ForwardTTSArgs(Coqpit):
     decoder_params: dict = field(
         default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
     )
-    use_d_vector: bool = False
-    d_vector_dim: int = 0
     detach_duration_predictor: bool = False
     max_duration: int = 75
+    num_speakers: int = 1
+    use_speaker_embedding: bool = False
+    speakers_file: str = None
+    use_d_vector_file: bool = False
+    d_vector_dim: int = None
+    d_vector_file: str = None
 
 
 class ForwardTTS(BaseTTS):
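The reworked `ForwardTTSArgs` replaces the old `use_d_vector`/`d_vector_dim` pair with explicit speaker-embedding and d-vector-file switches. As a quick reference, a hypothetical configuration sketch of the two modes; all concrete values below are illustrative, not defaults introduced by this diff:

```python
from TTS.tts.models.forward_tts import ForwardTTSArgs

# Trainable speaker-embedding path: an emb_g lookup keyed by integer speaker ids.
args_emb = ForwardTTSArgs(
    num_chars=120,               # character vocabulary size (illustrative)
    use_speaker_embedding=True,  # init_multispeaker() builds the emb_g layer
    num_speakers=10,             # sizes emb_g (illustrative)
)

# Precomputed d-vector path: external speaker vectors loaded from a file.
args_dvec = ForwardTTSArgs(
    num_chars=120,
    use_d_vector_file=True,
    d_vector_file="speakers.json",  # hypothetical path to precomputed d-vectors
    d_vector_dim=256,               # a 1x1 Conv1d (proj_g) is added when != hidden_channels
)
```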
@@ -150,6 +160,8 @@ class ForwardTTS(BaseTTS):
 
     Args:
         config (Coqpit): Model coqpit class.
+        speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models.
+            Defaults to None.
 
     Examples:
         >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs
@@ -158,10 +170,13 @@ class ForwardTTS(BaseTTS):
     """
 
     # pylint: disable=dangerous-default-value
-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
 
         super().__init__(config)
 
+        self.speaker_manager = speaker_manager
+        self.init_multispeaker(config)
+
         self.max_duration = self.args.max_duration
         self.use_aligner = self.args.use_aligner
         self.use_pitch = self.args.use_pitch
@@ -178,7 +193,7 @@ class ForwardTTS(BaseTTS):
             self.args.hidden_channels,
             self.args.encoder_type,
             self.args.encoder_params,
-            self.args.d_vector_dim,
+            self.embedded_speaker_dim,
         )
 
         if self.args.positional_encoding:
@@ -192,7 +207,7 @@ class ForwardTTS(BaseTTS):
         )
 
         self.duration_predictor = DurationPredictor(
-            self.args.hidden_channels + self.args.d_vector_dim,
+            self.args.hidden_channels + self.embedded_speaker_dim,
             self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
@@ -200,7 +215,7 @@ class ForwardTTS(BaseTTS):
 
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.args.d_vector_dim,
+                self.args.hidden_channels + self.embedded_speaker_dim,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
@@ -212,19 +227,37 @@ class ForwardTTS(BaseTTS):
                 padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
             )
 
-        if self.args.num_speakers > 1 and not self.args.use_d_vector:
-            # speaker embedding layer
-            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
-            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
-
-        if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
-            self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
-
         if self.args.use_aligner:
             self.aligner = AlignmentNetwork(
                 in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
             )
 
+    def init_multispeaker(self, config: Coqpit):
+        """Init for multi-speaker training.
+
+        Args:
+            config (Coqpit): Model configuration.
+        """
+        self.embedded_speaker_dim = 0
+        # init speaker manager
+        if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding):
+            raise ValueError(
+                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
+            )
+        # set number of speakers
+        if self.speaker_manager is not None:
+            self.num_speakers = self.speaker_manager.num_speakers
+        # init d-vector embedding
+        if config.use_d_vector_file:
+            self.embedded_speaker_dim = config.d_vector_dim
+            if self.args.d_vector_dim != self.args.hidden_channels:
+                self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+        # init speaker embedding layer
+        if config.use_speaker_embedding and not config.use_d_vector_file:
+            print(" > Init speaker_embedding layer.")
+            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels)
+            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
+
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
         """Generate an attention mask from the durations.
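Since `__init__` now delegates all speaker handling to `init_multispeaker`, constructing a multi-speaker model requires passing a `SpeakerManager` up front, otherwise the method raises a `ValueError`. A minimal construction sketch, assuming the `speaker_id_file_path` argument of `SpeakerManager` and a hypothetical speakers file:

```python
from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
from TTS.tts.utils.speakers import SpeakerManager

args = ForwardTTSArgs(num_chars=120, use_speaker_embedding=True, num_speakers=10)

# "speakers.json" is a hypothetical speaker-id mapping file.
speaker_manager = SpeakerManager(speaker_id_file_path="speakers.json")
model = ForwardTTS(args, speaker_manager=speaker_manager)
print(model.num_speakers)  # set from speaker_manager.num_speakers in init_multispeaker()
```

Note that `emb_g` is still sized by `args.num_speakers`, while `self.num_speakers` comes from the manager, so the two should agree in the config.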
@@ -289,18 +322,6 @@ class ForwardTTS(BaseTTS):
         o_dr = torch.round(o_dr)
         return o_dr
 
-    @staticmethod
-    def _concat_speaker_embedding(o_en, g):
-        g_exp = g.expand(-1, -1, o_en.size(-1))  # [B, C, T_en]
-        o_en = torch.cat([o_en, g_exp], 1)
-        return o_en
-
-    def _sum_speaker_embedding(self, x, g):
-        # project g to decoder dim.
-        if hasattr(self, "proj_g"):
-            g = self.proj_g(g)
-        return x + g
-
     def _forward_encoder(
         self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
@@ -309,7 +330,7 @@ class ForwardTTS(BaseTTS):
         1. Embed speaker IDs if multi-speaker mode.
         2. Embed character sequences.
         3. Run the encoder network.
-        4. Concat speaker embedding to the encoder output for the duration predictor.
+        4. Sum encoder outputs and speaker embeddings.
 
         Args:
             x (torch.LongTensor): Input sequence IDs.
@@ -327,19 +348,18 @@ class ForwardTTS(BaseTTS):
             - g: :math:`(B, C)`
         """
         if hasattr(self, "emb_g"):
-            g = nn.functional.normalize(self.emb_g(g))  # [B, C, 1]
+            g = self.emb_g(g)  # [B, C, 1]
         if g is not None:
             g = g.unsqueeze(-1)
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
-        # speaker conditioning for duration predictor
+        # speaker conditioning
+        # TODO: try different ways of conditioning
         if g is not None:
-            o_en_dp = self._concat_speaker_embedding(o_en, g)
-        else:
-            o_en_dp = o_en
-        return o_en, o_en_dp, x_mask, g, x_emb
+            o_en = o_en + g
+        return o_en, x_mask, g, x_emb
 
     def _forward_decoder(
         self,
@@ -373,9 +393,6 @@ class ForwardTTS(BaseTTS):
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
-        # speaker embedding
-        if g is not None:
-            o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask, g=g)
         return o_de.transpose(1, 2), attn.transpose(1, 2)
@@ -457,6 +474,19 @@ class ForwardTTS(BaseTTS):
         alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
         return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
 
+    def _set_speaker_input(self, aux_input: Dict):
+        d_vectors = aux_input.get("d_vectors", None)
+        speaker_ids = aux_input.get("speaker_ids", None)
+
+        if d_vectors is not None and speaker_ids is not None:
+            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
+
+        if speaker_ids is not None and not hasattr(self, "emb_g"):
+            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
+
+        g = speaker_ids if speaker_ids is not None else d_vectors
+        return g
+
     def forward(
         self,
         x: torch.LongTensor,
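Summing the speaker vector into `o_en` (instead of concatenating a copy for the duration predictor) means a single conditioned tensor now drives the duration predictor, the pitch predictor, and the decoder input. A minimal tensor sketch of the broadcast, with assumed shapes and dummy values:

```python
import torch

B, C, T_en = 2, 384, 13         # batch, hidden_channels, text length (illustrative)
o_en = torch.randn(B, C, T_en)  # encoder output, [B, C, T_en]
g = torch.randn(B, C, 1)        # speaker vector after unsqueeze(-1), [B, C, 1]

o_en = o_en + g                 # g broadcasts over the time axis
assert o_en.shape == (B, C, T_en)
```

The sum only works when the speaker vector matches `hidden_channels`: `emb_g` is created at that width, whereas d-vectors of a different width would need the `proj_g` projection, which this encoder path does not currently apply.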
@@ -487,17 +517,17 @@ class ForwardTTS(BaseTTS):
             - g: :math:`[B, C]`
             - pitch: :math:`[B, 1, T]`
         """
-        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
+        g = self._set_speaker_input(aux_input)
         # compute sequence masks
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
         # encoder pass
-        o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
+        o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
         if self.args.detach_duration_predictor:
-            o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
+            o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
         else:
-            o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+            o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # generate attn mask from predicted durations
         o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
@@ -517,10 +547,12 @@ class ForwardTTS(BaseTTS):
         o_pitch = None
         avg_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
+            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(
+            o_en, dr, x_mask, y_lengths, g=None
+        )  # TODO: maybe pass speaker embedding (g) too
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
             "durations_log": o_dr_log.squeeze(1),  # [B, T]
@@ -551,22 +583,22 @@ class ForwardTTS(BaseTTS):
             - x_lengths: [B]
             - g: [B, C]
         """
-        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
+        g = self._set_speaker_input(aux_input)
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
         # encoder pass
-        o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
+        o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
-        o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+        o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
         y_lengths = o_dr.sum(1)
         # pitch predictor pass
         o_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
+            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
         outputs = {
             "model_outputs": o_de,
             "alignments": attn,
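With `_set_speaker_input` validating `aux_input`, speaker identity now enters `forward` and `inference` through either `speaker_ids` or `d_vectors`, never both. A hedged end-to-end sketch (setup repeated from the construction sketch above; ids, shapes, and the speakers file are dummies):

```python
import torch

from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
from TTS.tts.utils.speakers import SpeakerManager

# Hypothetical setup, as in the construction sketch above.
args = ForwardTTSArgs(num_chars=120, use_speaker_embedding=True, num_speakers=10)
model = ForwardTTS(args, speaker_manager=SpeakerManager(speaker_id_file_path="speakers.json"))
model.eval()

token_ids = torch.randint(0, 120, (1, 42))  # [B, T_text] dummy character ids
with torch.no_grad():
    outputs = model.inference(token_ids, aux_input={"speaker_ids": torch.tensor([3])})
mel = outputs["model_outputs"]  # [B, T_de, out_channels]

# Passing both keys is rejected by _set_speaker_input:
# model.inference(token_ids, aux_input={"speaker_ids": ..., "d_vectors": ...})
# -> ValueError: [!] Cannot use d-vectors and speaker-ids together.
```

Note that both `forward` and `inference` currently call `_forward_decoder` with `g=None` (see the TODOs above), so the decoder itself is conditioned only through the summed encoder output.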