mirror of https://github.com/coqui-ai/TTS.git
feat: add openvoice vc model
This commit is contained in:
parent
1a21853b90
commit
fce3137e0d
|
@ -121,6 +121,7 @@ repository are also still a useful source of information.
|
||||||
|
|
||||||
### Voice Conversion
|
### Voice Conversion
|
||||||
- FreeVC: [paper](https://arxiv.org/abs/2210.15418)
|
- FreeVC: [paper](https://arxiv.org/abs/2210.15418)
|
||||||
|
- OpenVoice: [technical report](https://arxiv.org/abs/2312.01479)
|
||||||
|
|
||||||
You can also help us implement more models.
|
You can also help us implement more models.
|
||||||
|
|
||||||
|
|
|
@ -155,8 +155,10 @@ class TTS(nn.Module):
|
||||||
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
model_path, config_path, _, _, _ = self.download_model_by_name(model_name)
|
model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name)
|
||||||
self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
|
self.voice_converter = Synthesizer(
|
||||||
|
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
|
||||||
|
)
|
||||||
|
|
||||||
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
|
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
|
||||||
"""Load one of 🐸TTS models by name.
|
"""Load one of 🐸TTS models by name.
|
||||||
|
|
|
@ -424,7 +424,7 @@ class ModelManager(object):
|
||||||
model_file = None
|
model_file = None
|
||||||
config_file = None
|
config_file = None
|
||||||
for file_name in os.listdir(output_path):
|
for file_name in os.listdir(output_path):
|
||||||
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
|
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]:
|
||||||
model_file = os.path.join(output_path, file_name)
|
model_file = os.path.join(output_path, file_name)
|
||||||
elif file_name == "config.json":
|
elif file_name == "config.json":
|
||||||
config_file = os.path.join(output_path, file_name)
|
config_file = os.path.join(output_path, file_name)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -15,7 +16,9 @@ from TTS.tts.models.vits import Vits
|
||||||
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
|
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.audio.numpy_transforms import save_wav
|
from TTS.utils.audio.numpy_transforms import save_wav
|
||||||
|
from TTS.vc.configs.openvoice_config import OpenVoiceConfig
|
||||||
from TTS.vc.models import setup_model as setup_vc_model
|
from TTS.vc.models import setup_model as setup_vc_model
|
||||||
|
from TTS.vc.models.openvoice import OpenVoice
|
||||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||||
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
|
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
|
||||||
|
|
||||||
|
@ -97,7 +100,7 @@ class Synthesizer(nn.Module):
|
||||||
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
|
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
|
||||||
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
|
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
|
||||||
|
|
||||||
if vc_checkpoint:
|
if vc_checkpoint and model_dir is None:
|
||||||
self._load_vc(vc_checkpoint, vc_config, use_cuda)
|
self._load_vc(vc_checkpoint, vc_config, use_cuda)
|
||||||
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||||
|
|
||||||
|
@ -105,6 +108,9 @@ class Synthesizer(nn.Module):
|
||||||
if "fairseq" in model_dir:
|
if "fairseq" in model_dir:
|
||||||
self._load_fairseq_from_dir(model_dir, use_cuda)
|
self._load_fairseq_from_dir(model_dir, use_cuda)
|
||||||
self.output_sample_rate = self.tts_config.audio["sample_rate"]
|
self.output_sample_rate = self.tts_config.audio["sample_rate"]
|
||||||
|
elif "openvoice" in model_dir:
|
||||||
|
self._load_openvoice_from_dir(Path(model_dir), use_cuda)
|
||||||
|
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||||
else:
|
else:
|
||||||
self._load_tts_from_dir(model_dir, use_cuda)
|
self._load_tts_from_dir(model_dir, use_cuda)
|
||||||
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
||||||
|
@ -153,6 +159,19 @@ class Synthesizer(nn.Module):
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.tts_model.cuda()
|
self.tts_model.cuda()
|
||||||
|
|
||||||
|
def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None:
|
||||||
|
"""Load the OpenVoice model from a directory.
|
||||||
|
|
||||||
|
We assume the model knows how to load itself from the directory and
|
||||||
|
there is a config.json file in the directory.
|
||||||
|
"""
|
||||||
|
self.vc_config = OpenVoiceConfig()
|
||||||
|
self.vc_model = OpenVoice.init_from_config(self.vc_config)
|
||||||
|
self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True)
|
||||||
|
self.vc_config = self.vc_model.config
|
||||||
|
if use_cuda:
|
||||||
|
self.vc_model.cuda()
|
||||||
|
|
||||||
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
|
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
|
||||||
"""Load the TTS model from a directory.
|
"""Load the TTS model from a directory.
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,320 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Mapping, Optional, Union
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
import torch
|
||||||
|
from coqpit import Coqpit
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from trainer.io import load_fsspec
|
||||||
|
|
||||||
|
from TTS.tts.layers.vits.networks import PosteriorEncoder
|
||||||
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.utils.audio.torch_transforms import wav_to_spec
|
||||||
|
from TTS.vc.configs.openvoice_config import OpenVoiceConfig
|
||||||
|
from TTS.vc.models.base_vc import BaseVC
|
||||||
|
from TTS.vc.models.freevc import Generator, ResidualCouplingBlock
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceEncoder(nn.Module):
|
||||||
|
"""NN module creating a fixed size prosody embedding from a spectrogram.
|
||||||
|
|
||||||
|
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
|
||||||
|
outputs: [batch_size, embedding_dim]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, spec_channels: int, embedding_dim: int = 0, layernorm: bool = True) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.spec_channels = spec_channels
|
||||||
|
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
||||||
|
K = len(ref_enc_filters)
|
||||||
|
filters = [1] + ref_enc_filters
|
||||||
|
convs = [
|
||||||
|
torch.nn.utils.parametrizations.weight_norm(
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=filters[i],
|
||||||
|
out_channels=filters[i + 1],
|
||||||
|
kernel_size=(3, 3),
|
||||||
|
stride=(2, 2),
|
||||||
|
padding=(1, 1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for i in range(K)
|
||||||
|
]
|
||||||
|
self.convs = nn.ModuleList(convs)
|
||||||
|
|
||||||
|
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
||||||
|
self.gru = nn.GRU(
|
||||||
|
input_size=ref_enc_filters[-1] * out_channels,
|
||||||
|
hidden_size=256 // 2,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
self.proj = nn.Linear(128, embedding_dim)
|
||||||
|
self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||||
|
N = inputs.size(0)
|
||||||
|
|
||||||
|
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
||||||
|
if self.layernorm is not None:
|
||||||
|
out = self.layernorm(out)
|
||||||
|
|
||||||
|
for conv in self.convs:
|
||||||
|
out = conv(out)
|
||||||
|
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
||||||
|
|
||||||
|
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
||||||
|
T = out.size(1)
|
||||||
|
N = out.size(0)
|
||||||
|
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
||||||
|
|
||||||
|
self.gru.flatten_parameters()
|
||||||
|
_memory, out = self.gru(out) # out --- [1, N, 128]
|
||||||
|
|
||||||
|
return self.proj(out.squeeze(0))
|
||||||
|
|
||||||
|
def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
|
||||||
|
for _ in range(n_convs):
|
||||||
|
L = (L - kernel_size + 2 * pad) // stride + 1
|
||||||
|
return L
|
||||||
|
|
||||||
|
|
||||||
|
class OpenVoice(BaseVC):
|
||||||
|
"""
|
||||||
|
OpenVoice voice conversion model (inference only).
|
||||||
|
|
||||||
|
Source: https://github.com/myshell-ai/OpenVoice
|
||||||
|
Paper: https://arxiv.org/abs/2312.01479
|
||||||
|
|
||||||
|
Paper abstract:
|
||||||
|
We introduce OpenVoice, a versatile voice cloning approach that requires
|
||||||
|
only a short audio clip from the reference speaker to replicate their voice and
|
||||||
|
generate speech in multiple languages. OpenVoice represents a significant
|
||||||
|
advancement in addressing the following open challenges in the field: 1)
|
||||||
|
Flexible Voice Style Control. OpenVoice enables granular control over voice
|
||||||
|
styles, including emotion, accent, rhythm, pauses, and intonation, in addition
|
||||||
|
to replicating the tone color of the reference speaker. The voice styles are not
|
||||||
|
directly copied from and constrained by the style of the reference speaker.
|
||||||
|
Previous approaches lacked the ability to flexibly manipulate voice styles after
|
||||||
|
cloning. 2) Zero-Shot Cross-Lingual Voice Cloning. OpenVoice achieves zero-shot
|
||||||
|
cross-lingual voice cloning for languages not included in the massive-speaker
|
||||||
|
training set. Unlike previous approaches, which typically require extensive
|
||||||
|
massive-speaker multi-lingual (MSML) dataset for all languages, OpenVoice can
|
||||||
|
clone voices into a new language without any massive-speaker training data for
|
||||||
|
that language. OpenVoice is also computationally efficient, costing tens of
|
||||||
|
times less than commercially available APIs that offer even inferior
|
||||||
|
performance. To foster further research in the field, we have made the source
|
||||||
|
code and trained model publicly accessible. We also provide qualitative results
|
||||||
|
in our demo website. Prior to its public release, our internal version of
|
||||||
|
OpenVoice was used tens of millions of times by users worldwide between May and
|
||||||
|
October 2023, serving as the backend of MyShell.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = None) -> None:
|
||||||
|
super().__init__(config, None, speaker_manager, None)
|
||||||
|
|
||||||
|
self.init_multispeaker(config)
|
||||||
|
|
||||||
|
self.zero_g = self.args.zero_g
|
||||||
|
self.inter_channels = self.args.inter_channels
|
||||||
|
self.hidden_channels = self.args.hidden_channels
|
||||||
|
self.filter_channels = self.args.filter_channels
|
||||||
|
self.n_heads = self.args.n_heads
|
||||||
|
self.n_layers = self.args.n_layers
|
||||||
|
self.kernel_size = self.args.kernel_size
|
||||||
|
self.p_dropout = self.args.p_dropout
|
||||||
|
self.resblock = self.args.resblock
|
||||||
|
self.resblock_kernel_sizes = self.args.resblock_kernel_sizes
|
||||||
|
self.resblock_dilation_sizes = self.args.resblock_dilation_sizes
|
||||||
|
self.upsample_rates = self.args.upsample_rates
|
||||||
|
self.upsample_initial_channel = self.args.upsample_initial_channel
|
||||||
|
self.upsample_kernel_sizes = self.args.upsample_kernel_sizes
|
||||||
|
self.n_layers_q = self.args.n_layers_q
|
||||||
|
self.use_spectral_norm = self.args.use_spectral_norm
|
||||||
|
self.gin_channels = self.args.gin_channels
|
||||||
|
self.tau = self.args.tau
|
||||||
|
|
||||||
|
self.spec_channels = config.audio.fft_size // 2 + 1
|
||||||
|
|
||||||
|
self.dec = Generator(
|
||||||
|
self.inter_channels,
|
||||||
|
self.resblock,
|
||||||
|
self.resblock_kernel_sizes,
|
||||||
|
self.resblock_dilation_sizes,
|
||||||
|
self.upsample_rates,
|
||||||
|
self.upsample_initial_channel,
|
||||||
|
self.upsample_kernel_sizes,
|
||||||
|
gin_channels=self.gin_channels,
|
||||||
|
)
|
||||||
|
self.enc_q = PosteriorEncoder(
|
||||||
|
self.spec_channels,
|
||||||
|
self.inter_channels,
|
||||||
|
self.hidden_channels,
|
||||||
|
kernel_size=5,
|
||||||
|
dilation_rate=1,
|
||||||
|
num_layers=16,
|
||||||
|
cond_channels=self.gin_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.flow = ResidualCouplingBlock(
|
||||||
|
self.inter_channels,
|
||||||
|
self.hidden_channels,
|
||||||
|
kernel_size=5,
|
||||||
|
dilation_rate=1,
|
||||||
|
n_layers=4,
|
||||||
|
gin_channels=self.gin_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ref_enc = ReferenceEncoder(self.spec_channels, self.gin_channels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_from_config(config: OpenVoiceConfig) -> "OpenVoice":
|
||||||
|
return OpenVoice(config)
|
||||||
|
|
||||||
|
def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
|
||||||
|
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||||
|
or with external `d_vectors` computed from a speaker encoder model.
|
||||||
|
|
||||||
|
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (Coqpit): Model configuration.
|
||||||
|
data (list, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||||
|
"""
|
||||||
|
self.num_spks = config.num_speakers
|
||||||
|
if self.speaker_manager:
|
||||||
|
self.num_spks = self.speaker_manager.num_speakers
|
||||||
|
|
||||||
|
def load_checkpoint(
|
||||||
|
self,
|
||||||
|
config: OpenVoiceConfig,
|
||||||
|
checkpoint_path: Union[str, os.PathLike[Any]],
|
||||||
|
eval: bool = False,
|
||||||
|
strict: bool = True,
|
||||||
|
cache: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Map from OpenVoice's config structure."""
|
||||||
|
config_path = Path(checkpoint_path).parent / "config.json"
|
||||||
|
with open(config_path, encoding="utf-8") as f:
|
||||||
|
config_org = json.load(f)
|
||||||
|
self.config.audio.input_sample_rate = config_org["data"]["sampling_rate"]
|
||||||
|
self.config.audio.output_sample_rate = config_org["data"]["sampling_rate"]
|
||||||
|
self.config.audio.fft_size = config_org["data"]["filter_length"]
|
||||||
|
self.config.audio.hop_length = config_org["data"]["hop_length"]
|
||||||
|
self.config.audio.win_length = config_org["data"]["win_length"]
|
||||||
|
state = load_fsspec(str(checkpoint_path), map_location=torch.device("cpu"), cache=cache)
|
||||||
|
self.load_state_dict(state["model"], strict=strict)
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
def forward(self) -> None: ...
|
||||||
|
def train_step(self) -> None: ...
|
||||||
|
def eval_step(self) -> None: ...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tensor]]) -> torch.Tensor:
|
||||||
|
if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
|
||||||
|
return aux_input["x_lengths"]
|
||||||
|
return torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
aux_input: Mapping[str, Optional[torch.Tensor]] = {"x_lengths": None, "g_src": None, "g_tgt": None},
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Inference pass of the model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len).
|
||||||
|
x_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,).
|
||||||
|
g_src (torch.Tensor): Source speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
|
||||||
|
g_tgt (torch.Tensor): Target speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
o_hat: Output spectrogram tensor. Shape: (batch_size, spec_seq_len, spec_dim).
|
||||||
|
x_mask: Spectrogram mask. Shape: (batch_size, spec_seq_len).
|
||||||
|
(z, z_p, z_hat): A tuple of latent variables.
|
||||||
|
"""
|
||||||
|
x_lengths = self._set_x_lengths(x, aux_input)
|
||||||
|
if "g_src" in aux_input and aux_input["g_src"] is not None:
|
||||||
|
g_src = aux_input["g_src"]
|
||||||
|
else:
|
||||||
|
raise ValueError("aux_input must define g_src")
|
||||||
|
if "g_tgt" in aux_input and aux_input["g_tgt"] is not None:
|
||||||
|
g_tgt = aux_input["g_tgt"]
|
||||||
|
else:
|
||||||
|
raise ValueError("aux_input must define g_tgt")
|
||||||
|
z, _m_q, _logs_q, y_mask = self.enc_q(
|
||||||
|
x, x_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=self.tau
|
||||||
|
)
|
||||||
|
z_p = self.flow(z, y_mask, g=g_src)
|
||||||
|
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
||||||
|
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
|
||||||
|
return {
|
||||||
|
"model_outputs": o_hat,
|
||||||
|
"y_mask": y_mask,
|
||||||
|
"z": z,
|
||||||
|
"z_p": z_p,
|
||||||
|
"z_hat": z_hat,
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list[float]]) -> torch.Tensor:
|
||||||
|
"""Read and format the input audio."""
|
||||||
|
if isinstance(wav, str):
|
||||||
|
out = torch.from_numpy(librosa.load(wav, sr=self.config.audio.input_sample_rate)[0])
|
||||||
|
elif isinstance(wav, np.ndarray):
|
||||||
|
out = torch.from_numpy(wav)
|
||||||
|
elif isinstance(wav, list):
|
||||||
|
out = torch.from_numpy(np.array(wav))
|
||||||
|
else:
|
||||||
|
out = wav
|
||||||
|
return out.to(self.device).float()
|
||||||
|
|
||||||
|
def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
audio_ref = self.load_audio(audio)
|
||||||
|
y = torch.FloatTensor(audio_ref)
|
||||||
|
y = y.to(self.device)
|
||||||
|
y = y.unsqueeze(0)
|
||||||
|
spec = wav_to_spec(
|
||||||
|
y,
|
||||||
|
n_fft=self.config.audio.fft_size,
|
||||||
|
hop_length=self.config.audio.hop_length,
|
||||||
|
win_length=self.config.audio.win_length,
|
||||||
|
center=False,
|
||||||
|
).to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
g = self.ref_enc(spec.transpose(1, 2)).unsqueeze(-1)
|
||||||
|
|
||||||
|
return g, spec
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def voice_conversion(self, src: Union[str, torch.Tensor], tgt: Union[str, torch.Tensor]) -> npt.NDArray[np.float32]:
|
||||||
|
"""
|
||||||
|
Voice conversion pass of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src (str or torch.Tensor): Source utterance.
|
||||||
|
tgt (str or torch.Tensor): Target utterance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output numpy array.
|
||||||
|
"""
|
||||||
|
src_se, src_spec = self.extract_se(src)
|
||||||
|
tgt_se, _ = self.extract_se(tgt)
|
||||||
|
|
||||||
|
aux_input = {"g_src": src_se, "g_tgt": tgt_se}
|
||||||
|
audio = self.inference(src_spec, aux_input)
|
||||||
|
return audio["model_outputs"][0, 0].data.cpu().float().numpy()
|
|
@ -1,134 +0,0 @@
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from TTS.tts.layers.vits.networks import PosteriorEncoder
|
|
||||||
from TTS.vc.models.freevc import Generator, ResidualCouplingBlock
|
|
||||||
|
|
||||||
|
|
||||||
class ReferenceEncoder(nn.Module):
|
|
||||||
"""
|
|
||||||
inputs --- [N, Ty/r, n_mels*r] mels
|
|
||||||
outputs --- [N, ref_enc_gru_size]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, spec_channels, gin_channels=0, layernorm=True):
|
|
||||||
super().__init__()
|
|
||||||
self.spec_channels = spec_channels
|
|
||||||
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
|
||||||
K = len(ref_enc_filters)
|
|
||||||
filters = [1] + ref_enc_filters
|
|
||||||
convs = [
|
|
||||||
torch.nn.utils.parametrizations.weight_norm(
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=filters[i],
|
|
||||||
out_channels=filters[i + 1],
|
|
||||||
kernel_size=(3, 3),
|
|
||||||
stride=(2, 2),
|
|
||||||
padding=(1, 1),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for i in range(K)
|
|
||||||
]
|
|
||||||
self.convs = nn.ModuleList(convs)
|
|
||||||
|
|
||||||
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
|
||||||
self.gru = nn.GRU(
|
|
||||||
input_size=ref_enc_filters[-1] * out_channels,
|
|
||||||
hidden_size=256 // 2,
|
|
||||||
batch_first=True,
|
|
||||||
)
|
|
||||||
self.proj = nn.Linear(128, gin_channels)
|
|
||||||
if layernorm:
|
|
||||||
self.layernorm = nn.LayerNorm(self.spec_channels)
|
|
||||||
else:
|
|
||||||
self.layernorm = None
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
N = inputs.size(0)
|
|
||||||
|
|
||||||
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
|
||||||
if self.layernorm is not None:
|
|
||||||
out = self.layernorm(out)
|
|
||||||
|
|
||||||
for conv in self.convs:
|
|
||||||
out = conv(out)
|
|
||||||
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
|
||||||
|
|
||||||
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
|
||||||
T = out.size(1)
|
|
||||||
N = out.size(0)
|
|
||||||
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
|
||||||
|
|
||||||
self.gru.flatten_parameters()
|
|
||||||
_memory, out = self.gru(out) # out --- [1, N, 128]
|
|
||||||
|
|
||||||
return self.proj(out.squeeze(0))
|
|
||||||
|
|
||||||
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
|
|
||||||
for _ in range(n_convs):
|
|
||||||
L = (L - kernel_size + 2 * pad) // stride + 1
|
|
||||||
return L
|
|
||||||
|
|
||||||
|
|
||||||
class SynthesizerTrn(nn.Module):
|
|
||||||
"""
|
|
||||||
Synthesizer for Training
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
spec_channels,
|
|
||||||
inter_channels,
|
|
||||||
hidden_channels,
|
|
||||||
resblock,
|
|
||||||
resblock_kernel_sizes,
|
|
||||||
resblock_dilation_sizes,
|
|
||||||
upsample_rates,
|
|
||||||
upsample_initial_channel,
|
|
||||||
upsample_kernel_sizes,
|
|
||||||
n_speakers=0,
|
|
||||||
gin_channels=256,
|
|
||||||
zero_g=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dec = Generator(
|
|
||||||
inter_channels,
|
|
||||||
resblock,
|
|
||||||
resblock_kernel_sizes,
|
|
||||||
resblock_dilation_sizes,
|
|
||||||
upsample_rates,
|
|
||||||
upsample_initial_channel,
|
|
||||||
upsample_kernel_sizes,
|
|
||||||
gin_channels=gin_channels,
|
|
||||||
)
|
|
||||||
self.enc_q = PosteriorEncoder(
|
|
||||||
spec_channels,
|
|
||||||
inter_channels,
|
|
||||||
hidden_channels,
|
|
||||||
5,
|
|
||||||
1,
|
|
||||||
16,
|
|
||||||
cond_channels=gin_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
|
||||||
|
|
||||||
self.n_speakers = n_speakers
|
|
||||||
if n_speakers != 0:
|
|
||||||
raise ValueError("OpenVoice inference only supports n_speaker==0")
|
|
||||||
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
|
|
||||||
self.zero_g = zero_g
|
|
||||||
|
|
||||||
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
|
|
||||||
g_src = sid_src
|
|
||||||
g_tgt = sid_tgt
|
|
||||||
z, m_q, logs_q, y_mask = self.enc_q(
|
|
||||||
y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau
|
|
||||||
)
|
|
||||||
z_p = self.flow(z, y_mask, g=g_src)
|
|
||||||
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
|
||||||
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
|
|
||||||
return o_hat, y_mask, (z, z_p, z_hat)
|
|
Loading…
Reference in New Issue