Merge branch 'main' into dev

Merge changes by JMa into dev
Jindrich Matousek 2023-09-21 08:10:47 +02:00
commit 0a89a43a77
18 changed files with 237 additions and 22 deletions

.gitignore

@@ -169,3 +169,6 @@ wandb
depot/*
coqui_recipes/*
local_scripts/*
# SVN
.svn/

TTS/bin/collect_env_info.py Normal file → Executable file
TTS/bin/compute_attention_masks.py Normal file → Executable file
TTS/bin/compute_embeddings.py Normal file → Executable file
TTS/bin/eval_encoder.py Normal file → Executable file
TTS/bin/find_unique_chars.py Normal file → Executable file
TTS/bin/find_unique_phonemes.py Normal file → Executable file
TTS/bin/resample.py Normal file → Executable file
TTS/bin/train_encoder.py Normal file → Executable file
TTS/bin/train_tts.py Normal file → Executable file
TTS/bin/train_vocoder.py Normal file → Executable file
TTS/bin/tune_wavegrad.py Normal file → Executable file

@@ -57,13 +57,23 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
def add_extra_keys(metadata, language, dataset_name):
changes = {}
for item in metadata:
# add language name
item["language"] = language
# JMa: Add language name only if not defined at the sample level. Could be good for multi-language datasets.
if not item["language"]:
item["language"] = language
# JMa: Prepend dataset name to speaker name. Could be good for multispeaker datasets.
if dataset_name and item["speaker_name"] != dataset_name and not item["speaker_name"].startswith(dataset_name+"_"):
changes[item["speaker_name"]] = f'{dataset_name}_{item["speaker_name"]}'
item["speaker_name"] = f'{dataset_name}_{item["speaker_name"]}'
# add unique audio name
relfilepath = os.path.splitext(os.path.relpath(item["audio_file"], item["root_path"]))[0]
audio_unique_name = f"{dataset_name}#{relfilepath}"
item["audio_unique_name"] = audio_unique_name
# JMa: print changed speaker names if any
if changes:
for k, v in changes.items():
print(f" | > speaker name changed: {k} --> {v}")
return metadata
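
Illustrative usage sketch (not part of the diff; the sample dicts and dataset name below are made up) showing what the patched `add_extra_keys` does to a metadata list:

# Hypothetical samples; assumes the patched `add_extra_keys` is in scope.
metadata = [
    {"text": "Ahoj.", "audio_file": "/data/artic/wavs/a0001.wav",
     "root_path": "/data/artic", "speaker_name": "spk1", "language": ""},
    {"text": "Hello.", "audio_file": "/data/artic/wavs/a0002.wav",
     "root_path": "/data/artic", "speaker_name": "spk2", "language": "en"},
]
metadata = add_extra_keys(metadata, language="cs-CZ", dataset_name="artic")
# sample 1: language filled in as "cs-CZ", speaker_name -> "artic_spk1"
# sample 2: language "en" kept,            speaker_name -> "artic_spk2"
# both:     audio_unique_name like "artic#wavs/a0001"
# prints: " | > speaker name changed: spk1 --> artic_spk1" (and likewise for spk2)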


@@ -442,6 +442,54 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
return items
# JMa: VCTK with wav files (not flac)
def vctk_wav(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None):
"""VCTK dataset v0.92.
URL:
https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip
This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```.
It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset.
mic1:
Audio recorded using an omni-directional microphone (DPA 4035).
Contains very low frequency noises.
This is the same audio released in previous versions of VCTK:
https://doi.org/10.7488/ds/1994
mic2:
Audio recorded using a small diaphragm condenser microphone with
very wide bandwidth (Sennheiser MKH 800).
Two speakers, p280 and p315, had technical issues with the audio
recordings made using the MKH 800.
"""
file_ext = "wav"
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
# p280 has no mic2 recordings
if speaker_id == "p280":
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}")
else:
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
if os.path.exists(wav_file):
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
)
else:
print(f" [!] wav files don't exist - {wav_file}")
return items
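
For context, a hedged sketch of how the new formatter would typically be selected through a dataset config (the import path and field names follow upstream Coqui TTS and may differ in this fork; the path is made up):

# Assumption: BaseDatasetConfig is importable as in upstream Coqui TTS recipes.
from TTS.tts.configs.shared_configs import BaseDatasetConfig

dataset_config = BaseDatasetConfig(
    formatter="vctk_wav",                 # the wav-based formatter added above
    dataset_name="vctk",
    path="/data/VCTK-Corpus-0.92-wav",    # hypothetical wav copy of the corpus
    language="en",
    ignored_speakers=["p315"],            # optional
)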
def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
items = []
@@ -628,6 +676,71 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
return items
def artic(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Normalizes the ARTIC meta data file to TTS format
Args:
root_path (str): path to the artic dataset
meta_file (str): name of the meta file containing the names of the wavs to select and
the transcripts of the corresponding utterances
Returns:
List[List[str]]: List of (text, wav_path, speaker_name, language, root_path) associated with each utterance
"""
txt_file = os.path.join(root_path, meta_file)
items = []
# Speaker name is the name of the directory with the data (last part of `root_path`)
speaker_name = os.path.basename(os.path.normpath(root_path))
# Speaker name can consist of a language code (e.g. cs-CZ or en) and gender (m/f) separated by dots
# Example: AndJa.cs-CZ.m, LJS.en.f
try:
voice, lang, sex = speaker_name.split(".")
except ValueError:
voice = speaker_name
lang, sex = None, None
print(f" > ARTIC dataset: voice={voice}, sex={sex}, language={lang}")
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
# Check the number of standard separators
n_seps = line.count("|")
if n_seps > 0:
# Split according to standard separator
cols = line.split("|")
else:
# Assume ARTIC SNT format => wav name is delimited by the first space
cols = line.split(maxsplit=1)
# Either way, the wav name is stored in `cols[0]` and the text in `cols[-1]`
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[-1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "language": lang, "root_path": root_path})
return items
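
An illustrative sketch of the two meta-file line formats `artic` accepts and of the directory-name convention (file names, paths, and transcripts below are made up):

# Meta-file lines accepted by artic():
#   pipe-separated:  utt0001|Dobry den.     -> wav "wavs/utt0001.wav", text "Dobry den."
#   ARTIC SNT:       utt0001 Dobry den.     -> same result (split on the first space)
#
# Hypothetical call; the directory name encodes voice.language.gender:
items = artic("/data/AndJa.cs-CZ.m", "all_sentences.txt")
print(items[0]["speaker_name"], items[0]["language"])  # AndJa.cs-CZ.m cs-CZ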
def artic_multispeaker(root_path, meta_file, ignored_speakers=None): # pylint: disable=unused-argument
"""Normalizes the ARTIC multi-speaker meta data files to TTS format
Args:
root_path (str): path to the artic dataset
meta_file (str): name of the meta file containing the names of the wavs to select and
the transcripts of the corresponding utterances
!Must be the same for all speakers!
ignored_speakers (List[str]): list of ignored speakers (or None)
Returns:
List[List[str]]: List of (text, wav_path, speaker_name) associated with each utterance
"""
items = []
# Loop over speakers: speaker names are subdirs of `root_path`
for pth in glob(f"{root_path}/*/", recursive=False):
speaker_name = os.path.basename(os.path.normpath(pth))  # normpath needed: glob() returns dirs with a trailing slash, so plain basename() would be empty
# Ignore speakers
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
items.extend(artic(pth, meta_file))
return items
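
And a sketch of the directory layout the multi-speaker variant expects, one sub-directory per speaker sharing the same meta-file name (layout and names are hypothetical):

# /data/artic_all/
#   AndJa.cs-CZ.m/
#     all_sentences.txt        <- same meta file name for every speaker
#     wavs/
#   LJS.en.f/
#     all_sentences.txt
#     wavs/
items = artic_multispeaker("/data/artic_all", "all_sentences.txt",
                           ignored_speakers=["LJS.en.f"])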
def kss(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
"""Korean single-speaker dataset from https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset"""
txt_file = os.path.join(root_path, meta_file)


@@ -1084,12 +1084,47 @@ class Vits(BaseTTS):
if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
# JMa: set minimum duration if predicted duration is lower than threshold
# Workaround to avoid short durations that cause some chars/phonemes to be reduced
# @staticmethod
# def _set_min_inference_length(d, threshold):
# d_mask = d < threshold
# d[d_mask] = threshold
# return d
def _set_min_inference_length(self, x, durs, threshold):
punctlike = list(self.config.characters.punctuations) + [self.config.characters.blank]
# Get list of tokens from IDs
tokens = x.squeeze().tolist()
# Check current and next token
n = self.tokenizer.characters.id_to_char(tokens[0])
# for ix, (c, n) in enumerate(zip(tokens[:-1], tokens[1:])):
for ix, idx in enumerate(tokens[1:]):
# c = self.tokenizer.characters.id_to_char(id_c)
c = n
n = self.tokenizer.characters.id_to_char(idx)
if c in punctlike:
# Skip thresholding for punctuation
continue
# Add duration from next punctuation if possible
d = durs[:,:,ix] + durs[:,:,ix+1] if n in punctlike else durs[:,:,ix]
# Threshold duration if duration lower than threshold
if d < threshold:
durs[:,:,ix] = threshold
return durs
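
A toy, self-contained illustration of the thresholding idea using plain Python lists instead of the model's tensors and tokenizer (the characters, durations, and threshold are made up):

def min_length_demo(chars, durs, threshold, puncts=" ,.!?"):
    # Raise too-short durations of non-punctuation symbols; a following
    # punctuation symbol may "lend" its frames, as in the method above.
    durs = list(durs)
    for i, c in enumerate(chars):
        if c in puncts:
            continue  # punctuation keeps its predicted duration
        nxt = durs[i + 1] if i + 1 < len(chars) and chars[i + 1] in puncts else 0
        if durs[i] + nxt < threshold:
            durs[i] = threshold
    return durs

print(min_length_demo(list("ahoj, adame."), [1, 4, 3, 1, 2, 2, 5, 1, 4, 3, 1, 2], 3))
# -> [3, 4, 3, 1, 2, 2, 5, 3, 4, 3, 1, 2]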
@torch.no_grad()
def inference(
self,
x,
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
aux_input={"x_lengths": None,
"d_vectors": None,
"speaker_ids": None,
"language_ids": None,
"durations": None,
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
},
): # pylint: disable=dangerous-default-value
"""
Note:
@ -1100,6 +1135,8 @@ class Vits(BaseTTS):
- x_lengths: :math:`[B]`
- d_vectors: :math:`[B, C]`
- speaker_ids: :math:`[B]`
- durations: :math:`[B, T_seq]`
- length_scale: :math:`[B, T_seq]` or :math:`[B]`
Return Shapes:
- model_outputs: :math:`[B, 1, T_wav]`
@ -1109,6 +1146,9 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]`
"""
# JMa: Save input
x_input = x
sid, g, lid, durations = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input)
@ -1137,8 +1177,28 @@ class Vits(BaseTTS):
logw = self.duration_predictor(
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
)
w = torch.exp(logw) * x_mask * self.length_scale
# JMa: set minimum duration if required
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask
# JMa: length scale for the given sentence-like input
# ORIG: w = torch.exp(logw) * x_mask * self.length_scale
# If `length_scale` is in `aux_input`, it overrides the default value given by `self.length_scale`,
# otherwise the default `self.length_scale` from `config.json` is used.
length_scale = aux_input.get("length_scale", self.length_scale)
# JMa: `length_scale` scales durations relative to the predicted values; it should be either:
# - a float (or int) => the duration of the output speech is scaled linearly
# - a torch vector `[B, T_seq]` (`B` is the batch size, `T_seq` is the number of input symbols)
#   => each input symbol (phone or char) is scaled by the corresponding value in the vector
if isinstance(length_scale, float) or isinstance(length_scale, int):
w *= length_scale
else:
assert length_scale.shape[-1] == w.shape[-1]
w *= length_scale.unsqueeze(0)
else:
# To force absolute durations (in frames), `durations` has to be in `aux_input`.
# The durations should be a torch vector `[B, T_seq]` (`B` is the batch size, `T_seq` is the number of input symbols)
# => each input symbol (phone or char) will get the duration given by the corresponding value (number of frames) in the vector
assert durations.shape[-1] == x.shape[-1]
w = durations.unsqueeze(0)
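
An illustrative single-sentence call showing the new inference-time controls (a loaded single-speaker VITS `model` and the token IDs are assumed; shapes follow the comments above):

import torch

x = torch.tensor([[12, 5, 33, 5, 7, 2]])          # [B=1, T_seq] token IDs (made up)

# (a) global speed change plus a minimum per-symbol duration of 3 frames
out = model.inference(x, aux_input={"length_scale": 1.2, "min_input_length": 3})

# (b) per-symbol relative scaling: stretch only the third symbol
scale = torch.tensor([[1.0, 1.0, 2.0, 1.0, 1.0, 1.0]])   # [B, T_seq]
out = model.inference(x, aux_input={"length_scale": scale})

# (c) absolute per-symbol durations in frames
durs = torch.tensor([[4.0, 6.0, 8.0, 6.0, 5.0, 10.0]])   # [B, T_seq]
out = model.inference(x, aux_input={"durations": durs})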
@ -1439,7 +1499,8 @@ class Vits(BaseTTS):
test_sentences = self.config.test_sentences
for idx, s_info in enumerate(test_sentences):
aux_inputs = self.get_aux_input_from_test_sentences(s_info)
wav, alignment, _, _ = synthesis(
# JMa: replace individual variables with dictionary
outputs = synthesis(
self,
aux_inputs["text"],
self.config,
@ -1450,9 +1511,9 @@ class Vits(BaseTTS):
language_id=aux_inputs["language_id"],
use_griffin_lim=True,
do_trim_silence=False,
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
)
test_audios["{}-audio".format(idx)] = outputs["wav"]
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"].T, output_fig=False)
return {"figures": test_figures, "audios": test_audios}
def test_log(
@ -1618,7 +1679,8 @@ class Vits(BaseTTS):
dataset.preprocess_samples()
# get samplers
sampler = self.get_sampler(config, dataset, num_gpus)
# JMa: Pass `is_eval` explicitly because it defaults to `False`, which made evaluation loaders use `batch_size` instead of `eval_batch_size`
sampler = self.get_sampler(config, dataset, num_gpus, is_eval)
if sampler is None:
loader = DataLoader(
dataset,

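A hedged, simplified stand-in (not the fork's actual code) for why `is_eval` has to reach the sampler and loader construction: evaluation should use `eval_batch_size`, not `batch_size`.

from torch.utils.data import DataLoader

def make_loader(config, dataset, sampler, is_eval):
    # The point of `is_eval`: pick the evaluation batch size and worker count for eval loaders.
    batch_size = config.eval_batch_size if is_eval else config.batch_size
    num_workers = config.num_eval_loader_workers if is_eval else config.num_loader_workers
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      collate_fn=dataset.collate_fn, num_workers=num_workers)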

@@ -24,6 +24,7 @@ def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
return style_mel
# JMa: add `aux_input` to enable extra input (length_scale, durations)
def run_model_torch(
model: nn.Module,
inputs: torch.Tensor,
@ -32,6 +33,7 @@ def run_model_torch(
style_text: str = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
aux_input: Dict = {},
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
@ -50,17 +52,19 @@ def run_model_torch(
_func = model.module.inference
else:
_func = model.inference
outputs = _func(
inputs,
aux_input={
"x_lengths": input_lengths,
"speaker_ids": speaker_id,
"d_vectors": d_vector,
"style_mel": style_mel,
"style_text": style_text,
"language_ids": language_id,
},
)
# JMa: propagate other inputs like `durations`, `length_scale`, and `min_input_length`
# to `aux_input` to enable changing the length (durations) per input text (sentence)
# and to set the minimum allowed length of each input char/phoneme:
# - `length_scale` changes the length of the whole generated wav
# - `durations` sets the duration (in frames) for each input text ID
# - `min_input_length` sets the minimum allowed length (in frames) per input ID (char/phoneme) during inference
aux_input["x_lengths"] = input_lengths
aux_input["speaker_ids"] = speaker_id
aux_input["d_vectors"] = d_vector
aux_input["style_mel"] = style_mel
aux_input["style_text"] = style_text
aux_input["language_ids"] = language_id
outputs = _func(inputs, aux_input)
return outputs
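
Illustrative call showing how the extra keys travel through `run_model_torch` (the model and tokenized input are assumed to exist; values are made up):

out = run_model_torch(
    model,
    text_inputs,                     # [1, T_seq] token IDs
    aux_input={
        "length_scale": 0.9,         # speak ~10 % faster overall
        "min_input_length": 3,       # raise too-short predicted durations to 3 frames
        # "durations": durs_tensor,  # or force absolute per-symbol durations instead
    },
)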
@@ -113,6 +117,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
return wavs
# JMa: add `aux_input` to enable extra input (like length_scale, durations)
def synthesis(
model,
text,
@ -125,6 +130,7 @@ def synthesis(
do_trim_silence=False,
d_vector=None,
language_id=None,
aux_input={},
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
the vocoder model.
@ -217,6 +223,7 @@ def synthesis(
text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
outputs = run_model_torch(
model,
@ -226,10 +233,14 @@ def synthesis(
style_text,
d_vector=d_vector,
language_id=language_id,
# JMa: add `aux_input` to enable extra input (length_scale, durations)
aux_input=aux_input,
)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]
# JMa: extract durations
durations = outputs.get("durations", None)
# convert outputs to numpy
# plot results
@ -248,6 +259,8 @@ def synthesis(
"alignments": alignments,
"text_inputs": text_inputs,
"outputs": outputs,
# JMa: return durations
"durations": durations,
}
return return_dict
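
At the top level the same controls can be passed straight into `synthesis()`, and the per-symbol durations can be read back from the returned dictionary; a hedged sketch (argument names follow the signature above, the model and text are illustrative):

ret = synthesis(
    model,                               # a loaded VITS model
    "This is a test.",
    model.config,
    use_cuda=False,
    use_griffin_lim=True,
    aux_input={"length_scale": 1.1, "min_input_length": 3},
)
wav = ret["wav"]
durations = ret["durations"]             # None if the model does not return durations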


@@ -233,7 +233,21 @@ class Synthesizer(nn.Module):
Returns:
List[str]: list of sentences.
"""
return self.seg.segment(text)
# JMa
if "!" in self.tts_config.characters.characters:
# Our proprietary phonetic mode is enabled: the input text is assumed
# to be a sequence of phones plus punctuation marks (without "!") and pauses (#, $).
# Here (!) is a regular character (glottal stop), not a punctuation mark.
# Workaround: the glottal stop [!] is temporarily replaced with [*] so that
# it does not trigger sentence-boundary detection.
#
# Without the workaround: "!ahoj, !adame." -> ["!ahoj, !", "adame."]
# With the workaround:    "!ahoj, !adame." -> ["!ahoj, !adame."]
text = text.replace("!", "*")
sents = self.seg.segment(text)
return [s.replace("*", "!") for s in sents]
else: # Original code
return self.seg.segment(text)
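
A self-contained toy of the same replace-segment-restore idea, with a naive splitter standing in for the real segmenter:

def split_phonetic(text, segment=lambda t: [s.strip() + "." for s in t.split(".") if s.strip()]):
    text = text.replace("!", "*")               # hide glottal stops from boundary detection
    return [s.replace("*", "!") for s in segment(text)]

print(split_phonetic("!ahoj, !adame. !uf."))    # -> ['!ahoj, !adame.', '!uf.']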
def save_wav(self, wav: List[int], path: str) -> None:
"""Save the waveform as a file.


@@ -208,7 +208,7 @@ class GAN(BaseVocoder):
self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for training."""
figures, audios = self._log("eval", self.ap, batch, outputs)
figures, audios = self._log("train", self.ap, batch, outputs)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)