mirror of https://github.com/coqui-ai/TTS.git
Make style
This commit is contained in:
parent
3cb07fb6b5
commit
82fed4add2
|
@ -1,8 +1,8 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from TTS.tts.models.forward_tts import ForwardTTSArgs
|
|
||||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||||
|
from TTS.tts.models.forward_tts import ForwardTTSArgs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -419,6 +419,7 @@ class TTSDataset(Dataset):
|
||||||
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
||||||
else:
|
else:
|
||||||
d_vectors = None
|
d_vectors = None
|
||||||
|
|
||||||
# get numerical speaker ids from speaker names
|
# get numerical speaker ids from speaker names
|
||||||
if self.speaker_id_mapping:
|
if self.speaker_id_mapping:
|
||||||
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
|
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import copy
|
import copy
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Dict, List
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
|
|
|
@ -258,10 +258,10 @@ class Tacotron(BaseTacotron):
|
||||||
stop_targets.float(),
|
stop_targets.float(),
|
||||||
stop_target_lengths,
|
stop_target_lengths,
|
||||||
mel_lengths,
|
mel_lengths,
|
||||||
outputs["decoder_outputs_backward"].float(),
|
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
|
||||||
outputs["alignments"].float(),
|
outputs["alignments"].float(),
|
||||||
alignment_lengths,
|
alignment_lengths,
|
||||||
outputs["alignments_backward"].float(),
|
None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
|
||||||
text_lengths,
|
text_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
Loading…
Reference in New Issue