Make style

This commit is contained in:
Eren Gölge 2021-10-21 16:05:51 +00:00
parent 3cb07fb6b5
commit 82fed4add2
7 changed files with 9 additions and 7 deletions

View File

@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.models.forward_tts import ForwardTTSArgs
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@dataclass

View File

@ -419,6 +419,7 @@ class TTSDataset(Dataset):
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
else:
d_vectors = None
# get numerical speaker ids from speaker names
if self.speaker_id_mapping:
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]

View File

@ -100,7 +100,7 @@ class AlignTTS(BaseTTS):
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
self.speaker_manager = speaker_manager

View File

@ -1,6 +1,6 @@
import copy
from abc import abstractmethod
from typing import Dict, List
from typing import Dict
import torch
from coqpit import Coqpit

View File

@ -258,10 +258,10 @@ class Tacotron(BaseTacotron):
stop_targets.float(),
stop_target_lengths,
mel_lengths,
outputs["decoder_outputs_backward"].float(),
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
outputs["alignments"].float(),
alignment_lengths,
outputs["alignments_backward"].float(),
None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
text_lengths,
)

View File

@ -1,6 +1,7 @@
# coding: utf-8
from typing import Dict
import torch
from coqpit import Coqpit
from torch import nn
@ -237,7 +238,7 @@ class Tacotron2(BaseTacotron):
}
return outputs
def train_step(self, batch:Dict, criterion:torch.nn.Module):
def train_step(self, batch: Dict, criterion: torch.nn.Module):
"""A single training step. Forward pass and loss computation.
Args:

View File

@ -216,7 +216,7 @@ class Vits(BaseTTS):
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)