mirror of https://github.com/coqui-ai/TTS.git
Enable `custom_symbols` in text processing
Models can define their own custom symbol lists with a custom `make_symbols()`.
parent bd4e29b4dd
commit 003e5579e8
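
With this change a model can ship its own symbol inventory: if the model object exposes a `make_symbols()` method, `synthesis()` calls it and threads the returned list through text processing as `custom_symbols` (see the `hasattr(model, "make_symbols")` hook in the synthesis diff below). A minimal, hypothetical sketch of the model side follows; the class name and the symbol set are illustrative only and not part of this commit.

# Hypothetical model-side hook. Anything returned here ends up as the
# `custom_symbols` list consumed by text_to_sequence()/phoneme_to_sequence().
class MyTTSModel:
    @staticmethod
    def make_symbols(config):  # `config` is the model's config object
        # Any plain list of strings works: a pad token, space/punctuation,
        # lowercase letters, plus extra tokens this illustrative model predicts.
        return ["<PAD>"] + list(" .,!?") + list("abcdefghijklmnopqrstuvwxyz") + ["<laugh>", "<breath>"]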
@@ -23,7 +23,9 @@ class TTSDataset(Dataset):
         ap: AudioProcessor,
         meta_data: List[List],
         characters: Dict = None,
+        custom_symbols: List = None,
         add_blank: bool = False,
+        return_wav: bool = False,
         batch_group_size: int = 0,
         min_seq_len: int = 0,
         max_seq_len: int = float("inf"),
@@ -54,9 +56,14 @@ class TTSDataset(Dataset):

            characters (dict): `dict` of custom text characters used for converting texts to sequences.

+            custom_symbols (list): List of custom symbols used for converting texts to sequences. Models using its own
+                set of symbols need to pass it here. Defaults to `None`.
+
            add_blank (bool): Add a special `blank` character after every other character. It helps some
                models achieve better results. Defaults to false.

+            return_wav (bool): Return the waveform of the sample. Defaults to False.
+
            batch_group_size (int): Range of batch randomization after sorting
                sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
                batch. Set 0 to disable. Defaults to 0.
@@ -95,10 +102,12 @@ class TTSDataset(Dataset):
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.compute_linear_spec = compute_linear_spec
+        self.return_wav = return_wav
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
         self.characters = characters
+        self.custom_symbols = custom_symbols
         self.add_blank = add_blank
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
@@ -109,6 +118,7 @@ class TTSDataset(Dataset):
         self.use_noise_augment = use_noise_augment
         self.verbose = verbose
         self.input_seq_computed = False
+        self.rescue_item_idx = 1
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
@@ -128,13 +138,21 @@ class TTSDataset(Dataset):
         return data

     @staticmethod
-    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
+    def _generate_and_cache_phoneme_sequence(
+        text, cache_path, cleaners, language, custom_symbols, characters, add_blank
+    ):
         """generate a phoneme sequence from text.
         since the usage is for subsequent caching, we never add bos and
         eos chars here. Instead we add those dynamically later; based on the
         config option."""
         phonemes = phoneme_to_sequence(
-            text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
+            text,
+            [cleaners],
+            language=language,
+            enable_eos_bos=False,
+            custom_symbols=custom_symbols,
+            tp=characters,
+            add_blank=add_blank,
         )
         phonemes = np.asarray(phonemes, dtype=np.int32)
         np.save(cache_path, phonemes)
@@ -142,7 +160,7 @@ class TTSDataset(Dataset):

     @staticmethod
     def _load_or_generate_phoneme_sequence(
-        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
+        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, custom_symbols, characters, add_blank
     ):
         file_name = os.path.splitext(os.path.basename(wav_file))[0]

@@ -153,12 +171,12 @@ class TTSDataset(Dataset):
             phonemes = np.load(cache_path)
         except FileNotFoundError:
             phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, characters, add_blank
+                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
             )
         except (ValueError, IOError):
             print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
             phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
-                text, cache_path, cleaners, language, characters, add_blank
+                text, cache_path, cleaners, language, custom_symbols, characters, add_blank
             )
         if enable_eos_bos:
             phonemes = pad_with_eos_bos(phonemes, tp=characters)
@@ -189,13 +207,19 @@ class TTSDataset(Dataset):
                 self.enable_eos_bos,
                 self.cleaners,
                 self.phoneme_language,
+                self.custom_symbols,
                 self.characters,
                 self.add_blank,
             )

         else:
             text = np.asarray(
-                text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
+                text_to_sequence(
+                    text,
+                    [self.cleaners],
+                    custom_symbols=self.custom_symbols,
+                    tp=self.characters,
+                    add_blank=self.add_blank,
+                ),
                 dtype=np.int32,
             )

@@ -209,7 +233,7 @@ class TTSDataset(Dataset):
             # return a different sample if the phonemized
             # text is longer than the threshold
             # TODO: find a better fix
-            return self.load_data(100)
+            return self.load_data(self.rescue_item_idx)

         sample = {
             "text": text,
@@ -238,7 +262,13 @@ class TTSDataset(Dataset):
         for idx, item in enumerate(tqdm.tqdm(self.items)):
             text, *_ = item
             sequence = np.asarray(
-                text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
+                text_to_sequence(
+                    text,
+                    [self.cleaners],
+                    custom_symbols=self.custom_symbols,
+                    tp=self.characters,
+                    add_blank=self.add_blank,
+                ),
                 dtype=np.int32,
             )
             self.items[idx][0] = sequence
@@ -249,6 +279,7 @@ class TTSDataset(Dataset):
                 self.enable_eos_bos,
                 self.cleaners,
                 self.phoneme_language,
+                self.custom_symbols,
                 self.characters,
                 self.add_blank,
             ]
@@ -347,6 +378,14 @@ class TTSDataset(Dataset):

             mel_lengths = [m.shape[1] for m in mel]

+            # lengths adjusted by the reduction factor
+            mel_lengths_adjusted = [
+                m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
+                if m.shape[1] % self.outputs_per_step
+                else m.shape[1]
+                for m in mel
+            ]
+
             # compute 'stop token' targets
             stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]

@@ -385,6 +424,20 @@ class TTSDataset(Dataset):
             else:
                 linear = None

+            # format waveforms
+            wav_padded = None
+            if self.return_wav:
+                wav_lengths = [w.shape[0] for w in wav]
+                max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
+                wav_lengths = torch.LongTensor(wav_lengths)
+                wav_padded = torch.zeros(len(batch), 1, max_wav_len)
+                for i, w in enumerate(wav):
+                    mel_length = mel_lengths_adjusted[i]
+                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
+                    w = w[: mel_length * self.ap.hop_length]
+                    wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
+                wav_padded.transpose_(1, 2)
+
             # collate attention alignments
             if batch[0]["attn"] is not None:
                 attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
@@ -409,6 +462,7 @@ class TTSDataset(Dataset):
                     d_vectors,
                     speaker_ids,
                     attns,
+                    wav_padded,
                 )

         raise TypeError(
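
The new `return_wav` path above pads every waveform to the reduction-factor-adjusted mel length times `hop_length`, so waveforms and spectrograms stay frame-aligned after collation. Below is a standalone sketch of that length arithmetic with illustrative values; none of the numbers come from a real config.

# Illustrative values only.
hop_length = 256        # samples per spectrogram frame
outputs_per_step = 2    # reduction factor r

mel_len = 101           # frames of one mel in the batch
# round the frame count up to the next multiple of the reduction factor,
# exactly as mel_lengths_adjusted does in the collate function above
mel_len_adjusted = (
    mel_len + (outputs_per_step - mel_len % outputs_per_step)
    if mel_len % outputs_per_step
    else mel_len
)
assert mel_len_adjusted == 102

# each waveform is zero-padded into a [batch, 1, max_wav_len] tensor, where
# max_wav_len is the batch maximum of the adjusted lengths times hop_length
max_wav_len = mel_len_adjusted * hop_length  # 26112 samples in this example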
@ -15,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
def text_to_seq(text, CONFIG):
|
def text_to_seq(text, CONFIG, custom_symbols=None):
|
||||||
text_cleaner = [CONFIG.text_cleaner]
|
text_cleaner = [CONFIG.text_cleaner]
|
||||||
# text ot phonemes to sequence vector
|
# text ot phonemes to sequence vector
|
||||||
if CONFIG.use_phonemes:
|
if CONFIG.use_phonemes:
|
||||||
|
@ -28,16 +28,14 @@ def text_to_seq(text, CONFIG):
|
||||||
tp=CONFIG.characters,
|
tp=CONFIG.characters,
|
||||||
add_blank=CONFIG.add_blank,
|
add_blank=CONFIG.add_blank,
|
||||||
use_espeak_phonemes=CONFIG.use_espeak_phonemes,
|
use_espeak_phonemes=CONFIG.use_espeak_phonemes,
|
||||||
|
custom_symbols=custom_symbols,
|
||||||
),
|
),
|
||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
seq = np.asarray(
|
seq = np.asarray(
|
||||||
text_to_sequence(
|
text_to_sequence(
|
||||||
text,
|
text, text_cleaner, tp=CONFIG.characters, add_blank=CONFIG.add_blank, custom_symbols=custom_symbols
|
||||||
text_cleaner,
|
|
||||||
tp=CONFIG.characters,
|
|
||||||
add_blank=CONFIG.add_blank,
|
|
||||||
),
|
),
|
||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
|
@ -229,13 +227,16 @@ def synthesis(
|
||||||
"""
|
"""
|
||||||
# GST processing
|
# GST processing
|
||||||
style_mel = None
|
style_mel = None
|
||||||
|
custom_symbols = None
|
||||||
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
|
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
|
||||||
if isinstance(style_wav, dict):
|
if isinstance(style_wav, dict):
|
||||||
style_mel = style_wav
|
style_mel = style_wav
|
||||||
else:
|
else:
|
||||||
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
|
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
|
||||||
|
if hasattr(model, "make_symbols"):
|
||||||
|
custom_symbols = model.make_symbols(CONFIG)
|
||||||
# preprocess the given text
|
# preprocess the given text
|
||||||
text_inputs = text_to_seq(text, CONFIG)
|
text_inputs = text_to_seq(text, CONFIG, custom_symbols=custom_symbols)
|
||||||
# pass tensors to backend
|
# pass tensors to backend
|
||||||
if backend == "torch":
|
if backend == "torch":
|
||||||
if speaker_id is not None:
|
if speaker_id is not None:
|
||||||
|
@ -274,15 +275,18 @@ def synthesis(
|
||||||
# convert outputs to numpy
|
# convert outputs to numpy
|
||||||
# plot results
|
# plot results
|
||||||
wav = None
|
wav = None
|
||||||
if use_griffin_lim:
|
if hasattr(model, "END2END") and model.END2END:
|
||||||
wav = inv_spectrogram(model_outputs, ap, CONFIG)
|
wav = model_outputs.squeeze(0)
|
||||||
# trim silence
|
else:
|
||||||
if do_trim_silence:
|
if use_griffin_lim:
|
||||||
wav = trim_silence(wav, ap)
|
wav = inv_spectrogram(model_outputs, ap, CONFIG)
|
||||||
|
# trim silence
|
||||||
|
if do_trim_silence:
|
||||||
|
wav = trim_silence(wav, ap)
|
||||||
return_dict = {
|
return_dict = {
|
||||||
"wav": wav,
|
"wav": wav,
|
||||||
"alignments": alignments,
|
"alignments": alignments,
|
||||||
"model_outputs": model_outputs,
|
|
||||||
"text_inputs": text_inputs,
|
"text_inputs": text_inputs,
|
||||||
|
"outputs": outputs,
|
||||||
}
|
}
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
|
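
On the synthesis side the new plumbing is only discovery plus forwarding: `custom_symbols` is resolved from the model (if it defines `make_symbols()`) and handed to `text_to_seq()`. A hedged usage sketch follows; `model`, `CONFIG` and the sentence stand in for whatever the caller already has and are not defined here.

# Sketch of the new call path; assumes `model` and `CONFIG` already exist.
custom_symbols = None
if hasattr(model, "make_symbols"):
    # the model decides its own symbol inventory
    custom_symbols = model.make_symbols(CONFIG)

# text_to_seq() forwards the list to text_to_sequence()/phoneme_to_sequence()
text_inputs = text_to_seq("Be a voice, not an echo.", CONFIG, custom_symbols=custom_symbols)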
@@ -2,10 +2,9 @@
 # adapted from https://github.com/keithito/tacotron

 import re
-import unicodedata
+from typing import Dict, List

 import gruut
-from packaging import version

 from TTS.tts.utils.text import cleaners
 from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes
@@ -22,6 +21,7 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}

 _symbols = symbols
 _phonemes = phonemes
+
 # Regular expression matching text enclosed in curly braces:
 _CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")

@@ -81,7 +81,7 @@ def text2phone(text, language, use_espeak_phonemes=False):
         # Fix a few phonemes
         ph = ph.translate(GRUUT_TRANS_TABLE)

-        print(" > Phonemes: {}".format(ph))
+        # print(" > Phonemes: {}".format(ph))
         return ph

     raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@@ -106,13 +106,38 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):


 def phoneme_to_sequence(
-    text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False, use_espeak_phonemes=False
-):
+    text: str,
+    cleaner_names: List[str],
+    language: str,
+    enable_eos_bos: bool = False,
+    custom_symbols: List[str] = None,
+    tp: Dict = None,
+    add_blank: bool = False,
+    use_espeak_phonemes: bool = False,
+) -> List[int]:
+    """Converts a string of phonemes to a sequence of IDs.
+
+    Args:
+        text (str): string to convert to a sequence
+        cleaner_names (List[str]): names of the cleaner functions to run the text through
+        language (str): text language key for phonemization.
+        enable_eos_bos (bool): whether to append the end-of-sentence and beginning-of-sentence tokens.
+        tp (Dict): dictionary of character parameters to use a custom character set.
+        add_blank (bool): option to add a blank token between each token.
+        use_espeak_phonemes (bool): use espeak based lexicons to convert phonemes to sequenc
+
+    Returns:
+        List[int]: List of integers corresponding to the symbols in the text
+    """
     # pylint: disable=global-statement
     global _phonemes_to_id, _phonemes

     if tp:
         _, _phonemes = make_symbols(**tp)
         _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
+    elif custom_symbols is not None:
+        _phonemes = custom_symbols
+        _phonemes_to_id = {s: i for i, s in enumerate(custom_symbols)}

     sequence = []
     clean_text = _clean_text(text, cleaner_names)
@@ -127,7 +152,6 @@ def phoneme_to_sequence(
         sequence = pad_with_eos_bos(sequence, tp=tp)
     if add_blank:
         sequence = intersperse(sequence, len(_phonemes))  # add a blank token (new), whose id number is len(_phonemes)
-
     return sequence


@@ -149,27 +173,31 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
     return result.replace("}{", " ")


-def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
+def text_to_sequence(
+    text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
+) -> List[int]:
     """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
-
-    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
-    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

     Args:
-        text: string to convert to a sequence
-        cleaner_names: names of the cleaner functions to run the text through
-        tp: dictionary of character parameters to use a custom character set.
+        text (str): string to convert to a sequence
+        cleaner_names (List[str]): names of the cleaner functions to run the text through
+        tp (Dict): dictionary of character parameters to use a custom character set.
+        add_blank (bool): option to add a blank token between each token.

     Returns:
-        List of integers corresponding to the symbols in the text
+        List[int]: List of integers corresponding to the symbols in the text
     """
     # pylint: disable=global-statement
     global _symbol_to_id, _symbols
     if tp:
         _symbols, _ = make_symbols(**tp)
         _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
+    elif custom_symbols is not None:
+        _symbols = custom_symbols
+        _symbol_to_id = {s: i for i, s in enumerate(custom_symbols)}

     sequence = []

     # Check for curly braces and treat their contents as ARPAbet:
     while text:
         m = _CURLY_RE.match(text)
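
Below the API level, `custom_symbols` simply replaces the module-level symbol table before lookup, so the returned IDs index into the caller's list. A hedged, standalone usage sketch; the symbol list and sentence are made up, and `basic_cleaners` is one of the stock cleaner names.

from TTS.tts.utils.text import text_to_sequence

# illustrative inventory: ids 0..2 are specials and space, the rest lowercase letters
my_symbols = ["<PAD>", "<EOS>", " "] + list("abcdefghijklmnopqrstuvwxyz")

seq = text_to_sequence("hello world", ["basic_cleaners"], custom_symbols=my_symbols)
# each id indexes into my_symbols; characters missing from the list are dropped
print(seq)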
@@ -42,6 +42,7 @@ class TestTTSDataset(unittest.TestCase):
             r,
             c.text_cleaner,
             compute_linear_spec=True,
+            return_wav=True,
             ap=self.ap,
             meta_data=items,
             characters=c.characters,
@@ -75,16 +76,26 @@ class TestTTSDataset(unittest.TestCase):
             mel_lengths = data[5]
             stop_target = data[6]
             item_idx = data[7]
+            wavs = data[11]

             neg_values = text_input[text_input < 0]
             check_count = len(neg_values)
             assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
             assert isinstance(speaker_name[0], str)
             assert linear_input.shape[0] == c.batch_size
             assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
             assert mel_input.shape[0] == c.batch_size
             assert mel_input.shape[2] == c.audio["num_mels"]
+            assert (
+                wavs.shape[1] == mel_input.shape[1] * c.audio.hop_length
+            ), f"wavs.shape: {wavs.shape[1]}, mel_input.shape: {mel_input.shape[1] * c.audio.hop_length}"
+
+            # make sure that the computed mels and the waveform match and correctly computed
+            mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
+            ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
+            mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
+            assert abs(mel_diff.sum()) < 1e-5
+
             # check normalization ranges
             if self.ap.symmetric_norm:
                 assert mel_input.max() <= self.ap.max_norm
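
The test exercises the new flag end to end; in a data pipeline the same usage looks roughly like the sketch below. The constructor arguments mirror the test above, `ap` (an AudioProcessor), `items` (the metadata list) and `config` are assumed to already exist, and the import path is an assumption.

from torch.utils.data import DataLoader

from TTS.tts.datasets.TTSDataset import TTSDataset  # import path assumed

dataset = TTSDataset(
    2,                      # outputs_per_step (reduction factor), illustrative value
    config.text_cleaner,
    compute_linear_spec=False,
    return_wav=True,        # new flag from this commit
    ap=ap,
    meta_data=items,
    characters=config.characters,
)
loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn, shuffle=False)

batch = next(iter(loader))
wavs = batch[11]  # padded waveforms, [batch, samples, 1] after the transpose in collate_fn
# each waveform spans mel_frames * hop_length samples, as the test above asserts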
@@ -27,6 +27,7 @@ config = AlignTTSConfig(
         "Be a voice, not an echo.",
     ],
 )
+
 config.audio.do_trim_silence = True
 config.audio.trim_db = 60
 config.save_json(config_path)