mirror of https://github.com/coqui-ai/TTS.git
linter updates
This commit is contained in:
parent
ea976b0543
commit
c76a617072
|
@ -10,7 +10,6 @@ from TTS.utils.audio import AudioProcessor
|
|||
from TTS.utils.io import load_config
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||
from TTS.vocoder.models.wavegrad import Wavegrad
|
||||
from TTS.vocoder.utils.generic_utils import setup_generator
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
|
@ -61,6 +61,7 @@ class SpeakerEncoder(nn.Module):
|
|||
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
||||
return d
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x):
|
||||
d = self.layers.forward(x)
|
||||
if self.use_lstm_with_projection:
|
||||
|
|
|
@ -111,8 +111,8 @@ class WaveGradDataset(Dataset):
|
|||
mel = torch.from_numpy(mel).float().squeeze(0)
|
||||
return (mel, audio)
|
||||
|
||||
|
||||
def collate_full_clips(self, batch):
|
||||
@staticmethod
|
||||
def collate_full_clips(batch):
|
||||
"""This is used in tune_wavegrad.py.
|
||||
It pads sequences to the max length."""
|
||||
max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]
|
||||
|
|
Loading…
Reference in New Issue