mirror of https://github.com/coqui-ai/TTS.git
Make style
This commit is contained in:
parent
3cb07fb6b5
commit
82fed4add2
|
@ -1,8 +1,8 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from TTS.tts.models.forward_tts import ForwardTTSArgs
|
||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||
from TTS.tts.models.forward_tts import ForwardTTSArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -419,6 +419,7 @@ class TTSDataset(Dataset):
|
|||
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
||||
else:
|
||||
d_vectors = None
|
||||
|
||||
# get numerical speaker ids from speaker names
|
||||
if self.speaker_id_mapping:
|
||||
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
|
||||
|
|
|
@ -100,7 +100,7 @@ class AlignTTS(BaseTTS):
|
|||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
|
||||
super().__init__(config)
|
||||
self.speaker_manager = speaker_manager
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import copy
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
|
|
@ -258,10 +258,10 @@ class Tacotron(BaseTacotron):
|
|||
stop_targets.float(),
|
||||
stop_target_lengths,
|
||||
mel_lengths,
|
||||
outputs["decoder_outputs_backward"].float(),
|
||||
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
|
||||
outputs["alignments"].float(),
|
||||
alignment_lengths,
|
||||
outputs["alignments_backward"].float(),
|
||||
None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
@ -237,7 +238,7 @@ class Tacotron2(BaseTacotron):
|
|||
}
|
||||
return outputs
|
||||
|
||||
def train_step(self, batch:Dict, criterion:torch.nn.Module):
|
||||
def train_step(self, batch: Dict, criterion: torch.nn.Module):
|
||||
"""A single training step. Forward pass and loss computation.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -216,7 +216,7 @@ class Vits(BaseTTS):
|
|||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
|
|
Loading…
Reference in New Issue