Merge pull request #183 from idiap/openvoice

Add OpenVoice VC models
This commit is contained in:
Enno Hermann 2024-12-03 19:13:22 +01:00 committed by GitHub
commit 9ae0b27f3c
28 changed files with 671 additions and 65 deletions

View File

@ -2,6 +2,8 @@ repos:
- repo: "https://github.com/pre-commit/pre-commit-hooks"
rev: v5.0.0
hooks:
- id: check-json
files: "TTS/.models.json"
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

View File

@ -1,13 +1,12 @@
## 🐸Coqui TTS News
- 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts)
- 📣 [OpenVoice](https://github.com/myshell-ai/OpenVoice) models now available for voice conversion.
- 📣 Prebuilt wheels are now also published for Mac and Windows (in addition to Linux as before) for easier installation across platforms.
- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board.
- 📣 ⓍTTSv2 is here with 17 languages and better performance across the board. ⓍTTS can stream with <200ms latency.
- 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech).
- 📣 ⓍTTS can now stream with <200ms latency.
- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html)
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html)
- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
- 📣 You can use [Fairseq models in ~1100 languages](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
## <img src="https://raw.githubusercontent.com/idiap/coqui-ai-TTS/main/images/coqui-log-green-TTS.png" height="56"/>
@ -121,6 +120,7 @@ repository are also still a useful source of information.
### Voice Conversion
- FreeVC: [paper](https://arxiv.org/abs/2210.15418)
- OpenVoice: [technical report](https://arxiv.org/abs/2312.01479)
You can also help us implement more models.
@ -244,8 +244,14 @@ tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progr
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```
#### Example voice cloning together with the voice conversion model.
This way, you can clone voices by using any model in 🐸TTS.
Other available voice conversion models:
- `voice_conversion_models/multilingual/multi-dataset/openvoice_v1`
- `voice_conversion_models/multilingual/multi-dataset/openvoice_v2`
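As a minimal sketch (not part of this diff), either of the models listed above can be used through the same Python API as the FreeVC example:

```python
from TTS.api import TTS

# Voice conversion with one of the OpenVoice models listed above.
tts = TTS("voice_conversion_models/multilingual/multi-dataset/openvoice_v2", progress_bar=False)
tts.voice_conversion_to_file(
    source_wav="my/source.wav",  # utterance whose content is kept
    target_wav="my/target.wav",  # utterance whose voice is cloned
    file_path="output.wav",
)
```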
#### Example voice cloning together with the default voice conversion model.
This way, you can clone voices by using any model in 🐸TTS. The FreeVC model is
used for voice conversion after synthesizing speech.
```python
@ -412,4 +418,6 @@ $ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<mode
|- (same)
|- vocoder/ (Vocoder models.)
|- (same)
|- vc/ (Voice conversion models.)
|- (same)
```

View File

@ -931,6 +931,28 @@
"license": "MIT",
"commit": null
}
},
"multi-dataset": {
"openvoice_v1": {
"hf_url": [
"https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/config.json",
"https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/checkpoint.pth"
],
"description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2",
"author": "MyShell.ai",
"license": "MIT",
"commit": null
},
"openvoice_v2": {
"hf_url": [
"https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter/config.json",
"https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter/checkpoint.pth"
],
"description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2",
"author": "MyShell.ai",
"license": "MIT",
"commit": null
}
}
}
}
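A hedged sketch of how these entries are consumed (assuming `ModelManager.download_model`'s usual three-value return): the files behind `hf_url` are fetched into a local directory containing `config.json` and `checkpoint.pth`, which is why `checkpoint.pth` is added to the recognised checkpoint names further below.

```python
from TTS.utils.manage import ModelManager

manager = ModelManager()
# Downloads config.json and checkpoint.pth from the hf_url entries above.
model_path, config_path, model_item = manager.download_model(
    "voice_conversion_models/multilingual/multi-dataset/openvoice_v2"
)
print(model_path, config_path, model_item["license"])
```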

View File

@ -155,8 +155,10 @@ class TTS(nn.Module):
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.model_name = model_name
model_path, config_path, _, _, _ = self.download_model_by_name(model_name)
self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name)
self.voice_converter = Synthesizer(
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
)
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of 🐸TTS models by name.
@ -355,15 +357,17 @@ class TTS(nn.Module):
target_wav (str):
Path to the target wav file.
"""
wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav)
return wav
if self.voice_converter is None:
msg = "The selected model does not support voice conversion."
raise RuntimeError(msg)
return self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav)
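A small sketch of what the new guard changes for callers (the model name is just an illustrative TTS-only model): previously this call failed with an `AttributeError` on `None`, now it raises a clear `RuntimeError`.

```python
from TTS.api import TTS

tts = TTS("tts_models/en/ljspeech/glow-tts", progress_bar=False)
try:
    tts.voice_conversion(source_wav="my/source.wav", target_wav="my/target.wav")
except RuntimeError as err:
    print(err)  # The selected model does not support voice conversion.
```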
def voice_conversion_to_file(
self,
source_wav: str,
target_wav: str,
file_path: str = "output.wav",
):
) -> str:
"""Voice conversion with FreeVC. Convert source wav to target speaker.
Args:

View File

@ -407,18 +407,18 @@ def main():
# load models
synthesizer = Synthesizer(
tts_path,
tts_config_path,
speakers_file_path,
language_ids_file_path,
vocoder_path,
vocoder_config_path,
encoder_path,
encoder_config_path,
vc_path,
vc_config_path,
model_dir,
args.voice_dir,
tts_checkpoint=tts_path,
tts_config_path=tts_config_path,
tts_speakers_file=speakers_file_path,
tts_languages_file=language_ids_file_path,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config_path,
encoder_checkpoint=encoder_path,
encoder_config=encoder_config_path,
vc_checkpoint=vc_path,
vc_config=vc_config_path,
model_dir=model_dir,
voice_dir=args.voice_dir,
).to(device)
# query speaker ids of a multi-speaker model.

View File

@ -256,7 +256,7 @@ class PosteriorEncoder(nn.Module):
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
def forward(self, x, x_lengths, g=None, tau=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
@ -268,5 +268,5 @@ class PosteriorEncoder(nn.Module):
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask
return z, mean, log_scale, x_mask
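A minimal sketch (illustrative values only) of what the new `tau` argument does: it scales the sampling noise of the posterior, so `tau=0.0` collapses `z` to the predicted mean while `tau=1.0` reproduces the previous behaviour.

```python
import torch

mean = torch.zeros(1, 192, 50)
log_scale = torch.zeros(1, 192, 50)  # std = exp(0) = 1
x_mask = torch.ones(1, 1, 50)

for tau in (1.0, 0.3, 0.0):
    z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask
    print(tau, round(z.std().item(), 3))  # spread shrinks with tau; exactly 0.0 for tau=0
```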

View File

@ -424,7 +424,7 @@ class ModelManager(object):
model_file = None
config_file = None
for file_name in os.listdir(output_path):
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]:
model_file = os.path.join(output_path, file_name)
elif file_name == "config.json":
config_file = os.path.join(output_path, file_name)

View File

@ -1,6 +1,7 @@
import logging
import os
import time
from pathlib import Path
from typing import List
import numpy as np
@ -15,7 +16,9 @@ from TTS.tts.models.vits import Vits
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.vc.configs.openvoice_config import OpenVoiceConfig
from TTS.vc.models import setup_model as setup_vc_model
from TTS.vc.models.openvoice import OpenVoice
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
@ -25,6 +28,7 @@ logger = logging.getLogger(__name__)
class Synthesizer(nn.Module):
def __init__(
self,
*,
tts_checkpoint: str = "",
tts_config_path: str = "",
tts_speakers_file: str = "",
@ -91,23 +95,20 @@ class Synthesizer(nn.Module):
if tts_checkpoint:
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if vocoder_checkpoint:
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
if vc_checkpoint:
if vc_checkpoint and model_dir is None:
self._load_vc(vc_checkpoint, vc_config, use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if model_dir:
if "fairseq" in model_dir:
self._load_fairseq_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
elif "openvoice" in model_dir:
self._load_openvoice_from_dir(Path(model_dir), use_cuda)
else:
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@staticmethod
def _get_segmenter(lang: str):
@ -136,6 +137,7 @@ class Synthesizer(nn.Module):
"""
# pylint: disable=global-statement
self.vc_config = load_config(vc_config_path)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
self.vc_model = setup_vc_model(config=self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
if use_cuda:
@ -150,9 +152,24 @@ class Synthesizer(nn.Module):
self.tts_model = Vits.init_from_config(self.tts_config)
self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
self.tts_config = self.tts_model.config
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if use_cuda:
self.tts_model.cuda()
def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None:
"""Load the OpenVoice model from a directory.
We assume the model knows how to load itself from the directory and
there is a config.json file in the directory.
"""
self.vc_config = OpenVoiceConfig()
self.vc_model = OpenVoice.init_from_config(self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True)
self.vc_config = self.vc_model.config
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if use_cuda:
self.vc_model.cuda()
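A hedged usage sketch (the checkpoint directory path is a placeholder): pointing the `Synthesizer` at a downloaded OpenVoice checkpoint directory takes this code path because the directory name contains "openvoice".

```python
from TTS.utils.synthesizer import Synthesizer

synth = Synthesizer(model_dir="/path/to/openvoice_v2", use_cuda=False)
wav = synth.voice_conversion(source_wav="my/source.wav", target_wav="my/target.wav")
print(len(wav), synth.output_sample_rate)
```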
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the TTS model from a directory.
@ -160,6 +177,7 @@ class Synthesizer(nn.Module):
"""
config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
self.tts_model = setup_tts_model(config)
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
if use_cuda:
@ -181,6 +199,7 @@ class Synthesizer(nn.Module):
"""
# pylint: disable=global-statement
self.tts_config = load_config(tts_config_path)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
raise ValueError("Phonemizer is not defined in the TTS config.")
@ -218,6 +237,7 @@ class Synthesizer(nn.Module):
use_cuda (bool): enable/disable CUDA use.
"""
self.vocoder_config = load_config(model_config)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
self.vocoder_model = setup_vocoder_model(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)

View File

@ -229,7 +229,7 @@ class FreeVCConfig(BaseVCConfig):
If true, language embedding is used. Defaults to `False`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters.
Example:

View File

@ -0,0 +1,201 @@
from dataclasses import dataclass, field
from typing import Optional
from coqpit import Coqpit
from TTS.vc.configs.shared_configs import BaseVCConfig
@dataclass
class OpenVoiceAudioConfig(Coqpit):
"""Audio configuration
Args:
input_sample_rate (int):
The sampling rate of the input waveform.
output_sample_rate (int):
The sampling rate of the output waveform.
fft_size (int):
The length of the filter.
hop_length (int):
The hop length.
win_length (int):
The window length.
"""
input_sample_rate: int = field(default=22050)
output_sample_rate: int = field(default=22050)
fft_size: int = field(default=1024)
hop_length: int = field(default=256)
win_length: int = field(default=1024)
@dataclass
class OpenVoiceArgs(Coqpit):
"""OpenVoice model arguments.
zero_g (bool):
Whether to zero the gradients.
inter_channels (int):
The number of channels in the intermediate layers.
hidden_channels (int):
The number of channels in the hidden layers.
filter_channels (int):
The number of channels in the filter layers.
n_heads (int):
The number of attention heads.
n_layers (int):
The number of layers.
kernel_size (int):
The size of the kernel.
p_dropout (float):
The dropout probability.
resblock (str):
The type of residual block.
resblock_kernel_sizes (List[int]):
The kernel sizes for the residual blocks.
resblock_dilation_sizes (List[List[int]]):
The dilation sizes for the residual blocks.
upsample_rates (List[int]):
The upsample rates.
upsample_initial_channel (int):
The number of channels in the initial upsample layer.
upsample_kernel_sizes (List[int]):
The kernel sizes for the upsample layers.
n_layers_q (int):
The number of layers in the quantization network.
use_spectral_norm (bool):
Whether to use spectral normalization.
gin_channels (int):
The number of channels in the global conditioning vector.
tau (float):
Sampling temperature of the posterior encoder; scales the noise added to the predicted mean.
"""
zero_g: bool = field(default=True)
inter_channels: int = field(default=192)
hidden_channels: int = field(default=192)
filter_channels: int = field(default=768)
n_heads: int = field(default=2)
n_layers: int = field(default=6)
kernel_size: int = field(default=3)
p_dropout: float = field(default=0.1)
resblock: str = field(default="1")
resblock_kernel_sizes: list[int] = field(default_factory=lambda: [3, 7, 11])
resblock_dilation_sizes: list[list[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_rates: list[int] = field(default_factory=lambda: [8, 8, 2, 2])
upsample_initial_channel: int = field(default=512)
upsample_kernel_sizes: list[int] = field(default_factory=lambda: [16, 16, 4, 4])
n_layers_q: int = field(default=3)
use_spectral_norm: bool = field(default=False)
gin_channels: int = field(default=256)
tau: float = field(default=0.3)
@dataclass
class OpenVoiceConfig(BaseVCConfig):
"""Defines parameters for OpenVoice VC model.
Args:
model (str):
Model name. Do not change unless you know what you are doing.
model_args (OpenVoiceArgs):
Model architecture arguments. Defaults to `OpenVoiceArgs()`.
audio (OpenVoiceAudioConfig):
Audio processing configuration. Defaults to `OpenVoiceAudioConfig()`.
return_wav (bool):
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
use_weighted_sampler (bool):
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
weighted_sampler_attrs (dict):
Key returned by the formatter to be used for the weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
by overweighting `root_path` by 2.0. Defaults to `{}`.
weighted_sampler_multipliers (dict):
Weight each unique value of a key returned by the formatter for weighted sampling.
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
add_blank (bool):
If true, a blank token is added in between every character. Defaults to `True`.
Note:
Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters.
Example:
>>> from TTS.vc.configs.openvoice_config import OpenVoiceConfig
>>> config = OpenVoiceConfig()
"""
model: str = "openvoice"
# model specific params
model_args: OpenVoiceArgs = field(default_factory=OpenVoiceArgs)
audio: OpenVoiceAudioConfig = field(default_factory=OpenVoiceAudioConfig)
# optimizer
# TODO with training support
# loss params
# TODO with training support
# data loader params
return_wav: bool = True
compute_linear_spec: bool = True
# sampler params
use_weighted_sampler: bool = False # TODO: move it to the base config
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
# overrides
r: int = 1 # DO NOT CHANGE
add_blank: bool = True
# multi-speaker settings
# use speaker embedding layer
num_speakers: int = 0
speakers_file: Optional[str] = None
speaker_embedding_channels: int = 256
# use d-vectors
use_d_vector_file: bool = False
d_vector_file: Optional[list[str]] = None
d_vector_dim: Optional[int] = None
def __post_init__(self) -> None:
for key, val in self.model_args.items():
if hasattr(self, key):
self[key] = val
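A short sketch (values are arbitrary, not taken from any released checkpoint) of overriding the nested dataclasses; note that `load_checkpoint` later replaces the audio values with those read from the upstream `config.json`.

```python
from TTS.vc.configs.openvoice_config import (
    OpenVoiceArgs,
    OpenVoiceAudioConfig,
    OpenVoiceConfig,
)

config = OpenVoiceConfig(
    audio=OpenVoiceAudioConfig(input_sample_rate=24000, output_sample_rate=24000),
    model_args=OpenVoiceArgs(tau=0.5),
)
print(config.model, config.audio.output_sample_rate, config.model_args.tau)
```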

View File

@ -7,7 +7,7 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply
from TTS.vc.modules.freevc.commons import init_weights
from TTS.vc.layers.freevc.commons import init_weights
from TTS.vocoder.models.hifigan_generator import get_padding
LRELU_SLOPE = 0.1

View File

@ -5,7 +5,7 @@ from typing import Optional, Union
import librosa
import numpy as np
from TTS.vc.modules.freevc.speaker_encoder.hparams import (
from TTS.vc.layers.freevc.speaker_encoder.hparams import (
audio_norm_target_dBFS,
mel_n_channels,
mel_window_length,

View File

@ -7,8 +7,8 @@ import torch
from torch import nn
from trainer.io import load_fsspec
from TTS.vc.modules.freevc.speaker_encoder import audio
from TTS.vc.modules.freevc.speaker_encoder.hparams import (
from TTS.vc.layers.freevc.speaker_encoder import audio
from TTS.vc.layers.freevc.speaker_encoder.hparams import (
mel_n_channels,
mel_window_step,
model_embedding_size,

View File

@ -6,7 +6,7 @@ import torch
from trainer.io import get_user_data_dir
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
from TTS.vc.layers.freevc.wavlm.wavlm import WavLM, WavLMConfig
logger = logging.getLogger(__name__)

View File

@ -17,7 +17,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from TTS.vc.modules.freevc.wavlm.modules import (
from TTS.vc.layers.freevc.wavlm.modules import (
Fp32GroupNorm,
Fp32LayerNorm,
GLU_Linear,

View File

@ -12,17 +12,16 @@ from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec
import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
import TTS.vc.layers.freevc.modules as modules
from TTS.tts.layers.vits.discriminator import DiscriminatorS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.layers.freevc.commons import init_weights, rand_slice_segments
from TTS.vc.layers.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.layers.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.layers.freevc.wavlm import get_wavlm
from TTS.vc.models.base_vc import BaseVC
from TTS.vc.modules.freevc.commons import init_weights
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
logger = logging.getLogger(__name__)
@ -385,7 +384,7 @@ class FreeVC(BaseVC):
z_p = self.flow(z, spec_mask, g=g)
# Randomly slice z and compute o using dec
z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
z_slice, ids_slice = rand_slice_segments(z, spec_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

TTS/vc/models/openvoice.py (new file, 320 lines)
View File

@ -0,0 +1,320 @@
import json
import logging
import os
from pathlib import Path
from typing import Any, Mapping, Optional, Union
import librosa
import numpy as np
import numpy.typing as npt
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import functional as F
from trainer.io import load_fsspec
from TTS.tts.layers.vits.networks import PosteriorEncoder
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio.torch_transforms import wav_to_spec
from TTS.vc.configs.openvoice_config import OpenVoiceConfig
from TTS.vc.models.base_vc import BaseVC
from TTS.vc.models.freevc import Generator, ResidualCouplingBlock
logger = logging.getLogger(__name__)
class ReferenceEncoder(nn.Module):
"""NN module creating a fixed size prosody embedding from a spectrogram.
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, spec_channels: int, embedding_dim: int = 0, layernorm: bool = True) -> None:
super().__init__()
self.spec_channels = spec_channels
ref_enc_filters = [32, 32, 64, 64, 128, 128]
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [
torch.nn.utils.parametrizations.weight_norm(
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
)
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True,
)
self.proj = nn.Linear(128, embedding_dim)
self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
N = inputs.size(0)
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
if self.layernorm is not None:
out = self.layernorm(out)
for conv in self.convs:
out = conv(out)
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
T = out.size(1)
N = out.size(0)
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
self.gru.flatten_parameters()
_memory, out = self.gru(out) # out --- [1, N, 128]
return self.proj(out.squeeze(0))
def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
for _ in range(n_convs):
L = (L - kernel_size + 2 * pad) // stride + 1
return L
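A small sketch of the reference encoder in isolation (shapes follow the constructor above; the default `fft_size=1024` gives 513 spectrogram bins):

```python
import torch

from TTS.vc.models.openvoice import ReferenceEncoder

spec_channels = 1024 // 2 + 1                      # fft_size // 2 + 1
enc = ReferenceEncoder(spec_channels, embedding_dim=256)
spec = torch.rand(2, 120, spec_channels)           # [batch, frames, n_freqs]
emb = enc(spec)
print(emb.shape)                                   # torch.Size([2, 256])
```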
class OpenVoice(BaseVC):
"""
OpenVoice voice conversion model (inference only).
Source: https://github.com/myshell-ai/OpenVoice
Paper: https://arxiv.org/abs/2312.01479
Paper abstract:
We introduce OpenVoice, a versatile voice cloning approach that requires
only a short audio clip from the reference speaker to replicate their voice and
generate speech in multiple languages. OpenVoice represents a significant
advancement in addressing the following open challenges in the field: 1)
Flexible Voice Style Control. OpenVoice enables granular control over voice
styles, including emotion, accent, rhythm, pauses, and intonation, in addition
to replicating the tone color of the reference speaker. The voice styles are not
directly copied from and constrained by the style of the reference speaker.
Previous approaches lacked the ability to flexibly manipulate voice styles after
cloning. 2) Zero-Shot Cross-Lingual Voice Cloning. OpenVoice achieves zero-shot
cross-lingual voice cloning for languages not included in the massive-speaker
training set. Unlike previous approaches, which typically require extensive
massive-speaker multi-lingual (MSML) dataset for all languages, OpenVoice can
clone voices into a new language without any massive-speaker training data for
that language. OpenVoice is also computationally efficient, costing tens of
times less than commercially available APIs that offer even inferior
performance. To foster further research in the field, we have made the source
code and trained model publicly accessible. We also provide qualitative results
in our demo website. Prior to its public release, our internal version of
OpenVoice was used tens of millions of times by users worldwide between May and
October 2023, serving as the backend of MyShell.
"""
def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = None) -> None:
super().__init__(config, None, speaker_manager, None)
self.init_multispeaker(config)
self.zero_g = self.args.zero_g
self.inter_channels = self.args.inter_channels
self.hidden_channels = self.args.hidden_channels
self.filter_channels = self.args.filter_channels
self.n_heads = self.args.n_heads
self.n_layers = self.args.n_layers
self.kernel_size = self.args.kernel_size
self.p_dropout = self.args.p_dropout
self.resblock = self.args.resblock
self.resblock_kernel_sizes = self.args.resblock_kernel_sizes
self.resblock_dilation_sizes = self.args.resblock_dilation_sizes
self.upsample_rates = self.args.upsample_rates
self.upsample_initial_channel = self.args.upsample_initial_channel
self.upsample_kernel_sizes = self.args.upsample_kernel_sizes
self.n_layers_q = self.args.n_layers_q
self.use_spectral_norm = self.args.use_spectral_norm
self.gin_channels = self.args.gin_channels
self.tau = self.args.tau
self.spec_channels = config.audio.fft_size // 2 + 1
self.dec = Generator(
self.inter_channels,
self.resblock,
self.resblock_kernel_sizes,
self.resblock_dilation_sizes,
self.upsample_rates,
self.upsample_initial_channel,
self.upsample_kernel_sizes,
gin_channels=self.gin_channels,
)
self.enc_q = PosteriorEncoder(
self.spec_channels,
self.inter_channels,
self.hidden_channels,
kernel_size=5,
dilation_rate=1,
num_layers=16,
cond_channels=self.gin_channels,
)
self.flow = ResidualCouplingBlock(
self.inter_channels,
self.hidden_channels,
kernel_size=5,
dilation_rate=1,
n_layers=4,
gin_channels=self.gin_channels,
)
self.ref_enc = ReferenceEncoder(self.spec_channels, self.gin_channels)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
@staticmethod
def init_from_config(config: OpenVoiceConfig) -> "OpenVoice":
return OpenVoice(config)
def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
Args:
config (Coqpit): Model configuration.
data (list, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.num_spks = config.num_speakers
if self.speaker_manager:
self.num_spks = self.speaker_manager.num_speakers
def load_checkpoint(
self,
config: OpenVoiceConfig,
checkpoint_path: Union[str, os.PathLike[Any]],
eval: bool = False,
strict: bool = True,
cache: bool = False,
) -> None:
"""Map from OpenVoice's config structure."""
config_path = Path(checkpoint_path).parent / "config.json"
with open(config_path, encoding="utf-8") as f:
config_org = json.load(f)
self.config.audio.input_sample_rate = config_org["data"]["sampling_rate"]
self.config.audio.output_sample_rate = config_org["data"]["sampling_rate"]
self.config.audio.fft_size = config_org["data"]["filter_length"]
self.config.audio.hop_length = config_org["data"]["hop_length"]
self.config.audio.win_length = config_org["data"]["win_length"]
state = load_fsspec(str(checkpoint_path), map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"], strict=strict)
if eval:
self.eval()
def forward(self) -> None: ...
def train_step(self) -> None: ...
def eval_step(self) -> None: ...
@staticmethod
def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tensor]]) -> torch.Tensor:
if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
@torch.no_grad()
def inference(
self,
x: torch.Tensor,
aux_input: Mapping[str, Optional[torch.Tensor]] = {"x_lengths": None, "g_src": None, "g_tgt": None},
) -> dict[str, torch.Tensor]:
"""
Inference pass of the model
Args:
x (torch.Tensor): Input spectrogram tensor. Shape: (batch_size, spec_channels, spec_seq_len).
x_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,).
g_src (torch.Tensor): Source speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
g_tgt (torch.Tensor): Target speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
Returns:
o_hat: Output waveform tensor from the decoder. Shape: (batch_size, 1, out_seq_len).
y_mask: Spectrogram mask. Shape: (batch_size, 1, spec_seq_len).
(z, z_p, z_hat): A tuple of latent variables.
"""
x_lengths = self._set_x_lengths(x, aux_input)
if "g_src" in aux_input and aux_input["g_src"] is not None:
g_src = aux_input["g_src"]
else:
raise ValueError("aux_input must define g_src")
if "g_tgt" in aux_input and aux_input["g_tgt"] is not None:
g_tgt = aux_input["g_tgt"]
else:
raise ValueError("aux_input must define g_tgt")
z, _m_q, _logs_q, y_mask = self.enc_q(
x, x_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=self.tau
)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
return {
"model_outputs": o_hat,
"y_mask": y_mask,
"z": z,
"z_p": z_p,
"z_hat": z_hat,
}
def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list[float]]) -> torch.Tensor:
"""Read and format the input audio."""
if isinstance(wav, str):
out = torch.from_numpy(librosa.load(wav, sr=self.config.audio.input_sample_rate)[0])
elif isinstance(wav, np.ndarray):
out = torch.from_numpy(wav)
elif isinstance(wav, list):
out = torch.from_numpy(np.array(wav))
else:
out = wav
return out.to(self.device).float()
def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
audio_ref = self.load_audio(audio)
y = torch.FloatTensor(audio_ref)
y = y.to(self.device)
y = y.unsqueeze(0)
spec = wav_to_spec(
y,
n_fft=self.config.audio.fft_size,
hop_length=self.config.audio.hop_length,
win_length=self.config.audio.win_length,
center=False,
).to(self.device)
with torch.no_grad():
g = self.ref_enc(spec.transpose(1, 2)).unsqueeze(-1)
return g, spec
@torch.inference_mode()
def voice_conversion(self, src: Union[str, torch.Tensor], tgt: Union[str, torch.Tensor]) -> npt.NDArray[np.float32]:
"""
Voice conversion pass of the model.
Args:
src (str or torch.Tensor): Source utterance.
tgt (str or torch.Tensor): Target utterance.
Returns:
Output numpy array.
"""
src_se, src_spec = self.extract_se(src)
tgt_se, _ = self.extract_se(tgt)
aux_input = {"g_src": src_se, "g_tgt": tgt_se}
audio = self.inference(src_spec, aux_input)
return audio["model_outputs"][0, 0].data.cpu().float().numpy()

View File

@ -23,7 +23,7 @@ class SynthesizerTest(unittest.TestCase):
tts_root_path = get_tests_input_path()
tts_checkpoint = os.path.join(tts_root_path, "checkpoint_10.pth")
tts_config = os.path.join(tts_root_path, "dummy_model_config.json")
synthesizer = Synthesizer(tts_checkpoint, tts_config, None, None)
synthesizer = Synthesizer(tts_checkpoint=tts_checkpoint, tts_config_path=tts_config)
synthesizer.tts("Better this test works!!")
def test_split_into_sentences(self):

View File

@ -22,31 +22,19 @@ BATCH_SIZE = 3
class TestFreeVC(unittest.TestCase):
def _create_inputs(self, config, batch_size=2):
input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)
input_lengths = torch.randint(100, 30 * config.audio["hop_length"], (batch_size,)).long().to(device)
input_lengths[-1] = 30 * config.audio["hop_length"]
spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device)
mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device)
spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
spec_lengths[-1] = spec.size(2)
waveform = torch.rand(batch_size, spec.size(2) * config.audio["hop_length"]).to(device)
return input_dummy, input_lengths, mel, spec, spec_lengths, waveform
return mel, spec, spec_lengths, waveform
@staticmethod
def _create_inputs_inference():
source_wav = torch.rand(16000)
source_wav = torch.rand(15999)
target_wav = torch.rand(16000)
return source_wav, target_wav
@staticmethod
def _check_parameter_changes(model, model_ref):
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
def test_methods(self):
config = FreeVCConfig()
model = FreeVC(config).to(device)
@ -69,7 +57,7 @@ class TestFreeVC(unittest.TestCase):
model.train()
print(" > Num parameters for FreeVC model:%s" % (count_parameters(model)))
_, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)
mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)
wavlm_vec = model.extract_wavlm_features(waveform)
wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)
@ -86,7 +74,7 @@ class TestFreeVC(unittest.TestCase):
model = FreeVC(config).to(device)
model.eval()
_, _, mel, _, _, waveform = self._create_inputs(config, batch_size)
mel, _, _, waveform = self._create_inputs(config, batch_size)
wavlm_vec = model.extract_wavlm_features(waveform)
wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)
@ -108,8 +96,8 @@ class TestFreeVC(unittest.TestCase):
source_wav, target_wav = self._create_inputs_inference()
output_wav = model.voice_conversion(source_wav, target_wav)
assert (
output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0]
), f"{output_wav.shape} != {source_wav.shape}"
output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length
), f"{output_wav.shape} != {source_wav.shape}, {config.audio.hop_length}"
def test_train_step(self): ...

View File

@ -0,0 +1,42 @@
import os
import unittest
import torch
from tests import get_tests_input_path
from TTS.vc.models.openvoice import OpenVoice, OpenVoiceConfig
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
c = OpenVoiceConfig()
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
class TestOpenVoice(unittest.TestCase):
@staticmethod
def _create_inputs_inference():
source_wav = torch.rand(16100)
target_wav = torch.rand(16000)
return source_wav, target_wav
def test_load_audio(self):
config = OpenVoiceConfig()
model = OpenVoice(config).to(device)
wav = model.load_audio(WAV_FILE)
wav2 = model.load_audio(wav)
assert all(torch.isclose(wav, wav2))
def test_voice_conversion(self):
config = OpenVoiceConfig()
model = OpenVoice(config).to(device)
model.eval()
source_wav, target_wav = self._create_inputs_inference()
output_wav = model.voice_conversion(source_wav, target_wav)
assert (
output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length
), f"{output_wav.shape} != {source_wav.shape}"