linter updates

This commit is contained in:
erogol 2020-11-09 13:18:35 +01:00
parent ea976b0543
commit c76a617072
3 changed files with 3 additions and 3 deletions

View File

@ -10,7 +10,6 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.models.wavegrad import Wavegrad
from TTS.vocoder.utils.generic_utils import setup_generator
parser = argparse.ArgumentParser()

View File

@ -61,6 +61,7 @@ class SpeakerEncoder(nn.Module):
d = torch.nn.functional.normalize(d, p=2, dim=1)
return d
@torch.no_grad()
def inference(self, x):
d = self.layers.forward(x)
if self.use_lstm_with_projection:

View File

@ -111,8 +111,8 @@ class WaveGradDataset(Dataset):
mel = torch.from_numpy(mel).float().squeeze(0)
return (mel, audio)
def collate_full_clips(self, batch):
@staticmethod
def collate_full_clips(batch):
"""This is used in tune_wavegrad.py.
It pads sequences to the max length."""
max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]