Update GlowTTS

This commit is contained in:
Eren Gölge 2021-09-30 14:21:02 +00:00
parent 4baecdf92a
commit fd95926009
3 changed files with 42 additions and 9 deletions

View File

@@ -106,7 +106,6 @@ class InvConvNear(nn.Module):
- x: :math:`[B, C, T]` - x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]` - x_mask: :math:`[B, 1, T]`
""" """
b, c, t = x.size() b, c, t = x.size()
assert c % self.num_splits == 0 assert c % self.num_splits == 0
if x_mask is None: if x_mask is None:

View File

@@ -1,4 +1,5 @@
import math import math
from typing import Dict, Tuple
import torch import torch
from torch import nn from torch import nn
@@ -47,7 +48,7 @@ class GlowTTS(BaseTTS):
def __init__(self, config: GlowTTSConfig): def __init__(self, config: GlowTTSConfig):
super().__init__() super().__init__(config)
# pass all config fields to `self` # pass all config fields to `self`
# for fewer code change # for fewer code change
@@ -387,7 +388,7 @@ class GlowTTS(BaseTTS):
) )
return outputs, loss_dict return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use def _create_logs(self, batch, outputs, ap):
alignments = outputs["alignments"] alignments = outputs["alignments"]
text_input = batch["text_input"] text_input = batch["text_input"]
text_lengths = batch["text_lengths"] text_lengths = batch["text_lengths"]
@@ -416,15 +417,26 @@ class GlowTTS(BaseTTS):
train_audio = ap.inv_melspectrogram(pred_spec.T) train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": train_audio} return figures, {"audio": train_audio}
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
@torch.no_grad() @torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module): def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion) return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
return self.train_log(ap, batch, outputs) ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
@torch.no_grad() @torch.no_grad()
def test_run(self, ap): def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`. """Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour. You can override this for a different behaviour.
@@ -432,6 +444,7 @@ class GlowTTS(BaseTTS):
Returns: Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
""" """
ap = assets["audio_processor"]
print(" | > Synthesizing test sentences.") print(" | > Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}

View File

@@ -1,7 +1,10 @@
import os import os
from TTS.trainer import Trainer, TrainingArgs, init_training from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__)) output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig( dataset_config = BaseDatasetConfig(
@@ -25,6 +28,24 @@ config = GlowTTSConfig(
output_path=output_path, output_path=output_path,
datasets=[dataset_config], datasets=[dataset_config],
) )
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) # init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init model
model = GlowTTS(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit() trainer.fit()