format with black

This commit is contained in:
Eren Gölge 2021-04-09 00:54:59 +02:00
parent e5b9607bc3
commit 18d9ec8036
12 changed files with 27 additions and 11 deletions

View File

@ -7,8 +7,11 @@ import tensorflow as tf
import torch import torch
from TTS.utils.io import load_config from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (compare_torch_tf, convert_tf_name, from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
transfer_weights_torch_to_tf) compare_torch_tf,
convert_tf_name,
transfer_weights_torch_to_tf,
)
from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import save_checkpoint from TTS.vocoder.tf.utils.io import save_checkpoint
from TTS.vocoder.utils.generic_utils import setup_generator from TTS.vocoder.utils.generic_utils import setup_generator

View File

@ -4,6 +4,7 @@
import argparse import argparse
import sys import sys
from argparse import RawTextHelpFormatter from argparse import RawTextHelpFormatter
# pylint: disable=redefined-outer-name, unused-argument # pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path from pathlib import Path

View File

@ -17,8 +17,13 @@ from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder
from TTS.speaker_encoder.utils.visual import plot_embeddings from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (count_parameters, create_experiment_folder, get_git_branch, from TTS.utils.generic_utils import (
remove_experiment_folder, set_init_dict) count_parameters,
create_experiment_folder,
get_git_branch,
remove_experiment_folder,
set_init_dict,
)
from TTS.utils.io import copy_model_files, load_config from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.tensorboard_logger import TensorboardLogger

View File

@ -8,6 +8,7 @@ import traceback
from random import randrange from random import randrange
import torch import torch
# DISTRIBUTED # DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader

View File

@ -9,6 +9,7 @@ from random import randrange
import numpy as np import numpy as np
import torch import torch
# DISTRIBUTED # DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader

View File

@ -26,8 +26,14 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update, gradual_training_scheduler, set_weight_decay, from TTS.utils.training import (
setup_torch_training_env) NoamLR,
adam_weight_decay,
check_update,
gradual_training_scheduler,
set_weight_decay,
setup_torch_training_env,
)
use_cuda, num_gpus = setup_torch_training_env(True, False) use_cuda, num_gpus = setup_torch_training_env(True, False)

View File

@ -9,6 +9,7 @@ import traceback
from inspect import signature from inspect import signature
import torch import torch
# DISTRIBUTED # DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader

View File

@ -8,6 +8,7 @@ import traceback
import numpy as np import numpy as np
import torch import torch
# DISTRIBUTED # DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam from torch.optim import Adam

View File

@ -25,7 +25,6 @@ from TTS.vocoder.utils.io import save_best_model, save_checkpoint
# from torch.utils.data.distributed import DistributedSampler # from torch.utils.data.distributed import DistributedSampler
use_cuda, num_gpus = setup_torch_training_env(True, True) use_cuda, num_gpus = setup_torch_training_env(True, True)

View File

@ -147,9 +147,7 @@ class Encoder(nn.Module):
in_hidden_channels == out_channels in_hidden_channels == out_channels
), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'" ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
# pylint: disable=unexpected-keyword-arg # pylint: disable=unexpected-keyword-arg
self.encoder = FFTransformerBlock( self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)
in_hidden_channels, **encoder_params
)
else: else:
raise NotImplementedError(" [!] unknown encoder type.") raise NotImplementedError(" [!] unknown encoder type.")

View File

@ -6,6 +6,7 @@ import torch
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import load_speaker_mapping from TTS.tts.utils.speakers import load_speaker_mapping
# pylint: disable=unused-wildcard-import # pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, trim_silence from TTS.tts.utils.synthesis import synthesis, trim_silence

View File

@ -9,7 +9,6 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # FATAL
logging.getLogger("tensorflow").setLevel(logging.FATAL) logging.getLogger("tensorflow").setLevel(logging.FATAL)
# pylint: disable=too-many-ancestors # pylint: disable=too-many-ancestors
# pylint: disable=abstract-method # pylint: disable=abstract-method
class MelganGenerator(tf.keras.models.Model): class MelganGenerator(tf.keras.models.Model):