mirror of https://github.com/coqui-ai/TTS.git
refactor(delightful_tts): remove unused classes
parent 7cdfde226b
commit 6f25c2b904
@@ -1,20 +1,14 @@
 ### credit: https://github.com/dunky11/voicesmith
 import math
-from typing import Tuple
 
 import torch
 import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
 
-from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d
+from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d, calc_same_padding
 from TTS.tts.layers.delightful_tts.networks import GLUActivation
 
 
-def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
-    pad = kernel_size // 2
-    return (pad, pad - (kernel_size + 1) % 2)
-
-
 class Conformer(nn.Module):
     def __init__(
         self,
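For reference, `calc_same_padding` (now shared from `conv_layers`) returns a (left, right) padding pair that keeps the output of a stride-1 convolution the same length as its input. A minimal sketch with hypothetical tensor sizes:

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)


x = torch.randn(1, 8, 50)  # (batch, channels, time)
for k in (3, 4, 5):
    left, right = calc_same_padding(k)  # k=3 -> (1, 1), k=4 -> (2, 1), k=5 -> (2, 2)
    y = F.conv1d(F.pad(x, (left, right)), torch.randn(8, 8, k))
    assert y.shape[-1] == x.shape[-1]  # "same" output length at stride 1
```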
@@ -322,7 +316,7 @@ class ConformerMultiHeadedSelfAttention(nn.Module):
         value: torch.Tensor,
         mask: torch.Tensor,
         encoding: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, seq_length, _ = key.size()  # pylint: disable=unused-variable
         encoding = encoding[:, : key.shape[1]]
         encoding = encoding.repeat(batch_size, 1, 1)
@@ -378,7 +372,7 @@ class RelativeMultiHeadAttention(nn.Module):
         value: torch.Tensor,
         pos_embedding: torch.Tensor,
         mask: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size = query.shape[0]
         query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
         key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
@@ -411,40 +405,3 @@ class RelativeMultiHeadAttention(nn.Module):
         padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
         pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
         return pos_score
-
-
-class MultiHeadAttention(nn.Module):
-    """
-    input:
-        query --- [N, T_q, query_dim]
-        key --- [N, T_k, key_dim]
-    output:
-        out --- [N, T_q, num_units]
-    """
-
-    def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int):
-        super().__init__()
-        self.num_units = num_units
-        self.num_heads = num_heads
-        self.key_dim = key_dim
-
-        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
-        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
-        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
-
-    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
-        querys = self.W_query(query)  # [N, T_q, num_units]
-        keys = self.W_key(key)  # [N, T_k, num_units]
-        values = self.W_value(key)
-        split_size = self.num_units // self.num_heads
-        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
-        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
-        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
-        # score = softmax(QK^T / (d_k ** 0.5))
-        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
-        scores = scores / (self.key_dim**0.5)
-        scores = F.softmax(scores, dim=3)
-        # out = score * V
-        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
-        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
-        return out
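The `MultiHeadAttention` removed above duplicates the typed implementation kept in the other delightful_tts module changed further down; its forward pass is plain scaled dot-product attention with the heads stacked on a leading axis. A shape sketch with hypothetical dimensions:

```python
import torch

N, T_q, T_k = 2, 7, 11  # hypothetical batch size and sequence lengths
query_dim = key_dim = 16
num_units, num_heads = 32, 4

query, key = torch.randn(N, T_q, query_dim), torch.randn(N, T_k, key_dim)
W_q, W_k, W_v = (torch.randn(d, num_units) for d in (query_dim, key_dim, key_dim))

split = num_units // num_heads
q = torch.stack(torch.split(query @ W_q, split, dim=2), dim=0)  # [h, N, T_q, num_units/h]
k = torch.stack(torch.split(key @ W_k, split, dim=2), dim=0)    # [h, N, T_k, num_units/h]
v = torch.stack(torch.split(key @ W_v, split, dim=2), dim=0)    # [h, N, T_k, num_units/h]

scores = torch.softmax(q @ k.transpose(2, 3) / key_dim**0.5, dim=3)   # [h, N, T_q, T_k]
out = torch.cat(torch.split(scores @ v, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
assert out.shape == (N, T_q, num_units)
```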
@@ -3,9 +3,6 @@ from typing import Tuple
 import torch
 import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
-from torch.nn.utils import parametrize
-
-from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor
 
 
 def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
@@ -530,142 +527,3 @@ class CoordConv2d(nn.modules.conv.Conv2d):
         x = self.addcoords(x)
         x = self.conv(x)
         return x
-
-
-class LVCBlock(torch.nn.Module):
-    """the location-variable convolutions"""
-
-    def __init__(  # pylint: disable=dangerous-default-value
-        self,
-        in_channels,
-        cond_channels,
-        stride,
-        dilations=[1, 3, 9, 27],
-        lReLU_slope=0.2,
-        conv_kernel_size=3,
-        cond_hop_length=256,
-        kpnet_hidden_channels=64,
-        kpnet_conv_size=3,
-        kpnet_dropout=0.0,
-    ):
-        super().__init__()
-
-        self.cond_hop_length = cond_hop_length
-        self.conv_layers = len(dilations)
-        self.conv_kernel_size = conv_kernel_size
-
-        self.kernel_predictor = KernelPredictor(
-            cond_channels=cond_channels,
-            conv_in_channels=in_channels,
-            conv_out_channels=2 * in_channels,
-            conv_layers=len(dilations),
-            conv_kernel_size=conv_kernel_size,
-            kpnet_hidden_channels=kpnet_hidden_channels,
-            kpnet_conv_size=kpnet_conv_size,
-            kpnet_dropout=kpnet_dropout,
-            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
-        )
-
-        self.convt_pre = nn.Sequential(
-            nn.LeakyReLU(lReLU_slope),
-            nn.utils.parametrizations.weight_norm(
-                nn.ConvTranspose1d(
-                    in_channels,
-                    in_channels,
-                    2 * stride,
-                    stride=stride,
-                    padding=stride // 2 + stride % 2,
-                    output_padding=stride % 2,
-                )
-            ),
-        )
-
-        self.conv_blocks = nn.ModuleList()
-        for dilation in dilations:
-            self.conv_blocks.append(
-                nn.Sequential(
-                    nn.LeakyReLU(lReLU_slope),
-                    nn.utils.parametrizations.weight_norm(
-                        nn.Conv1d(
-                            in_channels,
-                            in_channels,
-                            conv_kernel_size,
-                            padding=dilation * (conv_kernel_size - 1) // 2,
-                            dilation=dilation,
-                        )
-                    ),
-                    nn.LeakyReLU(lReLU_slope),
-                )
-            )
-
-    def forward(self, x, c):
-        """forward propagation of the location-variable convolutions.
-        Args:
-            x (Tensor): the input sequence (batch, in_channels, in_length)
-            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
-
-        Returns:
-            Tensor: the output sequence (batch, in_channels, in_length)
-        """
-        _, in_channels, _ = x.shape  # (B, c_g, L')
-
-        x = self.convt_pre(x)  # (B, c_g, stride * L')
-        kernels, bias = self.kernel_predictor(c)
-
-        for i, conv in enumerate(self.conv_blocks):
-            output = conv(x)  # (B, c_g, stride * L')
-
-            k = kernels[:, i, :, :, :, :]  # (B, 2 * c_g, c_g, kernel_size, cond_length)
-            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)
-
-            output = self.location_variable_convolution(
-                output, k, b, hop_size=self.cond_hop_length
-            )  # (B, 2 * c_g, stride * L'): LVC
-            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
-                output[:, in_channels:, :]
-            )  # (B, c_g, stride * L'): GAU
-
-        return x
-
-    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):  # pylint: disable=no-self-use
-        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
-        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
-        Args:
-            x (Tensor): the input sequence (batch, in_channels, in_length).
-            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
-            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
-            dilation (int): the dilation of convolution.
-            hop_size (int): the hop_size of the conditioning sequence.
-        Returns:
-            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
-        """
-        batch, _, in_length = x.shape
-        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
-        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
-
-        padding = dilation * int((kernel_size - 1) / 2)
-        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
-        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)
-
-        if hop_size < dilation:
-            x = F.pad(x, (0, dilation), "constant", 0)
-        x = x.unfold(
-            3, dilation, dilation
-        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
-        x = x[:, :, :, :, :hop_size]
-        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
-        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
-
-        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
-        o = o.to(memory_format=torch.channels_last_3d)
-        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
-        o = o + bias
-        o = o.contiguous().view(batch, out_channels, -1)
-
-        return o
-
-    def remove_weight_norm(self):
-        self.kernel_predictor.remove_weight_norm()
-        parametrize.remove_parametrizations(self.convt_pre[1], "weight")
-        for block in self.conv_blocks:
-            parametrize.remove_parametrizations(block[1], "weight")
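For reference, the removed `location_variable_convolution` applies a different predicted kernel to every `hop_size`-long window of the input, using `unfold` to window the signal and a single `einsum` for the per-window convolutions. A minimal shape walk-through of the dilation-1 case, with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

batch, in_ch, out_ch, ksize, klen, hop = 2, 4, 8, 3, 5, 256  # hypothetical sizes
x = torch.randn(batch, in_ch, klen * hop)                # one predicted kernel per hop-sized window
kernel = torch.randn(batch, in_ch, out_ch, ksize, klen)
bias = torch.randn(batch, out_ch, klen)

pad = (ksize - 1) // 2                                    # dilation = 1
xu = F.pad(x, (pad, pad)).unfold(2, hop + 2 * pad, hop)   # (batch, in_ch, klen, hop + 2*pad)
xu = xu.unfold(3, 1, 1).transpose(3, 4)                   # (batch, in_ch, klen, 1, hop + 2*pad)
xu = xu.unfold(4, ksize, 1)                               # (batch, in_ch, klen, 1, hop, ksize)

o = torch.einsum("bildsk,biokl->bolsd", xu, kernel)       # (batch, out_ch, klen, hop, 1)
o = o + bias.unsqueeze(-1).unsqueeze(-1)                  # broadcast the per-window bias
o = o.contiguous().view(batch, out_ch, -1)                # (batch, out_ch, klen * hop)
assert o.shape == (batch, out_ch, klen * hop)
```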
@@ -1,128 +0,0 @@
-import torch.nn as nn  # pylint: disable=consider-using-from-import
-from torch.nn.utils import parametrize
-
-
-class KernelPredictor(nn.Module):
-    """Kernel predictor for the location-variable convolutions
-
-    Args:
-        cond_channels (int): number of channel for the conditioning sequence,
-        conv_in_channels (int): number of channel for the input sequence,
-        conv_out_channels (int): number of channel for the output sequence,
-        conv_layers (int): number of layers
-
-    """
-
-    def __init__(  # pylint: disable=dangerous-default-value
-        self,
-        cond_channels,
-        conv_in_channels,
-        conv_out_channels,
-        conv_layers,
-        conv_kernel_size=3,
-        kpnet_hidden_channels=64,
-        kpnet_conv_size=3,
-        kpnet_dropout=0.0,
-        kpnet_nonlinear_activation="LeakyReLU",
-        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
-    ):
-        super().__init__()
-
-        self.conv_in_channels = conv_in_channels
-        self.conv_out_channels = conv_out_channels
-        self.conv_kernel_size = conv_kernel_size
-        self.conv_layers = conv_layers
-
-        kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers  # l_w
-        kpnet_bias_channels = conv_out_channels * conv_layers  # l_b
-
-        self.input_conv = nn.Sequential(
-            nn.utils.parametrizations.weight_norm(
-                nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
-            ),
-            getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-        )
-
-        self.residual_convs = nn.ModuleList()
-        padding = (kpnet_conv_size - 1) // 2
-        for _ in range(3):
-            self.residual_convs.append(
-                nn.Sequential(
-                    nn.Dropout(kpnet_dropout),
-                    nn.utils.parametrizations.weight_norm(
-                        nn.Conv1d(
-                            kpnet_hidden_channels,
-                            kpnet_hidden_channels,
-                            kpnet_conv_size,
-                            padding=padding,
-                            bias=True,
-                        )
-                    ),
-                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                    nn.utils.parametrizations.weight_norm(
-                        nn.Conv1d(
-                            kpnet_hidden_channels,
-                            kpnet_hidden_channels,
-                            kpnet_conv_size,
-                            padding=padding,
-                            bias=True,
-                        )
-                    ),
-                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-                )
-            )
-        self.kernel_conv = nn.utils.parametrizations.weight_norm(
-            nn.Conv1d(
-                kpnet_hidden_channels,
-                kpnet_kernel_channels,
-                kpnet_conv_size,
-                padding=padding,
-                bias=True,
-            )
-        )
-        self.bias_conv = nn.utils.parametrizations.weight_norm(
-            nn.Conv1d(
-                kpnet_hidden_channels,
-                kpnet_bias_channels,
-                kpnet_conv_size,
-                padding=padding,
-                bias=True,
-            )
-        )
-
-    def forward(self, c):
-        """
-        Args:
-            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
-        """
-        batch, _, cond_length = c.shape
-        c = self.input_conv(c)
-        for residual_conv in self.residual_convs:
-            residual_conv.to(c.device)
-            c = c + residual_conv(c)
-        k = self.kernel_conv(c)
-        b = self.bias_conv(c)
-        kernels = k.contiguous().view(
-            batch,
-            self.conv_layers,
-            self.conv_in_channels,
-            self.conv_out_channels,
-            self.conv_kernel_size,
-            cond_length,
-        )
-        bias = b.contiguous().view(
-            batch,
-            self.conv_layers,
-            self.conv_out_channels,
-            cond_length,
-        )
-
-        return kernels, bias
-
-    def remove_weight_norm(self):
-        parametrize.remove_parametrizations(self.input_conv[0], "weight")
-        parametrize.remove_parametrizations(self.kernel_conv, "weight")
-        parametrize.remove_parametrizations(self.bias_conv, "weight")
-        for block in self.residual_convs:
-            parametrize.remove_parametrizations(block[1], "weight")
-            parametrize.remove_parametrizations(block[3], "weight")
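The deleted `KernelPredictor` packs every per-layer LVC kernel and bias into the channel dimension of two 1D convolutions and reshapes them afterwards; a small sketch of that channel bookkeeping with hypothetical hyperparameters:

```python
import torch

batch, cond_length = 2, 20  # hypothetical sizes
conv_layers, conv_in, conv_out, ksize = 4, 8, 16, 3

kpnet_kernel_channels = conv_in * conv_out * ksize * conv_layers  # l_w = 1536
kpnet_bias_channels = conv_out * conv_layers                      # l_b = 64

k = torch.randn(batch, kpnet_kernel_channels, cond_length)        # kernel_conv output
b = torch.randn(batch, kpnet_bias_channels, cond_length)          # bias_conv output

# One (in, out, ksize) kernel and one bias per conv layer and conditioning frame,
# exactly the shapes LVCBlock indexed with kernels[:, i] and bias[:, i].
kernels = k.view(batch, conv_layers, conv_in, conv_out, ksize, cond_length)
bias = b.view(batch, conv_layers, conv_out, cond_length)
```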
@@ -117,7 +117,7 @@ class MultiHeadAttention(nn.Module):
         out --- [N, T_q, num_units]
     """
 
-    def __init__(self, query_dim, key_dim, num_units, num_heads):
+    def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int):
         super().__init__()
         self.num_units = num_units
         self.num_heads = num_heads
@@ -127,7 +127,7 @@ class MultiHeadAttention(nn.Module):
         self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
         self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
 
-    def forward(self, query, key):
+    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
         queries = self.W_query(query)  # [N, T_q, num_units]
         keys = self.W_key(key)  # [N, T_k, num_units]
         values = self.W_value(key)
@@ -137,13 +137,11 @@ class MultiHeadAttention(nn.Module):
         keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
         values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
 
-        # score = softmax(QK^T / (d_k**0.5))
+        # score = softmax(QK^T / (d_k ** 0.5))
         scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
         scores = scores / (self.key_dim**0.5)
         scores = F.softmax(scores, dim=3)
 
         # out = score * V
         out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
-        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
-
-        return out
+        return torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
@@ -8,11 +8,9 @@ from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
-import torch.distributed as dist
-import torchaudio
 from coqpit import Coqpit
 from librosa.filters import mel as librosa_mel_fn
 from torch import nn
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.io import load_fsspec
@@ -24,8 +22,10 @@ from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
 from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.models.base_tts import BaseTTSE2E
+from TTS.tts.models.vits import load_audio
 from TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.synthesis import embedding_to_torch, id_to_torch, numpy_to_torch
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram
 from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
@@ -40,103 +40,10 @@ from TTS.vocoder.utils.generic_utils import plot_results
 logger = logging.getLogger(__name__)
 
 
-def id_to_torch(aux_id, cuda=False):
-    if aux_id is not None:
-        aux_id = np.asarray(aux_id)
-        aux_id = torch.from_numpy(aux_id)
-    if cuda:
-        return aux_id.cuda()
-    return aux_id
-
-
-def embedding_to_torch(d_vector, cuda=False):
-    if d_vector is not None:
-        d_vector = np.asarray(d_vector)
-        d_vector = torch.from_numpy(d_vector).float()
-        d_vector = d_vector.squeeze().unsqueeze(0)
-    if cuda:
-        return d_vector.cuda()
-    return d_vector
-
-
-def numpy_to_torch(np_array, dtype, cuda=False):
-    if np_array is None:
-        return None
-    tensor = torch.as_tensor(np_array, dtype=dtype)
-    if cuda:
-        return tensor.cuda()
-    return tensor
-
-
-def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
-    batch_size = lengths.shape[0]
-    max_len = torch.max(lengths).item()
-    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
-    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
-    return mask
-
-
-def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
-    out_list = torch.jit.annotate(List[torch.Tensor], [])
-    for batch in input_ele:
-        if len(batch.shape) == 1:
-            one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0)
-        else:
-            one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0)
-        out_list.append(one_batch_padded)
-    out_padded = torch.stack(out_list)
-    return out_padded
-
-
-def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
-    return torch.ceil(lens / stride).int()
-
-
-def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor:
-    assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..."
-    return torch.randn(shape) * np.sqrt(2 / shape[1])
-
-
-# pylint: disable=redefined-outer-name
-def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
-    pad = kernel_size // 2
-    return (pad, pad - (kernel_size + 1) % 2)
-
-
 hann_window = {}
 mel_basis = {}
 
 
-@torch.no_grad()
-def weights_reset(m: nn.Module):
-    # check if the current module has reset_parameters and if it is reset the weight
-    reset_parameters = getattr(m, "reset_parameters", None)
-    if callable(reset_parameters):
-        m.reset_parameters()
-
-
-def get_module_weights_sum(mdl: nn.Module):
-    dict_sums = {}
-    for name, w in mdl.named_parameters():
-        if "weight" in name:
-            value = w.data.sum().item()
-            dict_sums[name] = value
-    return dict_sums
-
-
-def load_audio(file_path: str):
-    """Load the audio file normalized in [-1, 1]
-
-    Return Shapes:
-        - x: :math:`[1, T]`
-    """
-    x, sr = torchaudio.load(
-        file_path,
-    )
-    assert (x > 1).sum() + (x < -1).sum() == 0
-    return x, sr
-
-
 def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
     y = y.squeeze(1)
 
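The conversion helpers removed above are now imported from `TTS.tts.utils.synthesis` and `TTS.tts.models.vits`; the remaining ones were unused. For reference, the deleted `get_mask_from_lengths` built a boolean padding mask (True at padded positions) from a length vector, e.g. with hypothetical lengths:

```python
import torch

lengths = torch.tensor([3, 5, 2])  # hypothetical utterance lengths
max_len = int(lengths.max())
ids = torch.arange(max_len).unsqueeze(0).expand(len(lengths), -1)
mask = ids >= lengths.unsqueeze(1)  # True marks padded positions
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False],
#         [False, False,  True,  True,  True]])
```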
@@ -1179,7 +1086,7 @@ class DelightfulTTS(BaseTTSE2E):
         **kwargs,
     ):  # pylint: disable=unused-argument
         # TODO: add cloning support with ref_waveform
-        is_cuda = next(self.parameters()).is_cuda
+        device = next(self.parameters()).device
 
         # convert text to sequence of token IDs
         text_inputs = np.asarray(
@@ -1193,14 +1100,14 @@ class DelightfulTTS(BaseTTSE2E):
         if isinstance(speaker_id, str) and self.args.use_speaker_embedding:
             # get the speaker id for the speaker embedding layer
             _speaker_id = self.speaker_manager.name_to_id[speaker_id]
-            _speaker_id = id_to_torch(_speaker_id, cuda=is_cuda)
+            _speaker_id = id_to_torch(_speaker_id, device=device)
 
         if speaker_id is not None and self.args.use_d_vector_file:
             # get the average d_vector for the speaker
             d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False)
-            d_vector = embedding_to_torch(d_vector, cuda=is_cuda)
+            d_vector = embedding_to_torch(d_vector, device=device)
 
-        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
+        text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
         text_inputs = text_inputs.unsqueeze(0)
 
         # synthesize voice
@@ -1223,7 +1130,7 @@ class DelightfulTTS(BaseTTSE2E):
         return return_dict
 
     def synthesize_with_gl(self, text: str, speaker_id, d_vector):
-        is_cuda = next(self.parameters()).is_cuda
+        device = next(self.parameters()).device
 
         # convert text to sequence of token IDs
         text_inputs = np.asarray(
@@ -1232,12 +1139,12 @@ class DelightfulTTS(BaseTTSE2E):
         )
         # pass tensors to backend
         if speaker_id is not None:
-            speaker_id = id_to_torch(speaker_id, cuda=is_cuda)
+            speaker_id = id_to_torch(speaker_id, device=device)
 
         if d_vector is not None:
-            d_vector = embedding_to_torch(d_vector, cuda=is_cuda)
+            d_vector = embedding_to_torch(d_vector, device=device)
 
-        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
+        text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
         text_inputs = text_inputs.unsqueeze(0)
 
         # synthesize voice
@@ -1,17 +1,16 @@
-from typing import Dict
+from typing import Dict, Optional, Union
 
 import numpy as np
 import torch
 from torch import nn
 
 
-def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
-    if cuda:
-        device = "cuda"
+def numpy_to_torch(
+    np_array: np.ndarray, dtype: torch.dtype, device: Union[str, torch.device] = "cpu"
+) -> Optional[torch.Tensor]:
     if np_array is None:
         return None
-    tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
-    return tensor
+    return torch.as_tensor(np_array, dtype=dtype, device=device)
 
 
 def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
@@ -76,18 +75,14 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav
 
 
-def id_to_torch(aux_id, cuda=False, device="cpu"):
-    if cuda:
-        device = "cuda"
+def id_to_torch(aux_id, device: Union[str, torch.device] = "cpu") -> Optional[torch.Tensor]:
     if aux_id is not None:
         aux_id = np.asarray(aux_id)
         aux_id = torch.from_numpy(aux_id).to(device)
     return aux_id
 
 
-def embedding_to_torch(d_vector, cuda=False, device="cpu"):
-    if cuda:
-        device = "cuda"
+def embedding_to_torch(d_vector, device: Union[str, torch.device] = "cpu") -> Optional[torch.Tensor]:
     if d_vector is not None:
         d_vector = np.asarray(d_vector)
         d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
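After this refactor the synthesis helpers take an explicit `device` instead of a `cuda` flag. A minimal sketch of a call site, using the import path shown above (values are hypothetical):

```python
import numpy as np
import torch

from TTS.tts.utils.synthesis import id_to_torch, numpy_to_torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Previously id_to_torch(speaker_id, cuda=is_cuda); the target device is now passed directly.
speaker_id = id_to_torch(3, device=device)
text_inputs = numpy_to_torch(np.asarray([[1, 2, 3]]), torch.long, device=device)
assert text_inputs is not None and text_inputs.dtype == torch.long
```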