mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev' of https://github.com/mozilla/TTS into dev
This commit is contained in:
commit
d00b91710a
|
@ -17,9 +17,12 @@ If you are new, you can also find [here](http://www.erogol.com/text-speech-deep-
|
|||
[Details...](https://github.com/mozilla/TTS/wiki/Mean-Opinion-Score-Results)
|
||||
|
||||
## Features
|
||||
- High performance Text2Speech models on Torch and Tensorflow 2.0.
|
||||
- High performance Speaker Encoder to compute speaker embeddings efficiently.
|
||||
- Integration with various Neural Vocoders (PWGAN, MelGAN, WaveRNN)
|
||||
- High performance Deep Learning models for Text2Speech related tasks.
|
||||
- Text2Speech models (Tacotron, Tacotron2).
|
||||
- Speaker Encoder to compute speaker embeddings efficiently.
|
||||
- Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS)
|
||||
- Support for multi-speaker TTS training.
|
||||
- Ability to convert Torch models to Tensorflow 2.0 for inference.
|
||||
- Released trained models.
|
||||
- Efficient training code for PyTorch. (soon for Tensorflow 2.0)
|
||||
- Code to convert Torch models to Tensorflow 2.0.
|
||||
|
|
|
@ -96,6 +96,8 @@
|
|||
"transition_agent": false, // enable/disable transition agent of forward attention.
|
||||
"location_attn": true, // enable/disable location sensitive attention. It is enabled for TACOTRON by default.
|
||||
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
|
||||
"double_decoder_consistency": true, // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/
|
||||
"ddc_r": 7, // reduction rate for coarse decoder.
|
||||
|
||||
// STOPNET
|
||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||
|
|
|
@ -184,7 +184,7 @@ class TacotronLoss(torch.nn.Module):
|
|||
|
||||
def forward(self, postnet_output, decoder_output, mel_input, linear_input,
|
||||
stopnet_output, stopnet_target, output_lens, decoder_b_output,
|
||||
alignments, alignment_lens, input_lens):
|
||||
alignments, alignment_lens, alignments_backwards, input_lens):
|
||||
|
||||
return_dict = {}
|
||||
# decoder and postnet losses
|
||||
|
@ -226,6 +226,15 @@ class TacotronLoss(torch.nn.Module):
|
|||
return_dict['decoder_b_loss'] = decoder_b_loss
|
||||
return_dict['decoder_c_loss'] = decoder_c_loss
|
||||
|
||||
# double decoder consistency loss (if enabled)
|
||||
if self.config.double_decoder_consistency:
|
||||
decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens)
|
||||
# decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
|
||||
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
|
||||
loss += decoder_b_loss + attention_c_loss
|
||||
return_dict['decoder_coarse_loss'] = decoder_b_loss
|
||||
return_dict['decoder_ddc_loss'] = attention_c_loss
|
||||
|
||||
# guided attention loss (if enabled)
|
||||
if self.config.ga_alpha > 0:
|
||||
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
|
||||
|
|
|
@ -1,23 +1,21 @@
|
|||
# coding: utf-8
|
||||
import torch
|
||||
import copy
|
||||
from torch import nn
|
||||
from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
|
||||
from TTS.utils.generic_utils import sequence_mask
|
||||
|
||||
from TTS.layers.gst_layers import GST
|
||||
from TTS.layers.tacotron import Decoder, Encoder, PostCBHG
|
||||
from TTS.models.tacotron_abstract import TacotronAbstract
|
||||
|
||||
|
||||
class Tacotron(nn.Module):
|
||||
class Tacotron(TacotronAbstract):
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r=5,
|
||||
postnet_output_dim=1025,
|
||||
decoder_output_dim=80,
|
||||
memory_size=5,
|
||||
attn_type='original',
|
||||
attn_win=False,
|
||||
gst=False,
|
||||
attn_norm="sigmoid",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
|
@ -27,38 +25,41 @@ class Tacotron(nn.Module):
|
|||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False):
|
||||
super(Tacotron, self).__init__()
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.gst = gst
|
||||
self.num_speakers = num_speakers
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
decoder_dim = 512 if num_speakers > 1 else 256
|
||||
encoder_dim = 512 if num_speakers > 1 else 256
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
gst=False,
|
||||
memory_size=5):
|
||||
super(Tacotron,
|
||||
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
|
||||
decoder_output_dim, attn_type, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, attn_K, separate_stopnet,
|
||||
bidirectional_decoder, double_decoder_consistency,
|
||||
ddc_r, gst)
|
||||
decoder_in_features = 512 if num_speakers > 1 else 256
|
||||
encoder_in_features = 512 if num_speakers > 1 else 256
|
||||
speaker_embedding_dim = 256
|
||||
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
||||
# embedding layer
|
||||
# base model layers
|
||||
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
|
||||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
# boilerplate model
|
||||
self.encoder = Encoder(encoder_dim)
|
||||
self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_type, attn_win,
|
||||
self.encoder = Encoder(encoder_in_features)
|
||||
self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, memory_size, attn_type, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, attn_K, separate_stopnet,
|
||||
proj_speaker_dim)
|
||||
if self.bidirectional_decoder:
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
self.postnet = PostCBHG(decoder_output_dim)
|
||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
|
||||
postnet_output_dim)
|
||||
# speaker embedding layers
|
||||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 256)
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.speaker_project_mel = nn.Sequential(
|
||||
nn.Linear(256, proj_speaker_dim), nn.Tanh())
|
||||
nn.Linear(speaker_embedding_dim, proj_speaker_dim), nn.Tanh())
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
# global style token layers
|
||||
|
@ -68,28 +69,15 @@ class Tacotron(nn.Module):
|
|||
num_heads=4,
|
||||
num_style_tokens=10,
|
||||
embedding_dim=gst_embedding_dim)
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
self._init_backward_decoder()
|
||||
# setup DDC
|
||||
if self.double_decoder_consistency:
|
||||
self._init_coarse_decoder()
|
||||
|
||||
def _init_states(self):
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(
|
||||
" [!] Model has speaker embedding layer but speaker_id is not provided"
|
||||
)
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.speaker_embeddings = self._compute_speaker_embedding(
|
||||
speaker_ids)
|
||||
self.speaker_embeddings_projected = self.speaker_project_mel(
|
||||
self.speaker_embeddings).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, mel_specs):
|
||||
gst_outputs = self.gst_layer(mel_specs)
|
||||
inputs = self._add_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
||||
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
|
||||
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None):
|
||||
"""
|
||||
Shapes:
|
||||
- characters: B x T_in
|
||||
|
@ -98,45 +86,59 @@ class Tacotron(nn.Module):
|
|||
- speaker_ids: B x 1
|
||||
"""
|
||||
self._init_states()
|
||||
mask = sequence_mask(text_lengths).to(characters.device)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
# B x T_in x embed_dim
|
||||
inputs = self.embedding(characters)
|
||||
# B x speaker_embed_dim
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if speaker_ids is not None:
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
# B x T_in x embed_dim + speaker_embed_dim
|
||||
inputs = self._concat_speaker_embedding(inputs,
|
||||
self.speaker_embeddings)
|
||||
# B x T_in x encoder_dim
|
||||
# B x T_in x encoder_in_features
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
# sequence masking
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
# global style token
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
||||
if self.num_speakers > 1:
|
||||
encoder_outputs = self._concat_speaker_embedding(
|
||||
encoder_outputs, self.speaker_embeddings)
|
||||
# decoder_outputs: B x decoder_dim x T_out
|
||||
# alignments: B x T_in x encoder_dim
|
||||
# decoder_outputs: B x decoder_in_features x T_out
|
||||
# alignments: B x T_in x encoder_in_features
|
||||
# stop_tokens: B x T_in
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||
encoder_outputs, mel_specs, mask,
|
||||
encoder_outputs, mel_specs, input_mask,
|
||||
self.speaker_embeddings_projected)
|
||||
# B x T_out x decoder_dim
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
|
||||
# B x T_out x decoder_in_features
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs)
|
||||
# B x T_out x postnet_dim
|
||||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
# B x T_out x decoder_dim
|
||||
# B x T_out x decoder_in_features
|
||||
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, characters, speaker_ids=None, style_mel=None):
|
||||
inputs = self.embedding(characters)
|
||||
self._init_states()
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if speaker_ids is not None:
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
inputs = self._concat_speaker_embedding(inputs,
|
||||
self.speaker_embeddings)
|
||||
|
@ -152,28 +154,3 @@ class Tacotron(nn.Module):
|
|||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
decoder_outputs = decoder_outputs.transpose(1, 2)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
def _backward_inference(self, mel_specs, encoder_outputs, mask):
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
|
||||
self.speaker_embeddings_projected)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _compute_speaker_embedding(self, speaker_ids):
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)
|
||||
return speaker_embeddings.unsqueeze_(1)
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + speaker_embeddings_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||
return outputs
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import copy
|
||||
import torch
|
||||
from math import sqrt
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from TTS.layers.tacotron2 import Encoder, Decoder, Postnet
|
||||
from TTS.utils.generic_utils import sequence_mask
|
||||
|
||||
from TTS.layers.gst_layers import GST
|
||||
from TTS.layers.tacotron2 import Decoder, Encoder, Postnet
|
||||
from TTS.models.tacotron_abstract import TacotronAbstract
|
||||
|
||||
|
||||
# TODO: match function arguments with tacotron
|
||||
class Tacotron2(nn.Module):
|
||||
class Tacotron2(TacotronAbstract):
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
|
@ -25,16 +27,22 @@ class Tacotron2(nn.Module):
|
|||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False):
|
||||
super(Tacotron2, self).__init__()
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.r = r
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
decoder_dim = 512 if num_speakers > 1 else 512
|
||||
encoder_dim = 512 if num_speakers > 1 else 512
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
gst=False):
|
||||
super(Tacotron2,
|
||||
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
|
||||
decoder_output_dim, attn_type, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, attn_K, separate_stopnet,
|
||||
bidirectional_decoder, double_decoder_consistency,
|
||||
ddc_r, gst)
|
||||
decoder_in_features = 512 if num_speakers > 1 else 512
|
||||
encoder_in_features = 512 if num_speakers > 1 else 512
|
||||
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
||||
# embedding layer
|
||||
# base layers
|
||||
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
|
||||
std = sqrt(2.0 / (num_chars + 512))
|
||||
val = sqrt(3.0) * std # uniform bounds for std
|
||||
|
@ -42,20 +50,25 @@ class Tacotron2(nn.Module):
|
|||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 512)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
self.encoder = Encoder(encoder_dim)
|
||||
self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_type, attn_win,
|
||||
self.encoder = Encoder(encoder_in_features)
|
||||
self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, attn_K, separate_stopnet, proj_speaker_dim)
|
||||
if self.bidirectional_decoder:
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
self.postnet = Postnet(self.postnet_output_dim)
|
||||
|
||||
def _init_states(self):
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
# global style token layers
|
||||
if self.gst:
|
||||
gst_embedding_dim = encoder_in_features
|
||||
self.gst_layer = GST(num_mel=80,
|
||||
num_heads=4,
|
||||
num_style_tokens=10,
|
||||
embedding_dim=gst_embedding_dim)
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
self._init_backward_decoder()
|
||||
# setup DDC
|
||||
if self.double_decoder_consistency:
|
||||
self._init_coarse_decoder()
|
||||
|
||||
@staticmethod
|
||||
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
|
||||
|
@ -63,31 +76,60 @@ class Tacotron2(nn.Module):
|
|||
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
||||
return mel_outputs, mel_outputs_postnet, alignments
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
|
||||
self._init_states()
|
||||
# compute mask for padding
|
||||
mask = sequence_mask(text_lengths).to(text.device)
|
||||
# B x T_in_max (boolean)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
# B x D_embed x T_in_max
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
# B x T_in_max x D_en
|
||||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
speaker_ids)
|
||||
# adding speaker embedding to encoder output
|
||||
# TODO: multi-speaker
|
||||
# B x speaker_embed_dim
|
||||
if speaker_ids is not None:
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
# B x T_in x embed_dim + speaker_embed_dim
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
self.speaker_embeddings)
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
# global style token
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
||||
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||
encoder_outputs, mel_specs, mask)
|
||||
encoder_outputs, mel_specs, input_mask)
|
||||
# sequence masking
|
||||
if mel_lengths is not None:
|
||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
|
||||
# B x mel_dim x T_out
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = decoder_outputs + postnet_outputs
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
|
||||
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||
decoder_outputs, postnet_outputs, alignments)
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, text, speaker_ids=None):
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
encoder_outputs = self.encoder.inference(embedded_inputs)
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
speaker_ids)
|
||||
if speaker_ids is not None:
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
self.speaker_embeddings)
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
encoder_outputs)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
|
@ -112,22 +154,16 @@ class Tacotron2(nn.Module):
|
|||
mel_outputs, mel_outputs_postnet, alignments)
|
||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
||||
|
||||
def _backward_inference(self, mel_specs, encoder_outputs, mask):
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
|
||||
self.speaker_embeddings_projected)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)
|
||||
def _speaker_embedding_pass(self, encoder_outputs, speaker_ids):
|
||||
# TODO: multi-speaker
|
||||
# if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
# raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
# if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
|
||||
speaker_embeddings.unsqueeze_(1)
|
||||
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
|
||||
encoder_outputs.size(1),
|
||||
-1)
|
||||
encoder_outputs = encoder_outputs + speaker_embeddings
|
||||
return encoder_outputs
|
||||
# speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
|
||||
# encoder_outputs.size(1),
|
||||
# -1)
|
||||
# encoder_outputs = encoder_outputs + speaker_embeddings
|
||||
# return encoder_outputs
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.utils.generic_utils import sequence_mask
|
||||
|
||||
|
||||
class TacotronAbstract(ABC, nn.Module):
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type='original',
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
gst=False):
|
||||
""" Abstract Tacotron class """
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.gst = gst
|
||||
self.num_speakers = num_speakers
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
self.double_decoder_consistency = double_decoder_consistency
|
||||
self.ddc_r = ddc_r
|
||||
self.attn_type = attn_type
|
||||
self.attn_win = attn_win
|
||||
self.attn_norm = attn_norm
|
||||
self.prenet_type = prenet_type
|
||||
self.prenet_dropout = prenet_dropout
|
||||
self.forward_attn = forward_attn
|
||||
self.trans_agent = trans_agent
|
||||
self.forward_attn_mask = forward_attn_mask
|
||||
self.location_attn = location_attn
|
||||
self.attn_K = attn_K
|
||||
self.separate_stopnet = separate_stopnet
|
||||
|
||||
# layers
|
||||
self.embedding = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self.postnet = None
|
||||
|
||||
# global style token
|
||||
if self.gst:
|
||||
self.gst_layer = None
|
||||
|
||||
# model states
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
# additional layers
|
||||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
||||
#############################
|
||||
# INIT FUNCTIONS
|
||||
#############################
|
||||
|
||||
def _init_states(self):
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
def _init_backward_decoder(self):
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
|
||||
def _init_coarse_decoder(self):
|
||||
self.coarse_decoder = copy.deepcopy(self.decoder)
|
||||
self.coarse_decoder.r_init = self.ddc_r
|
||||
self.coarse_decoder.set_r(self.ddc_r)
|
||||
|
||||
#############################
|
||||
# CORE FUNCTIONS
|
||||
#############################
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inference(self):
|
||||
pass
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_masks(self, text_lengths, mel_lengths):
|
||||
"""Compute masks against sequence paddings."""
|
||||
# B x T_in_max (boolean)
|
||||
device = text_lengths.device
|
||||
input_mask = sequence_mask(text_lengths).to(device)
|
||||
output_mask = None
|
||||
if mel_lengths is not None:
|
||||
max_len = mel_lengths.max()
|
||||
r = self.decoder.r
|
||||
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
|
||||
output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device)
|
||||
return input_mask, output_mask
|
||||
|
||||
def _backward_pass(self, mel_specs, encoder_outputs, mask):
|
||||
""" Run backwards decoder """
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
|
||||
self.speaker_embeddings_projected)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments,
|
||||
input_mask):
|
||||
""" Double Decoder Consistency """
|
||||
T = mel_specs.shape[1]
|
||||
if T % self.coarse_decoder.r > 0:
|
||||
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
|
||||
mel_specs = torch.nn.functional.pad(mel_specs,
|
||||
(0, 0, 0, padding_size, 0, 0))
|
||||
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
|
||||
encoder_outputs.detach(), mel_specs, input_mask)
|
||||
# scale_factor = self.decoder.r_init / self.decoder.r
|
||||
alignments_backward = torch.nn.functional.interpolate(
|
||||
alignments_backward.transpose(1, 2),
|
||||
size=alignments.shape[1],
|
||||
mode='nearest').transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
|
||||
return decoder_outputs_backward, alignments_backward
|
||||
|
||||
#############################
|
||||
# EMBEDDING FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
""" Compute speaker embedding vectors """
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(
|
||||
" [!] Model has speaker embedding layer but speaker_id is not provided"
|
||||
)
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1)
|
||||
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
|
||||
self.speaker_embeddings_projected = self.speaker_project_mel(
|
||||
self.speaker_embeddings).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, mel_specs):
|
||||
""" Compute global style token """
|
||||
# pylint: disable=not-callable
|
||||
gst_outputs = self.gst_layer(mel_specs)
|
||||
inputs = self._add_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + speaker_embeddings_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||
return outputs
|
|
@ -55,6 +55,8 @@
|
|||
"separate_stopnet": true, // Train stopnet separately if 'stopnet==true'. It prevents the stopnet loss from influencing the rest of the model. It yields a better model, but it trains SLOWER.
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"use_gst": false,
|
||||
"double_decoder_consistency": true, // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/
|
||||
"ddc_r": 7, // reduction rate for coarse decoder.
|
||||
|
||||
"batch_size": 32, // Batch size for training. Values lower than 32 might make attention hard to learn.
|
||||
"eval_batch_size":16,
|
||||
|
|
|
@ -51,7 +51,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
|
||||
input, input_lengths, mel_spec, speaker_ids)
|
||||
input, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
|
||||
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
|
||||
optimizer.zero_grad()
|
||||
|
|
|
@ -66,7 +66,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(5):
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, speaker_ids)
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
@ -95,6 +95,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
|
||||
linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device)
|
||||
mel_lengths = torch.randint(20, 120, (8, )).long().to(device)
|
||||
mel_lengths[-1] = 120
|
||||
stop_targets = torch.zeros(8, 120, 1).float().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
|
@ -130,7 +131,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(10):
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, speaker_ids)
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
|
10
train.py
10
train.py
|
@ -158,13 +158,14 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
optimizer_st.zero_grad()
|
||||
|
||||
# forward pass model
|
||||
if c.bidirectional_decoder:
|
||||
if c.bidirectional_decoder or c.double_decoder_consistency:
|
||||
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
|
||||
decoder_backward_output = None
|
||||
alignments_backward = None
|
||||
|
||||
# set the alignment lengths wrt reduction factor for guided attention
|
||||
if mel_lengths.max() % model.decoder.r != 0:
|
||||
|
@ -176,7 +177,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
loss_dict = criterion(postnet_output, decoder_output, mel_input,
|
||||
linear_input, stop_tokens, stop_targets,
|
||||
mel_lengths, decoder_backward_output,
|
||||
alignments, alignment_lengths, text_lengths)
|
||||
alignments, alignment_lengths, alignments_backward,
|
||||
text_lengths)
|
||||
if c.bidirectional_decoder:
|
||||
keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
|
||||
'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
|
||||
|
|
|
@ -160,13 +160,16 @@ def setup_model(num_chars, num_speakers, c):
|
|||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r)
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=c.audio['num_mels'],
|
||||
decoder_output_dim=c.audio['num_mels'],
|
||||
gst=c.use_gst,
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
|
@ -178,7 +181,9 @@ def setup_model(num_chars, num_speakers, c):
|
|||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r)
|
||||
return model
|
||||
|
||||
class KeepAverage():
|
||||
|
@ -313,6 +318,8 @@ def check_config(c):
|
|||
_check_argument('transition_agent', c, restricted=True, val_type=bool)
|
||||
_check_argument('location_attn', c, restricted=True, val_type=bool)
|
||||
_check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
|
||||
_check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
|
||||
_check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
|
||||
|
||||
# stopnet
|
||||
_check_argument('stopnet', c, restricted=True, val_type=bool)
|
||||
|
|
|
@ -77,6 +77,7 @@ class MultiScaleSTFTLoss(torch.nn.Module):
|
|||
|
||||
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
|
||||
""" Multiscale STFT loss for multi band model outputs """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, y_hat, y):
|
||||
y_hat = y_hat.view(-1, 1, y_hat.shape[2])
|
||||
y = y.view(-1, 1, y.shape[2])
|
||||
|
@ -85,6 +86,7 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
|
|||
|
||||
class MSEGLoss(nn.Module):
|
||||
""" Mean Squared Generator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake):
|
||||
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
|
||||
return loss_fake
|
||||
|
@ -92,6 +94,7 @@ class MSEGLoss(nn.Module):
|
|||
|
||||
class HingeGLoss(nn.Module):
|
||||
""" Hinge Generator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake):
|
||||
loss_fake = torch.mean(F.relu(1. + score_fake))
|
||||
return loss_fake
|
||||
|
@ -104,6 +107,7 @@ class HingeGLoss(nn.Module):
|
|||
|
||||
class MSEDLoss(nn.Module):
|
||||
""" Mean Squared Discriminator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
|
||||
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
|
||||
|
@ -113,6 +117,7 @@ class MSEDLoss(nn.Module):
|
|||
|
||||
class HingeDLoss(nn.Module):
|
||||
""" Hinge Discriminator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = torch.mean(F.relu(1. - score_real))
|
||||
loss_fake = torch.mean(F.relu(1. + score_fake))
|
||||
|
@ -121,6 +126,7 @@ class HingeDLoss(nn.Module):
|
|||
|
||||
|
||||
class MelganFeatureLoss(nn.Module):
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, fake_feats, real_feats):
|
||||
loss_feats = 0
|
||||
for fake_feat, real_feat in zip(fake_feats, real_feats):
|
||||
|
@ -193,8 +199,8 @@ class GeneratorLoss(nn.Module):
|
|||
|
||||
self.stft_loss_weight = C.stft_loss_weight
|
||||
self.subband_stft_loss_weight = C.subband_stft_loss_weight
|
||||
self.mse_gan_loss_weight = C.mse_gan_loss_weight
|
||||
self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
|
||||
self.mse_gan_loss_weight = C.mse_G_loss_weight
|
||||
self.hinge_gan_loss_weight = C.hinge_G_loss_weight
|
||||
self.feat_match_loss_weight = C.feat_match_loss_weight
|
||||
|
||||
if C.use_stft_loss:
|
||||
|
|
|
@ -52,7 +52,7 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(dataset,
|
||||
batch_size=1 if is_val else c.batch_size,
|
||||
shuffle=False,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
sampler=None,
|
||||
num_workers=c.num_val_loader_workers
|
||||
|
@ -120,11 +120,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
y_hat = model_G(c_G)
|
||||
y_hat_sub = None
|
||||
y_G_sub = None
|
||||
y_hat_vis = y_hat # for visualization
|
||||
|
||||
# PQMF formatting
|
||||
if y_hat.shape[1] > 1:
|
||||
y_hat_sub = y_hat
|
||||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
y_hat_vis = y_hat
|
||||
y_G_sub = model_G.pqmf_analysis(y_G)
|
||||
|
||||
if global_step > c.steps_to_start_discriminator:
|
||||
|
@ -171,7 +173,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
|
||||
loss_dict = dict()
|
||||
for key, value in loss_G_dict.items():
|
||||
loss_dict[key] = value.item()
|
||||
if isinstance(value, int):
|
||||
loss_dict[key] = value
|
||||
else:
|
||||
loss_dict[key] = value.item()
|
||||
|
||||
##############################
|
||||
# DISCRIMINATOR
|
||||
|
@ -265,12 +270,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
model_losses=loss_dict)
|
||||
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_hat, y_G, ap, global_step,
|
||||
figures = plot_results(y_hat_vis, y_G, ap, global_step,
|
||||
'train')
|
||||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
|
||||
tb_logger.tb_train_audios(global_step,
|
||||
{'train/audio': sample_voice},
|
||||
c.audio["sample_rate"])
|
||||
|
@ -322,8 +327,12 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
|||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
y_G_sub = model_G.pqmf_analysis(y_G)
|
||||
|
||||
D_out_fake = model_D(y_hat)
|
||||
if len(signature(model_D.forward).parameters) == 2:
|
||||
D_out_fake = model_D(y_hat, c_G)
|
||||
else:
|
||||
D_out_fake = model_D(y_hat)
|
||||
D_out_real = None
|
||||
|
||||
if c.use_feat_match_loss:
|
||||
with torch.no_grad():
|
||||
D_out_real = model_D(y_G)
|
||||
|
@ -354,7 +363,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
|||
for key, value in loss_G_dict.items():
|
||||
update_eval_values['avg_' + key] = value.item()
|
||||
update_eval_values['avg_loader_time'] = loader_time
|
||||
update_eval_values['avgP_step_time'] = step_time
|
||||
update_eval_values['avg_step_time'] = step_time
|
||||
keep_avg.update_values(update_eval_values)
|
||||
|
||||
# print eval stats
|
||||
|
|
Loading…
Reference in New Issue