mirror of https://github.com/coqui-ai/TTS.git
Refactor multi-speaker init in ForwardTTS
This commit is contained in:
parent
c125024da0
commit
28a53c7462
|
@ -213,18 +213,20 @@ class ForwardTTS(BaseTTS):
|
|||
)
|
||||
|
||||
self.duration_predictor = DurationPredictor(
|
||||
self.args.hidden_channels + self.embedded_speaker_dim,
|
||||
self.args.hidden_channels,
|
||||
self.args.duration_predictor_hidden_channels,
|
||||
self.args.duration_predictor_kernel_size,
|
||||
self.args.duration_predictor_dropout_p,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
)
|
||||
|
||||
if self.args.use_pitch:
|
||||
self.pitch_predictor = DurationPredictor(
|
||||
self.args.hidden_channels + self.embedded_speaker_dim,
|
||||
self.args.hidden_channels,
|
||||
self.args.pitch_predictor_hidden_channels,
|
||||
self.args.pitch_predictor_kernel_size,
|
||||
self.args.pitch_predictor_dropout_p,
|
||||
cond_channels=self.embedded_speaker_dim,
|
||||
)
|
||||
self.pitch_emb = nn.Conv1d(
|
||||
1,
|
||||
|
@ -245,24 +247,54 @@ class ForwardTTS(BaseTTS):
|
|||
config (Coqpit): Model configuration.
|
||||
"""
|
||||
self.embedded_speaker_dim = 0
|
||||
# init speaker manager
|
||||
if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding):
|
||||
raise ValueError(
|
||||
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
|
||||
)
|
||||
# set number of speakers
|
||||
if self.speaker_manager is not None:
|
||||
self.num_speakers = self.args.num_speakers
|
||||
self.audio_transform = None
|
||||
|
||||
if self.speaker_manager:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# init d-vector embedding
|
||||
if config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = config.d_vector_dim
|
||||
if self.args.d_vector_dim != self.args.hidden_channels:
|
||||
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
print(" > Init speaker_embedding layer.")
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
if self.args.use_speaker_embedding:
|
||||
self._init_speaker_embedding()
|
||||
|
||||
if self.args.use_d_vector_file:
|
||||
self._init_d_vector()
|
||||
|
||||
def _init_speaker_embedding(self):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if self.num_speakers > 0:
|
||||
print(" > initialization of speaker-embedding layers.")
|
||||
self.embedded_speaker_dim = self.args.speaker_embedding_channels
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
|
||||
def _init_d_vector(self):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if hasattr(self, "emb_g"):
|
||||
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
|
||||
self.embedded_speaker_dim = self.args.d_vector_dim
|
||||
|
||||
@staticmethod
|
||||
def _set_cond_input(aux_input: Dict):
|
||||
"""Set the speaker conditioning input based on the multi-speaker mode."""
|
||||
sid, g, lid = None, None, None
|
||||
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
|
||||
sid = aux_input["speaker_ids"]
|
||||
if sid.ndim == 0:
|
||||
sid = sid.unsqueeze_(0)
|
||||
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
|
||||
g = torch.nn.functional.normalize(aux_input["d_vectors"]).unsqueeze(-1)
|
||||
if g.ndim == 2:
|
||||
g = g.unsqueeze_(0)
|
||||
|
||||
if "language_ids" in aux_input and aux_input["language_ids"] is not None:
|
||||
lid = aux_input["language_ids"]
|
||||
if lid.ndim == 0:
|
||||
lid = lid.unsqueeze_(0)
|
||||
|
||||
return sid, g, lid
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
|
||||
|
||||
@staticmethod
|
||||
def generate_attn(dr, x_mask, y_mask=None):
|
||||
|
@ -362,10 +394,7 @@ class ForwardTTS(BaseTTS):
|
|||
# encoder pass
|
||||
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
|
||||
# speaker conditioning
|
||||
# TODO: try different ways of conditioning
|
||||
if g is not None:
|
||||
o_en = o_en + g
|
||||
return o_en, x_mask, g, x_emb
|
||||
return x_emb, x_mask, g, o_en
|
||||
|
||||
def _forward_decoder(
|
||||
self,
|
||||
|
@ -395,7 +424,7 @@ class ForwardTTS(BaseTTS):
|
|||
"""
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
|
||||
# expand o_en with durations
|
||||
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
||||
o_en_ex, attn = self.expand_encoder_outputs(en=o_en, dr=dr, x_mask=x_mask, y_mask=y_mask)
|
||||
# positional encoding
|
||||
if hasattr(self, "pos_encoder"):
|
||||
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
|
||||
|
@ -409,6 +438,7 @@ class ForwardTTS(BaseTTS):
|
|||
x_mask: torch.IntTensor,
|
||||
pitch: torch.FloatTensor = None,
|
||||
dr: torch.IntTensor = None,
|
||||
g: torch.FloatTensor = None,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Pitch predictor forward pass.
|
||||
|
||||
|
@ -421,6 +451,7 @@ class ForwardTTS(BaseTTS):
|
|||
x_mask (torch.IntTensor): Input sequence mask.
|
||||
pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
|
||||
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
|
||||
g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
|
||||
|
@ -431,7 +462,7 @@ class ForwardTTS(BaseTTS):
|
|||
- pitch: :math:`(B, 1, T_{de})`
|
||||
- dr: :math:`(B, T_{en})`
|
||||
"""
|
||||
o_pitch = self.pitch_predictor(o_en, x_mask)
|
||||
o_pitch = self.pitch_predictor(o_en, x_mask, g=g)
|
||||
if pitch is not None:
|
||||
avg_pitch = average_over_durations(pitch, dr)
|
||||
o_pitch_emb = self.pitch_emb(avg_pitch)
|
||||
|
@ -466,19 +497,19 @@ class ForwardTTS(BaseTTS):
|
|||
- x_mask: :math:`[B, 1, T_en]`
|
||||
- y_mask: :math:`[B, 1, T_de]`
|
||||
|
||||
- o_alignment_dur: :math:`[B, T_en]`
|
||||
- alignment_soft: :math:`[B, T_en, T_de]`
|
||||
- alignment_logprob: :math:`[B, 1, T_de, T_en]`
|
||||
- alignment_mas: :math:`[B, T_en, T_de]`
|
||||
- aligner_durations: :math:`[B, T_en]`
|
||||
- aligner_soft: :math:`[B, T_en, T_de]`
|
||||
- aligner_logprob: :math:`[B, 1, T_de, T_en]`
|
||||
- aligner_mas: :math:`[B, T_en, T_de]`
|
||||
"""
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
|
||||
alignment_mas = maximum_path(
|
||||
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
|
||||
aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
|
||||
aligner_mas = maximum_path(
|
||||
aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
|
||||
)
|
||||
o_alignment_dur = torch.sum(alignment_mas, -1).int()
|
||||
alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
|
||||
return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
|
||||
aligner_durations = torch.sum(aligner_mas, -1).int()
|
||||
aligner_soft = aligner_soft.squeeze(1).transpose(1, 2)
|
||||
return aligner_durations, aligner_soft, aligner_logprob, aligner_mas
|
||||
|
||||
def _set_speaker_input(self, aux_input: Dict):
|
||||
d_vectors = aux_input.get("d_vectors", None)
|
||||
|
@ -525,52 +556,54 @@ class ForwardTTS(BaseTTS):
|
|||
"""
|
||||
g = self._set_speaker_input(aux_input)
|
||||
# compute sequence masks
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() # [B, 1, T_max2]
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() # [B, 1, T_max]
|
||||
# encoder pass
|
||||
o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
|
||||
x_emb, x_mask, g, o_en = self._forward_encoder(x, x_mask, g) # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
|
||||
# duration predictor pass
|
||||
if self.args.detach_duration_predictor:
|
||||
o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
|
||||
o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g) # [B, 1, T_max]
|
||||
else:
|
||||
o_dr_log = self.duration_predictor(o_en, x_mask)
|
||||
o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g) # [B, 1, T_max]
|
||||
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
|
||||
# generate attn mask from predicted durations
|
||||
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
|
||||
dur_predictor_attn = self.generate_attn(o_dr.squeeze(1), x_mask) # [B, T_max, T_max2']
|
||||
# aligner
|
||||
o_alignment_dur = None
|
||||
alignment_soft = None
|
||||
alignment_logprob = None
|
||||
alignment_mas = None
|
||||
aligner_durations = None
|
||||
aligner_soft = None
|
||||
aligner_logprob = None
|
||||
aligner_mas = None
|
||||
if self.use_aligner:
|
||||
o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
|
||||
x_emb, y, x_mask, y_mask
|
||||
# TODO: try passing o_en instead of x_emb
|
||||
aligner_durations, aligner_soft, aligner_logprob, aligner_mas = self._forward_aligner(
|
||||
x=x_emb, y=y, x_mask=x_mask, y_mask=y_mask
|
||||
)
|
||||
alignment_soft = alignment_soft.transpose(1, 2)
|
||||
alignment_mas = alignment_mas.transpose(1, 2)
|
||||
dr = o_alignment_dur
|
||||
aligner_soft = aligner_soft.transpose(1, 2) # [B, T_max, T_max2] -> [B, T_max2, T_max]
|
||||
aligner_mas = aligner_mas.transpose(1, 2) # [B, T_max, T_max2] -> [B, T_max2, T_max]
|
||||
dr = aligner_durations # [B, T_max]
|
||||
# pitch predictor pass
|
||||
o_pitch = None
|
||||
avg_pitch = None
|
||||
if self.args.use_pitch:
|
||||
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
|
||||
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(
|
||||
o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
|
||||
)
|
||||
o_en = o_en + o_pitch_emb
|
||||
# decoder pass
|
||||
o_de, attn = self._forward_decoder(
|
||||
o_en, dr, x_mask, y_lengths, g=None
|
||||
) # TODO: maybe pass speaker embedding (g) too
|
||||
o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g) # [B, T_max2, C_de], [B, T_max2, T_max]
|
||||
outputs = {
|
||||
"model_outputs": o_de, # [B, T, C]
|
||||
"g": g, # [B, C]
|
||||
"durations_log": o_dr_log.squeeze(1), # [B, T]
|
||||
"durations": o_dr.squeeze(1), # [B, T]
|
||||
"attn_durations": o_attn, # for visualization [B, T_en, T_de']
|
||||
"attn_durations": dur_predictor_attn, # for visualization [B, T_en, T_de']
|
||||
"pitch_avg": o_pitch,
|
||||
"pitch_avg_gt": avg_pitch,
|
||||
"alignments": attn, # [B, T_de, T_en]
|
||||
"alignment_soft": alignment_soft,
|
||||
"alignment_mas": alignment_mas,
|
||||
"o_alignment_dur": o_alignment_dur,
|
||||
"alignment_logprob": alignment_logprob,
|
||||
"aligner_soft": aligner_soft,
|
||||
"aligner_mas": aligner_mas,
|
||||
"aligner_durations": aligner_durations,
|
||||
"aligner_logprob": aligner_logprob,
|
||||
"x_mask": x_mask,
|
||||
"y_mask": y_mask,
|
||||
}
|
||||
|
@ -593,7 +626,7 @@ class ForwardTTS(BaseTTS):
|
|||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
|
||||
# encoder pass
|
||||
o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
|
||||
_, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)
|
||||
# duration predictor pass
|
||||
o_dr_log = self.duration_predictor(o_en, x_mask)
|
||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||
|
@ -610,6 +643,7 @@ class ForwardTTS(BaseTTS):
|
|||
"alignments": attn,
|
||||
"pitch": o_pitch,
|
||||
"durations_log": o_dr_log,
|
||||
"g": g,
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -630,7 +664,7 @@ class ForwardTTS(BaseTTS):
|
|||
)
|
||||
# use aligner's output as the duration target
|
||||
if self.use_aligner:
|
||||
durations = outputs["o_alignment_dur"]
|
||||
durations = outputs["aligner_durations"]
|
||||
# use float32 in AMP
|
||||
with autocast(enabled=False):
|
||||
# compute loss
|
||||
|
@ -643,9 +677,9 @@ class ForwardTTS(BaseTTS):
|
|||
pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
|
||||
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
|
||||
input_lens=text_lengths,
|
||||
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
|
||||
alignment_soft=outputs["alignment_soft"],
|
||||
alignment_hard=outputs["alignment_mas"],
|
||||
aligner_logprob=outputs["aligner_logprob"] if self.use_aligner else None,
|
||||
aligner_soft=outputs["aligner_soft"],
|
||||
aligner_hard=outputs["aligner_mas"],
|
||||
binary_loss_weight=self.binary_loss_weight,
|
||||
)
|
||||
# compute duration error
|
||||
|
@ -655,7 +689,7 @@ class ForwardTTS(BaseTTS):
|
|||
|
||||
return outputs, loss_dict
|
||||
|
||||
def _create_logs(self, batch, outputs, ap):
|
||||
def create_logs(self, batch, outputs, ap):
|
||||
"""Create common logger outputs."""
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
|
@ -694,7 +728,7 @@ class ForwardTTS(BaseTTS):
|
|||
def train_log(
|
||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
figures, audios = self.create_logs(batch, outputs, self.ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
|
@ -702,7 +736,7 @@ class ForwardTTS(BaseTTS):
|
|||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
figures, audios = self.create_logs(batch, outputs, self.ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
|
|
Loading…
Reference in New Issue