formatting `cond_input` with a function in Tacotron models

Eren Gölge 2021-05-31 15:41:17 +02:00
parent ef4ea9e527
commit 8381379938
4 changed files with 27 additions and 0 deletions

View File

@@ -191,6 +191,7 @@ class Tacotron(TacotronAbstract):
mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors': [B, C]
"""
cond_input = self._format_cond_input(cond_input)
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
# B x T_in x embed_dim
@@ -250,6 +251,7 @@ class Tacotron(TacotronAbstract):
@torch.no_grad()
def inference(self, text_input, cond_input=None):
cond_input = self._format_cond_input(cond_input)
inputs = self.embedding(text_input)
encoder_outputs = self.encoder(inputs)
if self.gst and self.use_gst:
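
For orientation (not part of the diff): a minimal sketch of how a caller could now pass a partial `cond_input` to the patched `inference()`. The `model`, `text_input`, and `spk_ids` names are illustrative and assume an already initialized multi-speaker Tacotron instance.

import torch

text_input = torch.randint(0, 50, (1, 30))       # [B, T_in], dummy token ids
spk_ids = torch.zeros(1, 1, dtype=torch.long)    # [B, 1]
cond_input = {"speaker_ids": spk_ids}            # "x_vectors" deliberately left out

# inside inference(), _format_cond_input() fills the missing key, so the
# model sees {"speaker_ids": spk_ids, "x_vectors": None}
outputs = model.inference(text_input, cond_input=cond_input)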

View File

@@ -191,6 +191,7 @@ class Tacotron2(TacotronAbstract):
mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors': [B, C]
"""
cond_input = self._format_cond_input(cond_input)
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
# compute mask for padding
# B x T_in_max (boolean)
@@ -248,6 +249,7 @@ class Tacotron2(TacotronAbstract):
@torch.no_grad()
def inference(self, text, cond_input=None):
cond_input = self._format_cond_input(cond_input)
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)

View File

@@ -1,10 +1,12 @@
import copy
from abc import ABC, abstractmethod
from typing import Dict
import torch
from torch import nn
from TTS.tts.utils.data import sequence_mask
from TTS.utils.generic_utils import format_cond_input
from TTS.utils.training import gradual_training_scheduler
@@ -94,6 +96,10 @@ class TacotronAbstract(ABC, nn.Module):
self.decoder_backward = None
self.coarse_decoder = None
@staticmethod
def _format_cond_input(cond_input: Dict) -> Dict:
return format_cond_input({"x_vectors": None, "speaker_ids": None}, cond_input)
#############################
# INIT FUNCTIONS
#############################
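
A minimal sketch of what the new static method returns, assuming the corrected `format_cond_input()` from the next file; `spk_ids` stands in for any speaker-id tensor:

TacotronAbstract._format_cond_input(None)
# -> {"x_vectors": None, "speaker_ids": None}
TacotronAbstract._format_cond_input({"speaker_ids": spk_ids})
# -> {"speaker_ids": spk_ids, "x_vectors": None}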

View File

@@ -8,6 +8,7 @@ import shutil
import subprocess
import sys
from pathlib import Path
from typing import Dict
import torch
@@ -126,6 +127,22 @@ def set_init_dict(model_dict, checkpoint_state, c):
return model_dict
def format_cond_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Format kwargs to handle auxiliary inputs to models.

    Args:
        def_args (Dict): A dictionary of argument names and the default values to use when they are missing from `kwargs`.
        kwargs (Dict): A `dict` or `kwargs` that includes auxiliary inputs to the model.

    Returns:
        Dict: arguments with the auxiliary inputs filled in with their defaults.
    """
    kwargs = kwargs or {}  # `cond_input` may be None (e.g. at inference time)
    for name, def_val in def_args.items():
        if name not in kwargs:
            # fill missing auxiliary inputs with their default values
            kwargs[name] = def_val
    return kwargs
class KeepAverage:
def __init__(self):
self.avg_values = {}
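
A quick usage sketch of the utility on its own; the `reduction_factor` default below is hypothetical and only illustrates that missing keys receive their defaults while keys already present pass through unchanged:

defaults = {"speaker_ids": None, "reduction_factor": 2}   # hypothetical defaults
format_cond_input(defaults, {"speaker_ids": [3]})
# -> {"speaker_ids": [3], "reduction_factor": 2}
format_cond_input(defaults, None)
# -> {"speaker_ids": None, "reduction_factor": 2}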