mirror of https://github.com/coqui-ai/TTS.git
refactor: move duplicate alignment functions into helpers
parent 8bf288eeab
commit 7330ad8854
TTS/tts/layers/delightful_tts/acoustic_model.py
@@ -12,7 +12,6 @@ from TTS.tts.layers.delightful_tts.conformer import Conformer
 from TTS.tts.layers.delightful_tts.encoders import (
     PhonemeLevelProsodyEncoder,
     UtteranceLevelProsodyEncoder,
-    get_mask_from_lengths,
 )
 from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor
 from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding
@@ -20,7 +19,7 @@ from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProso
 from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor
 from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
 from TTS.tts.layers.generic.aligner import AlignmentNetwork
-from TTS.tts.utils.helpers import generate_path, sequence_mask
+from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask

 logger = logging.getLogger(__name__)

@@ -231,42 +230,6 @@ class AcousticModel(torch.nn.Module):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
         self.embedded_speaker_dim = self.args.d_vector_dim

-    @staticmethod
-    def generate_attn(dr, x_mask, y_mask=None):
-        """Generate an attention mask from the linear scale durations.
-
-        Args:
-            dr (Tensor): Linear scale durations.
-            x_mask (Tensor): Mask for the input (character) sequence.
-            y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
-                if None. Defaults to None.
-
-        Shapes
-           - dr: :math:`(B, T_{en})`
-           - x_mask: :math:`(B, T_{en})`
-           - y_mask: :math:`(B, T_{de})`
-        """
-        # compute decode mask from the durations
-        if y_mask is None:
-            y_lengths = dr.sum(1).long()
-            y_lengths[y_lengths < 1] = 1
-            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
-        return attn
-
-    def _expand_encoder_with_durations(
-        self,
-        o_en: torch.FloatTensor,
-        dr: torch.IntTensor,
-        x_mask: torch.IntTensor,
-        y_lengths: torch.IntTensor,
-    ):
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
-        attn = self.generate_attn(dr, x_mask, y_mask)
-        o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en])
-        return y_mask, o_en_ex, attn.transpose(1, 2)
-
     def _forward_aligner(
         self,
         x: torch.FloatTensor,
@@ -340,8 +303,8 @@ class AcousticModel(torch.nn.Module):
             {"d_vectors": d_vectors, "speaker_ids": speaker_idx}
         )  # pylint: disable=unused-variable

-        src_mask = get_mask_from_lengths(src_lens)  # [B, T_src]
-        mel_mask = get_mask_from_lengths(mel_lens)  # [B, T_mel]
+        src_mask = ~sequence_mask(src_lens)  # [B, T_src]
+        mel_mask = ~sequence_mask(mel_lens)  # [B, T_mel]

         # Token embeddings
         token_embeddings = self.src_word_emb(tokens)  # [B, T_src, C_hidden]
@@ -420,8 +383,8 @@ class AcousticModel(torch.nn.Module):
         encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb
         log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask)

-        mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
-            o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None]
+        encoder_outputs_ex, alignments, mel_pred_mask = expand_encoder_outputs(
+            encoder_outputs, y_lengths=mel_lens, duration=dr, x_mask=~src_mask[:, None]
         )

         x = self.decoder(
@@ -435,7 +398,7 @@ class AcousticModel(torch.nn.Module):
         dr = torch.log(dr + 1)

         dr_pred = torch.exp(log_duration_prediction) - 1
-        alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask)  # [B, T_max, T_max2']
+        alignments_dp = generate_attention(dr_pred, src_mask.unsqueeze(1), mel_pred_mask)  # [B, T_max, T_max2']

         return {
             "model_outputs": x,
@@ -448,7 +411,7 @@ class AcousticModel(torch.nn.Module):
             "p_prosody_pred": p_prosody_pred,
             "p_prosody_ref": p_prosody_ref,
             "alignments_dp": alignments_dp,
-            "alignments": alignments,  # [B, T_de, T_en]
+            "alignments": alignments.transpose(1, 2),  # [B, T_de, T_en]
             "aligner_soft": aligner_soft,
             "aligner_mas": aligner_mas,
             "aligner_durations": aligner_durations,
@@ -469,7 +432,7 @@ class AcousticModel(torch.nn.Module):
         pitch_transform: Callable = None,
         energy_transform: Callable = None,
     ) -> torch.Tensor:
-        src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
+        src_mask = ~sequence_mask(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
         src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device)  # pylint: disable=unused-variable
         sid, g, lid, _ = self._set_cond_input(  # pylint: disable=unused-variable
             {"d_vectors": d_vectors, "speaker_ids": speaker_idx}
@@ -536,11 +499,11 @@ class AcousticModel(torch.nn.Module):
         duration_pred = torch.round(duration_pred)  # -> [B, T_src]
         mel_lens = duration_pred.sum(1)  # -> [B,]

-        _, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
-            o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
+        encoder_outputs_ex, alignments, _ = expand_encoder_outputs(
+            encoder_outputs, y_lengths=mel_lens, duration=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
         )

-        mel_mask = get_mask_from_lengths(
+        mel_mask = ~sequence_mask(
             torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device)
         )

@@ -557,7 +520,7 @@ class AcousticModel(torch.nn.Module):
         x = self.to_mel(x)
         outputs = {
             "model_outputs": x,
-            "alignments": alignments,
+            "alignments": alignments.transpose(1, 2),
             # "pitch": pitch_emb_pred,
             "durations": duration_pred,
             "pitch": pitch_pred,
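Note on the call-site changes in this file: the removed `_expand_encoder_with_durations` returned `(y_mask, o_en_ex, attn.transpose(1, 2))`, while the shared helper returns `(x_expanded, attn, y_mask)` with `attn` untransposed, which is why the call sites above now transpose `alignments` when assembling the output dict. A minimal sketch of the new calling convention (illustrative only, with toy shapes; not part of the diff):

import torch

from TTS.tts.utils.helpers import expand_encoder_outputs

B, C, T_en = 2, 8, 5
encoder_outputs = torch.rand(B, C, T_en)      # (B, C_hidden, T_src)
dr = torch.randint(1, 4, (B, T_en)).float()   # linear-scale durations
src_mask = torch.zeros(B, T_en).bool()        # padding mask: True at padded steps
mel_lens = dr.sum(1).long()

encoder_outputs_ex, alignments, mel_pred_mask = expand_encoder_outputs(
    encoder_outputs, y_lengths=mel_lens, duration=dr, x_mask=~src_mask[:, None]
)
alignments = alignments.transpose(1, 2)       # recover the old [B, T_de, T_en] layout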
TTS/tts/layers/delightful_tts/encoders.py
@@ -7,14 +7,7 @@ import torch.nn.functional as F
 from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
 from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
 from TTS.tts.layers.delightful_tts.networks import STL
-
-
-def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
-    batch_size = lengths.shape[0]
-    max_len = torch.max(lengths).item()
-    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
-    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
-    return mask
+from TTS.tts.utils.helpers import sequence_mask


 def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
@@ -93,7 +86,7 @@ class ReferenceEncoder(nn.Module):
         outputs --- [N, E//2]
         """

-        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
+        mel_masks = ~sequence_mask(mel_lens).unsqueeze(1)
         x = x.masked_fill(mel_masks, 0)
         for conv, norm in zip(self.convs, self.norms):
             x = conv(x)
@@ -103,7 +96,7 @@ class ReferenceEncoder(nn.Module):
         for _ in range(2):
             mel_lens = stride_lens(mel_lens)

-            mel_masks = get_mask_from_lengths(mel_lens)
+            mel_masks = ~sequence_mask(mel_lens)

             x = x.masked_fill(mel_masks.unsqueeze(1), 0)
             x = x.permute((0, 2, 1))
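The substitution throughout this file relies on the two masks being exact complements: the removed `get_mask_from_lengths` returned True at padding positions, while `sequence_mask` returns True at valid positions, so every call site negates the helper. A quick sanity check (illustrative only, not part of the commit):

import torch

from TTS.tts.utils.helpers import sequence_mask

mel_lens = torch.tensor([3, 5, 2])
valid = sequence_mask(mel_lens)  # True at valid frames, shape [3, 5]
padding = ~valid                 # drop-in replacement for get_mask_from_lengths(mel_lens)

expected = torch.tensor(
    [[False, False, False, True, True],
     [False, False, False, False, False],
     [False, False, True, True, True]]
)
assert torch.equal(padding, expected)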
TTS/tts/models/align_tts.py
@@ -13,7 +13,7 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.models.base_tts import BaseTTS
-from TTS.tts.utils.helpers import generate_path, sequence_mask
+from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -169,35 +169,6 @@ class AlignTTS(BaseTTS):
         dr_mas = torch.sum(attn, -1)
         return dr_mas.squeeze(1), log_p

-    @staticmethod
-    def generate_attn(dr, x_mask, y_mask=None):
-        # compute decode mask from the durations
-        if y_mask is None:
-            y_lengths = dr.sum(1).long()
-            y_lengths[y_lengths < 1] = 1
-            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
-        return attn
-
-    def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
-        """Generate attention alignment map from durations and
-        expand encoder outputs
-
-        Examples::
-            - encoder output: [a,b,c,d]
-            - durations: [1, 3, 2, 1]
-
-            - expanded: [a, b, b, b, c, c, d]
-            - attention map: [[0, 0, 0, 0, 0, 0, 1],
-                              [0, 0, 0, 0, 1, 1, 0],
-                              [0, 1, 1, 1, 0, 0, 0],
-                              [1, 0, 0, 0, 0, 0, 0]]
-        """
-        attn = self.generate_attn(dr, x_mask, y_mask)
-        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
-        return o_en_ex, attn
-
     def format_durations(self, o_dr_log, x_mask):
         o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
         o_dr[o_dr < 1] = 1.0
@@ -243,9 +214,8 @@ class AlignTTS(BaseTTS):
         return o_en, o_en_dp, x_mask, g

     def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
         # expand o_en with durations
-        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
+        o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths)
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
@@ -282,7 +252,7 @@ class AlignTTS(BaseTTS):
             o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
             dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
             y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
-            attn = self.generate_attn(dr_mas, x_mask, y_mask)
+            attn = generate_attention(dr_mas, x_mask, y_mask)
         elif phase == 1:
             # train decoder
             o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
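AlignTTS previously expanded encoder outputs with a matmul, while the shared helper uses an einsum; the two are numerically equivalent for a (B, T_en, T_de) attention map. A quick check (illustrative only, with toy tensors):

import torch

B, C, T_en, T_de = 2, 4, 5, 11
en = torch.rand(B, C, T_en)       # encoder output
attn = torch.rand(B, T_en, T_de)  # attention map (already squeezed, unlike the old method)

old = torch.matmul(attn.transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
new = torch.einsum("kmn, kjm -> kjn", [attn, en])
assert torch.allclose(old, new, atol=1e-6)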
TTS/tts/models/forward_tts.py
@@ -14,7 +14,7 @@ from TTS.tts.layers.generic.aligner import AlignmentNetwork
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
-from TTS.tts.utils.helpers import average_over_durations, generate_path, sequence_mask
+from TTS.tts.utils.helpers import average_over_durations, expand_encoder_outputs, generate_attention, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
@@ -310,49 +310,6 @@ class ForwardTTS(BaseTTS):
             self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

-    @staticmethod
-    def generate_attn(dr, x_mask, y_mask=None):
-        """Generate an attention mask from the durations.
-
-        Shapes
-           - dr: :math:`(B, T_{en})`
-           - x_mask: :math:`(B, T_{en})`
-           - y_mask: :math:`(B, T_{de})`
-        """
-        # compute decode mask from the durations
-        if y_mask is None:
-            y_lengths = dr.sum(1).long()
-            y_lengths[y_lengths < 1] = 1
-            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
-        return attn
-
-    def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
-        """Generate attention alignment map from durations and
-        expand encoder outputs
-
-        Shapes:
-            - en: :math:`(B, D_{en}, T_{en})`
-            - dr: :math:`(B, T_{en})`
-            - x_mask: :math:`(B, T_{en})`
-            - y_mask: :math:`(B, T_{de})`
-
-        Examples::
-
-            encoder output: [a,b,c,d]
-            durations: [1, 3, 2, 1]
-
-            expanded: [a, b, b, b, c, c, d]
-            attention map: [[0, 0, 0, 0, 0, 0, 1],
-                            [0, 0, 0, 0, 1, 1, 0],
-                            [0, 1, 1, 1, 0, 0, 0],
-                            [1, 0, 0, 0, 0, 0, 0]]
-        """
-        attn = self.generate_attn(dr, x_mask, y_mask)
-        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
-        return o_en_ex, attn
-
     def format_durations(self, o_dr_log, x_mask):
         """Format predicted durations.
         1. Convert to linear scale from log scale
@@ -443,9 +400,8 @@ class ForwardTTS(BaseTTS):
         Returns:
             Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
         """
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
         # expand o_en with durations
-        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
+        o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths)
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
@@ -624,7 +580,7 @@ class ForwardTTS(BaseTTS):
         o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # generate attn mask from predicted durations
-        o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
+        o_attn = generate_attention(o_dr.squeeze(1), x_mask)
         # aligner
         o_alignment_dur = None
         alignment_soft = None
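At inference ForwardTTS now calls the helper without a `y_mask`, so the output mask is derived from the predicted durations themselves (with the total length clamped to at least one frame). A small sketch of that path (illustrative only, not part of the diff):

import torch

from TTS.tts.utils.helpers import generate_attention

o_dr = torch.tensor([[2.0, 1.0, 3.0]])  # predicted linear-scale durations, (B, T_en)
x_mask = torch.ones(1, 1, 3)            # (B, 1, T_en)

o_attn = generate_attention(o_dr, x_mask)  # y_mask inferred: T_de = o_dr.sum() = 6
assert o_attn.shape == (1, 3, 6)
assert o_attn.sum() == 6                   # exactly one source token per output frame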
TTS/tts/utils/helpers.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import torch
 from scipy.stats import betabinom
@@ -33,7 +35,7 @@ class StandardScaler:


 # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
-def sequence_mask(sequence_length, max_len=None):
+def sequence_mask(sequence_length: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
     """Create a sequence mask for filtering padding in a sequence tensor.

     Args:
@@ -44,7 +46,7 @@ def sequence_mask(sequence_length, max_len=None):
         - mask: :math:`[B, T_max]`
     """
     if max_len is None:
-        max_len = sequence_length.max()
+        max_len = int(sequence_length.max())
     seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
     # B x T_max
     return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
@@ -143,22 +145,75 @@ def convert_pad_shape(pad_shape: list[list]) -> list:
     return [item for sublist in l for item in sublist]


-def generate_path(duration, mask):
-    """
+def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Generate alignment path based on the given segment durations.
+
     Shapes:
         - duration: :math:`[B, T_en]`
         - mask: :math:'[B, T_en, T_de]`
         - path: :math:`[B, T_en, T_de]`
     """
     b, t_x, t_y = mask.shape
-    cum_duration = torch.cumsum(duration, 1)
+    cum_duration = torch.cumsum(duration, dim=1)
+
     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
     path = path.view(b, t_x, t_y)
     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-    path = path * mask
-    return path
+    return path * mask


+def generate_attention(
+    duration: torch.Tensor, x_mask: torch.Tensor, y_mask: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Generate an attention map from the linear scale durations.
+
+    Args:
+        duration (Tensor): Linear scale durations.
+        x_mask (Tensor): Mask for the input (character) sequence.
+        y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
+            if None. Defaults to None.
+
+    Shapes
+       - duration: :math:`(B, T_{en})`
+       - x_mask: :math:`(B, T_{en})`
+       - y_mask: :math:`(B, T_{de})`
+    """
+    # compute decode mask from the durations
+    if y_mask is None:
+        y_lengths = duration.sum(dim=1).long()
+        y_lengths[y_lengths < 1] = 1
+        y_mask = sequence_mask(y_lengths).unsqueeze(1).to(duration.dtype)
+    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
+    return generate_path(duration, attn_mask.squeeze(1)).to(duration.dtype)
+
+
+def expand_encoder_outputs(
+    x: torch.Tensor, duration: torch.Tensor, x_mask: torch.Tensor, y_lengths: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Generate attention alignment map from durations and expand encoder outputs.
+
+    Shapes:
+        - x: Encoder output :math:`(B, D_{en}, T_{en})`
+        - duration: :math:`(B, T_{en})`
+        - x_mask: :math:`(B, T_{en})`
+        - y_lengths: :math:`(B)`
+
+    Examples::
+
+        encoder output: [a,b,c,d]
+        durations: [1, 3, 2, 1]
+
+        expanded: [a, b, b, b, c, c, d]
+        attention map: [[0, 0, 0, 0, 0, 0, 1],
+                        [0, 0, 0, 0, 1, 1, 0],
+                        [0, 1, 1, 1, 0, 0, 0],
+                        [1, 0, 0, 0, 0, 0, 0]]
+    """
+    y_mask = sequence_mask(y_lengths).unsqueeze(1).to(x.dtype)
+    attn = generate_attention(duration, x_mask, y_mask)
+    x_expanded = torch.einsum("kmn, kjm -> kjn", [attn.float(), x])
+    return x_expanded, attn, y_mask


 def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
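Taken together, the new helpers reproduce the docstring example end to end. A runnable sketch (illustrative only, not part of the diff):

import torch

from TTS.tts.utils.helpers import expand_encoder_outputs

x = torch.rand(1, 2, 4)                  # encoder outputs [a, b, c, d], (B, D_en, T_en)
duration = torch.tensor([[1, 3, 2, 1]])  # durations from the docstring example
x_mask = torch.ones(1, 1, 4)
y_lengths = duration.sum(1)              # -> tensor([7])

x_ex, attn, y_mask = expand_encoder_outputs(x, duration, x_mask, y_lengths)
assert x_ex.shape == (1, 2, 7)           # expanded: [a, b, b, b, c, c, d]
assert attn.shape == (1, 4, 7)
assert y_mask.shape == (1, 1, 7)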
tests/tts_tests/test_helpers.py
@@ -1,6 +1,14 @@
 import torch as T

-from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask
+from TTS.tts.utils.helpers import (
+    average_over_durations,
+    expand_encoder_outputs,
+    generate_attention,
+    generate_path,
+    rand_segments,
+    segment,
+    sequence_mask,
+)


 def test_average_over_durations():  # pylint: disable=no-self-use
@@ -86,3 +94,24 @@ def test_generate_path():
             assert all(path[b, t, :current_idx] == 0.0)
             assert all(path[b, t, current_idx + durations[b, t].item() :] == 0.0)
             current_idx += durations[b, t].item()
+
+    assert T.all(path == generate_attention(durations, x_mask, y_mask))
+    assert T.all(path == generate_attention(durations, x_mask))
+
+
+def test_expand_encoder_outputs():
+    inputs = T.rand(2, 5, 57)
+    durations = T.randint(1, 4, (2, 57))
+
+    x_mask = T.ones(2, 1, 57)
+    y_lengths = T.ones(2) * durations.sum(1).max()
+
+    expanded, _, _ = expand_encoder_outputs(inputs, durations, x_mask, y_lengths)
+
+    for b in range(durations.shape[0]):
+        index = 0
+        for idx, dur in enumerate(durations[b]):
+            idx_expanded = expanded[b, :, index : index + dur.item()]
+            diff = (idx_expanded - inputs[b, :, idx].repeat(int(dur)).view(idx_expanded.shape)).sum()
+            assert abs(diff) < 1e-6, diff
+            index += dur
tests/tts_tests/test_forward_tts.py
@@ -6,29 +6,7 @@ from TTS.tts.utils.helpers import sequence_mask
 # pylint: disable=unused-variable


-def expand_encoder_outputs_test():
-    model = ForwardTTS(ForwardTTSArgs(num_chars=10))
-
-    inputs = T.rand(2, 5, 57)
-    durations = T.randint(1, 4, (2, 57))
-
-    x_mask = T.ones(2, 1, 57)
-    y_mask = T.ones(2, 1, durations.sum(1).max())
-
-    expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask)
-
-    for b in range(durations.shape[0]):
-        index = 0
-        for idx, dur in enumerate(durations[b]):
-            diff = (
-                expanded[b, :, index : index + dur.item()]
-                - inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape)
-            ).sum()
-            assert abs(diff) < 1e-6, diff
-            index += dur
-
-
-def model_input_output_test():
+def test_model_input_output():
     """Assert the output shapes of the model in different modes"""

     # VANILLA MODEL