From 8381379938b8e300c9112ff69ea28fc61002f4b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 31 May 2021 15:41:17 +0200 Subject: [PATCH] formating `cond_input` with a function in Tacotron models --- TTS/tts/models/tacotron.py | 2 ++ TTS/tts/models/tacotron2.py | 2 ++ TTS/tts/models/tacotron_abstract.py | 6 ++++++ TTS/utils/generic_utils.py | 17 +++++++++++++++++ 4 files changed, 27 insertions(+) diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index da574c05..8d3124c3 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -191,6 +191,7 @@ class Tacotron(TacotronAbstract): mel_lengths: [B] cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C] """ + cond_input = self._format_cond_input(cond_input) outputs = {"alignments_backward": None, "decoder_outputs_backward": None} input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) # B x T_in x embed_dim @@ -250,6 +251,7 @@ class Tacotron(TacotronAbstract): @torch.no_grad() def inference(self, text_input, cond_input=None): + cond_input = self._format_cond_input(cond_input) inputs = self.embedding(text_input) encoder_outputs = self.encoder(inputs) if self.gst and self.use_gst: diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 14a838d7..bd1ad03e 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -191,6 +191,7 @@ class Tacotron2(TacotronAbstract): mel_lengths: [B] cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C] """ + cond_input = self._format_cond_input(cond_input) outputs = {"alignments_backward": None, "decoder_outputs_backward": None} # compute mask for padding # B x T_in_max (boolean) @@ -248,6 +249,7 @@ class Tacotron2(TacotronAbstract): @torch.no_grad() def inference(self, text, cond_input=None): + cond_input = self._format_cond_input(cond_input) embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index 8eb7bf24..5e561066 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -1,10 +1,12 @@ import copy from abc import ABC, abstractmethod +from typing import Dict import torch from torch import nn from TTS.tts.utils.data import sequence_mask +from TTS.utils.generic_utils import format_cond_input from TTS.utils.training import gradual_training_scheduler @@ -94,6 +96,10 @@ class TacotronAbstract(ABC, nn.Module): self.decoder_backward = None self.coarse_decoder = None + @staticmethod + def _format_cond_input(cond_input: Dict) -> Dict: + return format_cond_input({"x_vectors": None, "speaker_ids": None}, cond_input) + ############################# # INIT FUNCTIONS ############################# diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index a562e86f..0c28116d 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -8,6 +8,7 @@ import shutil import subprocess import sys from pathlib import Path +from typing import Dict import torch @@ -126,6 +127,22 @@ def set_init_dict(model_dict, checkpoint_state, c): return model_dict +def format_cond_input(def_args: Dict, kwargs: Dict) -> Dict: + """Format kwargs to hande auxilary inputs to models. + + Args: + def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`. + kwargs (Dict): A `dict` or `kwargs` that includes auxilary inputs to the model. + + Returns: + Dict: arguments with formatted auxilary inputs. + """ + for name in def_args: + if name not in kwargs: + kwargs[def_args[name]] = None + return kwargs + + class KeepAverage: def __init__(self): self.avg_values = {}