mirror of https://github.com/coqui-ai/TTS.git
678 lines · 24 KiB · Python
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from coqpit import Coqpit

from TTS.tts.layers.glow_tts.monotonic_align import generate_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor

# pylint: disable=dangerous-default-value

def mask_from_lens(lens, max_len: int = None):
    if max_len is None:
        max_len = lens.max()
    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
    mask = torch.lt(ids, lens.unsqueeze(1))
    return mask

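# Illustrative sketch (not part of the original source): `mask_from_lens` turns a batch of
# sequence lengths into a boolean padding mask. Wrapped in a private helper so nothing runs
# at import time; it relies on the module-level `torch` import.
def _example_mask_from_lens():
    lens = torch.tensor([2, 4])
    mask = mask_from_lens(lens)
    # mask -> tensor([[ True,  True, False, False],
    #                 [ True,  True,  True,  True]])
    return mask
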
class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)

class ConvNorm(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        batch_norm=False,
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.norm = torch.nn.BatchNorm1d(out_channels) if batch_norm else None

        torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        if self.norm is None:
            return self.conv(signal)
        else:
            return self.norm(self.conv(signal))

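# Illustrative sketch (not part of the original source): with `padding=None` and an odd
# `kernel_size`, ConvNorm computes "same" padding, so the time dimension is preserved.
def _example_conv_norm():
    conv = ConvNorm(in_channels=80, out_channels=80, kernel_size=5, batch_norm=True)
    x = torch.randn(2, 80, 100)  # [batch, channels, time]
    y = conv(x)
    # y.shape -> torch.Size([2, 80, 100])
    return y
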
class ConvReLUNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0):
        super(ConvReLUNorm, self).__init__()
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size // 2))
        self.norm = torch.nn.LayerNorm(out_channels)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, signal):
        out = F.relu(self.conv(signal))
        out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype)
        return self.dropout(out)

class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0))
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
        if bsz is not None:
            return pos_emb[None, :, :].expand(bsz, -1, -1)
        else:
            return pos_emb[None, :, :]

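# Illustrative sketch (not part of the original source): PositionalEmbedding maps a 1-D
# position sequence of length T to sinusoidal features of shape [1, T, demb]
# (or [bsz, T, demb] when `bsz` is given).
def _example_positional_embedding():
    pos_emb_layer = PositionalEmbedding(demb=384)
    pos_seq = torch.arange(100, dtype=torch.float)
    pos_emb = pos_emb_layer(pos_seq)
    # pos_emb.shape -> torch.Size([1, 100, 384])
    return pos_emb
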
class PositionwiseConvFF(nn.Module):
    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
        super(PositionwiseConvFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
            nn.ReLU(),
            # nn.Dropout(dropout),  # worse convergence
            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        return self._forward(inp)

    def _forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
            core_out = core_out.transpose(1, 2)

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(core_out)
            core_out = core_out.transpose(1, 2)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out).to(inp.dtype)

        return output

class MultiHeadAttn(nn.Module):
    def __init__(self, num_heads, d_model, hidden_channels_head, dropout, dropout_attn=0.1, pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.num_heads = num_heads
        self.d_model = d_model
        self.hidden_channels_head = hidden_channels_head
        self.scale = 1 / (hidden_channels_head ** 0.5)
        self.pre_lnorm = pre_lnorm

        self.qkv_net = nn.Linear(d_model, 3 * num_heads * hidden_channels_head)
        self.drop = nn.Dropout(dropout)
        self.dropout_attn = nn.Dropout(dropout_attn)
        self.o_net = nn.Linear(num_heads * hidden_channels_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp, attn_mask=None):
        return self._forward(inp, attn_mask)

    def _forward(self, inp, attn_mask=None):
        residual = inp

        if self.pre_lnorm:
            # layer normalization
            inp = self.layer_norm(inp)

        num_heads, hidden_channels_head = self.num_heads, self.hidden_channels_head

        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
        head_q = head_q.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
        head_k = head_k.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
        head_v = head_v.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)

        q = head_q.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
        k = head_k.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
        v = head_v.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score.mul_(self.scale)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
            attn_mask = attn_mask.repeat(num_heads, attn_mask.size(2), 1)
            attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))

        attn_prob = F.softmax(attn_score, dim=2)
        attn_prob = self.dropout_attn(attn_prob)
        attn_vec = torch.bmm(attn_prob, v)

        attn_vec = attn_vec.view(num_heads, inp.size(0), inp.size(1), hidden_channels_head)
        attn_vec = (
            attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), num_heads * hidden_channels_head)
        )

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            output = residual + attn_out
        else:
            # residual connection + layer normalization
            output = self.layer_norm(residual + attn_out)

        output = output.to(attn_out.dtype)

        return output

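# Illustrative sketch (not part of the original source): MultiHeadAttn is batch-first
# self-attention over [batch, time, d_model] inputs; the output keeps the input shape.
def _example_multi_head_attn():
    attn = MultiHeadAttn(num_heads=2, d_model=384, hidden_channels_head=64, dropout=0.1)
    x = torch.randn(2, 50, 384)
    y = attn(x)  # no mask: attend over all 50 positions
    # y.shape -> torch.Size([2, 50, 384])
    return y
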
class TransformerLayer(nn.Module):
    def __init__(
        self, num_heads, hidden_channels, hidden_channels_head, hidden_channels_ffn, kernel_size, dropout, **kwargs
    ):
        super(TransformerLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(num_heads, hidden_channels, hidden_channels_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(
            hidden_channels, hidden_channels_ffn, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm")
        )

    def forward(self, dec_inp, mask=None):
        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
        output *= mask
        output = self.pos_ff(output)
        output *= mask
        return output

class FFTransformer(nn.Module):
    def __init__(
        self,
        num_layers,
        num_heads,
        hidden_channels,
        hidden_channels_head,
        hidden_channels_ffn,
        kernel_size,
        dropout,
        dropout_attn,
        dropemb=0.0,
        pre_lnorm=False,
    ):
        super(FFTransformer, self).__init__()
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.hidden_channels_head = hidden_channels_head

        self.pos_emb = PositionalEmbedding(self.hidden_channels)
        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(
                TransformerLayer(
                    num_heads,
                    hidden_channels,
                    hidden_channels_head,
                    hidden_channels_ffn,
                    kernel_size,
                    dropout,
                    dropout_attn=dropout_attn,
                    pre_lnorm=pre_lnorm,
                )
            )

    def forward(self, x, x_lengths, conditioning=0):
        mask = mask_from_lens(x_lengths).unsqueeze(2)

        pos_seq = torch.arange(x.size(1), device=x.device).to(x.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask

        if conditioning is None:
            conditioning = 0

        out = self.drop(x + pos_emb + conditioning)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return out, mask

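# Illustrative sketch (not part of the original source): an FFTransformer stack with the
# default encoder settings maps [batch, time, hidden_channels] features (plus the lengths
# used for masking) to same-shaped outputs, and also returns the boolean mask it applied.
def _example_fft_transformer():
    fft = FFTransformer(
        num_layers=2,
        num_heads=1,
        hidden_channels=384,
        hidden_channels_head=64,
        hidden_channels_ffn=1536,
        kernel_size=3,
        dropout=0.1,
        dropout_attn=0.1,
    )
    x = torch.randn(2, 40, 384)
    x_lengths = torch.tensor([40, 30])
    out, mask = fft(x, x_lengths)
    # out.shape -> torch.Size([2, 40, 384]); mask.shape -> torch.Size([2, 40, 1])
    return out, mask
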
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """Expand encoder outputs in time by repeating each frame according to its duration
    (rounded after scaling by `pace`). Works with either ground-truth or predicted durations."""
    dtype = enc_out.dtype
    reps = durations.float() / pace
    reps = (reps + 0.5).long()
    dec_lens = reps.sum(dim=1)

    max_len = dec_lens.max()
    reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
    reps_cumsum = reps_cumsum.to(dtype)

    range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
    mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
    mult = mult.to(dtype)
    en_ex = torch.matmul(mult, enc_out)

    if mel_max_len:
        en_ex = en_ex[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return en_ex, dec_lens

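# Illustrative sketch (not part of the original source): `regulate_len` repeats each encoder
# frame `duration` times (scaled by `pace`), which is the length-regulation step of
# FastSpeech-style models.
def _example_regulate_len():
    durations = torch.tensor([[1, 3, 2]])
    enc_out = torch.randn(1, 3, 384)
    en_ex, dec_lens = regulate_len(durations, enc_out)
    # en_ex.shape -> torch.Size([1, 6, 384]); dec_lens -> tensor([6])
    return en_ex, dec_lens
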
class TemporalPredictor(nn.Module):
    """Predicts a single float for each temporal location."""

    def __init__(self, input_size, filter_size, kernel_size, dropout, num_layers=2):
        super(TemporalPredictor, self).__init__()

        self.layers = nn.Sequential(
            *[
                ConvReLUNorm(
                    input_size if i == 0 else filter_size, filter_size, kernel_size=kernel_size, dropout=dropout
                )
                for i in range(num_layers)
            ]
        )
        self.fc = nn.Linear(filter_size, 1, bias=True)

    def forward(self, enc_out, enc_out_mask):
        out = enc_out * enc_out_mask
        out = self.layers(out.transpose(1, 2)).transpose(1, 2)
        out = self.fc(out) * enc_out_mask
        return out.squeeze(-1)

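# Illustrative sketch (not part of the original source): TemporalPredictor produces one scalar
# per time step (used below for durations and pitch) from [batch, time, channels] inputs.
def _example_temporal_predictor():
    predictor = TemporalPredictor(input_size=384, filter_size=256, kernel_size=3, dropout=0.1)
    enc_out = torch.randn(2, 40, 384)
    enc_out_mask = torch.ones(2, 40, 1)
    scores = predictor(enc_out, enc_out_mask)
    # scores.shape -> torch.Size([2, 40])
    return scores
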
@dataclass
class FastPitchArgs(Coqpit):
    num_chars: int = 100
    out_channels: int = 80
    hidden_channels: int = 384
    num_speakers: int = 0
    duration_predictor_hidden_channels: int = 256
    duration_predictor_dropout: float = 0.1
    duration_predictor_kernel_size: int = 3
    duration_predictor_dropout_p: float = 0.1
    duration_predictor_num_layers: int = 2
    pitch_predictor_hidden_channels: int = 256
    pitch_predictor_dropout: float = 0.1
    pitch_predictor_kernel_size: int = 3
    pitch_predictor_dropout_p: float = 0.1
    pitch_embedding_kernel_size: int = 3
    pitch_predictor_num_layers: int = 2
    positional_encoding: bool = True
    length_scale: int = 1
    encoder_type: str = "fftransformer"
    encoder_params: dict = field(
        default_factory=lambda: {
            "hidden_channels_head": 64,
            "hidden_channels_ffn": 1536,
            "num_heads": 1,
            "num_layers": 6,
            "kernel_size": 3,
            "dropout": 0.1,
            "dropout_attn": 0.1,
        }
    )
    decoder_type: str = "fftransformer"
    decoder_params: dict = field(
        default_factory=lambda: {
            "hidden_channels_head": 64,
            "hidden_channels_ffn": 1536,
            "num_heads": 1,
            "num_layers": 6,
            "kernel_size": 3,
            "dropout": 0.1,
            "dropout_attn": 0.1,
        }
    )
    use_d_vector: bool = False
    d_vector_dim: int = 0
    detach_duration_predictor: bool = False
    max_duration: int = 75
    use_gt_duration: bool = True

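# Illustrative sketch (not part of the original source): FastPitchArgs is a Coqpit dataclass,
# so defaults can be overridden as keyword arguments when configuring the model directly.
def _example_fast_pitch_args():
    args = FastPitchArgs(num_chars=120, out_channels=80, length_scale=1)
    # args.encoder_params["num_layers"] -> 6 (default encoder FFTransformer depth)
    return args
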
class FastPitch(BaseTTS):
    """FastPitch model. Very similar to the SpeedySpeech model but with pitch prediction.

    Paper abstract:
        "We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental
        frequency contours. The model predicts pitch contours during inference. By altering these predictions,
        the generated speech can be more expressive, better match the semantic of the utterance, and in the end
        more engaging to the listener. Uniformly increasing or decreasing pitch with FastPitch generates speech
        that resembles the voluntary modulation of voice. Conditioning on frequency contours improves the overall
        quality of synthesized speech, making it comparable to state-of-the-art. It does not introduce an overhead,
        and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time
        factor for mel-spectrogram synthesis of a typical utterance."

    Notes:
        TODO

    Args:
        config (Coqpit): Model coqpit class.

    Examples:
        >>> from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs
        >>> config = FastPitchArgs()
        >>> model = FastPitch(config)
    """

    def __init__(self, config: Coqpit):
        super().__init__()

        if "characters" in config:
            # loading from FastPitchConfig
            _, self.config, num_chars = self.get_characters(config)
            config.model_args.num_chars = num_chars
            args = self.config.model_args
        else:
            # loading from FastPitchArgs
            self.config = config
            args = config

        self.max_duration = args.max_duration
        self.use_gt_duration = args.use_gt_duration

        self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale

        self.encoder = FFTransformer(
            hidden_channels=args.hidden_channels,
            **args.encoder_params,
        )

        # if n_speakers > 1:
        #     self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
        # else:
        #     self.speaker_emb = None
        # self.speaker_emb_weight = speaker_emb_weight
        self.emb = nn.Embedding(args.num_chars, args.hidden_channels)

        self.duration_predictor = TemporalPredictor(
            args.hidden_channels,
            filter_size=args.duration_predictor_hidden_channels,
            kernel_size=args.duration_predictor_kernel_size,
            dropout=args.duration_predictor_dropout_p,
            num_layers=args.duration_predictor_num_layers,
        )

        self.decoder = FFTransformer(hidden_channels=args.hidden_channels, **args.decoder_params)

        self.pitch_predictor = TemporalPredictor(
            args.hidden_channels,
            filter_size=args.pitch_predictor_hidden_channels,
            kernel_size=args.pitch_predictor_kernel_size,
            dropout=args.pitch_predictor_dropout_p,
            num_layers=args.pitch_predictor_num_layers,
        )

        self.pitch_emb = nn.Conv1d(
            1,
            args.hidden_channels,
            kernel_size=args.pitch_embedding_kernel_size,
            padding=int((args.pitch_embedding_kernel_size - 1) / 2),
        )

        self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True)

    @staticmethod
    def expand_encoder_outputs(en, dr, x_mask, y_mask):
        """Generate attention alignment map from durations and
        expand encoder outputs

        Example:
            encoder output: [a,b,c,d]
            durations: [1, 3, 2, 1]

            expanded: [a, b, b, b, c, c, d]
            attention map: [[0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 1, 1, 0],
                            [0, 1, 1, 1, 0, 0, 0],
                            [1, 0, 0, 0, 0, 0, 0]]
        """
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
        o_en_ex = torch.matmul(attn.transpose(1, 2), en)
        return o_en_ex, attn.transpose(1, 2)

    def forward(self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": 0, "speaker_ids": None}):
        speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)

        # Calculate speaker embedding
        # if self.speaker_emb is None:
        #     speaker_embedding = 0
        # else:
        #     speaker_embedding = self.speaker_emb(speaker).unsqueeze(1)
        #     speaker_embedding.mul_(self.speaker_emb_weight)

        # character embedding
        embedding = self.emb(x)

        # Input FFT
        o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding)

        # Embedded for predictors
        o_en_dr, mask_en_dr = o_en, mask_en

        # Predict durations
        o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr)
        o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)

        # TODO: move this to the dataset
        avg_pitch = average_pitch(pitch, dr)

        # Predict pitch
        o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)
        pitch_emb = self.pitch_emb(avg_pitch)
        o_en = o_en + pitch_emb.transpose(1, 2)

        # len_regulated, dec_lens = regulate_len(dr, o_en, self.length_scale, mel_max_len)
        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)

        # Output FFT
        o_de, _ = self.decoder(o_en_ex, y_lengths)
        o_de = self.proj(o_de)
        outputs = {
            "model_outputs": o_de,
            "durations_log": o_dr_log.squeeze(1),
            "durations": o_dr.squeeze(1),
            "pitch": o_pitch,
            "pitch_gt": avg_pitch,
            "alignments": attn,
        }
        return outputs

    @torch.no_grad()
    def inference(self, x, aux_input={"d_vectors": 0, "speaker_ids": None}):  # pylint: disable=unused-argument
        speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0

        # input sequence should be greater than the max convolution size
        inference_padding = 5
        if x.shape[1] < 13:
            inference_padding += 13 - x.shape[1]

        # pad input to prevent dropping the last word
        x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

        # character embedding
        embedding = self.emb(x)

        # if self.speaker_emb is None:
        # else:
        #     speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker
        #     spk_emb = self.speaker_emb(speaker).unsqueeze(1)
        #     spk_emb.mul_(self.speaker_emb_weight)

        # Input FFT
        o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding)

        # Predict durations
        o_dr_log = self.duration_predictor(o_en, mask_en)
        o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
        o_dr = o_dr * self.length_scale

        # Pitch over chars
        o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)

        # if pitch_transform is not None:
        #     if self.pitch_std[0] == 0.0:
        #         # XXX LJSpeech-1.1 defaults
        #         mean, std = 218.14, 67.24
        #     else:
        #         mean, std = self.pitch_mean[0], self.pitch_std[0]
        #     pitch_pred = pitch_transform(pitch_pred, mask_en.sum(dim=(1, 2)), mean, std)

        o_pitch_emb = self.pitch_emb(o_pitch).transpose(1, 2)

        o_en = o_en + o_pitch_emb

        y_lengths = o_dr.sum(1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)

        o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask)
        o_de, _ = self.decoder(o_en_ex, y_lengths)
        o_de = self.proj(o_de)

        outputs = {"model_outputs": o_de, "alignments": attn, "pitch": o_pitch, "durations_log": o_dr_log}
        return outputs

    def train_step(self, batch: dict, criterion: nn.Module):
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        pitch = batch["pitch"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        durations = batch["durations"]

        aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
        outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input)

        # compute loss
        loss_dict = criterion(
            outputs["model_outputs"],
            mel_input,
            mel_lengths,
            outputs["durations_log"],
            durations,
            outputs["pitch"],
            outputs["pitch_gt"],
            text_lengths,
        )

        # compute duration error
        durations_pred = outputs["durations"]
        duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
        loss_dict["duration_error"] = duration_error
        return outputs, loss_dict

    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
        model_outputs = outputs["model_outputs"]
        alignments = outputs["alignments"]
        mel_input = batch["mel_input"]

        pred_spec = model_outputs[0].data.cpu().numpy()
        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        figures = {
            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # Sample audio
        train_audio = ap.inv_melspectrogram(pred_spec.T)
        return figures, {"audio": train_audio}

    def eval_step(self, batch: dict, criterion: nn.Module):
        return self.train_step(batch, criterion)

    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
        return self.train_log(ap, batch, outputs)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training

    def get_criterion(self):
        from TTS.tts.layers.losses import FastPitchLoss  # pylint: disable=import-outside-toplevel

        return FastPitchLoss(self.config)

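# Illustrative sketch (not part of the original source): building the model from FastPitchArgs
# and running inference on a batch of dummy character IDs. `num_chars` and the random input
# IDs are placeholder values, not a real text-processing pipeline, and an untrained model will
# only produce noise-like spectrograms.
def _example_fast_pitch_inference():
    config = FastPitchArgs(num_chars=100)
    model = FastPitch(config).eval()
    token_ids = torch.randint(0, config.num_chars, (1, 20))
    outputs = model.inference(token_ids)
    # outputs["model_outputs"].shape -> [1, T_mel, 80], where T_mel follows the predicted durations
    return outputs
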
def average_pitch(pitch, durs):
    durs_cums_ends = torch.cumsum(durs, dim=1).long()
    durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
    pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
    pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0))

    bs, l = durs_cums_ends.size()
    n_formants = pitch.size(1)
    dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l)
    dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l)

    pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float()
    pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float()

    pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems)
    return pitch_avg

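# Illustrative sketch (not part of the original source): `average_pitch` averages the
# frame-level pitch contour over each input token, using the token durations; zero
# (unvoiced) frames are excluded from the average.
def _example_average_pitch():
    pitch = torch.tensor([[[100.0, 120.0, 0.0, 140.0, 160.0]]])  # [batch, 1, frames]
    durs = torch.tensor([[2, 3]])  # two tokens covering 2 + 3 frames
    avg = average_pitch(pitch, durs)
    # avg -> tensor([[[110., 150.]]]): mean of (100, 120) and of (140, 160), skipping the 0
    return avg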