make style

Eren Gölge 2021-05-28 13:37:08 +02:00
parent b643e8b37c
commit d96ebcd6d3
13 changed files with 130 additions and 130 deletions
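Every hunk in this commit is mechanical restyling: imports are regrouped and alphabetized, single quotes become double quotes, and hanging-indent wrapping becomes trailing-comma style. A minimal sketch of the kind of formatting pass a `make style` target runs, assuming black and isort with a 120-character line length (inferred from the reformatted lines below, not taken from the repository's actual configuration):

import black
import isort

SRC = """from typing import Union, List, Any
import torch
import numpy as np
"""

# isort splits stdlib / third-party / first-party imports into groups and
# sorts each group alphabetically, which is what reorders the imports below.
sorted_src = isort.code(SRC, line_length=120)

# black normalizes quotes to double quotes, rewraps long calls with one
# argument per line, and adds trailing commas, matching the "+" lines below.
formatted = black.format_str(sorted_src, mode=black.Mode(line_length=120))
print(formatted)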

View File

@@ -8,10 +8,10 @@ import numpy as np
 import tensorflow as tf
 import torch

+from TTS.tts.models import setup_model
 from TTS.tts.tf.models.tacotron2 import Tacotron2
 from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
 from TTS.tts.tf.utils.generic_utils import save_checkpoint
-from TTS.tts.models import setup_model
 from TTS.tts.utils.text.symbols import phonemes, symbols
 from TTS.utils.io import load_config

View File

@@ -1,9 +1,10 @@
 import os
 import sys
 import traceback

+from TTS.trainer import TrainerTTS
 from TTS.utils.arguments import init_training
 from TTS.utils.generic_utils import remove_experiment_folder
-from TTS.trainer import TrainerTTS

 def main():

View File

@@ -4,21 +4,19 @@ import importlib
 import logging
 import os
 import time
+from argparse import Namespace
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple, Union

 import torch
 from coqpit import Coqpit
-from dataclasses import dataclass, field
-from typing import Tuple, Dict, List, Union
-from argparse import Namespace

 # DISTRIBUTED
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger

 from TTS.tts.datasets import TTSDataset, load_meta_data
 from TTS.tts.layers import setup_loss
 from TTS.tts.models import setup_model
@@ -30,49 +28,48 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.training import check_update, setup_torch_training_env


 @dataclass
 class TrainingArgs(Coqpit):
     continue_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            'Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder.'
-        })
+            "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
+        },
+    )
     restore_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            'Path to a model checkpoit. Restore the model with the given checkpoint and start a new training.'
-        })
+            "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training."
+        },
+    )
     best_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
-        })
+            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+        },
+    )
-    config_path: str = field(
-        default='', metadata={'help': 'Path to the configuration file.'})
-    rank: int = field(
-        default=0, metadata={'help': 'Process rank in distributed training.'})
-    group_id: str = field(
-        default='',
-        metadata={'help': 'Process group id in distributed training.'})
+    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
+    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
+    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})


 # pylint: disable=import-outside-toplevel, too-many-public-methods
 class TrainerTTS:
     use_cuda, num_gpus = setup_torch_training_env(True, False)

-    def __init__(self,
-                 args: Union[Coqpit, Namespace],
-                 config: Coqpit,
-                 c_logger: ConsoleLogger,
-                 tb_logger: TensorboardLogger,
-                 model: nn.Module = None,
-                 output_path: str = None) -> None:
+    def __init__(
+        self,
+        args: Union[Coqpit, Namespace],
+        config: Coqpit,
+        c_logger: ConsoleLogger,
+        tb_logger: TensorboardLogger,
+        model: nn.Module = None,
+        output_path: str = None,
+    ) -> None:
         self.args = args
         self.config = config
         self.c_logger = c_logger
@@ -90,8 +87,7 @@ class TrainerTTS:
         self.keep_avg_train = None
         self.keep_avg_eval = None

-        log_file = os.path.join(self.output_path,
-                                f"trainer_{args.rank}_log.txt")
+        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)

         # model, audio processor, datasets, loss
@@ -106,16 +102,19 @@ class TrainerTTS:
         # default speaker manager
         self.speaker_manager = self.get_speaker_manager(
-            self.config, args.restore_path, self.config.output_path, self.data_train)
+            self.config, args.restore_path, self.config.output_path, self.data_train
+        )

         # init TTS model
         if model is not None:
             self.model = model
         else:
             self.model = self.get_model(
-                len(self.model_characters), self.speaker_manager.num_speakers,
-                self.config, self.speaker_manager.x_vector_dim
-                if self.speaker_manager.x_vectors else None)
+                len(self.model_characters),
+                self.speaker_manager.num_speakers,
+                self.config,
+                self.speaker_manager.x_vector_dim if self.speaker_manager.x_vectors else None,
+            )

         # setup criterion
         self.criterion = self.get_criterion(self.config)
@@ -126,13 +125,16 @@ class TrainerTTS:
         # DISTRUBUTED
         if self.num_gpus > 1:
-            init_distributed(args.rank, self.num_gpus, args.group_id,
-                             self.config.distributed["backend"],
-                             self.config.distributed["url"])
+            init_distributed(
+                args.rank,
+                self.num_gpus,
+                args.group_id,
+                self.config.distributed["backend"],
+                self.config.distributed["url"],
+            )

         # scalers for mixed precision training
-        self.scaler = torch.cuda.amp.GradScaler(
-        ) if self.config.mixed_precision and self.use_cuda else None
+        self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None

         # setup optimizer
         self.optimizer = self.get_optimizer(self.model, self.config)
@@ -154,8 +156,7 @@ class TrainerTTS:
         print("\n > Model has {} parameters".format(num_params))

     @staticmethod
-    def get_model(num_chars: int, num_speakers: int, config: Coqpit,
-                  x_vector_dim: int) -> nn.Module:
+    def get_model(num_chars: int, num_speakers: int, config: Coqpit, x_vector_dim: int) -> nn.Module:
         model = setup_model(num_chars, num_speakers, config, x_vector_dim)
         return model
@@ -182,26 +183,32 @@ class TrainerTTS:
         return model_characters

     @staticmethod
-    def get_speaker_manager(config: Coqpit,
-                            restore_path: str = "",
-                            out_path: str = "",
-                            data_train: List = []) -> SpeakerManager:
+    def get_speaker_manager(
+        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = []
+    ) -> SpeakerManager:
         speaker_manager = SpeakerManager()
         if restore_path:
-            speakers_file = os.path.join(os.path.dirname(restore_path),
-                                         "speaker.json")
+            speakers_file = os.path.join(os.path.dirname(restore_path), "speaker.json")
             if not os.path.exists(speakers_file):
                 print(
                     "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                 )
                 speakers_file = config.external_speaker_embedding_file
             if config.use_external_speaker_embedding_file:
                 speaker_manager.load_x_vectors_file(speakers_file)
             else:
                 speaker_manager.load_ids_file(speakers_file)
         elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
             speaker_manager.load_x_vectors_file(config.external_speaker_embedding_file)
         else:
             speaker_manager.parse_speakers_from_items(data_train)
             file_path = os.path.join(out_path, "speakers.json")
             speaker_manager.save_ids_file(file_path)
         return speaker_manager

     @staticmethod
-    def get_scheduler(config: Coqpit,
-                      optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+    def get_scheduler(config: Coqpit, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
         lr_scheduler = config.lr_scheduler
         lr_scheduler_params = config.lr_scheduler_params
         if lr_scheduler is None:
@@ -224,7 +231,7 @@ class TrainerTTS:
         restore_path: str,
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
-        scaler: torch.cuda.amp.GradScaler = None
+        scaler: torch.cuda.amp.GradScaler = None,
     ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
         print(" > Restoring from %s ..." % os.path.basename(restore_path))
         checkpoint = torch.load(restore_path)
@@ -245,13 +252,21 @@ class TrainerTTS:
         for group in optimizer.param_groups:
             group["lr"] = self.config.lr
-        print(" > Model restored from step %d" % checkpoint["step"], )
+        print(
+            " > Model restored from step %d" % checkpoint["step"],
+        )
         restore_step = checkpoint["step"]
         return model, optimizer, scaler, restore_step

-    def _get_loader(self, r: int, ap: AudioProcessor, is_eval: bool,
-                    data_items: List, verbose: bool,
-                    speaker_mapping: Union[Dict, List]) -> DataLoader:
+    def _get_loader(
+        self,
+        r: int,
+        ap: AudioProcessor,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        speaker_mapping: Union[Dict, List],
+    ) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
         else:
@@ -295,17 +310,15 @@ class TrainerTTS:
             )
         return loader

-    def get_train_dataloader(self, r: int, ap: AudioProcessor,
-                             data_items: List, verbose: bool,
-                             speaker_mapping: Union[List, Dict]) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose,
-                                speaker_mapping)
+    def get_train_dataloader(
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+    ) -> DataLoader:
+        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)

-    def get_eval_dataloder(self, r: int, ap: AudioProcessor, data_items: List,
-                           verbose: bool,
-                           speaker_mapping: Union[List, Dict]) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose,
-                                speaker_mapping)
+    def get_eval_dataloder(
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+    ) -> DataLoader:
+        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)
     def format_batch(self, batch: List) -> Dict:
         # setup input batch

@@ -390,8 +403,7 @@ class TrainerTTS:
             "item_idx": item_idx,
         }

-    def train_step(self, batch: Dict, batch_n_steps: int, step: int,
-                   loader_start_time: float) -> Tuple[Dict, Dict]:
+    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
         step_start_time = time.time()
@@ -560,7 +572,9 @@ class TrainerTTS:
         self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
         self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)

-    def test_run(self, ) -> None:
+    def test_run(
+        self,
+    ) -> None:
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -581,28 +595,26 @@ class TrainerTTS:
                 do_trim_silence=False,
             ).values()

-            file_path = os.path.join(self.output_audio_path,
-                                     str(self.total_steps_done))
+            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
             os.makedirs(file_path, exist_ok=True)
-            file_path = os.path.join(file_path,
-                                     "TestSentence_{}.wav".format(idx))
+            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
             self.ap.save_wav(wav, file_path)
             test_audios["{}-audio".format(idx)] = wav
             test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
-        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
-                                      self.config.audio["sample_rate"])
+        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
         self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)

     def _get_cond_inputs(self) -> Dict:
         # setup speaker_id
         speaker_id = 0 if self.config.use_speaker_embedding else None
         # setup x_vector
-        x_vector = (self.speaker_manager.get_x_vectors_by_speaker(
-            self.speaker_manager.speaker_ids[0])
-                    if self.config.use_external_speaker_embedding_file
-                    and self.config.use_speaker_embedding else None)
+        x_vector = (
+            self.speaker_manager.get_x_vectors_by_speaker(self.speaker_manager.speaker_ids[0])
+            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
+            else None
+        )
         # setup style_mel
         if self.config.has("gst_style_input"):
             style_wav = self.config.gst_style_input
@@ -611,40 +623,29 @@ class TrainerTTS:
         if style_wav is None and "use_gst" in self.config and self.config.use_gst:
             # inicialize GST with zero dict.
             style_wav = {}
-            print(
-                "WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!"
-            )
+            print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
             for i in range(self.config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0

-        cond_inputs = {
-            "speaker_id": speaker_id,
-            "style_wav": style_wav,
-            "x_vector": x_vector
-        }
+        cond_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "x_vector": x_vector}
         return cond_inputs

     def fit(self) -> None:
         if self.restore_step != 0 or self.args.best_path:
-            print(" > Restoring best loss from "
-                  f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = torch.load(self.args.best_path,
-                                        map_location="cpu")["model_loss"]
+            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
+            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")

         # define data loaders
-        self.train_loader = self.get_train_dataloader(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids)
-        self.eval_loader = (self.get_eval_dataloder(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids)
-                            if self.config.run_eval else None)
+        self.train_loader = self.get_train_dataloader(
+            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+        )
+        self.eval_loader = (
+            self.get_eval_dataloder(
+                self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+            )
+            if self.config.run_eval
+            else None
+        )

         self.total_steps_done = self.restore_step
@@ -667,8 +668,7 @@ class TrainerTTS:
     def save_best_model(self) -> None:
         self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"]
-            if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
             self.best_loss,
             self.model,
             self.optimizer,
@@ -685,10 +685,8 @@ class TrainerTTS:
     @staticmethod
     def _setup_logger_config(log_file: str) -> None:
         logging.basicConfig(
-            level=logging.INFO,
-            format="",
-            handlers=[logging.FileHandler(log_file),
-                      logging.StreamHandler()])
+            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
+        )

     def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
         if hasattr(self.model, "on_epoch_start"):
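
The mixed-precision scaler reformatted in this file's __init__ follows PyTorch's standard torch.cuda.amp recipe. A minimal sketch of the usual pattern around such a scaler; model, optimizer, criterion, and batch are generic placeholders, not the trainer's real attributes:

import torch

# Created only when mixed_precision and CUDA are enabled, as in __init__ above.
scaler = torch.cuda.amp.GradScaler()

def amp_train_step(model, optimizer, criterion, batch):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        outputs = model(batch["inputs"])
        loss = criterion(outputs, batch["targets"])
    scaler.scale(loss).backward()    # scale the loss to avoid fp16 underflow
    scaler.unscale_(optimizer)       # unscale before clipping gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)           # skips the step if gradients overflowed
    scaler.update()                  # adapt the scale factor for the next step
    return loss.item()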

View File

@@ -1,9 +1,11 @@
 import sys
-import numpy as np
 from collections import Counter
 from pathlib import Path
-from TTS.tts.datasets.TTSDataset import TTSDataset

+import numpy as np
+
 from TTS.tts.datasets.formatters import *
+from TTS.tts.datasets.TTSDataset import TTSDataset

 ####################
 # UTILITIES

View File

@@ -7,7 +7,6 @@ from typing import List
 from tqdm import tqdm

-
 ########################
 # DATASETS
 ########################

View File

@@ -4,13 +4,13 @@ import torch.nn as nn
 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor

 class AlignTTS(nn.Module):

View File

@@ -6,11 +6,11 @@ from torch.nn import functional as F
 from TTS.tts.layers.glow_tts.decoder import Decoder
 from TTS.tts.layers.glow_tts.encoder import Encoder
+from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
-from TTS.tts.utils.data import sequence_mask

 class GlowTTS(nn.Module):

View File

@@ -3,13 +3,13 @@ from torch import nn
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor

 class SpeedySpeech(nn.Module):

View File

@@ -2,11 +2,11 @@
 import torch
 from torch import nn

-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
 from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

 class Tacotron(TacotronAbstract):

View File

@@ -3,11 +3,11 @@ import numpy as np
 import torch
 from torch import nn

-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
 from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

 class Tacotron2(TacotronAbstract):

View File

@@ -1,5 +1,5 @@
-import torch
 import numpy as np
+import torch

 def _pad_data(x, length):
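
Only the signature of _pad_data is visible in this hunk. A plausible sketch of the conventional zero-padding such a helper performs; the body below is an assumption for illustration, not the repository's implementation:

import numpy as np

def _pad_data(x, length):
    # Pad a 1-D sequence with zeros on the right up to `length` samples.
    assert length >= len(x), "target length must be >= input length"
    return np.pad(x, (0, length - len(x)), mode="constant", constant_values=0)

print(_pad_data(np.array([1, 2, 3]), 5))  # -> [1 2 3 0 0]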

View File

@@ -1,7 +1,7 @@
 import json
 import os
 import random
-from typing import Union, List, Any
+from typing import Any, List, Union

 import numpy as np
 import torch

View File

@@ -11,9 +11,9 @@ import torch
 from TTS.config import load_config
 from TTS.tts.utils.text.symbols import parse_symbols
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
 from TTS.utils.io import copy_model_files
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger

 def init_arguments(argv):
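
The TrainingArgs dataclass earlier in this commit stores each flag's help text in field(metadata={"help": ...}). A hedged sketch of how such metadata can drive an argparse parser; parser_from_dataclass is a hypothetical helper for illustration, not the coqpit machinery that init_arguments actually uses:

import argparse
from dataclasses import dataclass, field, fields

@dataclass
class TrainingArgs:
    continue_path: str = field(default="", metadata={"help": "Training folder to continue from."})
    restore_path: str = field(default="", metadata={"help": "Checkpoint to restore before training."})
    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})

def parser_from_dataclass(dc) -> argparse.ArgumentParser:
    # Build one --flag per dataclass field, reusing its default and help text.
    parser = argparse.ArgumentParser()
    for f in fields(dc):
        parser.add_argument(
            f"--{f.name}",
            type=f.type if callable(f.type) else str,
            default=f.default,
            help=f.metadata.get("help", ""),
        )
    return parser

args = parser_from_dataclass(TrainingArgs).parse_args(["--rank", "1"])
print(args.rank)  # -> 1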