Merge pull request #2390 from coqui-ai/dev

v0.12.0
Eren Gölge 2023-03-13 12:43:38 +01:00 committed by GitHub
commit 9bb62c570d
12 changed files with 119 additions and 52 deletions

View File

@@ -161,16 +161,14 @@ class ResNetSpeakerEncoder(BaseEncoder):
         Shapes:
             - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
         """
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x.squeeze_(1)
-                # if you torch spec compute it otherwise use the mel spec computed by the AP
-                if self.use_torch_spec:
-                    x = self.torch_spec(x)
+        x.squeeze_(1)
+        # if you torch spec compute it otherwise use the mel spec computed by the AP
+        if self.use_torch_spec:
+            x = self.torch_spec(x)
-                if self.log_input:
-                    x = (x + 1e-6).log()
-                x = self.instancenorm(x).unsqueeze(1)
+        if self.log_input:
+            x = (x + 1e-6).log()
+        x = self.instancenorm(x).unsqueeze(1)
         x = self.conv1(x)
         x = self.relu(x)
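For orientation, `forward` accepts either a raw waveform batch `(N, 1, T_in)` (when `use_torch_spec=True`) or a precomputed spectrogram `(N, D_spec, T_in)`. A minimal sketch of the spectrogram path; the constructor arguments below are assumptions for illustration, not part of this diff:

    import torch
    from TTS.encoder.models.resnet import ResNetSpeakerEncoder

    # Assumed arguments; see the class definition for the full signature.
    encoder = ResNetSpeakerEncoder(input_dim=64, proj_dim=512, use_torch_spec=False)
    spec = torch.randn(4, 64, 200)         # (N, D_spec, T_in) mel frames from the AP
    with torch.no_grad():
        emb = encoder(spec, l2_norm=True)  # -> (4, 512) speaker embeddings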

View File

@@ -7,8 +7,9 @@ import sys
 from pathlib import Path
 from threading import Lock
 from typing import Union
+from urllib.parse import parse_qs
-from flask import Flask, render_template, request, send_file
+from flask import Flask, render_template, render_template_string, request, send_file
 from TTS.config import load_config
 from TTS.utils.manage import ModelManager
@@ -187,15 +188,59 @@ def tts():
         language_idx = request.args.get("language_id", "")
         style_wav = request.args.get("style_wav", "")
         style_wav = style_wav_uri_to_dict(style_wav)
-        print(" > Model input: {}".format(text))
-        print(" > Speaker Idx: {}".format(speaker_idx))
-        print(" > Language Idx: {}".format(language_idx))
+        print(f" > Model input: {text}")
+        print(f" > Speaker Idx: {speaker_idx}")
+        print(f" > Language Idx: {language_idx}")
         wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
         out = io.BytesIO()
         synthesizer.save_wav(wavs, out)
     return send_file(out, mimetype="audio/wav")
+# Basic MaryTTS compatibility layer
+@app.route("/locales", methods=["GET"])
+def mary_tts_api_locales():
+    """MaryTTS-compatible /locales endpoint"""
+    # NOTE: We currently assume there is only one model active at the same time
+    if args.model_name is not None:
+        model_details = args.model_name.split("/")
+    else:
+        model_details = ["", "en", "", "default"]
+    return render_template_string("{{ locale }}\n", locale=model_details[1])
+@app.route("/voices", methods=["GET"])
+def mary_tts_api_voices():
+    """MaryTTS-compatible /voices endpoint"""
+    # NOTE: We currently assume there is only one model active at the same time
+    if args.model_name is not None:
+        model_details = args.model_name.split("/")
+    else:
+        model_details = ["", "en", "", "default"]
+    return render_template_string(
+        "{{ name }} {{ locale }} {{ gender }}\n", name=model_details[3], locale=model_details[1], gender="u"
+    )
+@app.route("/process", methods=["GET", "POST"])
+def mary_tts_api_process():
+    """MaryTTS-compatible /process endpoint"""
+    with lock:
+        if request.method == "POST":
+            data = parse_qs(request.get_data(as_text=True))
+            # NOTE: we ignore param. LOCALE and VOICE for now since we have only one active model
+            text = data.get("INPUT_TEXT", [""])[0]
+        else:
+            text = request.args.get("INPUT_TEXT", "")
+        print(f" > Model input: {text}")
+        wavs = synthesizer.tts(text)
+        out = io.BytesIO()
+        synthesizer.save_wav(wavs, out)
+    return send_file(out, mimetype="audio/wav")
 def main():
     app.run(debug=args.debug, host="::", port=args.port)
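The three new routes give existing MaryTTS clients a drop-in target. A quick way to exercise them once `tts-server` is running; the host and port below assume the server default of localhost:5002:

    import requests

    BASE = "http://localhost:5002"  # assumed default host/port of the demo server

    print(requests.get(f"{BASE}/locales").text)  # locale of the single active model, e.g. "en"
    print(requests.get(f"{BASE}/voices").text)   # "<name> <locale> <gender>"

    # /process takes INPUT_TEXT via GET or POST and returns a WAV payload.
    resp = requests.post(f"{BASE}/process", data={"INPUT_TEXT": "Hello from a MaryTTS client."})
    with open("out.wav", "wb") as f:
        f.write(resp.content)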

View File

@@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig):
     base_model: str = "forward_tts"
     # model specific params
-    model_args: ForwardTTSArgs = ForwardTTSArgs()
+    model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True)
     # multi-speaker settings
     num_speakers: int = 0
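With this default, `Fastspeech2Config` enables both the pitch and the energy predictor of the forward model out of the box. A minimal sketch of a config that also precomputes the matching caches; the paths are placeholders:

    from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
    from TTS.tts.models.forward_tts import ForwardTTSArgs

    config = Fastspeech2Config(
        model_args=ForwardTTSArgs(use_pitch=True, use_energy=True),  # now the default
        compute_f0=True,
        f0_cache_path="f0_cache/",          # placeholder
        compute_energy=True,
        energy_cache_path="energy_cache/",  # placeholder
    )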

View File

@@ -189,6 +189,8 @@ class TTSDataset(Dataset):
         self._samples = new_samples
         if hasattr(self, "f0_dataset"):
             self.f0_dataset.samples = new_samples
+        if hasattr(self, "energy_dataset"):
+            self.energy_dataset.samples = new_samples
         if hasattr(self, "phoneme_dataset"):
             self.phoneme_dataset.samples = new_samples
@@ -856,11 +858,11 @@ class EnergyDataset:
     def __getitem__(self, idx):
         item = self.samples[idx]
-        energy = self.compute_or_load(item["audio_file"])
+        energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
         if self.normalize_energy:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             energy = self.normalize(energy)
-        return {"audio_file": item["audio_file"], "energy": energy}
+        return {"audio_unique_name": item["audio_unique_name"], "energy": energy}
     def __len__(self):
         return len(self.samples)
@@ -884,7 +886,7 @@ class EnergyDataset:
         if self.normalize_energy:
             computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
-            energy_mean, energy_std = self.compute_pitch_stats(computed_data)
+            energy_mean, energy_std = self.compute_energy_stats(computed_data)
             energy_stats = {"mean": energy_mean, "std": energy_std}
             np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)
@@ -900,7 +902,7 @@ class EnergyDataset:
     @staticmethod
     def _compute_and_save_energy(ap, wav_file, energy_file=None):
         wav = ap.load_wav(wav_file)
-        energy = calculate_energy(wav)
+        energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length)
         if energy_file:
             np.save(energy_file, energy)
         return energy
@@ -931,11 +933,11 @@ class EnergyDataset:
         energy[zero_idxs] = 0.0
         return energy
-    def compute_or_load(self, wav_file):
+    def compute_or_load(self, wav_file, audio_unique_name):
         """
         compute energy and return a numpy array of energy values
         """
-        energy_file = self.create_Energy_file_path(wav_file, self.cache_path)
+        energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(energy_file):
             energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
         else:
@@ -943,14 +945,14 @@ class EnergyDataset:
         return energy.astype(np.float32)
     def collate_fn(self, batch):
-        audio_file = [item["audio_file"] for item in batch]
+        audio_unique_name = [item["audio_unique_name"] for item in batch]
         energys = [item["energy"] for item in batch]
         energy_lens = [len(item["energy"]) for item in batch]
         energy_lens_max = max(energy_lens)
         energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id())
         for i, energy_len in enumerate(energy_lens):
             energys_torch[i, :energy_len] = torch.LongTensor(energys[i])
-        return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens}
+        return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens}
     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
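For intuition about what `EnergyDataset` caches: with `fft_size`, `hop_length` and `win_length` now passed through, `calculate_energy` yields one energy value per STFT frame. A self-contained numpy/librosa sketch of the same idea (illustrative only, not the library implementation):

    import numpy as np
    import librosa

    def frame_energy(wav, fft_size=1024, hop_length=256, win_length=1024):
        # L2 norm of the magnitude spectrum of each frame.
        spec = np.abs(librosa.stft(wav, n_fft=fft_size, hop_length=hop_length, win_length=win_length))
        return np.sqrt((spec ** 2).sum(axis=0))

    wav, _ = librosa.load("sample.wav", sr=22050)  # placeholder file
    energy = frame_energy(wav)                     # shape: (num_frames,)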

View File

@@ -111,7 +111,7 @@ class BaseTTS(BaseTrainerModel):
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
-    def get_aux_input_from_test_setences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info):
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -134,7 +134,7 @@ class BaseTTS(BaseTrainerModel):
         # get speaker id/d_vector
         speaker_id, d_vector, language_id = None, None, None
-        if hasattr(self, "speaker_manager"):
+        if self.speaker_manager is not None:
             if config.use_d_vector_file:
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
@@ -147,7 +147,7 @@ class BaseTTS(BaseTrainerModel):
                 speaker_id = self.speaker_manager.name_to_id[speaker_name]
         # get language id
-        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+        if self.language_manager is not None and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.name_to_id[language_name]
         return {
@@ -183,6 +183,7 @@ class BaseTTS(BaseTrainerModel):
         attn_mask = batch["attns"]
         waveform = batch["waveform"]
         pitch = batch["pitch"]
+        energy = batch["energy"]
         language_ids = batch["language_ids"]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
@@ -231,6 +232,7 @@ class BaseTTS(BaseTrainerModel):
             "item_idx": item_idx,
             "waveform": waveform,
             "pitch": pitch,
+            "energy": energy,
             "language_ids": language_ids,
             "audio_unique_names": batch["audio_unique_names"],
         }
@@ -287,7 +289,7 @@ class BaseTTS(BaseTrainerModel):
             loader = None
         else:
             # setup multi-speaker attributes
-            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
+            if self.speaker_manager is not None:
                 if hasattr(config, "model_args"):
                     speaker_id_mapping = (
                         self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
@@ -302,7 +304,7 @@ class BaseTTS(BaseTrainerModel):
                 d_vector_mapping = None
             # setup multi-lingual attributes
-            if hasattr(self, "language_manager") and self.language_manager is not None:
+            if self.language_manager is not None:
                 language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
             else:
                 language_id_mapping = None
@@ -313,6 +315,8 @@ class BaseTTS(BaseTrainerModel):
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 compute_f0=config.get("compute_f0", False),
                 f0_cache_path=config.get("f0_cache_path", None),
+                compute_energy=config.get("compute_energy", False),
+                energy_cache_path=config.get("energy_cache_path", None),
                 samples=samples,
                 ap=self.ap,
                 return_wav=config.return_wav if "return_wav" in config else False,
@@ -424,7 +428,7 @@ class BaseTTS(BaseTrainerModel):
             print(f" > `speakers.pth` is saved to {output_path}.")
             print(" > `speakers_file` is updated in the config.json.")
-        if hasattr(self, "language_manager") and self.language_manager is not None:
+        if self.language_manager is not None:
             output_path = os.path.join(trainer.output_path, "language_ids.json")
             self.language_manager.save_ids_to_file(output_path)
             trainer.config.language_ids_file = output_path
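The dropped `hasattr` checks reflect that `BaseTTS` now always defines `speaker_manager` and `language_manager`, using `None` to mean "disabled", so `is not None` is the single pattern used throughout. A toy sketch of the convention (names illustrative, not the actual class body):

    class ToyTTS:
        def __init__(self, speaker_manager=None, language_manager=None):
            # Managers always exist as attributes; "unused" is expressed as None.
            self.speaker_manager = speaker_manager
            self.language_manager = language_manager

        def speaker_mapping(self):
            if self.speaker_manager is not None:  # preferred over hasattr(self, "speaker_manager")
                return self.speaker_manager.name_to_id
            return None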

View File

@@ -1,23 +1,4 @@
 {% extends "!page.html" %}
 {% block scripts %}
     {{ super() }}
-    <!-- DocsQA integration start -->
-    <script src="https://cdn.jsdelivr.net/npm/qabot@0.4"></script>
-    <qa-bot
-        token="qAFjWNovwHUXKKkVhy4AN6tawSwCMfdb3HJNPLVM23ACdrBGxmBNObM="
-        title="🐸💬TTS Bot"
-        description="A library for advanced Text-to-Speech generation"
-        style="bottom: calc(1.25em + 80px);"
-    >
-        <template>
-            <dl>
-                <dt>You can ask questions about TTS. Try</dt>
-                <dd>What is VITS?</dd>
-                <dd>How to train a TTS model?</dd>
-                <dd>What is the format of training data?</dd>
-            </dl>
-        </template>
-    </qa-bot>
-    <!-- DocsQA integration end -->
 {% endblock %}

View File

@@ -28,6 +28,7 @@
    formatting_your_dataset
    what_makes_a_good_dataset
    tts_datasets
+   marytts
 .. toctree::
     :maxdepth: 2
@@ -48,10 +49,10 @@
     models/vits.md
     models/forward_tts.md
     models/tacotron1-2.md
+    models/overflow.md
 .. toctree::
     :maxdepth: 2
     :caption: `vocoder` Models
 ```

docs/source/marytts.md Normal file
View File

View File

@@ -0,0 +1,36 @@
+# Overflow TTS
+Neural HMMs are a type of neural transducer recently proposed for
+sequence-to-sequence modelling in text-to-speech. They combine the best features
+of classic statistical speech synthesis and modern neural TTS, requiring less
+data and fewer training updates, and are less prone to gibberish output caused
+by neural attention failures. In this paper, we combine neural HMM TTS with
+normalising flows for describing the highly non-Gaussian distribution of speech
+acoustics. The result is a powerful, fully probabilistic model of durations and
+acoustics that can be trained using exact maximum likelihood. Compared to
+dominant flow-based acoustic models, our approach integrates autoregression for
+improved modelling of long-range dependences such as utterance-level prosody.
+Experiments show that a system based on our proposal gives more accurate
+pronunciations and better subjective speech quality than comparable methods,
+whilst retaining the original advantages of neural HMMs. Audio examples and code
+are available at https://shivammehta25.github.io/OverFlow/.
+## Important resources & papers
+- HMM: https://de.wikipedia.org/wiki/Hidden_Markov_Model
+- OverflowTTS paper: https://arxiv.org/abs/2211.06892
+- Neural HMM: https://arxiv.org/abs/2108.13320
+- Audio Samples: https://shivammehta25.github.io/OverFlow/
+## OverflowConfig
+```{eval-rst}
+.. autoclass:: TTS.tts.configs.overflow_config.OverflowConfig
+    :members:
+```
+## Overflow Model
+```{eval-rst}
+.. autoclass:: TTS.tts.models.overflow.Overflow
+    :members:
+```

View File

@@ -1,6 +1,6 @@
 # core deps
 numpy==1.21.6;python_version<"3.10"
-numpy==1.22.4;python_version=="3.10"
+numpy;python_version=="3.10"
 cython==0.29.28
 scipy>=1.4.0
 torch>=1.7

View File

@@ -38,7 +38,7 @@ config = Fastspeech2Config(
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     compute_f0=True,
     compute_energy=True,
-    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    energy_cache_path="tests/data/ljspeech/energy_cache/",
     run_eval=True,
     test_delay_epochs=-1,
     epochs=1,

View File

@@ -38,7 +38,7 @@ config = Fastspeech2Config(
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     compute_f0=True,
     compute_energy=True,
-    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    energy_cache_path="tests/data/ljspeech/energy_cache/",
     run_eval=True,
     test_delay_epochs=-1,
    epochs=1,