From c76a6170726c3314f9930160096b54b99d190cf8 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 9 Nov 2020 13:18:35 +0100 Subject: [PATCH] linter updates --- TTS/bin/tune_wavegrad.py | 1 - TTS/speaker_encoder/model.py | 1 + TTS/vocoder/datasets/wavegrad_dataset.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index ef971dfa..375e1f1c 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -10,7 +10,6 @@ from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset -from TTS.vocoder.models.wavegrad import Wavegrad from TTS.vocoder.utils.generic_utils import setup_generator parser = argparse.ArgumentParser() diff --git a/TTS/speaker_encoder/model.py b/TTS/speaker_encoder/model.py index df0527bc..322ee42f 100644 --- a/TTS/speaker_encoder/model.py +++ b/TTS/speaker_encoder/model.py @@ -61,6 +61,7 @@ class SpeakerEncoder(nn.Module): d = torch.nn.functional.normalize(d, p=2, dim=1) return d + @torch.no_grad() def inference(self, x): d = self.layers.forward(x) if self.use_lstm_with_projection: diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 83244c89..30cf9cb3 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -111,8 +111,8 @@ class WaveGradDataset(Dataset): mel = torch.from_numpy(mel).float().squeeze(0) return (mel, audio) - - def collate_full_clips(self, batch): + @staticmethod + def collate_full_clips(batch): """This is used in tune_wavegrad.py. It pads sequences to the max length.""" max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]