format with black

This commit is contained in:
Eren Gölge 2021-04-09 00:54:59 +02:00
parent e5b9607bc3
commit 18d9ec8036
12 changed files with 27 additions and 11 deletions

View File

@ -7,8 +7,11 @@ import tensorflow as tf
import torch
from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (compare_torch_tf, convert_tf_name,
transfer_weights_torch_to_tf)
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf,
convert_tf_name,
transfer_weights_torch_to_tf,
)
from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import save_checkpoint
from TTS.vocoder.utils.generic_utils import setup_generator

View File

@ -4,6 +4,7 @@
import argparse
import sys
from argparse import RawTextHelpFormatter
# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path

View File

@ -17,8 +17,13 @@ from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (count_parameters, create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import (
count_parameters,
create_experiment_folder,
get_git_branch,
remove_experiment_folder,
set_init_dict,
)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger

View File

@ -8,6 +8,7 @@ import traceback
from random import randrange
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader

View File

@ -9,6 +9,7 @@ from random import randrange
import numpy as np
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader

View File

@ -26,8 +26,14 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update, gradual_training_scheduler, set_weight_decay,
setup_torch_training_env)
from TTS.utils.training import (
NoamLR,
adam_weight_decay,
check_update,
gradual_training_scheduler,
set_weight_decay,
setup_torch_training_env,
)
use_cuda, num_gpus = setup_torch_training_env(True, False)

View File

@ -9,6 +9,7 @@ import traceback
from inspect import signature
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader

View File

@ -8,6 +8,7 @@ import traceback
import numpy as np
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam

View File

@ -25,7 +25,6 @@ from TTS.vocoder.utils.io import save_best_model, save_checkpoint
# from torch.utils.data.distributed import DistributedSampler
use_cuda, num_gpus = setup_torch_training_env(True, True)

View File

@ -147,9 +147,7 @@ class Encoder(nn.Module):
in_hidden_channels == out_channels
), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
# pylint: disable=unexpected-keyword-arg
self.encoder = FFTransformerBlock(
in_hidden_channels, **encoder_params
)
self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)
else:
raise NotImplementedError(" [!] unknown encoder type.")

View File

@ -6,6 +6,7 @@ import torch
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import load_speaker_mapping
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, trim_silence

View File

@ -9,7 +9,6 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # FATAL
logging.getLogger("tensorflow").setLevel(logging.FATAL)
# pylint: disable=too-many-ancestors
# pylint: disable=abstract-method
class MelganGenerator(tf.keras.models.Model):