mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'main' into dev
Merge changes by JMa into dev
commit 0a89a43a77
@@ -169,3 +169,6 @@ wandb
depot/*
coqui_recipes/*
local_scripts/*

# SVN
.svn/
@@ -57,13 +57,23 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):


def add_extra_keys(metadata, language, dataset_name):
    changes = {}
    for item in metadata:
        # add language name
        item["language"] = language
        # JMa: Add language name only if not defined at the sample level. Could be good for multi-language datasets.
        if not item["language"]:
            item["language"] = language
        # JMa: Prepend dataset name to speaker name. Could be good for multispeaker datasets.
        if dataset_name and item["speaker_name"] != dataset_name and not item["speaker_name"].startswith(dataset_name + "_"):
            changes[item["speaker_name"]] = f'{dataset_name}_{item["speaker_name"]}'
            item["speaker_name"] = f'{dataset_name}_{item["speaker_name"]}'
        # add unique audio name
        relfilepath = os.path.splitext(os.path.relpath(item["audio_file"], item["root_path"]))[0]
        audio_unique_name = f"{dataset_name}#{relfilepath}"
        item["audio_unique_name"] = audio_unique_name
    # JMa: print changed speaker names if any
    if changes:
        for k, v in changes.items():
            print(f" | > speaker name changed: {k} --> {v}")
    return metadata
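A minimal usage sketch of the modified `add_extra_keys`. The metadata dicts, paths, and dataset name below are invented for illustration, and the unconditional `item["language"] = language` line above is the one this change replaces with the per-sample check:

metadata = [
    {"language": "", "speaker_name": "spk1", "audio_file": "/data/mydataset/wavs/spk1/utt1.wav", "root_path": "/data/mydataset"},
    {"language": "cs", "speaker_name": "mydataset_spk2", "audio_file": "/data/mydataset/wavs/spk2/utt2.wav", "root_path": "/data/mydataset"},
]
metadata = add_extra_keys(metadata, language="en", dataset_name="mydataset")
# item 1: the empty language falls back to "en"; the speaker becomes "mydataset_spk1" and the change is printed
# item 2: the per-sample language "cs" is kept; the speaker name already carries the prefix, so it stays unchanged
# both items get an audio_unique_name such as "mydataset#wavs/spk1/utt1"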
@@ -442,6 +442,54 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
    return items


# JMa: VCTK with wav files (not flac)
def vctk_wav(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None):
    """VCTK dataset v0.92.

    URL:
        https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip

    This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```.
    It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset.

    mic1:
        Audio recorded using an omni-directional microphone (DPA 4035).
        Contains very low frequency noises.
        This is the same audio released in previous versions of VCTK:
        https://doi.org/10.7488/ds/1994

    mic2:
        Audio recorded using a small diaphragm condenser microphone with
        very wide bandwidth (Sennheiser MKH 800).
        Two speakers, p280 and p315, had technical issues with the audio
        recordings using MKH 800.
    """
    file_ext = "wav"
    items = []
    meta_files = glob(f"{os.path.join(root_path, 'txt')}/**/*.txt", recursive=True)
    for meta_file in meta_files:
        _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
        file_id = txt_file.split(".")[0]
        # ignore speakers
        if isinstance(ignored_speakers, list):
            if speaker_id in ignored_speakers:
                continue
        with open(meta_file, "r", encoding="utf-8") as file_text:
            text = file_text.readlines()[0]
        # p280 has no mic2 recordings
        if speaker_id == "p280":
            wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}")
        else:
            wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
        if os.path.exists(wav_file):
            items.append(
                {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
            )
        else:
            print(f" [!] wav file doesn't exist - {wav_file}")
    return items


def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    items = []
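A hedged usage sketch for the new `vctk_wav` formatter. The root path and directory layout are assumptions following the VCTK 0.92 structure described in the docstring:

# Expected layout (illustrative):
#   /data/VCTK/txt/p225/p225_001.txt
#   /data/VCTK/wav48_silence_trimmed/p225/p225_001_mic1.wav
items = vctk_wav("/data/VCTK", mic="mic1", ignored_speakers=["p315"])
print(items[0]["speaker_name"])  # "VCTK_" + speaker directory name, e.g. "VCTK_p225"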
@@ -628,6 +676,71 @@ def kokoro(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    return items


def artic(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Normalizes the ARTIC meta data file to TTS format

    Args:
        root_path (str): path to the artic dataset
        meta_file (str): name of the meta file containing names of wav files to select and
            transcripts of the corresponding utterances

    Returns:
        List[List[str]]: List of (text, wav_path, speaker_name, language, root_path) associated with each utterance
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    # Speaker name is the name of the directory with the data (last part of `root_path`)
    speaker_name = os.path.basename(os.path.normpath(root_path))
    # Speaker name can consist of a language code (e.g. cs-CZ or en) and gender (m/f), separated by dots
    # Example: AndJa.cs-CZ.m, LJS.en.f
    try:
        voice, lang, sex = speaker_name.split(".")
    except ValueError:
        voice = speaker_name
        lang, sex = None, None
    print(f" > ARTIC dataset: voice={voice}, sex={sex}, language={lang}")
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            # Check the number of standard separators
            n_seps = line.count("|")
            if n_seps > 0:
                # Split according to the standard separator
                cols = line.split("|")
            else:
                # Assume ARTIC SNT format => wav name is delimited by the first space
                cols = line.split(maxsplit=1)
            # Either way, the wav name is stored in `cols[0]` and the text in `cols[-1]`
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[-1]
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "language": lang, "root_path": root_path})
    return items


def artic_multispeaker(root_path, meta_file, ignored_speakers=None):  # pylint: disable=unused-argument
    """Normalizes the ARTIC multi-speaker meta data files to TTS format

    Args:
        root_path (str): path to the artic dataset
        meta_file (str): name of the meta file containing names of wav files to select and
            transcripts of the corresponding utterances
            !Must be the same for all speakers!
        ignored_speakers (List[str]): list of ignored speakers (or None)

    Returns:
        List[List[str]]: List of (text, wav_path, speaker_name) associated with each utterance
    """
    items = []
    # Loop over speakers: speaker names are subdirs of `root_path`
    for pth in glob(f"{root_path}/*/", recursive=False):
        # Normalize the path first: glob() keeps the trailing separator, which would make basename() return ""
        speaker_name = os.path.basename(os.path.normpath(pth))
        # Ignore speakers
        if isinstance(ignored_speakers, list):
            if speaker_name in ignored_speakers:
                continue
        items.extend(artic(pth, meta_file))
    return items


def kss(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Korean single-speaker dataset from https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset"""
    txt_file = os.path.join(root_path, meta_file)
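A small runnable sketch of the two meta-file line formats `artic` accepts, mirroring the splitting logic above (the utterance names and texts are made up):

for line in ["utt0001|Hello there.\n", "utt0002 How are you?\n"]:
    # Pipe-separated line vs. ARTIC SNT format (wav name delimited by the first space)
    cols = line.split("|") if line.count("|") > 0 else line.split(maxsplit=1)
    print(cols[0], "->", cols[-1].strip())
# utt0001 -> Hello there.
# utt0002 -> How are you?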
@@ -1084,12 +1084,47 @@ class Vits(BaseTTS):
        if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
            return aux_input["x_lengths"]
        return torch.tensor(x.shape[1:2]).to(x.device)

    # JMa: set minimum duration if the predicted duration is lower than a threshold
    # Workaround to avoid short durations that cause some chars/phonemes to be reduced
    # @staticmethod
    # def _set_min_inference_length(d, threshold):
    #     d_mask = d < threshold
    #     d[d_mask] = threshold
    #     return d

    def _set_min_inference_length(self, x, durs, threshold):
        punctlike = list(self.config.characters.punctuations) + [self.config.characters.blank]
        # Get list of tokens from IDs
        tokens = x.squeeze().tolist()
        # Check current and next token
        n = self.tokenizer.characters.id_to_char(tokens[0])
        # for ix, (c, n) in enumerate(zip(tokens[:-1], tokens[1:])):
        for ix, idx in enumerate(tokens[1:]):
            # c = self.tokenizer.characters.id_to_char(id_c)
            c = n
            n = self.tokenizer.characters.id_to_char(idx)
            if c in punctlike:
                # Skip thresholding for punctuation
                continue
            # Add the duration of the next punctuation if possible
            d = durs[:, :, ix] + durs[:, :, ix + 1] if n in punctlike else durs[:, :, ix]
            # Threshold the duration if it is lower than the threshold
            if d < threshold:
                durs[:, :, ix] = threshold
        return durs

    @torch.no_grad()
    def inference(
        self,
        x,
        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
        aux_input={"x_lengths": None,
                   "d_vectors": None,
                   "speaker_ids": None,
                   "language_ids": None,
                   "durations": None,
                   "min_input_length": 0,  # JMa: set minimum length if the predicted length is lower than `min_input_length`
                   },
    ):  # pylint: disable=dangerous-default-value
        """
        Note:
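A hedged sketch of passing the new `min_input_length` option at inference time; `model` is assumed to be an initialized `Vits` instance, and the tensor contents and threshold value are illustrative:

import torch

x = torch.randint(low=1, high=50, size=(1, 42))  # [B, T_seq] token IDs (placeholder values)
outputs = model.inference(
    x,
    aux_input={
        "x_lengths": torch.tensor([x.shape[1]]),
        "d_vectors": None,
        "speaker_ids": None,
        "language_ids": None,
        "durations": None,
        "min_input_length": 3,  # floor each non-punctuation token at 3 decoder frames
    },
)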
@@ -1100,6 +1135,8 @@ class Vits(BaseTTS):
            - x_lengths: :math:`[B]`
            - d_vectors: :math:`[B, C]`
            - speaker_ids: :math:`[B]`
            - durations: :math:`[B, T_seq]`
            - length_scale: :math:`[B, T_seq]` or :math:`[B]`

        Return Shapes:
            - model_outputs: :math:`[B, 1, T_wav]`
@@ -1109,6 +1146,9 @@ class Vits(BaseTTS):
            - m_p: :math:`[B, C, T_dec]`
            - logs_p: :math:`[B, C, T_dec]`
        """
        # JMa: Save input
        x_input = x

        sid, g, lid, durations = self._set_cond_input(aux_input)
        x_lengths = self._set_x_lengths(x, aux_input)
@@ -1137,8 +1177,28 @@ class Vits(BaseTTS):
                logw = self.duration_predictor(
                    x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
                )
            w = torch.exp(logw) * x_mask * self.length_scale
            # JMa: set minimum duration if required
            w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask

            # JMa: length scale for the given sentence-like input
            # ORIG: w = torch.exp(logw) * x_mask * self.length_scale
            # If `length_scale` is in `aux_input`, it overrides the default value given by `self.length_scale`;
            # otherwise the default `self.length_scale` from `config.json` is used.
            length_scale = aux_input.get("length_scale", self.length_scale)
            # JMa: `length_scale` scales the durations relative to the predicted values; it should be either:
            # - a float (or int) => the duration of the output speech is scaled linearly, or
            # - a torch vector `[B, T_seq]` (`B` is the batch size, `T_seq` is the length of the input symbols)
            #   => each input symbol (phone or char) is scaled according to the corresponding value in the vector
            if isinstance(length_scale, float) or isinstance(length_scale, int):
                w *= length_scale
            else:
                assert length_scale.shape[-1] == w.shape[-1]
                w *= length_scale.unsqueeze(0)

        else:
            # To force absolute durations (in frames), "durations" has to be in `aux_input`.
            # The durations should be a torch vector `[B, T_seq]` (`B` is the batch size, `T_seq` is the length of the input symbols)
            # => each input symbol (phone or char) will have the duration given by the corresponding value (number of frames) in the vector
            assert durations.shape[-1] == x.shape[-1]
            w = durations.unsqueeze(0)
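A sketch of the three duration controls described in the comments above, as they would appear in `aux_input` (shapes and values are illustrative; `T_seq` is the number of input tokens, here 5):

import torch

# 1) Global tempo change: a plain float slows the whole utterance down by 20 %.
aux_input = {"length_scale": 1.2}

# 2) Per-symbol scaling: a [B, T_seq] tensor stretches only the third symbol.
aux_input = {"length_scale": torch.tensor([[1.0, 1.0, 1.8, 1.0, 1.0]])}

# 3) Absolute durations: a [B, T_seq] tensor of frame counts bypasses the duration predictor.
aux_input = {"durations": torch.tensor([[4.0, 7.0, 11.0, 6.0, 9.0]])}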
@@ -1439,7 +1499,8 @@ class Vits(BaseTTS):
        test_sentences = self.config.test_sentences
        for idx, s_info in enumerate(test_sentences):
            aux_inputs = self.get_aux_input_from_test_sentences(s_info)
            wav, alignment, _, _ = synthesis(
            # JMa: replace individual variables with dictionary
            outputs = synthesis(
                self,
                aux_inputs["text"],
                self.config,
@@ -1450,9 +1511,9 @@ class Vits(BaseTTS):
                language_id=aux_inputs["language_id"],
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()
            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
            )
            test_audios["{}-audio".format(idx)] = outputs["wav"]
            test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"].T, output_fig=False)
        return {"figures": test_figures, "audios": test_audios}

    def test_log(
@@ -1618,7 +1679,8 @@ class Vits(BaseTTS):
            dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)
            # JMa: Add `is_eval` parameter because the default is `False` and `batch_size` was used instead of `eval_batch_size`
            sampler = self.get_sampler(config, dataset, num_gpus, is_eval)
            if sampler is None:
                loader = DataLoader(
                    dataset,
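The intent of the added `is_eval` argument, per the comment above, is that the data loader picks the evaluation batch size during evaluation. A hedged illustration of that selection (not the actual `get_sampler` body):

# Illustration only: choose the batch size according to the phase.
batch_size = config.eval_batch_size if is_eval else config.batch_size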
@@ -24,6 +24,7 @@ def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
    return style_mel


# JMa: add `aux_input` to enable extra input (length_scale, durations)
def run_model_torch(
    model: nn.Module,
    inputs: torch.Tensor,
@@ -32,6 +33,7 @@ def run_model_torch(
    style_text: str = None,
    d_vector: torch.Tensor = None,
    language_id: torch.Tensor = None,
    aux_input: Dict = {},
) -> Dict:
    """Run a torch model for inference. It does not support batch inference.
@@ -50,17 +52,19 @@ def run_model_torch(
        _func = model.module.inference
    else:
        _func = model.inference
    outputs = _func(
        inputs,
        aux_input={
            "x_lengths": input_lengths,
            "speaker_ids": speaker_id,
            "d_vectors": d_vector,
            "style_mel": style_mel,
            "style_text": style_text,
            "language_ids": language_id,
        },
    )
    # JMa: propagate other inputs like `durations`, `length_scale`, and `min_input_length`
    # to `aux_input` to enable changing the length (durations) per input text (sentence)
    # and to set the minimum allowed length of each input char/phoneme:
    # - `length_scale` changes the length of the whole generated wav
    # - `durations` sets up the duration (in frames) for each input text ID
    # - `min_input_length` sets the minimum allowed length (in frames) per input ID (char/phoneme) during inference
    aux_input["x_lengths"] = input_lengths
    aux_input["speaker_ids"] = speaker_id
    aux_input["d_vectors"] = d_vector
    aux_input["style_mel"] = style_mel
    aux_input["style_text"] = style_text
    aux_input["language_ids"] = language_id
    outputs = _func(inputs, aux_input)
    return outputs
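A hedged sketch of a direct `run_model_torch` call after this change, showing the extra keys being handed over via `aux_input` (the model and the prepared `text_inputs` tensor are assumed to come from elsewhere; values are placeholders):

outputs = run_model_torch(
    model,
    text_inputs,  # [1, T_seq] token IDs
    d_vector=None,
    language_id=None,
    aux_input={"length_scale": 1.1, "min_input_length": 3},  # forwarded to model.inference
)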
@@ -113,6 +117,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
    return wavs


# JMa: add `aux_input` to enable extra input (like length_scale, durations)
def synthesis(
    model,
    text,
@@ -125,6 +130,7 @@ def synthesis(
    do_trim_silence=False,
    d_vector=None,
    language_id=None,
    aux_input={},
):
    """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
    the vocoder model.
@@ -217,6 +223,7 @@ def synthesis(

    text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
    text_inputs = text_inputs.unsqueeze(0)

    # synthesize voice
    outputs = run_model_torch(
        model,
@@ -226,10 +233,14 @@ def synthesis(
        style_text,
        d_vector=d_vector,
        language_id=language_id,
        # JMa: add `aux_input` to enable extra input (length_scale, durations)
        aux_input=aux_input,
    )
    model_outputs = outputs["model_outputs"]
    model_outputs = model_outputs[0].data.cpu().numpy()
    alignments = outputs["alignments"]
    # JMa: extract durations
    durations = outputs.get("durations", None)

    # convert outputs to numpy
    # plot results
@@ -248,6 +259,8 @@ def synthesis(
        "alignments": alignments,
        "text_inputs": text_inputs,
        "outputs": outputs,
        # JMa: return durations
        "durations": durations,
    }
    return return_dict
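A hedged end-to-end sketch of calling `synthesis` with the new `aux_input` argument and reading the extra `durations` entry from the returned dictionary. The model and config are assumed to come from a loaded VITS setup, and arguments not visible in this diff (such as `use_cuda`) follow the upstream `synthesis` signature:

result = synthesis(
    model,
    "A sentence to synthesize.",
    model.config,
    use_cuda=False,  # assumed from the upstream signature
    use_griffin_lim=True,
    do_trim_silence=False,
    aux_input={"length_scale": 1.1, "min_input_length": 3},
)
wav = result["wav"]
durations = result["durations"]  # None if the model does not return per-token durations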
@@ -233,7 +233,21 @@ class Synthesizer(nn.Module):
        Returns:
            List[str]: list of sentences.
        """
        return self.seg.segment(text)
        # JMa
        if "!" in self.tts_config.characters.characters:
            # Our proprietary phonetic mode is enabled: the input text is assumed
            # to be a sequence of phones plus punctuation (without "!") and pauses (#, $).
            # (!) is a regular character, not a punctuation mark.
            # Workaround: the glottal stop [!] is temporarily replaced with [*] to prevent
            # sentence-boundary detection.
            #
            # Example: "!ahoj, !adame." -> ["!ahoj, !", "adame."]
            # Fix:     "!ahoj, !adame." -> ["!ahoj, !adame."]
            text = text.replace("!", "*")
            sents = self.seg.segment(text)
            return [s.replace("*", "!") for s in sents]
        else:  # Original code
            return self.seg.segment(text)

    def save_wav(self, wav: List[int], path: str) -> None:
        """Save the waveform as a file.
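A hedged sketch of the masking trick described above, using a pysbd segmenter directly (Coqui's `Synthesizer` is assumed to build `self.seg` as a pysbd segmenter; the example string comes from the comment):

import pysbd

seg = pysbd.Segmenter(language="en", clean=True)
text = "!ahoj, !adame."
masked = text.replace("!", "*")  # hide glottal stops from the segmenter
sentences = [s.replace("*", "!") for s in seg.segment(masked)]
print(sentences)  # ["!ahoj, !adame."] instead of ["!ahoj, !", "adame."]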
@@ -208,7 +208,7 @@ class GAN(BaseVocoder):
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
    ) -> Tuple[Dict, np.ndarray]:
        """Call `_log()` for training."""
        figures, audios = self._log("eval", self.ap, batch, outputs)
        figures, audios = self._log("train", self.ap, batch, outputs)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)