mirror of https://github.com/coqui-ai/TTS.git
Update VITS
parent 4f94f91305
commit 45889804c2
@@ -217,7 +217,7 @@ class Vits(BaseTTS):
     def __init__(self, config: Coqpit):

-        super().__init__()
+        super().__init__(config)

         self.END2END = True

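The only functional change in this hunk is that the Coqpit config is now forwarded to the BaseTTS constructor. A minimal sketch of what that enables, assuming the base class simply stores the config for shared helpers (the real BaseTTS.__init__ may do more); BaseTTSSketch and VitsSketch are hypothetical stand-ins, not the actual TTS classes:

    from coqpit import Coqpit

    class BaseTTSSketch:  # hypothetical stand-in for the BaseTTS parent class
        def __init__(self, config: Coqpit):
            # assumption: the base class keeps the config so generic helpers
            # (data loading, checkpointing, speaker setup) can read it later
            self.config = config

    class VitsSketch(BaseTTSSketch):  # hypothetical stand-in for the Vits model
        def __init__(self, config: Coqpit):
            super().__init__(config)  # the old call, super().__init__(), left self.config unset
            self.END2END = True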
@@ -576,22 +576,7 @@ class Vits(BaseTTS):
                )
        return outputs, loss_dict

-    def train_log(
-        self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
-    ):  # pylint: disable=no-self-use
-        """Create visualizations and waveform examples.
-
-        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
-        be projected onto Tensorboard.
-
-        Args:
-            ap (AudioProcessor): audio processor used at training.
-            batch (Dict): Model inputs used at the previous training step.
-            outputs (Dict): Model outputs generated at the previous training step.
-
-        Returns:
-            Tuple[Dict, np.ndarray]: training plots and output waveform.
-        """
+    def _log(self, ap, batch, outputs, name_prefix="train"):
        y_hat = outputs[0]["model_outputs"]
        y = outputs[0]["waveform_seg"]
        figures = plot_results(y_hat, y, ap, name_prefix)
@@ -609,12 +594,32 @@ class Vits(BaseTTS):

        return figures, audios

+    def train_log(
+        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
+    ):  # pylint: disable=no-self-use
+        """Create visualizations and waveform examples.
+
+        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
+        be projected onto Tensorboard.
+
+        Args:
+            ap (AudioProcessor): audio processor used at training.
+            batch (Dict): Model inputs used at the previous training step.
+            outputs (Dict): Model outputs generated at the previous training step.
+
+        Returns:
+            Tuple[Dict, np.ndarray]: training plots and output waveform.
+        """
+        ap = assets["audio_processor"]
+        self._log(ap, batch, outputs, "train")
+
    @torch.no_grad()
    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
        return self.train_step(batch, criterion, optimizer_idx)

-    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
-        return self.train_log(ap, batch, outputs, "eval")
+    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
+        ap = assets["audio_processor"]
+        return self._log(ap, batch, outputs, "eval")

    @torch.no_grad()
    def test_run(self, ap) -> Tuple[Dict, Dict]:
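With this refactor the shared plotting code moves into _log, and train_log/eval_log no longer receive an AudioProcessor directly: the trainer passes them a logger, an assets dict, and the current step, and the model pulls the processor out of assets. A hedged sketch of the caller side, not the actual TTS.trainer code; model, batch, outputs, logger, ap, and steps are assumed to already exist:

    def log_training_step(model, batch, outputs, logger, ap, steps):
        assets = {"audio_processor": ap}  # same key the new train_log/eval_log look up
        model.train_log(batch, outputs, logger, assets, steps)

    def log_eval_step(model, batch, outputs, logger, ap, steps):
        assets = {"audio_processor": ap}
        model.eval_log(batch, outputs, logger, assets, steps)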
@@ -1,8 +1,12 @@
 import os

 from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
 from TTS.tts.configs import BaseDatasetConfig, VitsConfig
+from TTS.tts.models.vits import Vits
+from TTS.utils.audio import AudioProcessor
+from TTS.tts.datasets import load_tts_samples

 output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
@@ -24,6 +28,7 @@ audio_config = BaseAudioConfig(
     signal_norm=False,
     do_amp_to_db_linear=False,
 )

 config = VitsConfig(
     audio=audio_config,
     run_name="vits_ljspeech",
@@ -47,6 +52,24 @@ config = VitsConfig(
     output_path=output_path,
     datasets=[dataset_config],
 )
-args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=True)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+
+# init model
+model = Vits(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
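The net effect on the recipe: instead of letting init_training build the console/Tensorboard loggers and wire up the trainer, the script now constructs the AudioProcessor, loads the train/eval samples, and instantiates the Vits model itself, then passes them to Trainer directly. The training_assets={"audio_processor": ap} argument is the bridge to the logging refactor above, since the new train_log/eval_log read the processor from that dict instead of receiving it as a parameter.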