feat: add openvoice vc model

This commit is contained in:
Enno Hermann 2024-06-25 23:01:47 +02:00
parent 1a21853b90
commit fce3137e0d
7 changed files with 346 additions and 138 deletions

View File

@ -121,6 +121,7 @@ repository are also still a useful source of information.
### Voice Conversion
- FreeVC: [paper](https://arxiv.org/abs/2210.15418)
- OpenVoice: [technical report](https://arxiv.org/abs/2312.01479)
You can also help us implement more models.

View File

@ -155,8 +155,10 @@ class TTS(nn.Module):
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.model_name = model_name
model_path, config_path, _, _, _ = self.download_model_by_name(model_name)
self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name)
self.voice_converter = Synthesizer(
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
)
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of 🐸TTS models by name.

View File

@ -424,7 +424,7 @@ class ModelManager(object):
model_file = None
config_file = None
for file_name in os.listdir(output_path):
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]:
model_file = os.path.join(output_path, file_name)
elif file_name == "config.json":
config_file = os.path.join(output_path, file_name)

View File

@ -1,6 +1,7 @@
import logging
import os
import time
from pathlib import Path
from typing import List
import numpy as np
@ -15,7 +16,9 @@ from TTS.tts.models.vits import Vits
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.vc.configs.openvoice_config import OpenVoiceConfig
from TTS.vc.models import setup_model as setup_vc_model
from TTS.vc.models.openvoice import OpenVoice
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
@ -97,7 +100,7 @@ class Synthesizer(nn.Module):
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
if vc_checkpoint:
if vc_checkpoint and model_dir is None:
self._load_vc(vc_checkpoint, vc_config, use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
@ -105,6 +108,9 @@ class Synthesizer(nn.Module):
if "fairseq" in model_dir:
self._load_fairseq_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
elif "openvoice" in model_dir:
self._load_openvoice_from_dir(Path(model_dir), use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
else:
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@ -153,6 +159,19 @@ class Synthesizer(nn.Module):
if use_cuda:
self.tts_model.cuda()
def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None:
"""Load the OpenVoice model from a directory.
We assume the model knows how to load itself from the directory and
there is a config.json file in the directory.
"""
self.vc_config = OpenVoiceConfig()
self.vc_model = OpenVoice.init_from_config(self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True)
self.vc_config = self.vc_model.config
if use_cuda:
self.vc_model.cuda()
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the TTS model from a directory.

320
TTS/vc/models/openvoice.py Normal file
View File

@ -0,0 +1,320 @@
import json
import logging
import os
from pathlib import Path
from typing import Any, Mapping, Optional, Union
import librosa
import numpy as np
import numpy.typing as npt
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import functional as F
from trainer.io import load_fsspec
from TTS.tts.layers.vits.networks import PosteriorEncoder
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio.torch_transforms import wav_to_spec
from TTS.vc.configs.openvoice_config import OpenVoiceConfig
from TTS.vc.models.base_vc import BaseVC
from TTS.vc.models.freevc import Generator, ResidualCouplingBlock
logger = logging.getLogger(__name__)
class ReferenceEncoder(nn.Module):
"""NN module creating a fixed size prosody embedding from a spectrogram.
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, spec_channels: int, embedding_dim: int = 0, layernorm: bool = True) -> None:
super().__init__()
self.spec_channels = spec_channels
ref_enc_filters = [32, 32, 64, 64, 128, 128]
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [
torch.nn.utils.parametrizations.weight_norm(
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
)
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True,
)
self.proj = nn.Linear(128, embedding_dim)
self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
N = inputs.size(0)
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
if self.layernorm is not None:
out = self.layernorm(out)
for conv in self.convs:
out = conv(out)
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
T = out.size(1)
N = out.size(0)
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
self.gru.flatten_parameters()
_memory, out = self.gru(out) # out --- [1, N, 128]
return self.proj(out.squeeze(0))
def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
for _ in range(n_convs):
L = (L - kernel_size + 2 * pad) // stride + 1
return L
class OpenVoice(BaseVC):
"""
OpenVoice voice conversion model (inference only).
Source: https://github.com/myshell-ai/OpenVoice
Paper: https://arxiv.org/abs/2312.01479
Paper abstract:
We introduce OpenVoice, a versatile voice cloning approach that requires
only a short audio clip from the reference speaker to replicate their voice and
generate speech in multiple languages. OpenVoice represents a significant
advancement in addressing the following open challenges in the field: 1)
Flexible Voice Style Control. OpenVoice enables granular control over voice
styles, including emotion, accent, rhythm, pauses, and intonation, in addition
to replicating the tone color of the reference speaker. The voice styles are not
directly copied from and constrained by the style of the reference speaker.
Previous approaches lacked the ability to flexibly manipulate voice styles after
cloning. 2) Zero-Shot Cross-Lingual Voice Cloning. OpenVoice achieves zero-shot
cross-lingual voice cloning for languages not included in the massive-speaker
training set. Unlike previous approaches, which typically require extensive
massive-speaker multi-lingual (MSML) dataset for all languages, OpenVoice can
clone voices into a new language without any massive-speaker training data for
that language. OpenVoice is also computationally efficient, costing tens of
times less than commercially available APIs that offer even inferior
performance. To foster further research in the field, we have made the source
code and trained model publicly accessible. We also provide qualitative results
in our demo website. Prior to its public release, our internal version of
OpenVoice was used tens of millions of times by users worldwide between May and
October 2023, serving as the backend of MyShell.
"""
def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = None) -> None:
super().__init__(config, None, speaker_manager, None)
self.init_multispeaker(config)
self.zero_g = self.args.zero_g
self.inter_channels = self.args.inter_channels
self.hidden_channels = self.args.hidden_channels
self.filter_channels = self.args.filter_channels
self.n_heads = self.args.n_heads
self.n_layers = self.args.n_layers
self.kernel_size = self.args.kernel_size
self.p_dropout = self.args.p_dropout
self.resblock = self.args.resblock
self.resblock_kernel_sizes = self.args.resblock_kernel_sizes
self.resblock_dilation_sizes = self.args.resblock_dilation_sizes
self.upsample_rates = self.args.upsample_rates
self.upsample_initial_channel = self.args.upsample_initial_channel
self.upsample_kernel_sizes = self.args.upsample_kernel_sizes
self.n_layers_q = self.args.n_layers_q
self.use_spectral_norm = self.args.use_spectral_norm
self.gin_channels = self.args.gin_channels
self.tau = self.args.tau
self.spec_channels = config.audio.fft_size // 2 + 1
self.dec = Generator(
self.inter_channels,
self.resblock,
self.resblock_kernel_sizes,
self.resblock_dilation_sizes,
self.upsample_rates,
self.upsample_initial_channel,
self.upsample_kernel_sizes,
gin_channels=self.gin_channels,
)
self.enc_q = PosteriorEncoder(
self.spec_channels,
self.inter_channels,
self.hidden_channels,
kernel_size=5,
dilation_rate=1,
num_layers=16,
cond_channels=self.gin_channels,
)
self.flow = ResidualCouplingBlock(
self.inter_channels,
self.hidden_channels,
kernel_size=5,
dilation_rate=1,
n_layers=4,
gin_channels=self.gin_channels,
)
self.ref_enc = ReferenceEncoder(self.spec_channels, self.gin_channels)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
@staticmethod
def init_from_config(config: OpenVoiceConfig) -> "OpenVoice":
return OpenVoice(config)
def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
Args:
config (Coqpit): Model configuration.
data (list, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.num_spks = config.num_speakers
if self.speaker_manager:
self.num_spks = self.speaker_manager.num_speakers
def load_checkpoint(
self,
config: OpenVoiceConfig,
checkpoint_path: Union[str, os.PathLike[Any]],
eval: bool = False,
strict: bool = True,
cache: bool = False,
) -> None:
"""Map from OpenVoice's config structure."""
config_path = Path(checkpoint_path).parent / "config.json"
with open(config_path, encoding="utf-8") as f:
config_org = json.load(f)
self.config.audio.input_sample_rate = config_org["data"]["sampling_rate"]
self.config.audio.output_sample_rate = config_org["data"]["sampling_rate"]
self.config.audio.fft_size = config_org["data"]["filter_length"]
self.config.audio.hop_length = config_org["data"]["hop_length"]
self.config.audio.win_length = config_org["data"]["win_length"]
state = load_fsspec(str(checkpoint_path), map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"], strict=strict)
if eval:
self.eval()
def forward(self) -> None: ...
def train_step(self) -> None: ...
def eval_step(self) -> None: ...
@staticmethod
def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tensor]]) -> torch.Tensor:
if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
@torch.no_grad()
def inference(
self,
x: torch.Tensor,
aux_input: Mapping[str, Optional[torch.Tensor]] = {"x_lengths": None, "g_src": None, "g_tgt": None},
) -> dict[str, torch.Tensor]:
"""
Inference pass of the model
Args:
x (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len).
x_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,).
g_src (torch.Tensor): Source speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
g_tgt (torch.Tensor): Target speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
Returns:
o_hat: Output spectrogram tensor. Shape: (batch_size, spec_seq_len, spec_dim).
x_mask: Spectrogram mask. Shape: (batch_size, spec_seq_len).
(z, z_p, z_hat): A tuple of latent variables.
"""
x_lengths = self._set_x_lengths(x, aux_input)
if "g_src" in aux_input and aux_input["g_src"] is not None:
g_src = aux_input["g_src"]
else:
raise ValueError("aux_input must define g_src")
if "g_tgt" in aux_input and aux_input["g_tgt"] is not None:
g_tgt = aux_input["g_tgt"]
else:
raise ValueError("aux_input must define g_tgt")
z, _m_q, _logs_q, y_mask = self.enc_q(
x, x_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=self.tau
)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
return {
"model_outputs": o_hat,
"y_mask": y_mask,
"z": z,
"z_p": z_p,
"z_hat": z_hat,
}
def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list[float]]) -> torch.Tensor:
"""Read and format the input audio."""
if isinstance(wav, str):
out = torch.from_numpy(librosa.load(wav, sr=self.config.audio.input_sample_rate)[0])
elif isinstance(wav, np.ndarray):
out = torch.from_numpy(wav)
elif isinstance(wav, list):
out = torch.from_numpy(np.array(wav))
else:
out = wav
return out.to(self.device).float()
def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
audio_ref = self.load_audio(audio)
y = torch.FloatTensor(audio_ref)
y = y.to(self.device)
y = y.unsqueeze(0)
spec = wav_to_spec(
y,
n_fft=self.config.audio.fft_size,
hop_length=self.config.audio.hop_length,
win_length=self.config.audio.win_length,
center=False,
).to(self.device)
with torch.no_grad():
g = self.ref_enc(spec.transpose(1, 2)).unsqueeze(-1)
return g, spec
@torch.inference_mode()
def voice_conversion(self, src: Union[str, torch.Tensor], tgt: Union[str, torch.Tensor]) -> npt.NDArray[np.float32]:
"""
Voice conversion pass of the model.
Args:
src (str or torch.Tensor): Source utterance.
tgt (str or torch.Tensor): Target utterance.
Returns:
Output numpy array.
"""
src_se, src_spec = self.extract_se(src)
tgt_se, _ = self.extract_se(tgt)
aux_input = {"g_src": src_se, "g_tgt": tgt_se}
audio = self.inference(src_spec, aux_input)
return audio["model_outputs"][0, 0].data.cpu().float().numpy()

View File

@ -1,134 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.vits.networks import PosteriorEncoder
from TTS.vc.models.freevc import Generator, ResidualCouplingBlock
class ReferenceEncoder(nn.Module):
"""
inputs --- [N, Ty/r, n_mels*r] mels
outputs --- [N, ref_enc_gru_size]
"""
def __init__(self, spec_channels, gin_channels=0, layernorm=True):
super().__init__()
self.spec_channels = spec_channels
ref_enc_filters = [32, 32, 64, 64, 128, 128]
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [
torch.nn.utils.parametrizations.weight_norm(
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
)
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True,
)
self.proj = nn.Linear(128, gin_channels)
if layernorm:
self.layernorm = nn.LayerNorm(self.spec_channels)
else:
self.layernorm = None
def forward(self, inputs):
N = inputs.size(0)
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
if self.layernorm is not None:
out = self.layernorm(out)
for conv in self.convs:
out = conv(out)
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
T = out.size(1)
N = out.size(0)
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
self.gru.flatten_parameters()
_memory, out = self.gru(out) # out --- [1, N, 128]
return self.proj(out.squeeze(0))
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
for _ in range(n_convs):
L = (L - kernel_size + 2 * pad) // stride + 1
return L
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
def __init__(
self,
spec_channels,
inter_channels,
hidden_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=256,
zero_g=False,
**kwargs,
):
super().__init__()
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
cond_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
self.n_speakers = n_speakers
if n_speakers != 0:
raise ValueError("OpenVoice inference only supports n_speaker==0")
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
self.zero_g = zero_g
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
g_src = sid_src
g_tgt = sid_tgt
z, m_q, logs_q, y_mask = self.enc_q(
y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau
)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
return o_hat, y_mask, (z, z_p, z_hat)