Merge pull request #914 from coqui-ai/dev

v0.4.2
commit 7f1a23787e
Eren Gölge, 2021-12-08 16:41:44 +01:00 (committed by GitHub)
24 changed files with 373 additions and 152 deletions

View File

@@ -6,33 +6,53 @@ labels: bug
 assignees: ''
 ---
-Welcome to the 🐸TTS project! We are excited to see your interest, and appreciate your support!
-
-This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.
-
-If you've found a bug, please provide the following information:
-
-**Describe the bug**
-A clear and concise description of what the bug is.
-
-**To Reproduce**
-Steps to reproduce the behavior:
+<!-- Welcome to the 🐸TTS project!
+We are excited to see your interest, and appreciate your support! --->
+
+## 🐛 Description
+
+<!-- A clear and concise description of what the bug is. -->
+
+### To Reproduce
+
+<!--
+Please share your code to reproduce the error. Issues are fixed faster if you can provide a working example.
+The best place for sharing code is Colab (https://colab.research.google.com/), so we can directly run your code and reproduce the issue.
+In the worst case, provide steps to reproduce the behaviour:
 1. Run the following command '...'
 2. ...
 3. See error
+-->
-**Expected behavior**
-A clear and concise description of what you expected to happen.
+### Expected behavior
+
+<!-- Write down what the expected behaviour is. -->
-**Environment (please complete the following information):**
-- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
-- **PyTorch or TensorFlow version (use command below)**:
-- **Python version**:
-- **CUDA/cuDNN version**:
-- **GPU model and memory**:
-- **Exact command to reproduce**:
+### Environment
+
+<!--
+You can either run `TTS/bin/collect_env_info.py`
+
+```bash
+wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_details.py
+python collect_env_details.py
+```
+
+or fill in the fields below manually.
+-->
+
+- 🐸TTS Version (e.g., 1.3.0):
+- PyTorch Version (e.g., 1.8):
+- Python version:
+- OS (e.g., Linux):
+- CUDA/cuDNN version:
+- GPU models and configuration:
+- How you installed PyTorch (`conda`, `pip`, source):
+- Any other relevant information:
-**Additional context**
-Add any other context about the problem here.
+### Additional context
+
+<!-- Add any other context about the problem here. -->

View File

@@ -6,21 +6,20 @@ labels: feature request
 assignees: ''
 ---
-Welcome to the 🐸TTS project! We are excited to see your interest, and appreciate your support!
-
-This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.
-
-If you have a feature request, then please provide the following information:
-
-**Is your feature request related to a problem? Please describe.**
-A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+<!-- Welcome to the 🐸TTS project!
+We are excited to see your interest, and appreciate your support! --->
+
+**🚀 Feature Description**
+
+<!-- A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] -->
-**Describe the solution you'd like**
-A clear and concise description of what you want to happen.
+**Solution**
+
+<!-- A clear and concise description of what you want to happen. -->
-**Describe alternatives you've considered**
-A clear and concise description of any alternative solutions or features you've considered.
+**Alternative Solutions**
+
+<!-- A clear and concise description of any alternative solutions or features you've considered. -->
 **Additional context**
-Add any other context or screenshots about the feature request here.
+
+<!-- Add any other context or screenshots about the feature request here. -->

View File

@@ -98,7 +98,7 @@
         "fast_pitch": {
             "description": "FastPitch model trained on VCTK dataset.",
             "github_rls_url": "https://coqui.gateway.scarf.sh/v0.4.0/tts_models--en--vctk--fast_pitch.zip",
-            "default_vocoder": "vocoder_models/en/vctk/hifigan_v2",
+            "default_vocoder": null,
             "commit": "bdab788d",
             "author": "Eren @erogol",
             "license": "CC BY-NC-ND 4.0",

View File

@@ -0,0 +1,48 @@
"""Get detailed info about the working environment."""
import os
import platform
import sys

import numpy
import torch

sys.path += [os.path.abspath(".."), os.path.abspath(".")]
import json

import TTS


def system_info():
    return {
        "OS": platform.system(),
        "architecture": platform.architecture(),
        "version": platform.version(),
        "processor": platform.processor(),
        "python": platform.python_version(),
    }


def cuda_info():
    return {
        "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
        "available": torch.cuda.is_available(),
        "version": torch.version.cuda,
    }


def package_info():
    return {
        "numpy": numpy.__version__,
        "PyTorch_version": torch.__version__,
        "PyTorch_debug": torch.version.debug,
        "TTS": TTS.__version__,
    }


def main():
    details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()}
    print(json.dumps(details, indent=4, sort_keys=True))


if __name__ == "__main__":
    main()
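For reference, the same report can be assembled programmatically instead of running the script from the shell. This is a hedged sketch: the module path `TTS.bin.collect_env_info` is assumed from where the script appears to live in this diff (the issue template above references both `collect_env_info.py` and `collect_env_details.py`, so the exact name may differ).

```python
# Hypothetical usage sketch: call the helpers directly from a repo checkout.
import json

from TTS.bin.collect_env_info import cuda_info, package_info, system_info  # module name assumed

details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()}
print(json.dumps(details, indent=4, sort_keys=True))
```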

View File

@@ -254,7 +254,7 @@ def main():
     print(" > Text: {}".format(args.text))

     # kick it
-    wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav)
+    wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style)

     # save the results
     print(" > Saving output to {}".format(args.out_path))

View File

@@ -103,7 +103,7 @@ synthesizer = Synthesizer(
     model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
 )
-use_multi_speaker = hasattr(synthesizer.tts_model, "speaker_manager") and synthesizer.tts_model.num_speakers > 1
+use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
 speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
 # TODO: set this from SpeakerManager
 use_gst = synthesizer.tts_config.get("use_gst", False)

View File

@@ -284,8 +284,8 @@ class Trainer:
         self.optimizer = self.get_optimizer(self.model, self.config)

         # CALLBACK
-        self.callbacks = TrainerCallback(self)
-        self.callbacks.on_init_start()
+        self.callbacks = TrainerCallback()
+        self.callbacks.on_init_start(self)

         # init AMP
         if self.use_amp_scaler:
@@ -324,7 +324,7 @@ class Trainer:
         num_params = count_parameters(self.model)
         print("\n > Model has {} parameters".format(num_params))
-        self.callbacks.on_init_end()
+        self.callbacks.on_init_end(self)

     @staticmethod
     def parse_argv(args: Union[Coqpit, List]):
@@ -677,7 +677,7 @@ class Trainer:
         Returns:
             Tuple[Dict, Dict]: Model outputs and losses.
         """
-        self.callbacks.on_train_step_start()
+        self.callbacks.on_train_step_start(self)
         # format data
         batch = self.format_batch(batch)
         loader_time = time.time() - loader_start_time
@@ -792,7 +792,7 @@ class Trainer:
             self.dashboard_logger.flush()

         self.total_steps_done += 1
-        self.callbacks.on_train_step_end()
+        self.callbacks.on_train_step_end(self)
         return outputs, loss_dict

     def train_epoch(self) -> None:
@@ -983,7 +983,7 @@ class Trainer:
             if self.num_gpus > 1:
                 # let all processes sync up before starting with a new epoch of training
                 dist.barrier()
-            self.callbacks.on_epoch_start()
+            self.callbacks.on_epoch_start(self)
             self.keep_avg_train = KeepAverage()
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
@@ -999,7 +999,7 @@ class Trainer:
             )
             if self.args.rank in [None, 0]:
                 self.save_best_model()
-            self.callbacks.on_epoch_end()
+            self.callbacks.on_epoch_end(self)

     def fit(self) -> None:
         """Where the ✨magic✨ happens..."""
@@ -1008,7 +1008,7 @@ class Trainer:
             if self.args.rank == 0:
                 self.dashboard_logger.finish()
         except KeyboardInterrupt:
-            self.callbacks.on_keyboard_interrupt()
+            self.callbacks.on_keyboard_interrupt(self)
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
             # clear the DDP processes

View File

@@ -110,6 +110,7 @@ class FastSpeechConfig(BaseTTSConfig):
     model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=False)

     # multi-speaker settings
+    num_speakers: int = 0
     speakers_file: str = None
     use_speaker_embedding: bool = False
     use_d_vector_file: bool = False
@@ -142,7 +143,7 @@ class FastSpeechConfig(BaseTTSConfig):
     r: int = 1  # DO NOT CHANGE

     # dataset configs
-    compute_f0: bool = True
+    compute_f0: bool = False
     f0_cache_path: str = None

     # testing
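A quick sketch of what the two changes mean in practice: FastSpeech does not use pitch, so F0 computation is now off by default, and multi-speaker runs set `num_speakers` explicitly. The import path below is assumed from the repo layout used elsewhere in this diff.

```python
# Sketch under an assumed import path; values here are illustrative.
from TTS.tts.configs.fast_speech_config import FastSpeechConfig

config = FastSpeechConfig(num_speakers=4, use_speaker_embedding=True)
assert config.compute_f0 is False  # the default flipped in this diff
```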

View File

@@ -91,6 +91,8 @@ def load_tts_samples(
             for idx, ins in enumerate(meta_data_eval_all):
                 attn_file = meta_data[ins[1]].strip()
                 meta_data_eval_all[idx].append(attn_file)
+        # set none for the next iter
+        formatter = None
     return meta_data_train_all, meta_data_eval_all
@@ -110,3 +112,18 @@ def _get_formatter_by_name(name):
     """Returns the respective preprocessing function."""
     thismodule = sys.modules[__name__]
     return getattr(thismodule, name.lower())
+
+
+def find_unique_chars(data_samples, verbose=True):
+    texts = "".join(item[0] for item in data_samples)
+    chars = set(texts)
+    lower_chars = filter(lambda c: c.islower(), chars)
+    chars_force_lower = [c.lower() for c in chars]
+    chars_force_lower = set(chars_force_lower)
+
+    if verbose:
+        print(f" > Number of unique characters: {len(chars)}")
+        print(f" > Unique characters: {''.join(sorted(chars))}")
+        print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
+        print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
+    return chars_force_lower
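For example, the new helper can be run on loaded samples to inspect a dataset's character set before building a vocabulary. This is a sketch: the dataset config values are placeholders, and samples are assumed to carry the text in `item[0]` as in this diff.

```python
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import find_unique_chars, load_tts_samples

# Hypothetical dataset config; the path and formatter name are placeholders.
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path="/data/LJSpeech-1.1/")

train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
chars = find_unique_chars(train_samples + eval_samples, verbose=True)
```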

View File

@@ -60,7 +60,13 @@ def mozilla_de(root_path, meta_file):

 def mailabs(root_path, meta_files=None):
-    """Normalizes M-AI-Labs meta data files to TTS format"""
+    """Normalizes M-AI-Labs meta data files to TTS format
+
+    Args:
+        root_path (str): root folder of the MAILAB language folder.
+        meta_files (str): list of meta files to be used in the training. If None, finds all the csv files
+            recursively. Defaults to None.
+    """
     speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
     if meta_files is None:
         csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
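A usage sketch matching the documented signature; the module path and the on-disk layout are assumptions, and the returned sample structure follows the formatter conventions of this codebase.

```python
# Sketch: discover every metadata.csv under an M-AI-Labs language folder.
from TTS.tts.datasets.formatters import mailabs  # module path assumed from this diff

items = mailabs("/data/m-ailabs/en_US", meta_files=None)
print(items[0])  # expected shape: [text, wav_file, speaker_name]
```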

View File

@@ -266,7 +266,7 @@ class StochasticDurationPredictor(nn.Module):
             flows = list(reversed(self.flows))
             flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
-            z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
+            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
             for flow in flows:
                 z = torch.flip(z, [1])
                 z = flow(z, x_mask, g=x, reverse=reverse)
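The one-character change matters because the duration flow's latent prior is a standard normal: `torch.rand` samples uniformly on [0, 1), so it never produces negative values, while `torch.randn` samples N(0, 1) as the flow expects. A tiny illustration:

```python
import torch

torch.manual_seed(0)
u = torch.rand(2, 2)   # uniform on [0, 1): always non-negative, the wrong prior
n = torch.randn(2, 2)  # standard normal: zero-mean, what the inverse flow assumes
print(u.min() >= 0)    # True
print(n.mean())        # roughly zero for large tensors
```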

View File

@@ -219,7 +219,7 @@ class BaseTTS(BaseModel):
             use_phonemes=config.use_phonemes,
             phoneme_language=config.phoneme_language,
             enable_eos_bos=config.enable_eos_bos_chars,
-            use_noise_augment=not is_eval,
+            use_noise_augment=False if is_eval else config.use_noise_augment,
             verbose=verbose,
             speaker_id_mapping=speaker_id_mapping,
             d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,

View File

@@ -250,11 +250,11 @@ def synthesis(
     # GST processing
     style_mel = None
     custom_symbols = None
-    if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
+    if style_wav:
+        if isinstance(style_wav, dict):
+            style_mel = style_wav
+        else:
+            style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
+    elif CONFIG.has("gst") and CONFIG.gst and not style_wav:
+        if CONFIG.gst.gst_style_input_weights:
+            style_mel = CONFIG.gst.gst_style_input_weights
     if hasattr(model, "make_symbols"):
         custom_symbols = model.make_symbols(CONFIG)
     # preprocess the given text
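In practice the new branch accepts either a reference wav path, which is converted with `compute_style_mel()`, or a dict of GST token weights, which is used as-is. A hedged sketch, with the call shape taken from the `synthesize.py` hunk above (a `synthesizer` built as in `server.py` is assumed, and the token ids/weights are made up):

```python
# Sketch: two ways to drive GST style after this change.
text = "Hello there."
style_from_wav = "reference.wav"            # a wav path becomes a style mel internally
style_from_tokens = {"0": 0.25, "3": -0.1}  # hypothetical token-weight dict, used directly

# positional arguments mirror synthesizer.tts(text, speaker_idx, speaker_wav, gst_style)
wav = synthesizer.tts(text, None, None, style_from_tokens)
```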

View File

@@ -5,6 +5,7 @@ import re
 from typing import Dict, List

 import gruut
+from gruut_ipa import IPA

 from TTS.tts.utils.text import cleaners
 from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
@@ -32,7 +33,7 @@ PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+"
 GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")

-def text2phone(text, language, use_espeak_phonemes=False):
+def text2phone(text, language, use_espeak_phonemes=False, keep_stress=False):
     """Convert graphemes to phonemes.
     Parameters:
         text (str): text to phonemize
@@ -51,37 +52,45 @@ def text2phone(text, language, use_espeak_phonemes=False):
         ph = japanese_text_to_phonemes(text)
         return ph

-    if gruut.is_language_supported(language):
-        # Use gruut for phonemization
-        phonemizer_args = {
-            "remove_stress": True,
-            "ipa_minor_breaks": False,  # don't replace commas/semi-colons with IPA |
-            "ipa_major_breaks": False,  # don't replace periods with IPA ‖
-        }
-        if use_espeak_phonemes:
-            # Use a lexicon/g2p model trained on eSpeak IPA instead of gruut IPA.
-            # This is intended for backwards compatibility with TTS<=v0.0.13
-            # pre-trained models.
-            phonemizer_args["model_prefix"] = "espeak"
-        ph_list = gruut.text_to_phonemes(
-            text,
-            lang=language,
-            return_format="word_phonemes",
-            phonemizer_args=phonemizer_args,
-        )
-        # Join and re-split to break apart diphthongs, suprasegmentals, etc.
-        ph_words = ["|".join(word_phonemes) for word_phonemes in ph_list]
-        ph = "| ".join(ph_words)
-        # Fix a few phonemes
-        ph = ph.translate(GRUUT_TRANS_TABLE)
-        return ph
-
-    raise ValueError(f" [!] Language {language} is not supported for phonemization.")
+    if not gruut.is_language_supported(language):
+        raise ValueError(f" [!] Language {language} is not supported for phonemization.")
+
+    # Use gruut for phonemization
+    ph_list = []
+    for sentence in gruut.sentences(text, lang=language, espeak=use_espeak_phonemes):
+        for word in sentence:
+            if word.is_break:
+                # Use actual character for break phoneme (e.g., comma)
+                if ph_list:
+                    # Join with previous word
+                    ph_list[-1].append(word.text)
+                else:
+                    # First word is punctuation
+                    ph_list.append([word.text])
+            elif word.phonemes:
+                # Add phonemes for word
+                word_phonemes = []
+                for word_phoneme in word.phonemes:
+                    if not keep_stress:
+                        # Remove primary/secondary stress
+                        word_phoneme = IPA.without_stress(word_phoneme)
+                    word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
+                    if word_phoneme:
+                        # Flatten phonemes
+                        word_phonemes.extend(word_phoneme)
+                if word_phonemes:
+                    ph_list.append(word_phonemes)
+
+    # Join and re-split to break apart diphthongs, suprasegmentals, etc.
+    ph_words = ["|".join(word_phonemes) for word_phonemes in ph_list]
+    ph = "| ".join(ph_words)
+    return ph

 def intersperse(sequence, token):
     result = [token] * (len(sequence) * 2 + 1)
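For illustration, calling the updated API; gruut ~2.0 must be installed, and per this diff the output joins phonemes with `|` inside a word and separates words with `| `. The exact phoneme string depends on the gruut version, so the printed value below is only indicative.

```python
# Sketch: grapheme-to-phoneme with the new keep_stress flag.
from TTS.tts.utils.text import text2phone

ph = text2phone("Be a voice, not an echo.", "en-us", use_espeak_phonemes=True, keep_stress=False)
print(ph)  # e.g. "b|iː| ɐ| v|ɔɪ|s|,| ..." (exact output depends on the gruut version)
```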

View File

@@ -674,7 +674,7 @@ class AudioProcessor(object):
         return f0

     ### Audio Processing ###
-    def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int:
+    def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int:
         """Find the last point without silence at the end of an audio signal.

         Args:
@@ -687,7 +687,7 @@
         """
         window_length = int(self.sample_rate * min_silence_sec)
         hop_length = int(window_length / 4)
-        threshold = self._db_to_amp(threshold_db)
+        threshold = self._db_to_amp(-self.trim_db)
         for x in range(hop_length, len(wav) - window_length, hop_length):
             if np.max(wav[x : x + window_length]) < threshold:
                 return x + hop_length
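After this change the endpoint threshold follows the processor-wide `trim_db` setting rather than a per-call argument, so silence detection stays consistent with trimming. A minimal sketch, assuming `AudioProcessor` accepts these keyword arguments with its usual defaults:

```python
import numpy as np
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(sample_rate=22050, trim_db=45.0)  # threshold becomes _db_to_amp(-45)
wav = np.zeros(22050 * 3, dtype=np.float32)           # 3 s of silence, for illustration
end = ap.find_endpoint(wav)                           # no threshold_db argument anymore
```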

View File

@@ -1,75 +1,105 @@
 class TrainerCallback:
-    def __init__(self, trainer):
-        super().__init__()
-        self.trainer = trainer
-
-    def on_init_start(self) -> None:
-        if hasattr(self.trainer.model, "on_init_start"):
-            self.trainer.model.on_init_start(self.trainer)
-        if hasattr(self.trainer.criterion, "on_init_start"):
-            self.trainer.criterion.on_init_start(self.trainer)
-        if hasattr(self.trainer.optimizer, "on_init_start"):
-            self.trainer.optimizer.on_init_start(self.trainer)
+    @staticmethod
+    def on_init_start(trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_init_start"):
+                trainer.model.module.on_init_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_init_start"):
+                trainer.model.on_init_start(trainer)
+
+        if hasattr(trainer.criterion, "on_init_start"):
+            trainer.criterion.on_init_start(trainer)
+
+        if hasattr(trainer.optimizer, "on_init_start"):
+            trainer.optimizer.on_init_start(trainer)

-    def on_init_end(self) -> None:
-        if hasattr(self.trainer.model, "on_init_end"):
-            self.trainer.model.on_init_end(self.trainer)
-        if hasattr(self.trainer.criterion, "on_init_end"):
-            self.trainer.criterion.on_init_end(self.trainer)
-        if hasattr(self.trainer.optimizer, "on_init_end"):
-            self.trainer.optimizer.on_init_end(self.trainer)
+    @staticmethod
+    def on_init_end(trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_init_end"):
+                trainer.model.module.on_init_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_init_end"):
+                trainer.model.on_init_end(trainer)
+
+        if hasattr(trainer.criterion, "on_init_end"):
+            trainer.criterion.on_init_end(trainer)
+
+        if hasattr(trainer.optimizer, "on_init_end"):
+            trainer.optimizer.on_init_end(trainer)

-    def on_epoch_start(self) -> None:
-        if hasattr(self.trainer.model, "on_epoch_start"):
-            self.trainer.model.on_epoch_start(self.trainer)
-        if hasattr(self.trainer.criterion, "on_epoch_start"):
-            self.trainer.criterion.on_epoch_start(self.trainer)
-        if hasattr(self.trainer.optimizer, "on_epoch_start"):
-            self.trainer.optimizer.on_epoch_start(self.trainer)
+    @staticmethod
+    def on_epoch_start(trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_epoch_start"):
+                trainer.model.module.on_epoch_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_epoch_start"):
+                trainer.model.on_epoch_start(trainer)
+
+        if hasattr(trainer.criterion, "on_epoch_start"):
+            trainer.criterion.on_epoch_start(trainer)
+
+        if hasattr(trainer.optimizer, "on_epoch_start"):
+            trainer.optimizer.on_epoch_start(trainer)

-    def on_epoch_end(self) -> None:
-        if hasattr(self.trainer.model, "on_epoch_end"):
-            self.trainer.model.on_epoch_end(self.trainer)
-        if hasattr(self.trainer.criterion, "on_epoch_end"):
-            self.trainer.criterion.on_epoch_end(self.trainer)
-        if hasattr(self.trainer.optimizer, "on_epoch_end"):
-            self.trainer.optimizer.on_epoch_end(self.trainer)
+    @staticmethod
+    def on_epoch_end(trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_epoch_end"):
+                trainer.model.module.on_epoch_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_epoch_end"):
+                trainer.model.on_epoch_end(trainer)
+
+        if hasattr(trainer.criterion, "on_epoch_end"):
+            trainer.criterion.on_epoch_end(trainer)
+
+        if hasattr(trainer.optimizer, "on_epoch_end"):
+            trainer.optimizer.on_epoch_end(trainer)

-    def on_train_step_start(self) -> None:
-        if hasattr(self.trainer.model, "on_train_step_start"):
-            self.trainer.model.on_train_step_start(self.trainer)
-        if hasattr(self.trainer.criterion, "on_train_step_start"):
-            self.trainer.criterion.on_train_step_start(self.trainer)
-        if hasattr(self.trainer.optimizer, "on_train_step_start"):
-            self.trainer.optimizer.on_train_step_start(self.trainer)
+    @staticmethod
+    def on_train_step_start(trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_step_start"):
+                trainer.model.module.on_train_step_start(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_step_start"):
+                trainer.model.on_train_step_start(trainer)
+
+        if hasattr(trainer.criterion, "on_train_step_start"):
+            trainer.criterion.on_train_step_start(trainer)
+
+        if hasattr(trainer.optimizer, "on_train_step_start"):
+            trainer.optimizer.on_train_step_start(trainer)

-    def on_train_step_end(self) -> None:
-        if hasattr(self.trainer.model, "on_train_step_end"):
-            self.trainer.model.on_train_step_end(self.trainer)
-        if hasattr(self.trainer.criterion, "on_train_step_end"):
-            self.trainer.criterion.on_train_step_end(self.trainer)
-        if hasattr(self.trainer.optimizer, "on_train_step_end"):
-            self.trainer.optimizer.on_train_step_end(self.trainer)
+    @staticmethod
+    def on_train_step_end(trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_train_step_end"):
+                trainer.model.module.on_train_step_end(trainer)
+        else:
+            if hasattr(trainer.model, "on_train_step_end"):
+                trainer.model.on_train_step_end(trainer)
+
+        if hasattr(trainer.criterion, "on_train_step_end"):
+            trainer.criterion.on_train_step_end(trainer)
+
+        if hasattr(trainer.optimizer, "on_train_step_end"):
+            trainer.optimizer.on_train_step_end(trainer)

-    def on_keyboard_interrupt(self) -> None:
-        if hasattr(self.trainer.model, "on_keyboard_interrupt"):
-            self.trainer.model.on_keyboard_interrupt(self.trainer)
-        if hasattr(self.trainer.criterion, "on_keyboard_interrupt"):
-            self.trainer.criterion.on_keyboard_interrupt(self.trainer)
-        if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"):
-            self.trainer.optimizer.on_keyboard_interrupt(self.trainer)
+    @staticmethod
+    def on_keyboard_interrupt(trainer) -> None:
+        if hasattr(trainer.model, "module"):
+            if hasattr(trainer.model.module, "on_keyboard_interrupt"):
+                trainer.model.module.on_keyboard_interrupt(trainer)
+        else:
+            if hasattr(trainer.model, "on_keyboard_interrupt"):
+                trainer.model.on_keyboard_interrupt(trainer)
+
+        if hasattr(trainer.criterion, "on_keyboard_interrupt"):
+            trainer.criterion.on_keyboard_interrupt(trainer)
+
+        if hasattr(trainer.optimizer, "on_keyboard_interrupt"):
+            trainer.optimizer.on_keyboard_interrupt(trainer)
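With the static dispatch above, any object the trainer holds (model, criterion, optimizer) opts into a hook simply by defining the matching method; the DDP `model.module` indirection is handled inside the callback. A minimal, hypothetical model-side hook:

```python
import torch

class MyTTSModel(torch.nn.Module):  # hypothetical model for illustration
    def on_epoch_start(self, trainer):
        # e.g. anneal a loss weight as training progresses
        self.kl_weight = min(1.0, trainer.epochs_done / 100.0)
```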

View File

@@ -74,6 +74,8 @@ class Synthesizer(object):
         if vocoder_checkpoint:
             self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
             self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
+        else:
+            print(" > Using Griffin-Lim as no vocoder model defined")

     @staticmethod
     def _get_segmenter(lang: str):
@@ -265,6 +267,7 @@ class Synthesizer(object):
                 waveform = waveform.squeeze()

                 # trim silence
-                waveform = trim_silence(waveform, self.ap)
+                if self.tts_config.audio["do_trim_silence"] is True:
+                    waveform = trim_silence(waveform, self.ap)

                 wavs += list(waveform)

View File

@@ -310,6 +310,7 @@ class GAN(BaseVocoder):
         data_items: List,
         verbose: bool,
         num_gpus: int,
+        rank: int = 0,  # pylint: disable=unused-argument
     ):
         """Initiate and return the GAN dataloader.

View File

@@ -93,13 +93,13 @@ them and fine-tune it for your own dataset. This will help you in two main ways:
    ```bash
    CUDA_VISIBLE_DEVICES="0" python recipes/ljspeech/glow_tts/train_glowtts.py \
-       --restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts
+       --restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts/model_file.pth.tar
    ```

    ```bash
    CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tts.py \
        --config_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts/config.json \
-       --restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts
+       --restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts/model_file.pth.tar
    ```

    As stated above, you can also use command-line arguments to change the model configuration.

@@ -107,7 +107,7 @@ them and fine-tune it for your own dataset. This will help you in two main ways:
    ```bash
    CUDA_VISIBLE_DEVICES="0" python recipes/ljspeech/glow_tts/train_glowtts.py \
-       --restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts
+       --restore_path /home/ubuntu/.local/share/tts/tts_models--en--ljspeech--glow-tts/model_file.pth.tar
        --coqpit.run_name "glow-tts-finetune" \
        --coqpit.lr 0.00001
    ```

View File

@@ -19,15 +19,15 @@ Let's assume you created the audio clips and their transcription. You can collec
 You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each line must be delimited by a special character separating the audio file name from the transcription. Make sure that the delimiter is not used in the transcription text.

-We recommend the following format delimited by `||`.
+We recommend the following format delimited by `||`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc.

 ```
 # metadata.txt

-audio1.wav || This is my sentence.
-audio2.wav || This is maybe my sentence.
-audio3.wav || This is certainly my sentence.
-audio4.wav || Let this be your sentence.
+audio1||This is my sentence.
+audio2||This is maybe my sentence.
+audio3||This is certainly my sentence.
+audio4||Let this be your sentence.
 ...
 ```
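If your dataset uses exactly this layout, a formatter along these lines could parse it. This is a sketch following the conventions of the formatters module shown earlier in this diff; the function name, `wavs` folder, and speaker name are hypothetical.

```python
import os

def my_dataset(root_path, meta_file):
    """Parse `audio_id||transcription` lines into TTS samples (hypothetical formatter)."""
    items = []
    speaker_name = "my_speaker"  # placeholder for a single-speaker dataset
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.strip().split("||")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            items.append([cols[1], wav_file, speaker_name])
    return items
```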

View File

@@ -35,7 +35,7 @@ config = Tacotron2Config(  # This is the config that is saved for future use
     test_delay_epochs=-1,
     r=2,
     # gradual_training=[[0, 6, 48], [10000, 4, 32], [50000, 3, 32], [100000, 2, 32]],
-    double_decoder_consistency=False,
+    double_decoder_consistency=True,
     epochs=1000,
     text_cleaner="phoneme_cleaners",
     use_phonemes=True,

View File

@@ -0,0 +1,87 @@
import os

from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.tacotron2 import Tacotron2
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))

audio_config = BaseAudioConfig(
    sample_rate=22050,
    resample=False,  # Resampling slows down training. Use `TTS/bin/resample.py` to pre-resample to 22050 Hz and keep this False for faster training.
    do_trim_silence=True,
    trim_db=23.0,
    signal_norm=False,
    mel_fmin=0.0,
    mel_fmax=8000,
    spec_gain=1.0,
    log_func="np.log",
    preemphasis=0.0,
)

config = Tacotron2Config(  # This is the config that is saved for future use
    audio=audio_config,
    batch_size=32,
    eval_batch_size=16,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=-1,
    r=2,
    # gradual_training=[[0, 6, 48], [10000, 4, 32], [50000, 3, 32], [100000, 2, 32]],
    double_decoder_consistency=False,
    epochs=1000,
    text_cleaner="phoneme_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    print_step=150,
    print_eval=False,
    mixed_precision=True,
    sort_by_audio_len=True,
    min_seq_len=14800,
    max_seq_len=22050 * 10,  # 44k is the original sampling rate before resampling; this corresponds to 10 seconds of audio
    output_path=output_path,
    datasets=[dataset_config],
    use_speaker_embedding=True,  # set this to enable multi-speaker training
    decoder_ssim_alpha=0.0,  # disable SSIM losses that cause NaN for some runs
    postnet_ssim_alpha=0.0,
    postnet_diff_spec_alpha=0.0,
    decoder_diff_spec_alpha=0.0,
    attention_norm="softmax",
    optimizer="Adam",
    lr_scheduler=None,
    lr=3e-5,
)

# init audio processor
ap = AudioProcessor(**config.audio.to_dict())

# load training samples
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)

# init speaker manager for multi-speaker training
# it mainly handles the speaker-id to speaker-name mapping for the model and the data loader
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)

# init model
model = Tacotron2(config, speaker_manager)

# init the trainer and 🚀
trainer = Trainer(
    TrainingArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    training_assets={"audio_processor": ap},
)
trainer.fit()

View File

@@ -23,6 +23,6 @@ coqpit
 mecab-python3==1.0.3
 unidic-lite==1.0.8
 # gruut+supported langs
-gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
+gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
 fsspec>=2021.04.0
 pyworld

View File

@@ -9,12 +9,12 @@ LANG = "en-us"
 EXAMPLE_TEXT = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"

-EXPECTED_PHONEMES = "ɹ|iː|s|ə|n|t| ɹ|ᵻ|s|ɜː|tʃ| æ|t| h|ɑːɹ|v|ɚ|d| h|æ|z| ʃ|oʊ|n| m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| f|ɔːɹ| æ|z| l|ɪ|ɾ|əl| æ|z| eɪ|t| w|iː|k|s| k|æ|n| æ|k|tʃ|uː|əl|i| ɪ|ŋ|k|ɹ|iː|s| ,| ð|ə| ɡ|ɹ|eɪ| m|æ|ɾ|ɚ| ɪ|n| ð|ə| p|ɑːɹ|t|s| ʌ|v| ð|ə| b|ɹ|eɪ|n| ɹ|ᵻ|s|p|ɑː|n|s|ᵻ|b|əl| f|ɔːɹ| ɪ|m|oʊ|ʃ|ə|n|əl| ɹ|ɛ|ɡ|j|ʊ|l|eɪ|ʃ|ə|n| æ|n|d| l|ɜː|n|ɪ|ŋ| !"
+EXPECTED_PHONEMES = "ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ| f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s|,| ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ| ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l| f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ|!"

 # -----------------------------------------------------------------------------

-class TextProcessingTextCase(unittest.TestCase):
+class TextProcessingTestCase(unittest.TestCase):
     """Tests for text to phoneme conversion"""

     def test_phoneme_to_sequence(self):
@@ -40,7 +40,7 @@ class TextProcessingTextCase(unittest.TestCase):
         sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True)
         text_hat = sequence_to_phoneme(sequence)
         text_hat_with_params = sequence_to_phoneme(sequence)
-        gt = "biː ɐ vɔɪs , nɑːt ɐn ! ɛkoʊ ?"
+        gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
         print(text_hat)
         print(len(sequence))
         self.assertEqual(text_hat, text_hat_with_params)
@@ -51,7 +51,7 @@ class TextProcessingTextCase(unittest.TestCase):
         sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True)
         text_hat = sequence_to_phoneme(sequence)
         text_hat_with_params = sequence_to_phoneme(sequence)
-        gt = "biː ɐ vɔɪs , nɑːt ɐn ! ɛkoʊ"
+        gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
         print(text_hat)
         print(len(sequence))
         self.assertEqual(text_hat, text_hat_with_params)
@@ -62,7 +62,7 @@ class TextProcessingTextCase(unittest.TestCase):
         sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True)
         text_hat = sequence_to_phoneme(sequence)
         text_hat_with_params = sequence_to_phoneme(sequence)
-        gt = "biː ɐ vɔɪs , nɑːt ɐn ɛkoʊ !"
+        gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
         print(text_hat)
         print(len(sequence))
         self.assertEqual(text_hat, text_hat_with_params)
@@ -73,7 +73,7 @@ class TextProcessingTextCase(unittest.TestCase):
         sequence = phoneme_to_sequence(text, text_cleaner, LANG, add_blank=add_blank, use_espeak_phonemes=True)
         text_hat = sequence_to_phoneme(sequence)
         text_hat_with_params = sequence_to_phoneme(sequence)
-        gt = "biː ɐ vɔɪs , nɑːt ɐn ! ɛkoʊ ."
+        gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
         print(text_hat)
         print(len(sequence))
         self.assertEqual(text_hat, text_hat_with_params)
@@ -86,7 +86,7 @@ class TextProcessingTextCase(unittest.TestCase):
         )
         text_hat = sequence_to_phoneme(sequence)
         text_hat_with_params = sequence_to_phoneme(sequence)
-        gt = "^biː ɐ vɔɪs , nɑːt ɐn ! ɛkoʊ .~"
+        gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
         print(text_hat)
         print(len(sequence))
         self.assertEqual(text_hat, text_hat_with_params)