Make style

This commit is contained in:
Eren Gölge 2021-10-21 16:05:51 +00:00
parent 3cb07fb6b5
commit 82fed4add2
7 changed files with 9 additions and 7 deletions

View File

@ -1,8 +1,8 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List from typing import List
from TTS.tts.models.forward_tts import ForwardTTSArgs
from TTS.tts.configs.shared_configs import BaseTTSConfig from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@dataclass @dataclass

View File

@ -419,6 +419,7 @@ class TTSDataset(Dataset):
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names] d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
else: else:
d_vectors = None d_vectors = None
# get numerical speaker ids from speaker names # get numerical speaker ids from speaker names
if self.speaker_id_mapping: if self.speaker_id_mapping:
speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]

View File

@ -1,6 +1,6 @@
import copy import copy
from abc import abstractmethod from abc import abstractmethod
from typing import Dict, List from typing import Dict
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit

View File

@ -258,10 +258,10 @@ class Tacotron(BaseTacotron):
stop_targets.float(), stop_targets.float(),
stop_target_lengths, stop_target_lengths,
mel_lengths, mel_lengths,
outputs["decoder_outputs_backward"].float(), None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
outputs["alignments"].float(), outputs["alignments"].float(),
alignment_lengths, alignment_lengths,
outputs["alignments_backward"].float(), None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(),
text_lengths, text_lengths,
) )

View File

@ -1,6 +1,7 @@
# coding: utf-8 # coding: utf-8
from typing import Dict from typing import Dict
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn