formatting `cond_input` with a function in Tacotron models

Eren Gölge 2021-05-31 15:41:17 +02:00
parent 254707c610
commit f568833d28
4 changed files with 27 additions and 0 deletions

View File

@@ -191,6 +191,7 @@ class Tacotron(TacotronAbstract):
mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
"""
cond_input = self._format_cond_input(cond_input)
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
# B x T_in x embed_dim
@@ -250,6 +251,7 @@ class Tacotron(TacotronAbstract):
@torch.no_grad()
def inference(self, text_input, cond_input=None):
cond_input = self._format_cond_input(cond_input)
inputs = self.embedding(text_input)
encoder_outputs = self.encoder(inputs)
if self.gst and self.use_gst:
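
For readers unfamiliar with the `cond_input` convention documented in the forward() hunk above, the dict is expected to carry 'speaker_ids' shaped [B, 1] and 'x_vectors' shaped [B, C]. A minimal sketch of building such a dict (batch size, dimensionality, and dtype below are illustrative assumptions, not taken from the commit):

    import torch

    B, C = 4, 256  # illustrative batch size and x-vector dimensionality
    cond_input = {
        "speaker_ids": torch.zeros(B, 1, dtype=torch.long),  # [B, 1]
        "x_vectors": torch.randn(B, C),                       # [B, C]
    }
    # either key may be omitted (or cond_input left as None); the new
    # _format_cond_input call fills any missing entry with None.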

View File

@@ -191,6 +191,7 @@ class Tacotron2(TacotronAbstract):
mel_lengths: [B]
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
"""
cond_input = self._format_cond_input(cond_input)
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
# compute mask for padding
# B x T_in_max (boolean)
@@ -248,6 +249,7 @@ class Tacotron2(TacotronAbstract):
@torch.no_grad()
def inference(self, text, cond_input=None):
cond_input = self._format_cond_input(cond_input)
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)

View File

@@ -1,10 +1,12 @@
import copy
from abc import ABC, abstractmethod
from typing import Dict
import torch
from torch import nn
from TTS.tts.utils.data import sequence_mask
from TTS.utils.generic_utils import format_cond_input
from TTS.utils.training import gradual_training_scheduler
@@ -94,6 +96,10 @@ class TacotronAbstract(ABC, nn.Module):
self.decoder_backward = None
self.coarse_decoder = None

@staticmethod
def _format_cond_input(cond_input: Dict) -> Dict:
    return format_cond_input({"x_vectors": None, "speaker_ids": None}, cond_input)

#############################
# INIT FUNCTIONS
#############################

View File

@@ -8,6 +8,7 @@ import shutil
import subprocess
import sys
from pathlib import Path
from typing import Dict
import torch
@@ -126,6 +127,22 @@ def set_init_dict(model_dict, checkpoint_state, c):
return model_dict
def format_cond_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Format kwargs to handle auxiliary inputs to models.

    Args:
        def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`.
        kwargs (Dict): A `dict` or `kwargs` that includes auxiliary inputs to the model.

    Returns:
        Dict: arguments with formatted auxiliary inputs.
    """
    if kwargs is None:
        # no auxiliary inputs were passed (e.g. cond_input=None at inference); use the defaults
        return def_args
    for name in def_args:
        if name not in kwargs:
            # fill in any missing auxiliary input with its default value
            kwargs[name] = def_args[name]
    return kwargs
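
A quick, self-contained illustration of how the helper behaves with the defaults used by the Tacotron models (the example values are made up for illustration and are not part of the commit):

    defaults = {"x_vectors": None, "speaker_ids": None}

    # partial input: only 'speaker_ids' is given, 'x_vectors' gets its default
    print(format_cond_input(defaults, {"speaker_ids": [[0]]}))
    # -> {'speaker_ids': [[0]], 'x_vectors': None}

    # no auxiliary input at all: the defaults are returned unchanged
    print(format_cond_input(defaults, None))
    # -> {'x_vectors': None, 'speaker_ids': None}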
class KeepAverage:
def __init__(self):
self.avg_values = {}