coqui-tts/TTS/vc/modules/openvoice/models.py

135 lines
4.1 KiB
Python

import torch
from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.vits.networks import PosteriorEncoder
from TTS.vc.models.freevc import Generator, ResidualCouplingBlock
class ReferenceEncoder(nn.Module):
"""
inputs --- [N, Ty/r, n_mels*r] mels
outputs --- [N, ref_enc_gru_size]
"""
def __init__(self, spec_channels, gin_channels=0, layernorm=True):
super().__init__()
self.spec_channels = spec_channels
ref_enc_filters = [32, 32, 64, 64, 128, 128]
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [
torch.nn.utils.parametrizations.weight_norm(
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
)
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True,
)
self.proj = nn.Linear(128, gin_channels)
if layernorm:
self.layernorm = nn.LayerNorm(self.spec_channels)
else:
self.layernorm = None
def forward(self, inputs):
N = inputs.size(0)
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
if self.layernorm is not None:
out = self.layernorm(out)
for conv in self.convs:
out = conv(out)
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
T = out.size(1)
N = out.size(0)
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
self.gru.flatten_parameters()
_memory, out = self.gru(out) # out --- [1, N, 128]
return self.proj(out.squeeze(0))
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
for _ in range(n_convs):
L = (L - kernel_size + 2 * pad) // stride + 1
return L
class SynthesizerTrn(nn.Module):
    """OpenVoice synthesizer: posterior encoder, flow, decoder and reference encoder.

    Upstream describes this class as the "Synthesizer for Training", but the
    constructor rejects any ``n_speakers`` other than 0, and the only entry
    point defined here is :meth:`voice_conversion`.
    """

    def __init__(
        self,
        spec_channels,
        inter_channels,
        hidden_channels,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=256,
        zero_g=False,
        **kwargs,
    ):
        """Build the sub-modules.

        Args:
            spec_channels: Passed to the posterior encoder and the reference encoder.
            inter_channels: Passed to the decoder, posterior encoder and flow.
            hidden_channels: Passed to the posterior encoder and flow.
            resblock, resblock_kernel_sizes, resblock_dilation_sizes,
            upsample_rates, upsample_initial_channel, upsample_kernel_sizes:
                Forwarded unchanged to ``Generator``.
            n_speakers: Must be 0.
            gin_channels: Conditioning size shared by all sub-modules.
            zero_g: If True, :meth:`voice_conversion` feeds an all-zero
                conditioning tensor to the posterior encoder and the decoder.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``n_speakers`` is not 0.
        """
        super().__init__()
        # Fail fast: validate before building any of the sub-networks.
        if n_speakers != 0:
            raise ValueError("OpenVoice inference only supports n_speaker==0")

        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        # NOTE(review): the positional 5, 1, 16 are presumably kernel size,
        # dilation rate and layer count -- confirm against PosteriorEncoder.
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            cond_channels=gin_channels,
        )
        # NOTE(review): the positional 5, 1, 4 are presumably kernel size,
        # dilation rate and layer count -- confirm against ResidualCouplingBlock.
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        self.n_speakers = n_speakers
        self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
        self.zero_g = zero_g

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
        """Convert ``y`` from the source conditioning to the target conditioning.

        Args:
            y: Input passed to the posterior encoder.
            y_lengths: Lengths passed alongside ``y`` to the posterior encoder.
            sid_src: Source conditioning tensor. Despite the name it is used
                directly as ``g``; no embedding lookup happens here.
            sid_tgt: Target conditioning tensor, used the same way.
            tau: Forwarded to the posterior encoder.

        Returns:
            ``(o_hat, y_mask, (z, z_p, z_hat))``: the decoder output, the mask
            returned by the posterior encoder, and the latents before the
            flow, after the forward flow and after the reverse flow.
        """
        g_src = sid_src
        g_tgt = sid_tgt
        # zero_g blanks the conditioning for the encoder and the decoder only;
        # the flow below always receives the real source/target conditioning.
        enc_g = torch.zeros_like(g_src) if self.zero_g else g_src
        dec_g = torch.zeros_like(g_tgt) if self.zero_g else g_tgt

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=enc_g, tau=tau)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=dec_g)
        return o_hat, y_mask, (z, z_p, z_hat)