import torch
from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.vits.networks import PosteriorEncoder
from TTS.vc.models.freevc import Generator, ResidualCouplingBlock


class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0, layernorm=True):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        # Stack of K strided convolutions; each halves the time and frequency
        # axes, with weight norm applied to every kernel.
        convs = [
            torch.nn.utils.parametrizations.weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)

        # Frequency bins remaining after K stride-2 convolutions; this fixes
        # the GRU input size.
        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)
        if layernorm:
            self.layernorm = nn.LayerNorm(self.spec_channels)
        else:
            self.layernorm = None

    def forward(self, inputs):
        N = inputs.size(0)

        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        if self.layernorm is not None:
            out = self.layernorm(out)

        for conv in self.convs:
            out = conv(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        _memory, out = self.gru(out)  # out --- [1, N, 128]

        # Project the final GRU state to the conditioning dimension.
        return self.proj(out.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        # Standard conv output-length formula, applied once per layer:
        # L_out = (L_in - kernel_size + 2 * pad) // stride + 1
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
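
# Illustrative usage sketch (not part of the original file; spec_channels=513
# and gin_channels=256 are assumed values chosen for illustration):
#
#     enc = ReferenceEncoder(spec_channels=513, gin_channels=256)
#     spec = torch.randn(2, 100, 513)  # [N, Ty, n_freqs] reference frames
#     g = enc(spec)                    # [N, 256] utterance-level embedding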


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        spec_channels,
        inter_channels,
        hidden_channels,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=256,
        zero_g=False,
        **kwargs,
    ):
        super().__init__()

        # HiFi-GAN style decoder that renders the latent back to a waveform.
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        # Posterior encoder: spectrogram -> latent z (speaker-conditioned).
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            cond_channels=gin_channels,
        )

        # Normalizing flow used to strip the source timbre and re-apply the
        # target timbre during voice conversion.
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        self.n_speakers = n_speakers
        if n_speakers != 0:
            raise ValueError("OpenVoice inference only supports n_speakers==0")
        self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
        self.zero_g = zero_g

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
        g_src = sid_src
        g_tgt = sid_tgt
        # Encode the source spectrogram into the latent z, conditioned on the
        # source speaker embedding (or a zero embedding when zero_g is set).
        z, m_q, logs_q, y_mask = self.enc_q(
            y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau
        )
        # Flow forward to the speaker-independent space with the source
        # embedding, then invert the flow with the target embedding.
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
        return o_hat, y_mask, (z, z_p, z_hat)
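
# Illustrative inference sketch (not part of the original file). The speaker
# embeddings sid_src/sid_tgt are assumed to come from the model's own
# ReferenceEncoder; constructor arguments are elided because the shipped
# OpenVoice config values are not shown in this file:
#
#     model = SynthesizerTrn(spec_channels=513, inter_channels=192, ...)
#     g_src = model.ref_enc(y_src.transpose(1, 2)).unsqueeze(-1)  # [N, gin, 1]
#     g_tgt = model.ref_enc(y_tgt.transpose(1, 2)).unsqueeze(-1)
#     wav, mask, _ = model.voice_conversion(y_src, y_lengths, g_src, g_tgt, tau=0.3)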