Merge pull request #2499 from coqui-ai/dev

 v0.13.1
Eren Gölge authored on 2023-04-12 16:58:33 +02:00; committed via GitHub
commit bb8d0800f1
9 changed files with 66 additions and 21 deletions

View File

@@ -1,10 +1,13 @@
-<img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>
-----
-### 📣 Clone your voice with a single click on [🐸Coqui.ai](https://app.coqui.ai/auth/signin)
+## 🐸Coqui.ai News
+- 📣 The Coqui Studio API has landed in 🐸TTS. You can use the Studio voices in combination with 🐸TTS models. [Example](https://github.com/coqui-ai/TTS/edit/dev/README.md#-python-api)
+- 📣 Voice generation with prompts - **Prompt to Voice** - is live on Coqui.ai! [Blog Post](https://coqui.ai/blog/tts/prompt-to-voice)
+- 📣 Clone your voice with a single click on [🐸Coqui.ai](https://app.coqui.ai/auth/signin)
+<br>
+## <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>
+----
 🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research and designed to achieve the best trade-off among ease of training, speed, and quality.
 🐸TTS comes with pretrained models and tools for measuring dataset quality, and is already used in **20+ languages** for products and research projects.
@@ -123,6 +126,9 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
 - HiFiGAN: [paper](https://arxiv.org/abs/2010.05646)
 - UnivNet: [paper](https://arxiv.org/abs/2106.07889)
+### Voice Conversion
+- FreeVC: [paper](https://arxiv.org/abs/2210.15418)
 You can also help us implement more models.

 ## Install TTS
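The README's new Python API pointer boils down to the flow below; a minimal sketch using the model name that appears in this commit's test suite (file paths and the sample text are placeholders):

```python
# Minimal sketch of the 🐸TTS Python API the README advertises.
# The model name comes from this commit's tests; paths are placeholders.
from TTS.api import TTS

tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts")
tts.tts_to_file(
    "Hello world!",
    speaker_wav="my_voice.wav",  # reference clip for voice cloning
    language="en",
    file_path="out.wav",
)
```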

View File

@@ -1 +1 @@
-0.13.0
+0.13.1

View File

@@ -7,6 +7,7 @@ from pathlib import Path
 from typing import Tuple

 import numpy as np
+import requests
 from scipy.io import wavfile

 from TTS.utils.audio.numpy_transforms import save_wav
@@ -65,6 +66,11 @@ class CS_API:
         self._speakers = None
         self._check_token()

+    @staticmethod
+    def ping_api():
+        URL = "https://coqui.gateway.scarf.sh/tts/api"
+        _ = requests.get(URL)
+
     @property
     def speakers(self):
         if self._speakers is None:
@@ -80,12 +86,13 @@ class CS_API:
         return ["Neutral", "Happy", "Sad", "Angry", "Dull"]

     def _check_token(self):
+        self.ping_api()
         if self.api_token is None:
             self.api_token = os.environ.get("COQUI_STUDIO_TOKEN")
             self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_token}"}
         if not self.api_token:
             raise ValueError(
-                "No API token found for 🐸Coqui Studio voices - https://coqui.ai.\n"
+                "No API token found for 🐸Coqui Studio voices - https://coqui.ai \n"
                 "Visit 🔗https://app.coqui.ai/account to get one.\n"
                 "Set it as an environment variable `export COQUI_STUDIO_TOKEN=<token>`\n"
                 ""
@@ -273,8 +280,11 @@ class TTS:
         self.csapi = None
         self.model_name = None

-        if model_name:
-            self.load_tts_model_by_name(model_name, gpu)
+        if model_name is not None:
+            if "tts_models" in model_name or "coqui_studio" in model_name:
+                self.load_tts_model_by_name(model_name, gpu)
+            elif "voice_conversion_models" in model_name:
+                self.load_vc_model_by_name(model_name, gpu)

         if model_path:
             self.load_tts_model_by_path(
@@ -342,6 +352,7 @@ class TTS:
             model_name (str): Model name to load. You can list models by ```tts.models```.
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
+        self.model_name = model_name
         model_path, config_path, _, _ = self.download_model_by_name(model_name)
         self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
@@ -565,19 +576,39 @@ class TTS:
     def voice_conversion(
         self,
-        sourve_wav: str,
+        source_wav: str,
         target_wav: str,
     ):
         """Voice conversion with FreeVC. Convert source wav to target speaker.

         Args:
             source_wav (str):
                 Path to the source wav file.
             target_wav (str):
                 Path to the target wav file.
         """
-        wav = self.synthesizer.voice_conversion(source_wav=sourve_wav, target_wav=target_wav)
+        wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav)
         return wav

+    def voice_conversion_to_file(
+        self,
+        source_wav: str,
+        target_wav: str,
+        file_path: str = "output.wav",
+    ):
+        """Voice conversion with FreeVC. Convert source wav to target speaker.
+
+        Args:
+            source_wav (str):
+                Path to the source wav file.
+            target_wav (str):
+                Path to the target wav file.
+            file_path (str, optional):
+                Output file path. Defaults to "output.wav".
+        """
+        wav = self.voice_conversion(source_wav=source_wav, target_wav=target_wav)
+        save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
+        return file_path
+
     def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None):
         """Convert text to speech with voice conversion.

View File

@@ -149,7 +149,7 @@ def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
     dtype_device = str(spec.dtype) + "_" + str(spec.device)
     fmax_dtype_device = str(fmax) + "_" + dtype_device
     if fmax_dtype_device not in mel_basis:
-        mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
+        mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
     mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
     mel = amp_to_db(mel)
@@ -176,7 +176,7 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax):
     fmax_dtype_device = str(fmax) + "_" + dtype_device
     wnsize_dtype_device = str(win_length) + "_" + dtype_device
     if fmax_dtype_device not in mel_basis:
-        mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
+        mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
     if wnsize_dtype_device not in hann_window:
         hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
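Both hunks fix the same breakage: librosa 0.10 makes the arguments of `librosa.filters.mel` (which `librosa_mel_fn` aliases) keyword-only, so the old positional calls raise a `TypeError`. A standalone sketch of the migrated call (parameter values here are illustrative):

```python
import librosa

# Keyword arguments work on both old and new librosa; positional arguments
# to filters.mel fail on librosa >= 0.10.
mel_basis = librosa.filters.mel(sr=22050, n_fft=1024, n_mels=80, fmin=0, fmax=8000)
print(mel_basis.shape)  # (n_mels, 1 + n_fft // 2) == (80, 513)
```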

View File

@@ -269,7 +269,7 @@ def compute_f0(
         np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`

     Examples:
-        >>> WAV_FILE = filename = librosa.util.example_audio_file()
+        >>> WAV_FILE = filename = librosa.example('vibeace')
         >>> from TTS.config import BaseAudioConfig
         >>> from TTS.utils.audio import AudioProcessor
         >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
@@ -310,7 +310,7 @@ def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
     Returns:
         np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length`

     Examples:
-        >>> WAV_FILE = filename = librosa.util.example_audio_file()
+        >>> WAV_FILE = filename = librosa.example('vibeace')
         >>> from TTS.config import BaseAudioConfig
         >>> from TTS.utils.audio import AudioProcessor
         >>> conf = BaseAudioConfig()
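`librosa.util.example_audio_file()` no longer exists in recent librosa releases; `librosa.example(key)` is the replacement the docstrings switch to, returning a local path after fetching the named recording on first use. A quick sketch (needs network access the first time):

```python
import librosa

# "vibeace" is one of librosa's example keys; the clip is cached locally
# after the first download and its path is returned.
wav_path = librosa.example("vibeace")
y, sr = librosa.load(wav_path)
print(wav_path, sr, y.shape)
```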

View File

@@ -243,7 +243,7 @@ class AudioProcessor(object):
         if self.mel_fmax is not None:
             assert self.mel_fmax <= self.sample_rate // 2
         return librosa.filters.mel(
-            self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+            sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
         )

     def _stft_parameters(
@@ -569,7 +569,7 @@ class AudioProcessor(object):
             np.ndarray: Pitch.

         Examples:
-            >>> WAV_FILE = filename = librosa.util.example_audio_file()
+            >>> WAV_FILE = filename = librosa.example('vibeace')
             >>> from TTS.config import BaseAudioConfig
             >>> from TTS.utils.audio import AudioProcessor
             >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1)
@@ -711,7 +711,7 @@ class AudioProcessor(object):
         Args:
             filename (str): Path to the wav file.
         """
-        return librosa.get_duration(filename)
+        return librosa.get_duration(filename=filename)

     @staticmethod
     def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
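The `get_duration` fix is more than cosmetic: the function's first positional parameter is the audio buffer `y`, not a path, so the old positional call handed the filename to the wrong parameter, and librosa 0.10 makes these arguments keyword-only in any case. A minimal sketch (the path is a placeholder; note that newer librosa also renames `filename` to `path`, so check against the pinned version):

```python
import librosa

# Passing the path by keyword avoids binding it to the audio-buffer
# parameter `y` and satisfies librosa 0.10's keyword-only signature.
duration_sec = librosa.get_duration(filename="audio.wav")
```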

View File

@@ -144,8 +144,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
     def _build_mel_basis(self):
         mel_basis = librosa.filters.mel(
-            self.sample_rate,
-            self.n_fft,
+            sr=self.sample_rate,
+            n_fft=self.n_fft,
             n_mels=self.n_mels,
             fmin=self.mel_fmin,
             fmax=self.mel_fmax,

View File

@@ -6,7 +6,7 @@ scipy>=1.4.0
 torch>=1.7
 torchaudio
 soundfile
-librosa==0.8.0
+librosa==0.10.0.*
 numba==0.55.1;python_version<"3.9"
 numba==0.56.4;python_version>="3.9"
 inflect==5.6.0
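Note that `librosa==0.10.0.*` is a PEP 440 prefix match: it accepts any release whose version starts with 0.10.0 but excludes 0.10.1 and later, pinning the keyword-only API that the call sites above now assume.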

View File

@@ -93,3 +93,11 @@ class TTSTest(unittest.TestCase):
         tts = TTS()
         tts.load_tts_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
         tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)
+
+    def test_voice_conversion(self):  # pylint: disable=no-self-use
+        tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=False)
+        tts.voice_conversion_to_file(
+            source_wav=cloning_test_wav_path,
+            target_wav=cloning_test_wav_path,
+            file_path=OUTPUT_PATH,
+        )
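Passing the same clip as both source and target keeps the new test self-contained: it still exercises model download, FreeVC inference, and the write path through `voice_conversion_to_file` without requiring a second fixture.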