diff --git a/.gitignore b/.gitignore index 563040e8..ee9ab0db 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,6 @@ wandb depot/* coqui_recipes/* local_scripts/* + +# SVN +.svn/ diff --git a/TTS/bin/collect_env_info.py b/TTS/bin/collect_env_info.py old mode 100644 new mode 100755 diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py old mode 100644 new mode 100755 diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py old mode 100644 new mode 100755 diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py old mode 100644 new mode 100755 diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py old mode 100644 new mode 100755 diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py old mode 100644 new mode 100755 diff --git a/TTS/bin/resample.py b/TTS/bin/resample.py old mode 100644 new mode 100755 diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py old mode 100644 new mode 100755 diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py old mode 100644 new mode 100755 diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py old mode 100644 new mode 100755 diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py old mode 100644 new mode 100755 diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 19213856..41d767a0 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -57,13 +57,23 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): def add_extra_keys(metadata, language, dataset_name): + changes = {} for item in metadata: - # add language name - item["language"] = language + # JMa: Add language name only if not defined at the sample level. Could be good for multi-language datasets. + if not item["language"]: + item["language"] = language + # JMa: Prepend dataset name to speaker name. Could be good for multispeaker datasets. + if dataset_name and item["speaker_name"] != dataset_name and not item["speaker_name"].startswith(dataset_name+"_"): + changes[item["speaker_name"]] = f'{dataset_name}_{item["speaker_name"]}' + item["speaker_name"] = f'{dataset_name}_{item["speaker_name"]}' # add unique audio name relfilepath = os.path.splitext(os.path.relpath(item["audio_file"], item["root_path"]))[0] audio_unique_name = f"{dataset_name}#{relfilepath}" item["audio_unique_name"] = audio_unique_name + # JMa: print changed speaker names if any + if changes: + for k, v in changes.items(): + print(f" | > speaker name changed: {k} --> {v}") return metadata diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index fbf6881f..2c6041ba 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -442,6 +442,54 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic return items +# JMa: VCTK with wav files (not flac) +def vctk_wav(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None): + """VCTK dataset v0.92. + + URL: + https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip + + This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```. + It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset. + + mic1: + Audio recorded using an omni-directional microphone (DPA 4035). + Contains very low frequency noises. 
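Note on the `add_extra_keys` change above: it now fills in `language` only when a sample does not already carry one, and prefixes `speaker_name` with the dataset name so speakers stay unique when several datasets are mixed. Below is a minimal, standalone sketch of that logic for illustration only; it does not import the TTS package, the sample dicts are made up, and it uses `dict.get` defensively where the patch indexes the key directly.

```python
# Standalone sketch of the add_extra_keys() behaviour introduced above (illustrative only).
def add_extra_keys_sketch(metadata, language, dataset_name):
    changes = {}
    for item in metadata:
        # Keep a sample-level language if the formatter already set one.
        if not item.get("language"):
            item["language"] = language
        # Prefix the speaker name with the dataset name unless it is already prefixed.
        old_name = item["speaker_name"]
        if dataset_name and old_name != dataset_name and not old_name.startswith(dataset_name + "_"):
            changes[old_name] = f"{dataset_name}_{old_name}"
            item["speaker_name"] = changes[old_name]
    return metadata, changes


samples = [
    {"speaker_name": "spk1", "language": "", "audio_file": "wavs/a.wav", "root_path": "."},
    {"speaker_name": "vctk_p226", "language": "en-GB", "audio_file": "wavs/b.wav", "root_path": "."},
]
metadata, changes = add_extra_keys_sketch(samples, "en", "vctk")
print(changes)  # {'spk1': 'vctk_spk1'}; the already-prefixed speaker and its language are untouched
```

The same renaming is what the added `print(f" | > speaker name changed: ...")` loop reports during data loading.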
+ This is the same audio released in previous versions of VCTK: + https://doi.org/10.7488/ds/1994 + + mic2: + Audio recorded using a small diaphragm condenser microphone with + very wide bandwidth (Sennheiser MKH 800). + Two speakers, p280 and p315 had technical issues of the audio + recordings using MKH 800. + """ + file_ext = "wav" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + # p280 has no mic2 recordings + if speaker_id == "p280": + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}") + else: + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") + if os.path.exists(wav_file): + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} + ) + else: + print(f" [!] wav files don't exist - {wav_file}") + return items + + def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" items = [] @@ -628,6 +676,71 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument return items +def artic(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the ARTIC meta data file to TTS format + + Args: + root_path (str): path to the artic dataset + meta_file (str): name of the meta file containing names of wav to select and + transcripts of the corresponding utterances + + Returns: + List[List[str]]: List of (text, wav_path, speaker_name, language, root_path) associated with each utterance + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + # Speaker name is the name of the directory with the data (last part of `root_path`) + speaker_name = os.path.basename(os.path.normpath(root_path)) + # Speaker name can consists of language code (eg. cs-CZ or en) and gender (m/f) separated by dots + # Example: AndJa.cs-CZ.m, LJS.en.f + try: + voice, lang, sex = speaker_name.split(".") + except ValueError: + voice = speaker_name + lang, sex = None, None + print(f" > ARTIC dataset: voice={voice}, sex={sex}, language={lang}") + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + # Check the number of standard separators + n_seps = line.count("|") + if n_seps > 0: + # Split according to standard separator + cols = line.split("|") + else: + # Assume ARTIC SNT format => wav name is delimited by the first space + cols = line.split(maxsplit=1) + # In either way, wav name is stored in `cols[0]` and text in `cols[-1]` + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[-1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "language": lang, "root_path": root_path}) + return items + + +def artic_multispeaker(root_path, meta_file, ignored_speakers=None): # pylint: disable=unused-argument + """Normalizes the ARTIC multi-speaker meta data files to TTS format + + Args: + root_path (str): path to the artic dataset + meta_file (str): name of the meta file containing names of wav to select and + transcripts of the corresponding utterances + !Must be the same for all speakers! 
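Note on the `artic` formatter above: it accepts two metadata layouts, the pipe-separated format used elsewhere in the repo and the ARTIC SNT format, where the wav name is whatever precedes the first space. The following is a small, self-contained sketch of just that parsing rule; the helper name and sample lines are made up.

```python
# Illustrative re-implementation of the line parsing inside artic() above.
def parse_artic_line(line: str):
    if line.count("|") > 0:
        cols = line.split("|")          # standard "name|text" (or "name|raw|normalized") format
    else:
        cols = line.split(maxsplit=1)   # ARTIC SNT format: wav name ends at the first space
    # Either way the wav name is cols[0] and the transcript is cols[-1].
    return cols[0], cols[-1].strip()


print(parse_artic_line("utt0001|Hello there."))      # ('utt0001', 'Hello there.')
print(parse_artic_line("utt0002 Hello once more."))  # ('utt0002', 'Hello once more.')
```

The speaker directory name is also parsed as `voice.language.gender` (for example `LJS.en.f`), which is where the per-sample `language` value consumed by the new `add_extra_keys` logic comes from.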
+ ignore_speakers (List[str]): list of ignored speakers (or None) + + Returns: + List[List[str]]: List of (text, wav_path, speaker_name) associated with each utterance + """ + items = [] + # Loop over speakers: speaker names are subdirs of `root_path` + for pth in glob(f"{root_path}/*/", recursive=False): + speaker_name = os.path.basename(pth) + # Ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + items.extend(artic(pth, meta_file)) + return items + + def kss(root_path, meta_file, **kwargs): # pylint: disable=unused-argument """Korean single-speaker dataset from https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset""" txt_file = os.path.join(root_path, meta_file) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index d9b1f596..c11d9b96 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1084,12 +1084,47 @@ class Vits(BaseTTS): if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: return aux_input["x_lengths"] return torch.tensor(x.shape[1:2]).to(x.device) + + # JMa: set minimum duration if predicted duration is lower than threshold + # Workaround to avoid short durations that cause some chars/phonemes to be reduced + # @staticmethod + # def _set_min_inference_length(d, threshold): + # d_mask = d < threshold + # d[d_mask] = threshold + # return d + + def _set_min_inference_length(self, x, durs, threshold): + punctlike = list(self.config.characters.punctuations) + [self.config.characters.blank] + # Get list of tokens from IDs + tokens = x.squeeze().tolist() + # Check current and next token + n = self.tokenizer.characters.id_to_char(tokens[0]) + # for ix, (c, n) in enumerate(zip(tokens[:-1], tokens[1:])): + for ix, idx in enumerate(tokens[1:]): + # c = self.tokenizer.characters.id_to_char(id_c) + c = n + n = self.tokenizer.characters.id_to_char(idx) + if c in punctlike: + # Skip thresholding for punctuation + continue + # Add duration from next punctuation if possible + d = durs[:,:,ix] + durs[:,:,ix+1] if n in punctlike else durs[:,:,ix] + # Threshold duration if duration lower than threshold + if d < threshold: + durs[:,:,ix] = threshold + return durs @torch.no_grad() def inference( self, x, - aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None}, + aux_input={"x_lengths": None, + "d_vectors": None, + "speaker_ids": None, + "language_ids": None, + "durations": None, + "min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length` + }, ): # pylint: disable=dangerous-default-value """ Note: @@ -1100,6 +1135,8 @@ class Vits(BaseTTS): - x_lengths: :math:`[B]` - d_vectors: :math:`[B, C]` - speaker_ids: :math:`[B]` + - durations: :math: `[B, T_seq]` + - length_scale: :math: `[B, T_seq]` or `[B]` Return Shapes: - model_outputs: :math:`[B, 1, T_wav]` @@ -1109,6 +1146,9 @@ class Vits(BaseTTS): - m_p: :math:`[B, C, T_dec]` - logs_p: :math:`[B, C, T_dec]` """ + # JMa: Save input + x_input = x + sid, g, lid, durations = self._set_cond_input(aux_input) x_lengths = self._set_x_lengths(x, aux_input) @@ -1137,8 +1177,28 @@ class Vits(BaseTTS): logw = self.duration_predictor( x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb ) - w = torch.exp(logw) * x_mask * self.length_scale + # JMa: set minimum duration if required + w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) 
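Note on `_set_min_inference_length` above: it walks the input token IDs and raises any predicted duration that falls below `min_input_length`, skipping punctuation and blank tokens and letting a symbol that is followed by punctuation count the punctuation's duration before the comparison. The toy sketch below mirrors that loop on a plain tensor; `clamp_durations`, `is_punct`, and the numbers are made up for illustration, and the real method derives the punctuation mask from the tokenizer instead.

```python
import torch

# Toy version of the thresholding loop in _set_min_inference_length() above.
def clamp_durations(durs: torch.Tensor, is_punct, threshold: float) -> torch.Tensor:
    # durs: [1, 1, T] predicted frame counts, one entry per input symbol
    T = durs.shape[-1]
    for ix in range(T - 1):
        if is_punct[ix]:
            continue  # punctuation/blank keeps its predicted duration
        # A symbol right before punctuation may borrow the punctuation's duration.
        d = durs[:, :, ix] + durs[:, :, ix + 1] if is_punct[ix + 1] else durs[:, :, ix]
        if d < threshold:
            durs[:, :, ix] = threshold
    return durs


durs = torch.tensor([[[1.0, 0.5, 2.0, 0.2, 3.0]]])
print(clamp_durations(durs, [False, False, False, True, False], threshold=1.5))
# tensor([[[1.5000, 1.5000, 2.0000, 0.2000, 3.0000]]])
```

Passing `"min_input_length": 0` (the new default in `aux_input`) disables the behaviour, so existing callers are unaffected.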
else torch.exp(logw) * x_mask + + # JMa: length scale for the given sentence-like input + # ORIG: w = torch.exp(logw) * x_mask * self.length_scale + # If `length_scale` is in `aux_input`, it resets the default value given by `self.length_scale`, + # otherwise the default `self.length_scale` from `config.json` is used. + length_scale = aux_input.get("length_scale", self.length_scale) + # JMa: `length_scale` is used to scale duration relatively to the predicted values, it should be: + # - float (or int) => duration of the output speech will be linearly scaled + # - torch vector `[B, T_seq]`` (`B`` is batch size, `T_seq`` is the length of the input symbols) + # => each input symbol (phone or char) is scaled according to the corresponding value in the torch vector + if isinstance(length_scale, float) or isinstance(length_scale, int): + w *= length_scale + else: + assert length_scale.shape[-1] == w.shape[-1] + w *= length_scale.unsqueeze(0) + else: + # To force absolute durations (in frames), "durations" has to be in `aux_input`. + # The durations should be a torch vector [B, N] (`B`` is batch size, `T_seq`` is the length of the input symbols) + # => each input symbol (phone or char) will have the duration given by the corresponding value (number of frames) in the torch vector assert durations.shape[-1] == x.shape[-1] w = durations.unsqueeze(0) @@ -1439,7 +1499,8 @@ class Vits(BaseTTS): test_sentences = self.config.test_sentences for idx, s_info in enumerate(test_sentences): aux_inputs = self.get_aux_input_from_test_sentences(s_info) - wav, alignment, _, _ = synthesis( + # JMa: replace individual variables with dictionary + outputs = synthesis( self, aux_inputs["text"], self.config, @@ -1450,9 +1511,9 @@ class Vits(BaseTTS): language_id=aux_inputs["language_id"], use_griffin_lim=True, do_trim_silence=False, - ).values() - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + ) + test_audios["{}-audio".format(idx)] = outputs["wav"] + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"].T, output_fig=False) return {"figures": test_figures, "audios": test_audios} def test_log( @@ -1618,7 +1679,8 @@ class Vits(BaseTTS): dataset.preprocess_samples() # get samplers - sampler = self.get_sampler(config, dataset, num_gpus) + # JMa: Add `is_eval` parameter because the default is `False` and `batch_size` was used instead of `eval_batch_size` + sampler = self.get_sampler(config, dataset, num_gpus, is_eval) if sampler is None: loader = DataLoader( dataset, diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 797151c2..949bd2f9 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -24,6 +24,7 @@ def compute_style_mel(style_wav, ap, cuda=False, device="cpu"): return style_mel +# JMa: add `aux_input` to enable extra input (length_scale, durations) def run_model_torch( model: nn.Module, inputs: torch.Tensor, @@ -32,6 +33,7 @@ def run_model_torch( style_text: str = None, d_vector: torch.Tensor = None, language_id: torch.Tensor = None, + aux_input: Dict = {}, ) -> Dict: """Run a torch model for inference. It does not support batch inference. 
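Note on the duration-scaling block above: `aux_input["length_scale"]` can now override `self.length_scale`, either globally (a float or int) or per input symbol (a tensor of shape `[B, T_seq]`). The snippet below is a standalone sketch of just that scaling step; `scale_durations` and the sample tensors are illustrative, not part of the patch.

```python
import torch

# Mirrors the scalar-vs-vector length_scale handling added above.
def scale_durations(w: torch.Tensor, length_scale):
    # w: predicted durations [B, 1, T_seq]
    if isinstance(length_scale, (float, int)):
        return w * length_scale                      # uniform tempo change
    assert length_scale.shape[-1] == w.shape[-1]
    return w * length_scale.unsqueeze(0)             # per-symbol tempo change


w = torch.ones(1, 1, 4)
print(scale_durations(w, 1.2))                                    # every symbol 20% longer
print(scale_durations(w, torch.tensor([[1.0, 2.0, 1.0, 1.0]])))   # only the second symbol stretched
```

Absolute control is still available through `aux_input["durations"]`, which bypasses the predictor entirely, as in the unchanged `else` branch.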
@@ -50,17 +52,19 @@ def run_model_torch( _func = model.module.inference else: _func = model.inference - outputs = _func( - inputs, - aux_input={ - "x_lengths": input_lengths, - "speaker_ids": speaker_id, - "d_vectors": d_vector, - "style_mel": style_mel, - "style_text": style_text, - "language_ids": language_id, - }, - ) + # JMa: propagate other inputs like `durations``, `length_scale``, and `min_input_length` + # to `aux_input` to enable changing length (durations) per each input text (sentence) + # and to set minimum allowed length of each input char/phoneme + # - `length_scale` changes length of the whole generated wav + # - `durations` sets up duration (in frames) for each input text ID + # - minimum allowed length (in frames) per input ID (char/phoneme) during inference + aux_input["x_lengths"] = input_lengths + aux_input["speaker_ids"] = speaker_id + aux_input["d_vectors"] = d_vector + aux_input["style_mel"] = style_mel + aux_input["style_text"] = style_text + aux_input["language_ids"] = language_id + outputs = _func(inputs, aux_input) return outputs @@ -113,6 +117,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap): return wavs +# JMa: add `aux_input` to enable extra input (like length_scale, durations) def synthesis( model, text, @@ -125,6 +130,7 @@ def synthesis( do_trim_silence=False, d_vector=None, language_id=None, + aux_input={}, ): """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to the vocoder model. @@ -217,6 +223,7 @@ def synthesis( text_inputs = numpy_to_torch(text_inputs, torch.long, device=device) text_inputs = text_inputs.unsqueeze(0) + # synthesize voice outputs = run_model_torch( model, @@ -226,10 +233,14 @@ def synthesis( style_text, d_vector=d_vector, language_id=language_id, + # JMa: add `aux_input` to enable extra input (length_scale, durations) + aux_input=aux_input, ) model_outputs = outputs["model_outputs"] model_outputs = model_outputs[0].data.cpu().numpy() alignments = outputs["alignments"] + # JMa: extract durations + durations = outputs.get("durations", None) # convert outputs to numpy # plot results @@ -248,6 +259,8 @@ def synthesis( "alignments": alignments, "text_inputs": text_inputs, "outputs": outputs, + # JMa: return durations + "durations": durations, } return return_dict diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 24a078f5..250e75eb 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -233,7 +233,21 @@ class Synthesizer(nn.Module): Returns: List[str]: list of sentences. """ - return self.seg.segment(text) + # JMa + if "!" in self.tts_config.characters.characters: + # Our proprietary phonetic mode enabled: the input text is assumed + # to be a sequence of phones plus punctuations (without "!") and pauses (#, $). + # (!) is a regular character, not a punctuation + # WA: Glottal stop [!] is temporarily replaced with [*] to prevent + # boundary detection. + # + # Example: "!ahoj, !adame." -> ["!ahoj, !", "adame."] + # Fix: "!ahoj, !adame." -> ["!ahoj, !adame."] + text = text.replace("!", "*") + sents = self.seg.segment(text) + return [s.replace("*", "!") for s in sents] + else: # Original code + return self.seg.segment(text) def save_wav(self, wav: List[int], path: str) -> None: """Save the waveform as a file. 
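Note on the `split_into_sentences` change above: it protects the glottal-stop symbol `!` (a regular phone in the proprietary phonetic mode) from being read as sentence-final punctuation by masking it before segmentation and restoring it afterwards. Below is a self-contained sketch of the same round trip; `segment` is a stand-in for `self.seg.segment`, and the naive splitter exists only for this demo.

```python
# Mask "!" so a sentence segmenter does not break on it, then restore it.
def split_phonetic_text(text: str, segment) -> list:
    masked = text.replace("!", "*")
    return [s.replace("*", "!") for s in segment(masked)]


# Naive segmenter used only for this example: split after every full stop.
def naive_segment(text: str) -> list:
    return [part.strip() + "." for part in text.split(".") if part.strip()]


print(split_phonetic_text("!ahoj, !adame. !dobry den.", naive_segment))
# ['!ahoj, !adame.', '!dobry den.']
```

This assumes `*` never occurs in the phonetic input; if it can, a rarer placeholder would be needed.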
diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 19c30e98..8cb8c260 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -208,7 +208,7 @@ class GAN(BaseVocoder): self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument ) -> Tuple[Dict, np.ndarray]: """Call `_log()` for training.""" - figures, audios = self._log("eval", self.ap, batch, outputs) + figures, audios = self._log("train", self.ap, batch, outputs) logger.eval_figures(steps, figures) logger.eval_audios(steps, audios, self.ap.sample_rate)
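Note on the one-word `gan.py` change above: the training logger callback now calls `_log()` with the `"train"` label instead of `"eval"`. The sketch below only illustrates why that label matters, assuming (as in the upstream `_log`) that the name is used to prefix the logged figure and audio keys; the exact keys in `gan.py` may differ.

```python
# Illustration of how the label passed to _log() ends up in the logged keys.
def log_keys(name: str):
    figures = {f"{name}/spectrogram": "figure placeholder"}
    audios = {f"{name}/audio": [0.0, 0.1, 0.0]}
    return figures, audios


print(log_keys("train")[0].keys())  # dict_keys(['train/spectrogram'])
print(log_keys("eval")[0].keys())   # dict_keys(['eval/spectrogram'])
```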