import torch
from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.vits.networks import PosteriorEncoder
from TTS.vc.models.freevc import Generator, ResidualCouplingBlock


class ReferenceEncoder(nn.Module):
    """Encode a spectrogram into a single fixed-size conditioning vector.

    The input is reshaped to ``[N, 1, -1, spec_channels]``. It is treated as a
    one-channel image whose last axis has ``spec_channels`` bins. The input is
    presumably ``[N, T, spec_channels]``; confirm against callers.

    Pipeline: optional LayerNorm over the last axis, then six stride-2 2-D
    convolutions with ReLU, then a single-layer GRU. The GRU's final hidden
    state is linearly projected.

    inputs  --- [N, T, spec_channels] (assumed layout, see above)
    outputs --- [N, gin_channels]
    """

    def __init__(self, spec_channels, gin_channels=0, layernorm=True):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        n_convs = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        # The GRU hidden size and the projection input width must agree.
        # Keep them in one place.
        gru_hidden_size = 128

        # Each conv (k=3, s=2, p=1) roughly halves both spatial axes.
        self.convs = nn.ModuleList(
            [
                torch.nn.utils.parametrizations.weight_norm(
                    nn.Conv2d(
                        in_channels=in_ch,
                        out_channels=out_ch,
                        kernel_size=(3, 3),
                        stride=(2, 2),
                        padding=(1, 1),
                    )
                )
                for in_ch, out_ch in zip(filters[:-1], filters[1:])
            ]
        )

        # Width of the frequency axis after all convolutions.
        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, n_convs)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=gru_hidden_size,
            batch_first=True,
        )
        self.proj = nn.Linear(gru_hidden_size, gin_channels)
        self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None

    def forward(self, inputs):
        """Return a ``[N, gin_channels]`` embedding for the batch ``inputs``."""
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, T, spec_channels]
        if self.layernorm is not None:
            out = self.layernorm(out)

        for conv in self.convs:
            out = F.relu(conv(out))  # final: [N, 128, T', F']

        out = out.transpose(1, 2)  # [N, T', 128, F']
        N, T = out.size(0), out.size(1)
        out = out.contiguous().view(N, T, -1)  # [N, T', 128 * F']

        self.gru.flatten_parameters()
        # nn.GRU returns (output, h_n). Only the final hidden state is used.
        _, hidden = self.gru(out)  # hidden: [1, N, 128]
        return self.proj(hidden.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        """Return the length of an axis of size ``L`` after ``n_convs`` identical convolutions."""
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


class SynthesizerTrn(nn.Module):
    """Synthesizer used for OpenVoice-style voice conversion.

    Components:
      * ``enc_q``: posterior encoder from spectrogram to latent ``z``.
      * ``flow``: residual coupling flow, conditioned on ``g``.
      * ``dec``: waveform generator from latent, conditioned on ``g``.
      * ``ref_enc``: ``ReferenceEncoder`` producing a ``gin_channels`` vector.

    Only ``n_speakers == 0`` is supported. Extra ``**kwargs`` are accepted and
    ignored. When ``zero_g`` is True, ``enc_q`` and ``dec`` receive an all-zeros
    conditioning tensor during ``voice_conversion``. The flow always receives
    the real ones.
    """

    def __init__(
        self,
        spec_channels,
        inter_channels,
        hidden_channels,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=256,
        zero_g=False,
        **kwargs,
    ):
        super().__init__()
        # Validate before building the sub-networks so a bad config fails fast.
        if n_speakers != 0:
            raise ValueError("OpenVoice inference only supports n_speaker==0")
        self.n_speakers = n_speakers

        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        # Positional hyper-parameters (5, 1, 16) / (5, 1, 4) are passed through
        # unchanged. Presumably kernel size, dilation rate and layer count;
        # confirm against the callee signatures.
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            cond_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
        self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
        self.zero_g = zero_g

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
        """Re-synthesize ``y`` under the target conditioning instead of the source one.

        Args:
            y: spectrogram batch fed to the posterior encoder.
            y_lengths: per-item lengths of ``y``, forwarded to ``enc_q``.
            sid_src: source conditioning tensor. Despite the name, it is used
                directly as ``g``, not as an index.
            sid_tgt: target conditioning tensor, also used directly as ``g``.
            tau: forwarded unchanged to ``enc_q``.

        Returns:
            ``(o_hat, y_mask, (z, z_p, z_hat))``: generator output, the mask
            returned by ``enc_q``, and the intermediate latents.
        """
        g_src = sid_src
        g_tgt = sid_tgt

        enc_g = torch.zeros_like(g_src) if self.zero_g else g_src
        z, _m_q, _logs_q, y_mask = self.enc_q(y, y_lengths, g=enc_g, tau=tau)

        # Map the source latent through the flow with the source conditioning,
        # then invert the flow with the target conditioning.
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)

        dec_g = torch.zeros_like(g_tgt) if self.zero_g else g_tgt
        o_hat = self.dec(z_hat * y_mask, g=dec_g)
        return o_hat, y_mask, (z, z_p, z_hat)