mirror of https://github.com/coqui-ai/TTS.git
commit 9bb62c570d
@@ -161,16 +161,14 @@ class ResNetSpeakerEncoder(BaseEncoder):
         Shapes:
             - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
         """
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x.squeeze_(1)
-                # if you torch spec compute it otherwise use the mel spec computed by the AP
-                if self.use_torch_spec:
-                    x = self.torch_spec(x)
+        x.squeeze_(1)
+        # if you torch spec compute it otherwise use the mel spec computed by the AP
+        if self.use_torch_spec:
+            x = self.torch_spec(x)

-                if self.log_input:
-                    x = (x + 1e-6).log()
-                x = self.instancenorm(x).unsqueeze(1)
+        if self.log_input:
+            x = (x + 1e-6).log()
+        x = self.instancenorm(x).unsqueeze(1)

         x = self.conv1(x)
         x = self.relu(x)
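A quick way to see what removing the `torch.no_grad()` / `autocast` wrappers buys: the encoder forward pass is now differentiable end to end. The sketch below is illustrative only; the constructor arguments are the library defaults as I understand them, not values taken from this commit.

```python
import torch

from TTS.encoder.models.resnet import ResNetSpeakerEncoder

encoder = ResNetSpeakerEncoder(input_dim=64, proj_dim=512)  # assumed default sizes
spec = torch.randn(2, 64, 250)  # (N, D_spec, T_in) mel-spectrogram frames
emb = encoder(spec, l2_norm=True)
print(emb.shape)  # torch.Size([2, 512])
emb.sum().backward()  # gradients now reach the encoder weights (no torch.no_grad())
```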
@@ -7,8 +7,9 @@ import sys
 from pathlib import Path
 from threading import Lock
 from typing import Union
+from urllib.parse import parse_qs

-from flask import Flask, render_template, request, send_file
+from flask import Flask, render_template, render_template_string, request, send_file

 from TTS.config import load_config
 from TTS.utils.manage import ModelManager
@@ -187,15 +188,59 @@ def tts():
         language_idx = request.args.get("language_id", "")
         style_wav = request.args.get("style_wav", "")
         style_wav = style_wav_uri_to_dict(style_wav)
-        print(" > Model input: {}".format(text))
-        print(" > Speaker Idx: {}".format(speaker_idx))
-        print(" > Language Idx: {}".format(language_idx))
+        print(f" > Model input: {text}")
+        print(f" > Speaker Idx: {speaker_idx}")
+        print(f" > Language Idx: {language_idx}")
         wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
         out = io.BytesIO()
         synthesizer.save_wav(wavs, out)
         return send_file(out, mimetype="audio/wav")


+# Basic MaryTTS compatibility layer
+
+
+@app.route("/locales", methods=["GET"])
+def mary_tts_api_locales():
+    """MaryTTS-compatible /locales endpoint"""
+    # NOTE: We currently assume there is only one model active at the same time
+    if args.model_name is not None:
+        model_details = args.model_name.split("/")
+    else:
+        model_details = ["", "en", "", "default"]
+    return render_template_string("{{ locale }}\n", locale=model_details[1])
+
+
+@app.route("/voices", methods=["GET"])
+def mary_tts_api_voices():
+    """MaryTTS-compatible /voices endpoint"""
+    # NOTE: We currently assume there is only one model active at the same time
+    if args.model_name is not None:
+        model_details = args.model_name.split("/")
+    else:
+        model_details = ["", "en", "", "default"]
+    return render_template_string(
+        "{{ name }} {{ locale }} {{ gender }}\n", name=model_details[3], locale=model_details[1], gender="u"
+    )
+
+
+@app.route("/process", methods=["GET", "POST"])
+def mary_tts_api_process():
+    """MaryTTS-compatible /process endpoint"""
+    with lock:
+        if request.method == "POST":
+            data = parse_qs(request.get_data(as_text=True))
+            # NOTE: we ignore param. LOCALE and VOICE for now since we have only one active model
+            text = data.get("INPUT_TEXT", [""])[0]
+        else:
+            text = request.args.get("INPUT_TEXT", "")
+        print(f" > Model input: {text}")
+        wavs = synthesizer.tts(text)
+        out = io.BytesIO()
+        synthesizer.save_wav(wavs, out)
+        return send_file(out, mimetype="audio/wav")
+
+
 def main():
     app.run(debug=args.debug, host="::", port=args.port)

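The three endpoints above mirror a small subset of the MaryTTS HTTP API, so existing MaryTTS clients can be pointed at a running 🐸TTS server. A minimal smoke test, assuming the server was started on its default port 5002 (adjust to whatever `--port` you used):

```python
import requests

BASE = "http://localhost:5002"

print(requests.get(f"{BASE}/locales").text)  # e.g. "en"
print(requests.get(f"{BASE}/voices").text)   # e.g. "default en u"

# /process accepts GET with ?INPUT_TEXT=... or a form-encoded POST body.
resp = requests.post(f"{BASE}/process", data={"INPUT_TEXT": "Hello from the MaryTTS layer."})
with open("mary_out.wav", "wb") as f:
    f.write(resp.content)
```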
@@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig):
     base_model: str = "forward_tts"

     # model specific params
-    model_args: ForwardTTSArgs = ForwardTTSArgs()
+    model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True)

     # multi-speaker settings
     num_speakers: int = 0
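With `use_pitch=True, use_energy=True` as the default `model_args`, a FastSpeech2 config enables both predictors out of the box. A minimal sketch; the cache paths are placeholders and the import path assumes the usual `TTS.tts.configs` module naming:

```python
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config

config = Fastspeech2Config(
    compute_f0=True,
    f0_cache_path="f0_cache/",          # placeholder path
    compute_energy=True,
    energy_cache_path="energy_cache/",  # placeholder path
)
print(config.model_args.use_pitch, config.model_args.use_energy)  # True True
```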
@@ -189,6 +189,8 @@ class TTSDataset(Dataset):
         self._samples = new_samples
         if hasattr(self, "f0_dataset"):
             self.f0_dataset.samples = new_samples
+        if hasattr(self, "energy_dataset"):
+            self.energy_dataset.samples = new_samples
         if hasattr(self, "phoneme_dataset"):
             self.phoneme_dataset.samples = new_samples

@@ -856,11 +858,11 @@ class EnergyDataset:

     def __getitem__(self, idx):
         item = self.samples[idx]
-        energy = self.compute_or_load(item["audio_file"])
+        energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
         if self.normalize_energy:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             energy = self.normalize(energy)
-        return {"audio_file": item["audio_file"], "energy": energy}
+        return {"audio_unique_name": item["audio_unique_name"], "energy": energy}

     def __len__(self):
         return len(self.samples)
@@ -884,7 +886,7 @@ class EnergyDataset:

         if self.normalize_energy:
             computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
-            energy_mean, energy_std = self.compute_pitch_stats(computed_data)
+            energy_mean, energy_std = self.compute_energy_stats(computed_data)
             energy_stats = {"mean": energy_mean, "std": energy_std}
             np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)

@@ -900,7 +902,7 @@ class EnergyDataset:
     @staticmethod
     def _compute_and_save_energy(ap, wav_file, energy_file=None):
         wav = ap.load_wav(wav_file)
-        energy = calculate_energy(wav)
+        energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length)
         if energy_file:
             np.save(energy_file, energy)
         return energy
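Passing `fft_size`/`hop_length`/`win_length` through matters because the cached energy track has to stay frame-aligned with the spectrogram features. Roughly, frame energy is the L2 norm of each STFT frame; the sketch below is an illustrative re-derivation, not the exact `calculate_energy` implementation:

```python
import librosa
import numpy as np

def frame_energy(wav, fft_size=1024, hop_length=256, win_length=1024):
    """Per-frame energy: L2 norm of the STFT magnitudes, one value per frame."""
    spec = np.abs(librosa.stft(wav, n_fft=fft_size, hop_length=hop_length, win_length=win_length))
    return np.sqrt((spec**2).sum(axis=0))

wav = np.random.randn(22050).astype(np.float32)  # one second of noise as a stand-in
print(frame_energy(wav).shape)  # roughly (1 + 22050 // 256,) frames
```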
@@ -931,11 +933,11 @@ class EnergyDataset:
         energy[zero_idxs] = 0.0
         return energy

-    def compute_or_load(self, wav_file):
+    def compute_or_load(self, wav_file, audio_unique_name):
         """
         compute energy and return a numpy array of energy values
         """
-        energy_file = self.create_Energy_file_path(wav_file, self.cache_path)
+        energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(energy_file):
             energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
         else:
@@ -943,14 +945,14 @@ class EnergyDataset:
         return energy.astype(np.float32)

     def collate_fn(self, batch):
-        audio_file = [item["audio_file"] for item in batch]
+        audio_unique_name = [item["audio_unique_name"] for item in batch]
         energys = [item["energy"] for item in batch]
         energy_lens = [len(item["energy"]) for item in batch]
         energy_lens_max = max(energy_lens)
         energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id())
         for i, energy_len in enumerate(energy_lens):
             energys_torch[i, :energy_len] = torch.LongTensor(energys[i])
-        return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens}
+        return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens}

     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
@@ -111,7 +111,7 @@ class BaseTTS(BaseTrainerModel):
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}

-    def get_aux_input_from_test_setences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info):
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -134,7 +134,7 @@ class BaseTTS(BaseTrainerModel):

         # get speaker id/d_vector
         speaker_id, d_vector, language_id = None, None, None
-        if hasattr(self, "speaker_manager"):
+        if self.speaker_manager is not None:
             if config.use_d_vector_file:
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
@@ -147,7 +147,7 @@ class BaseTTS(BaseTrainerModel):
                     speaker_id = self.speaker_manager.name_to_id[speaker_name]

         # get language id
-        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+        if self.language_manager is not None and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.name_to_id[language_name]

         return {
@@ -183,6 +183,7 @@ class BaseTTS(BaseTrainerModel):
         attn_mask = batch["attns"]
         waveform = batch["waveform"]
         pitch = batch["pitch"]
+        energy = batch["energy"]
         language_ids = batch["language_ids"]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
@@ -231,6 +232,7 @@ class BaseTTS(BaseTrainerModel):
             "item_idx": item_idx,
             "waveform": waveform,
             "pitch": pitch,
+            "energy": energy,
             "language_ids": language_ids,
             "audio_unique_names": batch["audio_unique_names"],
         }
@@ -287,7 +289,7 @@ class BaseTTS(BaseTrainerModel):
             loader = None
         else:
             # setup multi-speaker attributes
-            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
+            if self.speaker_manager is not None:
                 if hasattr(config, "model_args"):
                     speaker_id_mapping = (
                         self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
@@ -302,7 +304,7 @@ class BaseTTS(BaseTrainerModel):
                 d_vector_mapping = None

             # setup multi-lingual attributes
-            if hasattr(self, "language_manager") and self.language_manager is not None:
+            if self.language_manager is not None:
                 language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
             else:
                 language_id_mapping = None
@@ -313,6 +315,8 @@ class BaseTTS(BaseTrainerModel):
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 compute_f0=config.get("compute_f0", False),
                 f0_cache_path=config.get("f0_cache_path", None),
+                compute_energy=config.get("compute_energy", False),
+                energy_cache_path=config.get("energy_cache_path", None),
                 samples=samples,
                 ap=self.ap,
                 return_wav=config.return_wav if "return_wav" in config else False,
@@ -424,7 +428,7 @@ class BaseTTS(BaseTrainerModel):
             print(f" > `speakers.pth` is saved to {output_path}.")
             print(" > `speakers_file` is updated in the config.json.")

-        if hasattr(self, "language_manager") and self.language_manager is not None:
+        if self.language_manager is not None:
             output_path = os.path.join(trainer.output_path, "language_ids.json")
             self.language_manager.save_ids_to_file(output_path)
             trainer.config.language_ids_file = output_path
@@ -1,23 +1,4 @@
 {% extends "!page.html" %}
 {% block scripts %}
     {{ super() }}
-    <!-- DocsQA integration start -->
-    <script src="https://cdn.jsdelivr.net/npm/qabot@0.4"></script>
-
-    <qa-bot
-        token="qAFjWNovwHUXKKkVhy4AN6tawSwCMfdb3HJNPLVM23ACdrBGxmBNObM="
-        title="🐸💬TTS Bot"
-        description="A library for advanced Text-to-Speech generation"
-        style="bottom: calc(1.25em + 80px);"
-    >
-        <template>
-            <dl>
-                <dt>You can ask questions about TTS. Try</dt>
-                <dd>What is VITS?</dd>
-                <dd>How to train a TTS model?</dd>
-                <dd>What is the format of training data?</dd>
-            </dl>
-        </template>
-    </qa-bot>
-    <!-- DocsQA integration end -->
 {% endblock %}
@@ -28,6 +28,7 @@
     formatting_your_dataset
     what_makes_a_good_dataset
     tts_datasets
+    marytts

 .. toctree::
     :maxdepth: 2
@@ -48,10 +49,10 @@
     models/vits.md
     models/forward_tts.md
     models/tacotron1-2.md
+    models/overflow.md

 .. toctree::
     :maxdepth: 2
     :caption: `vocoder` Models

 ```

@@ -0,0 +1,36 @@
+# Overflow TTS
+
+Neural HMMs are a type of neural transducer recently proposed for
+sequence-to-sequence modelling in text-to-speech. They combine the best features
+of classic statistical speech synthesis and modern neural TTS, requiring less
+data and fewer training updates, and are less prone to gibberish output caused
+by neural attention failures. In this paper, we combine neural HMM TTS with
+normalising flows for describing the highly non-Gaussian distribution of speech
+acoustics. The result is a powerful, fully probabilistic model of durations and
+acoustics that can be trained using exact maximum likelihood. Compared to
+dominant flow-based acoustic models, our approach integrates autoregression for
+improved modelling of long-range dependences such as utterance-level prosody.
+Experiments show that a system based on our proposal gives more accurate
+pronunciations and better subjective speech quality than comparable methods,
+whilst retaining the original advantages of neural HMMs. Audio examples and code
+are available at https://shivammehta25.github.io/OverFlow/.
+
+
+## Important resources & papers
+- HMM: https://de.wikipedia.org/wiki/Hidden_Markov_Model
+- OverflowTTS paper: https://arxiv.org/abs/2211.06892
+- Neural HMM: https://arxiv.org/abs/2108.13320
+- Audio Samples: https://shivammehta25.github.io/OverFlow/
+
+
+## OverflowConfig
+```{eval-rst}
+.. autoclass:: TTS.tts.configs.overflow_config.OverflowConfig
+    :members:
+```
+
+## Overflow Model
+```{eval-rst}
+.. autoclass:: TTS.tts.models.overflow.Overflow
+    :members:
+```
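To hear the model rather than read about it, the released LJSpeech checkpoint can be tried from Python. This is a hedged sketch: `tts_models/en/ljspeech/overflow` is the name used by the released checkpoint as far as I can tell (`tts --list_models` will confirm), and it assumes an installed TTS version that ships `TTS.api.TTS`.

```python
from TTS.api import TTS

tts = TTS(model_name="tts_models/en/ljspeech/overflow")
tts.tts_to_file(
    text="Neural HMMs combined with normalising flows, trained with exact maximum likelihood.",
    file_path="overflow_sample.wav",
)
```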
@@ -1,6 +1,6 @@
 # core deps
 numpy==1.21.6;python_version<"3.10"
-numpy==1.22.4;python_version=="3.10"
+numpy;python_version=="3.10"
 cython==0.29.28
 scipy>=1.4.0
 torch>=1.7
@@ -39,4 +39,4 @@ gruut[de]==2.2.3
 # deps for korean
 jamo
 nltk
-g2pkk>=0.1.1
\ No newline at end of file
+g2pkk>=0.1.1
@@ -38,7 +38,7 @@ config = Fastspeech2Config(
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     compute_f0=True,
     compute_energy=True,
-    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    energy_cache_path="tests/data/ljspeech/energy_cache/",
     run_eval=True,
     test_delay_epochs=-1,
     epochs=1,
@@ -38,7 +38,7 @@ config = Fastspeech2Config(
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     compute_f0=True,
     compute_energy=True,
-    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    energy_cache_path="tests/data/ljspeech/energy_cache/",
     run_eval=True,
     test_delay_epochs=-1,
     epochs=1,