mirror of https://github.com/coqui-ai/TTS.git
Update GlowTTS
This commit is contained in:
parent
4baecdf92a
commit
fd95926009
|
@ -106,7 +106,6 @@ class InvConvNear(nn.Module):
|
||||||
- x: :math:`[B, C, T]`
|
- x: :math:`[B, C, T]`
|
||||||
- x_mask: :math:`[B, 1, T]`
|
- x_mask: :math:`[B, 1, T]`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
b, c, t = x.size()
|
b, c, t = x.size()
|
||||||
assert c % self.num_splits == 0
|
assert c % self.num_splits == 0
|
||||||
if x_mask is None:
|
if x_mask is None:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import math
|
import math
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -47,7 +48,7 @@ class GlowTTS(BaseTTS):
|
||||||
|
|
||||||
def __init__(self, config: GlowTTSConfig):
|
def __init__(self, config: GlowTTSConfig):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
|
|
||||||
# pass all config fields to `self`
|
# pass all config fields to `self`
|
||||||
# for fewer code change
|
# for fewer code change
|
||||||
|
@ -387,7 +388,7 @@ class GlowTTS(BaseTTS):
|
||||||
)
|
)
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
|
def _create_logs(self, batch, outputs, ap):
|
||||||
alignments = outputs["alignments"]
|
alignments = outputs["alignments"]
|
||||||
text_input = batch["text_input"]
|
text_input = batch["text_input"]
|
||||||
text_lengths = batch["text_lengths"]
|
text_lengths = batch["text_lengths"]
|
||||||
|
@ -416,15 +417,26 @@ class GlowTTS(BaseTTS):
|
||||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||||
return figures, {"audio": train_audio}
|
return figures, {"audio": train_audio}
|
||||||
|
|
||||||
|
def train_log(
|
||||||
|
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||||
|
) -> None: # pylint: disable=no-self-use
|
||||||
|
ap = assets["audio_processor"]
|
||||||
|
figures, audios = self._create_logs(batch, outputs, ap)
|
||||||
|
logger.train_figures(steps, figures)
|
||||||
|
logger.train_audios(steps, audios, ap.sample_rate)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||||
return self.train_step(batch, criterion)
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||||
return self.train_log(ap, batch, outputs)
|
ap = assets["audio_processor"]
|
||||||
|
figures, audios = self._create_logs(batch, outputs, ap)
|
||||||
|
logger.eval_figures(steps, figures)
|
||||||
|
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_run(self, ap):
|
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
|
||||||
"""Generic test run for `tts` models used by `Trainer`.
|
"""Generic test run for `tts` models used by `Trainer`.
|
||||||
|
|
||||||
You can override this for a different behaviour.
|
You can override this for a different behaviour.
|
||||||
|
@ -432,6 +444,7 @@ class GlowTTS(BaseTTS):
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||||
"""
|
"""
|
||||||
|
ap = assets["audio_processor"]
|
||||||
print(" | > Synthesizing test sentences.")
|
print(" | > Synthesizing test sentences.")
|
||||||
test_audios = {}
|
test_audios = {}
|
||||||
test_figures = {}
|
test_figures = {}
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from TTS.trainer import Trainer, TrainingArgs, init_training
|
from TTS.trainer import Trainer, TrainingArgs
|
||||||
from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig
|
from TTS.tts.configs import BaseDatasetConfig, GlowTTSConfig
|
||||||
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
from TTS.tts.models.glow_tts import GlowTTS
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
output_path = os.path.dirname(os.path.abspath(__file__))
|
output_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
dataset_config = BaseDatasetConfig(
|
dataset_config = BaseDatasetConfig(
|
||||||
|
@ -25,6 +28,24 @@ config = GlowTTSConfig(
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
)
|
)
|
||||||
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
|
||||||
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
# init audio processor
|
||||||
|
ap = AudioProcessor(**config.audio.to_dict())
|
||||||
|
|
||||||
|
# load training samples
|
||||||
|
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
||||||
|
|
||||||
|
# init model
|
||||||
|
model = GlowTTS(config)
|
||||||
|
|
||||||
|
# init the trainer and 🚀
|
||||||
|
trainer = Trainer(
|
||||||
|
TrainingArgs(),
|
||||||
|
config,
|
||||||
|
output_path,
|
||||||
|
model=model,
|
||||||
|
train_samples=train_samples,
|
||||||
|
eval_samples=eval_samples,
|
||||||
|
training_assets={"audio_processor": ap},
|
||||||
|
)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|
Loading…
Reference in New Issue