diff --git a/README.md b/README.md index 5ca825b6..381a8e95 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,7 @@ repository are also still a useful source of information. ### Voice Conversion - FreeVC: [paper](https://arxiv.org/abs/2210.15418) +- OpenVoice: [technical report](https://arxiv.org/abs/2312.01479) You can also help us implement more models. diff --git a/TTS/api.py b/TTS/api.py index 250ed1a0..12e82af5 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -155,8 +155,10 @@ class TTS(nn.Module): gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ self.model_name = model_name - model_path, config_path, _, _, _ = self.download_model_by_name(model_name) - self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) + model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name) + self.voice_converter = Synthesizer( + vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu + ) def load_tts_model_by_name(self, model_name: str, gpu: bool = False): """Load one of 🐸TTS models by name. diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index bd445b3a..38fcfd60 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -424,7 +424,7 @@ class ModelManager(object): model_file = None config_file = None for file_name in os.listdir(output_path): - if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: + if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]: model_file = os.path.join(output_path, file_name) elif file_name == "config.json": config_file = os.path.join(output_path, file_name) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 90af4f48..a158df60 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,6 +1,7 @@ import logging import os import time +from pathlib import Path from typing import List import numpy as np @@ -15,7 +16,9 @@ from TTS.tts.models.vits import Vits from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import save_wav +from TTS.vc.configs.openvoice_config import OpenVoiceConfig from TTS.vc.models import setup_model as setup_vc_model +from TTS.vc.models.openvoice import OpenVoice from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input @@ -97,7 +100,7 @@ class Synthesizer(nn.Module): self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) self.output_sample_rate = self.vocoder_config.audio["sample_rate"] - if vc_checkpoint: + if vc_checkpoint and model_dir is None: self._load_vc(vc_checkpoint, vc_config, use_cuda) self.output_sample_rate = self.vc_config.audio["output_sample_rate"] @@ -105,6 +108,9 @@ class Synthesizer(nn.Module): if "fairseq" in model_dir: self._load_fairseq_from_dir(model_dir, use_cuda) self.output_sample_rate = self.tts_config.audio["sample_rate"] + elif "openvoice" in model_dir: + self._load_openvoice_from_dir(Path(model_dir), use_cuda) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] else: self._load_tts_from_dir(model_dir, use_cuda) self.output_sample_rate = self.tts_config.audio["output_sample_rate"] @@ -153,6 +159,19 @@ class Synthesizer(nn.Module): if use_cuda: self.tts_model.cuda() + def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None: + """Load the OpenVoice model from a directory. + + We assume the model knows how to load itself from the directory and + there is a config.json file in the directory. + """ + self.vc_config = OpenVoiceConfig() + self.vc_model = OpenVoice.init_from_config(self.vc_config) + self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True) + self.vc_config = self.vc_model.config + if use_cuda: + self.vc_model.cuda() + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: """Load the TTS model from a directory. diff --git a/TTS/vc/models/openvoice.py b/TTS/vc/models/openvoice.py new file mode 100644 index 00000000..135b0861 --- /dev/null +++ b/TTS/vc/models/openvoice.py @@ -0,0 +1,320 @@ +import json +import logging +import os +from pathlib import Path +from typing import Any, Mapping, Optional, Union + +import librosa +import numpy as np +import numpy.typing as npt +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import functional as F +from trainer.io import load_fsspec + +from TTS.tts.layers.vits.networks import PosteriorEncoder +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.audio.torch_transforms import wav_to_spec +from TTS.vc.configs.openvoice_config import OpenVoiceConfig +from TTS.vc.models.base_vc import BaseVC +from TTS.vc.models.freevc import Generator, ResidualCouplingBlock + +logger = logging.getLogger(__name__) + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. + + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, spec_channels: int, embedding_dim: int = 0, layernorm: bool = True) -> None: + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + torch.nn.utils.parametrizations.weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, embedding_dim) + self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + N = inputs.size(0) + + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + if self.layernorm is not None: + out = self.layernorm(out) + + for conv in self.convs: + out = conv(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + _memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)) + + def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int: + for _ in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class OpenVoice(BaseVC): + """ + OpenVoice voice conversion model (inference only). + + Source: https://github.com/myshell-ai/OpenVoice + Paper: https://arxiv.org/abs/2312.01479 + + Paper abstract: + We introduce OpenVoice, a versatile voice cloning approach that requires + only a short audio clip from the reference speaker to replicate their voice and + generate speech in multiple languages. OpenVoice represents a significant + advancement in addressing the following open challenges in the field: 1) + Flexible Voice Style Control. OpenVoice enables granular control over voice + styles, including emotion, accent, rhythm, pauses, and intonation, in addition + to replicating the tone color of the reference speaker. The voice styles are not + directly copied from and constrained by the style of the reference speaker. + Previous approaches lacked the ability to flexibly manipulate voice styles after + cloning. 2) Zero-Shot Cross-Lingual Voice Cloning. OpenVoice achieves zero-shot + cross-lingual voice cloning for languages not included in the massive-speaker + training set. Unlike previous approaches, which typically require extensive + massive-speaker multi-lingual (MSML) dataset for all languages, OpenVoice can + clone voices into a new language without any massive-speaker training data for + that language. OpenVoice is also computationally efficient, costing tens of + times less than commercially available APIs that offer even inferior + performance. To foster further research in the field, we have made the source + code and trained model publicly accessible. We also provide qualitative results + in our demo website. Prior to its public release, our internal version of + OpenVoice was used tens of millions of times by users worldwide between May and + October 2023, serving as the backend of MyShell. + """ + + def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = None) -> None: + super().__init__(config, None, speaker_manager, None) + + self.init_multispeaker(config) + + self.zero_g = self.args.zero_g + self.inter_channels = self.args.inter_channels + self.hidden_channels = self.args.hidden_channels + self.filter_channels = self.args.filter_channels + self.n_heads = self.args.n_heads + self.n_layers = self.args.n_layers + self.kernel_size = self.args.kernel_size + self.p_dropout = self.args.p_dropout + self.resblock = self.args.resblock + self.resblock_kernel_sizes = self.args.resblock_kernel_sizes + self.resblock_dilation_sizes = self.args.resblock_dilation_sizes + self.upsample_rates = self.args.upsample_rates + self.upsample_initial_channel = self.args.upsample_initial_channel + self.upsample_kernel_sizes = self.args.upsample_kernel_sizes + self.n_layers_q = self.args.n_layers_q + self.use_spectral_norm = self.args.use_spectral_norm + self.gin_channels = self.args.gin_channels + self.tau = self.args.tau + + self.spec_channels = config.audio.fft_size // 2 + 1 + + self.dec = Generator( + self.inter_channels, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + gin_channels=self.gin_channels, + ) + self.enc_q = PosteriorEncoder( + self.spec_channels, + self.inter_channels, + self.hidden_channels, + kernel_size=5, + dilation_rate=1, + num_layers=16, + cond_channels=self.gin_channels, + ) + + self.flow = ResidualCouplingBlock( + self.inter_channels, + self.hidden_channels, + kernel_size=5, + dilation_rate=1, + n_layers=4, + gin_channels=self.gin_channels, + ) + + self.ref_enc = ReferenceEncoder(self.spec_channels, self.gin_channels) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @staticmethod + def init_from_config(config: OpenVoiceConfig) -> "OpenVoice": + return OpenVoice(config) + + def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None: + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (list, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.num_spks = config.num_speakers + if self.speaker_manager: + self.num_spks = self.speaker_manager.num_speakers + + def load_checkpoint( + self, + config: OpenVoiceConfig, + checkpoint_path: Union[str, os.PathLike[Any]], + eval: bool = False, + strict: bool = True, + cache: bool = False, + ) -> None: + """Map from OpenVoice's config structure.""" + config_path = Path(checkpoint_path).parent / "config.json" + with open(config_path, encoding="utf-8") as f: + config_org = json.load(f) + self.config.audio.input_sample_rate = config_org["data"]["sampling_rate"] + self.config.audio.output_sample_rate = config_org["data"]["sampling_rate"] + self.config.audio.fft_size = config_org["data"]["filter_length"] + self.config.audio.hop_length = config_org["data"]["hop_length"] + self.config.audio.win_length = config_org["data"]["win_length"] + state = load_fsspec(str(checkpoint_path), map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"], strict=strict) + if eval: + self.eval() + + def forward(self) -> None: ... + def train_step(self) -> None: ... + def eval_step(self) -> None: ... + + @staticmethod + def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tensor]]) -> torch.Tensor: + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + @torch.no_grad() + def inference( + self, + x: torch.Tensor, + aux_input: Mapping[str, Optional[torch.Tensor]] = {"x_lengths": None, "g_src": None, "g_tgt": None}, + ) -> dict[str, torch.Tensor]: + """ + Inference pass of the model + + Args: + x (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len). + x_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,). + g_src (torch.Tensor): Source speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + g_tgt (torch.Tensor): Target speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + + Returns: + o_hat: Output spectrogram tensor. Shape: (batch_size, spec_seq_len, spec_dim). + x_mask: Spectrogram mask. Shape: (batch_size, spec_seq_len). + (z, z_p, z_hat): A tuple of latent variables. + """ + x_lengths = self._set_x_lengths(x, aux_input) + if "g_src" in aux_input and aux_input["g_src"] is not None: + g_src = aux_input["g_src"] + else: + raise ValueError("aux_input must define g_src") + if "g_tgt" in aux_input and aux_input["g_tgt"] is not None: + g_tgt = aux_input["g_tgt"] + else: + raise ValueError("aux_input must define g_tgt") + z, _m_q, _logs_q, y_mask = self.enc_q( + x, x_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=self.tau + ) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt)) + return { + "model_outputs": o_hat, + "y_mask": y_mask, + "z": z, + "z_p": z_p, + "z_hat": z_hat, + } + + def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list[float]]) -> torch.Tensor: + """Read and format the input audio.""" + if isinstance(wav, str): + out = torch.from_numpy(librosa.load(wav, sr=self.config.audio.input_sample_rate)[0]) + elif isinstance(wav, np.ndarray): + out = torch.from_numpy(wav) + elif isinstance(wav, list): + out = torch.from_numpy(np.array(wav)) + else: + out = wav + return out.to(self.device).float() + + def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + audio_ref = self.load_audio(audio) + y = torch.FloatTensor(audio_ref) + y = y.to(self.device) + y = y.unsqueeze(0) + spec = wav_to_spec( + y, + n_fft=self.config.audio.fft_size, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + center=False, + ).to(self.device) + with torch.no_grad(): + g = self.ref_enc(spec.transpose(1, 2)).unsqueeze(-1) + + return g, spec + + @torch.inference_mode() + def voice_conversion(self, src: Union[str, torch.Tensor], tgt: Union[str, torch.Tensor]) -> npt.NDArray[np.float32]: + """ + Voice conversion pass of the model. + + Args: + src (str or torch.Tensor): Source utterance. + tgt (str or torch.Tensor): Target utterance. + + Returns: + Output numpy array. + """ + src_se, src_spec = self.extract_se(src) + tgt_se, _ = self.extract_se(tgt) + + aux_input = {"g_src": src_se, "g_tgt": tgt_se} + audio = self.inference(src_spec, aux_input) + return audio["model_outputs"][0, 0].data.cpu().float().numpy() diff --git a/TTS/vc/modules/openvoice/__init__.py b/TTS/vc/modules/openvoice/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/TTS/vc/modules/openvoice/models.py b/TTS/vc/modules/openvoice/models.py deleted file mode 100644 index 89a1c3a4..00000000 --- a/TTS/vc/modules/openvoice/models.py +++ /dev/null @@ -1,134 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - -from TTS.tts.layers.vits.networks import PosteriorEncoder -from TTS.vc.models.freevc import Generator, ResidualCouplingBlock - - -class ReferenceEncoder(nn.Module): - """ - inputs --- [N, Ty/r, n_mels*r] mels - outputs --- [N, ref_enc_gru_size] - """ - - def __init__(self, spec_channels, gin_channels=0, layernorm=True): - super().__init__() - self.spec_channels = spec_channels - ref_enc_filters = [32, 32, 64, 64, 128, 128] - K = len(ref_enc_filters) - filters = [1] + ref_enc_filters - convs = [ - torch.nn.utils.parametrizations.weight_norm( - nn.Conv2d( - in_channels=filters[i], - out_channels=filters[i + 1], - kernel_size=(3, 3), - stride=(2, 2), - padding=(1, 1), - ) - ) - for i in range(K) - ] - self.convs = nn.ModuleList(convs) - - out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) - self.gru = nn.GRU( - input_size=ref_enc_filters[-1] * out_channels, - hidden_size=256 // 2, - batch_first=True, - ) - self.proj = nn.Linear(128, gin_channels) - if layernorm: - self.layernorm = nn.LayerNorm(self.spec_channels) - else: - self.layernorm = None - - def forward(self, inputs): - N = inputs.size(0) - - out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] - if self.layernorm is not None: - out = self.layernorm(out) - - for conv in self.convs: - out = conv(out) - out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] - - out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] - T = out.size(1) - N = out.size(0) - out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] - - self.gru.flatten_parameters() - _memory, out = self.gru(out) # out --- [1, N, 128] - - return self.proj(out.squeeze(0)) - - def calculate_channels(self, L, kernel_size, stride, pad, n_convs): - for _ in range(n_convs): - L = (L - kernel_size + 2 * pad) // stride + 1 - return L - - -class SynthesizerTrn(nn.Module): - """ - Synthesizer for Training - """ - - def __init__( - self, - spec_channels, - inter_channels, - hidden_channels, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - n_speakers=0, - gin_channels=256, - zero_g=False, - **kwargs, - ): - super().__init__() - - self.dec = Generator( - inter_channels, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=gin_channels, - ) - self.enc_q = PosteriorEncoder( - spec_channels, - inter_channels, - hidden_channels, - 5, - 1, - 16, - cond_channels=gin_channels, - ) - - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - - self.n_speakers = n_speakers - if n_speakers != 0: - raise ValueError("OpenVoice inference only supports n_speaker==0") - self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) - self.zero_g = zero_g - - def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): - g_src = sid_src - g_tgt = sid_tgt - z, m_q, logs_q, y_mask = self.enc_q( - y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau - ) - z_p = self.flow(z, y_mask, g=g_src) - z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) - o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt)) - return o_hat, y_mask, (z, z_p, z_hat)