From 3c15f0619a3c87aca6ed99c614d930575ecb7ee1 Mon Sep 17 00:00:00 2001 From: Roee Shenberg Date: Wed, 15 Mar 2023 12:02:11 +0100 Subject: [PATCH] Bug fixes in OverFlow audio generation (#2380) --- TTS/tts/models/neuralhmm_tts.py | 7 ++++--- TTS/tts/models/overflow.py | 7 ++++--- TTS/utils/generic_utils.py | 5 +++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py index e4f88452..e2414108 100644 --- a/TTS/tts/models/neuralhmm_tts.py +++ b/TTS/tts/models/neuralhmm_tts.py @@ -179,6 +179,7 @@ class NeuralhmmTTS(BaseTTS): Args: aux_inputs (Dict): Dictionary containing the auxiliary inputs. """ + default_input_dict = default_input_dict.copy() default_input_dict.update( { "sampling_temp": self.sampling_temp, @@ -187,8 +188,8 @@ class NeuralhmmTTS(BaseTTS): } ) if aux_input: - return format_aux_input(aux_input, default_input_dict) - return None + return format_aux_input(default_input_dict, aux_input) + return default_input_dict @torch.no_grad() def inference( @@ -319,7 +320,7 @@ class NeuralhmmTTS(BaseTTS): # sample one item from the batch -1 will give the smalles item print(" | > Synthesising audio from the model...") inference_output = self.inference( - batch["text_input"][-1].unsqueeze(0), aux_input={"x_lenghts": batch["text_lengths"][-1].unsqueeze(0)} + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} ) figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py index c2b5b7c2..92b3c767 100644 --- a/TTS/tts/models/overflow.py +++ b/TTS/tts/models/overflow.py @@ -192,6 +192,7 @@ class Overflow(BaseTTS): Args: aux_inputs (Dict): Dictionary containing the auxiliary inputs. """ + default_input_dict = default_input_dict.copy() default_input_dict.update( { "sampling_temp": self.sampling_temp, @@ -200,8 +201,8 @@ class Overflow(BaseTTS): } ) if aux_input: - return format_aux_input(aux_input, default_input_dict) - return None + return format_aux_input(default_input_dict, aux_input) + return default_input_dict @torch.no_grad() def inference( @@ -335,7 +336,7 @@ class Overflow(BaseTTS): # sample one item from the batch -1 will give the smalles item print(" | > Synthesising audio from the model...") inference_output = self.inference( - batch["text_input"][-1].unsqueeze(0), aux_input={"x_lenghts": batch["text_lengths"][-1].unsqueeze(0)} + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} ) figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index b685210c..d5312f49 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -167,9 +167,10 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: Returns: Dict: arguments with formatted auxilary inputs. """ + kwargs = kwargs.copy() for name in def_args: - if name not in kwargs: - kwargs[def_args[name]] = None + if name not in kwargs or kwargs[name] is None: + kwargs[name] = def_args[name] return kwargs