mirror of https://github.com/coqui-ai/TTS.git

Make linter

This commit is contained in:
  parent 0b1986384f
  commit 37959ad0c7
@@ -1,5 +1,4 @@
 import os
-from typing import List, Union
 
 from coqpit import Coqpit
 
@@ -90,7 +90,7 @@ class TrainingArgs(Coqpit):
 
 
 class Trainer:
-    def __init__(
+    def __init__(  # pylint: disable=dangerous-default-value
         self,
         args: Union[Coqpit, Namespace],
         config: Coqpit,
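The `# pylint: disable=dangerous-default-value` comment added above silences pylint's W0102 check, which fires when a function declares a mutable object (a list or dict) as a default argument; presumably the Trainer's `__init__` does this deliberately. A minimal sketch of the pitfall the check guards against, with illustrative names not taken from this repository:

def collect_bad(item, bucket=[]):  # W0102: the default list is created once and shared across calls
    bucket.append(item)
    return bucket


def collect_good(item, bucket=None):  # conventional fix: use None as the sentinel
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket


print(collect_bad(1), collect_bad(2))    # [1, 2] [1, 2] -- state leaks between calls
print(collect_good(1), collect_good(2))  # [1] [2]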
@@ -335,7 +335,9 @@ class Trainer:
         args.parse_args(training_args)
         return args, coqpit_overrides
 
-    def init_training(self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None):
+    def init_training(
+        self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None
+    ):  # pylint: disable=no-self-use
         """Initialize training and update model configs from command line arguments.
 
         Args:
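Splitting the signature across lines lets the trailing `# pylint: disable=no-self-use` comment sit on the line where pylint reports the message; a trailing disable comment suppresses only the named check on that line (no-self-use, R0201 in contemporary pylint releases, is reported at the `def`). A generic sketch with illustrative names:

class Example:
    def __init__(self, offset):
        self.offset = offset

    def helper(self, value):  # pylint: disable=no-self-use
        # The trailing comment silences the message reported on this line;
        # the rest of the class is still linted normally.
        return value * 2

    def uses_state(self, value):
        return value + self.offset  # no suppression needed: `self` is used


print(Example(10).helper(2), Example(10).uses_state(2))  # 4 12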
@@ -387,14 +389,13 @@ class Trainer:
 
     @staticmethod
     def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module:
-        if isinstance(get_data_samples, Callable):
+        if callable(get_data_samples):
             if len(signature(get_data_samples).sig.parameters) == 1:
                 train_samples, eval_samples = get_data_samples(config)
             else:
                 train_samples, eval_samples = get_data_samples()
             return train_samples, eval_samples
-        else:
-            return None, None
+        return None, None
 
     def restore_model(
         self,
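Two idiomatic cleanups appear in this hunk: the built-in `callable(obj)` replaces `isinstance(obj, Callable)`, and the `else:` after a `return` is dropped (pylint's no-else-return). For the arity check, the standard `inspect` API exposes the parameter mapping directly as `signature(fn).parameters` (the returned Signature object has no intermediate `.sig` attribute), so the sketch below uses that form; the `.sig.parameters` spelling above is left exactly as it appears in the commit. A standalone sketch of the same dispatch idea, with illustrative names:

from inspect import signature
from typing import Callable, List, Optional, Tuple


def run_samples_factory(config: dict, factory: Optional[Callable]) -> Tuple[Optional[List], Optional[List]]:
    """Call a user-supplied factory, passing `config` only if it declares exactly one parameter."""
    if callable(factory):
        # signature(fn).parameters is an ordered mapping of the declared parameters.
        if len(signature(factory).parameters) == 1:
            return factory(config)
        return factory()
    return None, None


print(run_samples_factory({"n": 2}, lambda cfg: (list(range(cfg["n"])), [])))  # ([0, 1], [])
print(run_samples_factory({}, lambda: ([1], [2])))                             # ([1], [2])
print(run_samples_factory({}, None))                                           # (None, None)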
@@ -1,5 +1,4 @@
 from dataclasses import dataclass, field
-from typing import Dict, Tuple
 
 import torch
 from coqpit import Coqpit
@@ -13,7 +12,6 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_fsspec
 
 
@@ -360,7 +358,7 @@ class AlignTTS(BaseTTS):
 
         return outputs, loss_dict
 
-    def _create_logs(self, batch, outputs, ap):
+    def _create_logs(self, batch, outputs, ap):  # pylint: disable=no-self-use
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         mel_input = batch["mel_input"]
@@ -1,17 +1,15 @@
 import copy
 from abc import abstractmethod
-from dataclasses import dataclass
 from typing import Dict, List
 
 import torch
-from coqpit import MISSING, Coqpit
+from coqpit import Coqpit
 from torch import nn
 
 from TTS.tts.layers.losses import TacotronLoss
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
-from TTS.tts.utils.text import make_symbols
 from TTS.utils.generic_utils import format_aux_input
 from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler
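Most of the remaining hunks are the same kind of cleanup: imports the modules never reference (AudioProcessor, ModelManager, MISSING, unused typing names), which pylint reports as unused-import (W0611), are removed. A tiny hypothetical module, not from this repository, showing what that check flags:

"""Hypothetical module illustrating pylint's unused-import (W0611)."""
import os   # W0611: `os` is imported but never referenced below
import sys


def interpreter_path() -> str:
    # Only `sys` is used, so the `os` import is dead weight the linter asks to drop.
    return sys.executable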
@@ -14,7 +14,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 
 
 @dataclass
@@ -14,7 +14,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.speakers import get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_fsspec
 
 
@@ -1,7 +1,5 @@
 # coding: utf-8
 
-from typing import Dict, Tuple
-
 import torch
 from coqpit import Coqpit
 from torch import nn
@@ -11,7 +9,6 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
 from TTS.tts.models.base_tacotron import BaseTacotron
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 
 
 class Tacotron(BaseTacotron):
@@ -1,7 +1,5 @@
 # coding: utf-8
 
-from typing import Dict, Tuple
-
 import torch
 from coqpit import Coqpit
 from torch import nn
@@ -11,7 +9,6 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
 from TTS.tts.models.base_tacotron import BaseTacotron
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 
 
 class Tacotron2(BaseTacotron):
@@ -17,7 +17,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
 from TTS.tts.utils.speakers import get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment
-from TTS.utils.audio import AudioProcessor
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -576,7 +575,7 @@ class Vits(BaseTTS):
             )
         return outputs, loss_dict
 
-    def _log(self, ap, batch, outputs, name_prefix="train"):
+    def _log(self, ap, batch, outputs, name_prefix="train"):  # pylint: disable=unused-argument,no-self-use
         y_hat = outputs[0]["model_outputs"]
         y = outputs[0]["waveform_seg"]
         figures = plot_results(y_hat, y, ap, name_prefix)
@@ -1,7 +1,5 @@
 from dataclasses import dataclass, field
 
-from coqpit import MISSING
-
 from TTS.config import BaseAudioConfig, BaseTrainingConfig
 
 
@@ -9,7 +9,6 @@ from torch.nn.utils import weight_norm
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
 from TTS.vocoder.datasets import WaveGradDataset
@@ -6,7 +6,6 @@ from TTS.tts.configs import SpeedySpeechConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.forward_tts import ForwardTTS
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.manage import ModelManager
 
 output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(