Bug fixes in OverFlow audio generation (#2380)

This commit is contained in:
Roee Shenberg 2023-03-15 12:02:11 +01:00 committed by GitHub
parent b8d9837d27
commit 3c15f0619a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 8 deletions

View File

@ -179,6 +179,7 @@ class NeuralhmmTTS(BaseTTS):
Args:
aux_inputs (Dict): Dictionary containing the auxiliary inputs.
"""
default_input_dict = default_input_dict.copy()
default_input_dict.update(
{
"sampling_temp": self.sampling_temp,
@ -187,8 +188,8 @@ class NeuralhmmTTS(BaseTTS):
}
)
if aux_input:
return format_aux_input(aux_input, default_input_dict)
return None
return format_aux_input(default_input_dict, aux_input)
return default_input_dict
@torch.no_grad()
def inference(
@ -319,7 +320,7 @@ class NeuralhmmTTS(BaseTTS):
# sample one item from the batch -1 will give the smalles item
print(" | > Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lenghts": batch["text_lengths"][-1].unsqueeze(0)}
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)
figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3))

View File

@ -192,6 +192,7 @@ class Overflow(BaseTTS):
Args:
aux_inputs (Dict): Dictionary containing the auxiliary inputs.
"""
default_input_dict = default_input_dict.copy()
default_input_dict.update(
{
"sampling_temp": self.sampling_temp,
@ -200,8 +201,8 @@ class Overflow(BaseTTS):
}
)
if aux_input:
return format_aux_input(aux_input, default_input_dict)
return None
return format_aux_input(default_input_dict, aux_input)
return default_input_dict
@torch.no_grad()
def inference(
@ -335,7 +336,7 @@ class Overflow(BaseTTS):
# sample one item from the batch -1 will give the smalles item
print(" | > Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lenghts": batch["text_lengths"][-1].unsqueeze(0)}
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)
figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3))

View File

@ -167,9 +167,10 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
Returns:
Dict: arguments with formatted auxilary inputs.
"""
kwargs = kwargs.copy()
for name in def_args:
if name not in kwargs:
kwargs[def_args[name]] = None
if name not in kwargs or kwargs[name] is None:
kwargs[name] = def_args[name]
return kwargs