mirror of https://github.com/coqui-ai/TTS.git
Make style
This commit is contained in:
parent
3cb07fb6b5
commit
82fed4add2
|
@ -1,8 +1,8 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from TTS.tts.models.forward_tts import ForwardTTSArgs
|
|
||||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||||
|
from TTS.tts.models.forward_tts import ForwardTTSArgs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -419,6 +419,7 @@ class TTSDataset(Dataset):
|
||||||
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
||||||
else:
|
else:
|
||||||
d_vectors = None
|
d_vectors = None
|
||||||
|
|
||||||
# get numerical speaker ids from speaker names
|
# get numerical speaker ids from speaker names
|
||||||
if self.speaker_id_mapping:
|
if self.speaker_id_mapping:
|
||||||
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
|
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
|
||||||
|
|
|
@ -100,7 +100,7 @@ class AlignTTS(BaseTTS):
|
||||||
|
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
|
|
||||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
|
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||||
|
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.speaker_manager = speaker_manager
|
self.speaker_manager = speaker_manager
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import copy
|
import copy
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Dict, List
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
|
|
|
@ -258,10 +258,10 @@ class Tacotron(BaseTacotron):
|
||||||
stop_targets.float(),
|
stop_targets.float(),
|
||||||
stop_target_lengths,
|
stop_target_lengths,
|
||||||
mel_lengths,
|
mel_lengths,
|
||||||
outputs["decoder_outputs_backward"].float(),
|
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
|
||||||
outputs["alignments"].float(),
|
outputs["alignments"].float(),
|
||||||
alignment_lengths,
|
alignment_lengths,
|
||||||
outputs["alignments_backward"].float(),
|
None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
|
||||||
text_lengths,
|
text_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -237,7 +238,7 @@ class Tacotron2(BaseTacotron):
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def train_step(self, batch:Dict, criterion:torch.nn.Module):
|
def train_step(self, batch: Dict, criterion: torch.nn.Module):
|
||||||
"""A single training step. Forward pass and loss computation.
|
"""A single training step. Forward pass and loss computation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -216,7 +216,7 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
|
|
||||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
|
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||||
|
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue