refix linter

This commit is contained in:
WeberJulian 2021-07-13 23:12:18 +02:00
parent 7d92b30946
commit c79a82ed07
6 changed files with 13 additions and 8 deletions

View File

@ -22,7 +22,6 @@ from torch.utils.data import DataLoader
from TTS.config import load_config, register_config from TTS.config import load_config, register_config
from TTS.tts.datasets import load_meta_data from TTS.tts.datasets import load_meta_data
from TTS.tts.models import setup_model as setup_tts_model from TTS.tts.models import setup_model as setup_tts_model
from TTS.vocoder.models.wavegrad import Wavegrad
from TTS.tts.utils.text.symbols import parse_symbols from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.callbacks import TrainerCallback from TTS.utils.callbacks import TrainerCallback
@ -41,6 +40,7 @@ from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.models.wavegrad import Wavegrad
if platform.system() != "Windows": if platform.system() != "Windows":
# https://github.com/pytorch/pytorch/issues/973 # https://github.com/pytorch/pytorch/issues/973
@ -767,13 +767,14 @@ class Trainer:
if hasattr(self.model, "test_run"): if hasattr(self.model, "test_run"):
if isinstance(self.model, Wavegrad): if isinstance(self.model, Wavegrad):
return None # TODO: Fix inference on WaveGrad return None # TODO: Fix inference on WaveGrad
elif hasattr(self.eval_loader.dataset, "load_test_samples"): if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1) samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None, self.use_cuda) figures, audios = self.model.test_run(self.ap, samples, None, self.use_cuda)
else: else:
figures, audios = self.model.test_run(self.ap, self.use_cuda) figures, audios = self.model.test_run(self.ap, self.use_cuda)
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, figures) self.tb_logger.tb_test_figures(self.total_steps_done, figures)
return None
def _fit(self) -> None: def _fit(self) -> None:
"""🏃 train -> evaluate -> test for the number of epochs.""" """🏃 train -> evaluate -> test for the number of epochs."""

View File

@ -261,7 +261,9 @@ class Wavegrad(BaseModel):
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
return None, None return None, None
def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda): # pylint: disable=unused-argument def test_run(
self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda
): # pylint: disable=unused-argument
# setup noise schedule and inference # setup noise schedule and inference
noise_schedule = self.config["test_noise_schedule"] noise_schedule = self.config["test_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])

View File

@ -6,6 +6,7 @@ from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig from TTS.config.shared_configs import BaseAudioConfig
from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
def run_test_train(): def run_test_train():
command = ( command = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} " f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
@ -17,6 +18,7 @@ def run_test_train():
) )
run_cli(command) run_cli(command)
config_path = os.path.join(get_tests_output_path(), "test_speaker_encoder_config.json") config_path = os.path.join(get_tests_output_path(), "test_speaker_encoder_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs") output_path = os.path.join(get_tests_output_path(), "train_outputs")