diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py
index a1273f7f..46e1812d 100644
--- a/TTS/tts/models/forward_tts.py
+++ b/TTS/tts/models/forward_tts.py
@@ -213,18 +213,20 @@ class ForwardTTS(BaseTTS):
         )
 
         self.duration_predictor = DurationPredictor(
-            self.args.hidden_channels + self.embedded_speaker_dim,
+            self.args.hidden_channels,
             self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
+            cond_channels=self.embedded_speaker_dim,
         )
 
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.embedded_speaker_dim,
+                self.args.hidden_channels,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
+                cond_channels=self.embedded_speaker_dim,
             )
             self.pitch_emb = nn.Conv1d(
                 1,
@@ -245,24 +247,54 @@ class ForwardTTS(BaseTTS):
             config (Coqpit): Model configuration.
         """
         self.embedded_speaker_dim = 0
-        # init speaker manager
-        if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding):
-            raise ValueError(
-                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
-            )
-        # set number of speakers
-        if self.speaker_manager is not None:
+        self.num_speakers = self.args.num_speakers
+        self.audio_transform = None
+
+        if self.speaker_manager:
             self.num_speakers = self.speaker_manager.num_speakers
-            # init d-vector embedding
-            if config.use_d_vector_file:
-                self.embedded_speaker_dim = config.d_vector_dim
-                if self.args.d_vector_dim != self.args.hidden_channels:
-                    self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
-            # init speaker embedding layer
-            if config.use_speaker_embedding and not config.use_d_vector_file:
-                print(" > Init speaker_embedding layer.")
-                self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
-                nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
+
+        if self.args.use_speaker_embedding:
+            self._init_speaker_embedding()
+
+        if self.args.use_d_vector_file:
+            self._init_d_vector()
+
+    def _init_speaker_embedding(self):
+        # pylint: disable=attribute-defined-outside-init
+        if self.num_speakers > 0:
+            print(" > initialization of speaker-embedding layers.")
+            self.embedded_speaker_dim = self.args.speaker_embedding_channels
+            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
+
+    def _init_d_vector(self):
+        # pylint: disable=attribute-defined-outside-init
+        if hasattr(self, "emb_g"):
+            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
+        self.embedded_speaker_dim = self.args.d_vector_dim
+
+    @staticmethod
+    def _set_cond_input(aux_input: Dict):
+        """Set the speaker conditioning input based on the multi-speaker mode."""
+        sid, g, lid = None, None, None
+        if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
+            sid = aux_input["speaker_ids"]
+            if sid.ndim == 0:
+                sid = sid.unsqueeze_(0)
+        if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
+            g = torch.nn.functional.normalize(aux_input["d_vectors"]).unsqueeze(-1)
+            if g.ndim == 2:
+                g = g.unsqueeze_(0)
+
+        if "language_ids" in aux_input and aux_input["language_ids"] is not None:
+            lid = aux_input["language_ids"]
+            if lid.ndim == 0:
+                lid = lid.unsqueeze_(0)
+
+        return sid, g, lid
+
+    def get_aux_input(self, aux_input: Dict):
+        sid, g, lid = self._set_cond_input(aux_input)
+        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
 
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
@@ -362,10 +394,7 @@ class ForwardTTS(BaseTTS):
         # encoder pass
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
         # speaker conditioning
-        # TODO: try different ways of conditioning
-        if g is not None:
-            o_en = o_en + g
-        return o_en, x_mask, g, x_emb
+        return x_emb, x_mask, g, o_en
 
     def _forward_decoder(
         self,
@@ -395,7 +424,7 @@ class ForwardTTS(BaseTTS):
         """
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
         # expand o_en with durations
-        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
+        o_en_ex, attn = self.expand_encoder_outputs(en=o_en, dr=dr, x_mask=x_mask, y_mask=y_mask)
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
@@ -409,6 +438,7 @@ class ForwardTTS(BaseTTS):
         x_mask: torch.IntTensor,
         pitch: torch.FloatTensor = None,
         dr: torch.IntTensor = None,
+        g: torch.FloatTensor = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         """Pitch predictor forward pass.
 
@@ -421,6 +451,7 @@ class ForwardTTS(BaseTTS):
             x_mask (torch.IntTensor): Input sequence mask.
             pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
             dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
+            g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
 
         Returns:
             Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
@@ -431,7 +462,7 @@ class ForwardTTS(BaseTTS):
             - pitch: :math:`(B, 1, T_{de})`
             - dr: :math:`(B, T_{en})`
         """
-        o_pitch = self.pitch_predictor(o_en, x_mask)
+        o_pitch = self.pitch_predictor(o_en, x_mask, g=g)
         if pitch is not None:
             avg_pitch = average_over_durations(pitch, dr)
             o_pitch_emb = self.pitch_emb(avg_pitch)
@@ -466,19 +497,19 @@ class ForwardTTS(BaseTTS):
             - x_mask: :math:`[B, 1, T_en]`
             - y_mask: :math:`[B, 1, T_de]`
 
-            - o_alignment_dur: :math:`[B, T_en]`
-            - alignment_soft: :math:`[B, T_en, T_de]`
-            - alignment_logprob: :math:`[B, 1, T_de, T_en]`
-            - alignment_mas: :math:`[B, T_en, T_de]`
+            - aligner_durations: :math:`[B, T_en]`
+            - aligner_soft: :math:`[B, T_en, T_de]`
+            - aligner_logprob: :math:`[B, 1, T_de, T_en]`
+            - aligner_mas: :math:`[B, T_en, T_de]`
         """
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
-        alignment_mas = maximum_path(
-            alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
+        aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
+        aligner_mas = maximum_path(
+            aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
         )
-        o_alignment_dur = torch.sum(alignment_mas, -1).int()
-        alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
-        return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
+        aligner_durations = torch.sum(aligner_mas, -1).int()
+        aligner_soft = aligner_soft.squeeze(1).transpose(1, 2)
+        return aligner_durations, aligner_soft, aligner_logprob, aligner_mas
 
     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
@@ -525,52 +556,54 @@ class ForwardTTS(BaseTTS):
         """
         g = self._set_speaker_input(aux_input)
         # compute sequence masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()  # [B, 1, T_max2]
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()  # [B, 1, T_max]
         # encoder pass
-        o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
+        x_emb, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
         # duration predictor pass
         if self.args.detach_duration_predictor:
-            o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
+            o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g)  # [B, 1, T_max]
         else:
-            o_dr_log = self.duration_predictor(o_en, x_mask)
+            o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g)  # [B, 1, T_max]
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # generate attn mask from predicted durations
-        o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
+        dur_predictor_attn = self.generate_attn(o_dr.squeeze(1), x_mask)  # [B, T_max, T_max2']
        # aligner
-        o_alignment_dur = None
-        alignment_soft = None
-        alignment_logprob = None
-        alignment_mas = None
+        aligner_durations = None
+        aligner_soft = None
+        aligner_logprob = None
+        aligner_mas = None
         if self.use_aligner:
-            o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
-                x_emb, y, x_mask, y_mask
+            # TODO: try passing o_en instead of x_emb
+            aligner_durations, aligner_soft, aligner_logprob, aligner_mas = self._forward_aligner(
+                x=x_emb, y=y, x_mask=x_mask, y_mask=y_mask
             )
-            alignment_soft = alignment_soft.transpose(1, 2)
-            alignment_mas = alignment_mas.transpose(1, 2)
-            dr = o_alignment_dur
+            aligner_soft = aligner_soft.transpose(1, 2)  # [B, T_max, T_max2] -> [B, T_max2, T_max]
+            aligner_mas = aligner_mas.transpose(1, 2)  # [B, T_max, T_max2] -> [B, T_max2, T_max]
+            dr = aligner_durations  # [B, T_max]
         # pitch predictor pass
         o_pitch = None
         avg_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
+            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(
+                o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
+            )
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(
-            o_en, dr, x_mask, y_lengths, g=None
-        )  # TODO: maybe pass speaker embedding (g) too
+        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)  # [B, T_max2, C_de], [B, T_max2, T_max]
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
+            "g": g,  # [B, C]
             "durations_log": o_dr_log.squeeze(1),  # [B, T]
             "durations": o_dr.squeeze(1),  # [B, T]
-            "attn_durations": o_attn,  # for visualization [B, T_en, T_de']
+            "attn_durations": dur_predictor_attn,  # for visualization [B, T_en, T_de']
             "pitch_avg": o_pitch,
             "pitch_avg_gt": avg_pitch,
             "alignments": attn,  # [B, T_de, T_en]
-            "alignment_soft": alignment_soft,
-            "alignment_mas": alignment_mas,
-            "o_alignment_dur": o_alignment_dur,
-            "alignment_logprob": alignment_logprob,
+            "aligner_soft": aligner_soft,
+            "aligner_mas": aligner_mas,
+            "aligner_durations": aligner_durations,
+            "aligner_logprob": aligner_logprob,
             "x_mask": x_mask,
             "y_mask": y_mask,
         }
@@ -593,7 +626,7 @@ class ForwardTTS(BaseTTS):
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
         # encoder pass
-        o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
+        _, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
         o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
@@ -610,6 +643,7 @@ class ForwardTTS(BaseTTS):
             "alignments": attn,
             "pitch": o_pitch,
             "durations_log": o_dr_log,
+            "g": g,
         }
         return outputs
 
@@ -630,7 +664,7 @@ class ForwardTTS(BaseTTS):
         )
         # use aligner's output as the duration target
         if self.use_aligner:
-            durations = outputs["o_alignment_dur"]
+            durations = outputs["aligner_durations"]
         # use float32 in AMP
         with autocast(enabled=False):
             # compute loss
@@ -643,9 +677,9 @@ class ForwardTTS(BaseTTS):
                 pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
                 pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
                 input_lens=text_lengths,
-                alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
-                alignment_soft=outputs["alignment_soft"],
-                alignment_hard=outputs["alignment_mas"],
+                aligner_logprob=outputs["aligner_logprob"] if self.use_aligner else None,
+                aligner_soft=outputs["aligner_soft"],
+                aligner_hard=outputs["aligner_mas"],
                 binary_loss_weight=self.binary_loss_weight,
             )
             # compute duration error
@@ -655,7 +689,7 @@ class ForwardTTS(BaseTTS):
 
         return outputs, loss_dict
 
-    def _create_logs(self, batch, outputs, ap):
+    def create_logs(self, batch, outputs, ap):
         """Create common logger outputs."""
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
@@ -694,7 +728,7 @@ class ForwardTTS(BaseTTS):
     def train_log(
         self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ) -> None:  # pylint: disable=no-self-use
-        figures, audios = self._create_logs(batch, outputs, self.ap)
+        figures, audios = self.create_logs(batch, outputs, self.ap)
         logger.train_figures(steps, figures)
         logger.train_audios(steps, audios, self.ap.sample_rate)
 
@@ -702,7 +736,7 @@ class ForwardTTS(BaseTTS):
         return self.train_step(batch, criterion)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        figures, audios = self._create_logs(batch, outputs, self.ap)
+        figures, audios = self.create_logs(batch, outputs, self.ap)
         logger.eval_figures(steps, figures)
         logger.eval_audios(steps, audios, self.ap.sample_rate)
 
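
The constructor changes in the first hunk assume a companion update to DurationPredictor (not part of this diff) that takes a cond_channels argument at construction time and a conditioning tensor g at call time, replacing the old scheme of widening the input by the speaker-embedding size. A minimal sketch of that conditioning pattern, with hypothetical names and a simplified convolution stack: the conditioning vector is projected with a 1x1 convolution and added to the predictor input, broadcast over time.

    import torch
    from torch import nn

    class CondDurationPredictorSketch(nn.Module):
        """Duration predictor with additive conditioning (illustrative only)."""

        def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2),
                nn.ReLU(),
                nn.Dropout(dropout_p),
            )
            self.proj = nn.Conv1d(hidden_channels, 1, 1)
            # 1x1 projection of the conditioning vector into the input space
            self.cond = nn.Conv1d(cond_channels, in_channels, 1) if cond_channels else None

        def forward(self, x, x_mask, g=None):
            # x: [B, C_in, T_en], x_mask: [B, 1, T_en], g: [B, C_g, 1]
            if g is not None and self.cond is not None:
                x = x + self.cond(g)  # broadcasts over the time axis
            return self.proj(self.conv(x * x_mask)) * x_mask  # log-durations, [B, 1, T_en]

Keeping the conditioning additive means the predictor's input width no longer depends on the speaker-embedding size, which is why the first positional argument in the patched constructor calls drops back to self.args.hidden_channels.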
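
The new _set_cond_input staticmethod can be exercised on its own. Per the hunk above, d-vectors are L2-normalized and given a trailing time axis so they broadcast against [B, C, T] feature maps, and 0-dim id tensors are promoted to batched ones. A small usage sketch (tensor sizes are illustrative):

    import torch
    from TTS.tts.models.forward_tts import ForwardTTS

    aux_input = {"d_vectors": torch.randn(2, 512), "speaker_ids": None, "language_ids": None}
    sid, g, lid = ForwardTTS._set_cond_input(aux_input)
    print(g.shape)    # torch.Size([2, 512, 1]): unit-norm rows with a time axis for broadcasting
    print(sid, lid)   # None, None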