Make style

Eren Gölge 2022-01-25 10:40:29 +00:00
parent b3ed6ff6b7
commit 1f0c8179da
13 changed files with 42 additions and 29 deletions

View File

@@ -9,7 +9,6 @@ from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut

 phonemizer = Gruut(language="en-us")

View File

@@ -57,6 +57,12 @@ class BaseAudioConfig(Coqpit):
         do_amp_to_db_mel (bool, optional):
             enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
+        pitch_fmax (float, optional):
+            Maximum frequency of the F0 frames. Defaults to ```640```.
+        pitch_fmin (float, optional):
+            Minimum frequency of the F0 frames. Defaults to ```0```.
         trim_db (int):
             Silence threshold used for silence trimming. Defaults to 45.
@@ -135,6 +141,9 @@
     spec_gain: int = 20
     do_amp_to_db_linear: bool = True
     do_amp_to_db_mel: bool = True
+    # f0 params
+    pitch_fmax: float = 640.0
+    pitch_fmin: float = 0.0
     # normalization params
     signal_norm: bool = True
     min_level_db: int = -100
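The two new fields bound the F0 values used for pitch extraction. A minimal sketch of setting them, assuming the BaseAudioConfig import path and the other argument values (not taken from this diff):

from TTS.config.shared_configs import BaseAudioConfig  # module path assumed

# Bound F0 extraction to a plausible speech range; the extractor
# presumably treats frames outside [pitch_fmin, pitch_fmax] as invalid.
audio_config = BaseAudioConfig(
    sample_rate=22050,  # illustrative value
    hop_length=256,     # illustrative value
    pitch_fmin=0.0,     # new default from this commit
    pitch_fmax=640.0,   # new default from this commit
)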

View File

@@ -1,5 +1,4 @@
 import collections
-from email.mime import audio
 import os
 import random
 from typing import Dict, List, Union
@@ -291,7 +290,8 @@ class TTSDataset(Dataset):
         samples[offset:end_offset] = temp_items
         return samples

-    def _select_samples_by_idx(self, idxs, samples):
+    @staticmethod
+    def _select_samples_by_idx(idxs, samples):
         samples_new = []
         for idx in idxs:
             samples_new.append(samples[idx])
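Since the method never referenced self, promoting it to a @staticmethod is behavior-preserving. A standalone sketch of the same pattern (class and names hypothetical):

class Example:
    @staticmethod
    def select_by_idx(idxs, samples):
        # a pure function of its arguments, so no instance state is required
        return [samples[i] for i in idxs]

# a staticmethod is callable on the class or on an instance alike
assert Example.select_by_idx([2, 0], ["a", "b", "c"]) == ["c", "a"]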
@@ -307,9 +307,7 @@
         text_lengths = [i[-1] for i in samples]
         audio_lengths = [i[-2] for i in samples]
         text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
-        audio_ignore_idx, audio_keep_idx = self.filter_by_length(
-            audio_lengths, self.min_audio_len, self.max_audio_len
-        )
+        audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
         keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
         ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
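The surrounding code keeps a sample only if it passes both the text-length and the audio-length filter. A hypothetical standalone rendering of that set logic (filter_by_length here is a stand-in, not the dataset's actual method):

def filter_by_length(lengths, min_len, max_len):
    # hypothetical stand-in: split indices into in-range and out-of-range
    ignore_idx, keep_idx = [], []
    for idx, length in enumerate(lengths):
        (keep_idx if min_len <= length <= max_len else ignore_idx).append(idx)
    return ignore_idx, keep_idx

text_ignore, text_keep = filter_by_length([5, 80, 30], 10, 100)        # sample 0 too short
audio_ignore, audio_keep = filter_by_length([2.0, 9.0, 3.0], 1.0, 5.0)  # sample 1 too long
keep_idx = list(set(audio_keep) & set(text_keep))        # passes BOTH filters -> [2]
ignore_idx = list(set(audio_ignore) | set(text_ignore))  # fails EITHER filter -> [0, 1]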

View File

@@ -740,7 +740,7 @@ class ForwardTTSLoss(nn.Module):
         alignment_logprob=None,
         alignment_hard=None,
         alignment_soft=None,
-        binary_loss_weight=None
+        binary_loss_weight=None,
     ):
         loss = 0
         return_dict = {}
@@ -774,7 +774,9 @@ class ForwardTTSLoss(nn.Module):
             binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
             loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
             if binary_loss_weight:
-                return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
+                return_dict["loss_binary_alignment"] = (
+                    self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
+                )
             else:
                 return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
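For intuition, the optional binary_loss_weight scales the reported aligner loss on top of the fixed alpha. A toy numeric sketch with assumed values:

binary_alignment_loss_alpha = 0.1  # assumed value
binary_alignment_loss = 0.5        # stand-in for the computed scalar loss
binary_loss_weight = 0.2           # optional extra weight; falsy values skip the scaling

if binary_loss_weight:
    loss_binary = binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight  # 0.01
else:
    loss_binary = binary_alignment_loss_alpha * binary_alignment_loss  # 0.05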

View File

@@ -646,7 +646,7 @@ class ForwardTTS(BaseTTS):
             alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
             alignment_soft=outputs["alignment_soft"],
             alignment_hard=outputs["alignment_mas"],
-            binary_loss_weight=self.binary_loss_weight
+            binary_loss_weight=self.binary_loss_weight,
         )
         # compute duration error
         durations_pred = outputs["durations"]

View File

@@ -364,7 +364,9 @@ class Vits(BaseTTS):
         )
         upsample_rate = math.prod(self.args.upsample_rates_decoder)
-        assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
+        assert (
+            upsample_rate == self.config.audio.hop_length
+        ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
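The rewrapped assert encodes a structural constraint: the decoder's upsampling stages must expand each spectrogram frame into exactly one hop of audio samples. A quick standalone check with typical HiFi-GAN-style values (assumed, not from this diff):

import math

upsample_rates_decoder = [8, 8, 2, 2]  # assumed example rates
hop_length = 256                       # assumed example hop length

upsample_rate = math.prod(upsample_rates_decoder)  # 8 * 8 * 2 * 2 = 256
assert upsample_rate == hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {hop_length}"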
@@ -666,7 +668,7 @@ class Vits(BaseTTS):
                 waveform,
                 slice_ids * self.config.audio.hop_length,
                 self.args.spec_segment_size * self.config.audio.hop_length,
-                pad_short=True
+                pad_short=True,
             )
             if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:

View File

@@ -83,7 +83,9 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
     return segments

-def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False):
+def rand_segments(
+    x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
+):
     """Create random segments based on the input lengths.

     Args:
@@ -110,7 +112,9 @@ def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=
         _x_lenghts[len_diff < 0] = segment_size
         len_diff = _x_lenghts - segment_size + 1
     else:
-        assert all(len_diff > 0), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
+        assert all(
+            len_diff > 0
+        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
     segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
     ret = segment(x, segment_indices, segment_size)
     return ret, segment_indices
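A minimal usage sketch of the rewrapped helper (the import path appears in the test file at the end of this diff; tensor shapes are assumed for illustration):

import torch

from TTS.tts.utils.helpers import rand_segments

x = torch.randn(2, 80, 50)          # [batch, channels, time], shapes assumed
x_lengths = torch.tensor([50, 30])  # valid frames per sample

segments, indices = rand_segments(x, x_lengths, segment_size=4)
print(segments.shape)  # torch.Size([2, 80, 4])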

View File

@@ -1,7 +1,6 @@
 import json
 import os
 from typing import Dict, List
-from TTS.config import check_config_and_model_args

 import fsspec
 import numpy as np
@@ -9,6 +8,8 @@ import torch
 from coqpit import Coqpit
 from torch.utils.data.sampler import WeightedRandomSampler

+from TTS.config import check_config_and_model_args
+

 class LanguageManager:
     """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information

View File

@@ -104,7 +104,7 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
     fig, ax = plt.subplots()
     x = np.array(range(len(chars)))
-    my_xticks = [c for c in chars]
+    my_xticks = chars
     plt.xticks(x, my_xticks)
     ax.set_xlabel("characters")
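The dropped list comprehension was redundant: plt.xticks accepts any sequence of label strings, and iterating a Python string already yields its characters one by one. A minimal check (values assumed):

import matplotlib.pyplot as plt
import numpy as np

chars = "hello"
x = np.array(range(len(chars)))
plt.xticks(x, chars)  # same result as plt.xticks(x, [c for c in chars])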

View File

@@ -5,10 +5,8 @@ import numpy as np
 import pysbd
 import torch

-from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config
+from TTS.config import load_config
 from TTS.tts.models import setup_model as setup_tts_model
-from TTS.tts.utils.languages import LanguageManager
-from TTS.tts.utils.speakers import SpeakerManager

 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import

View File

@@ -306,7 +306,7 @@ class GAN(BaseVocoder):
         x, y = batch
         return {"input": x, "waveform": y}

-    def get_data_loader(  # pylint: disable=no-self-use
+    def get_data_loader(  # pylint: disable=no-self-use, unused-argument
         self,
         config: Coqpit,
         assets: Dict,

View File

@@ -63,7 +63,7 @@ class TestTTSDataset(unittest.TestCase):
             max_text_len=c.max_text_len,
             min_audio_len=c.min_audio_len,
             max_audio_len=c.max_audio_len,
-            start_by_longest=start_by_longest
+            start_by_longest=start_by_longest,
         )
         dataloader = DataLoader(
             dataset,

View File

@@ -1,6 +1,6 @@
 import torch as T

-from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments
+from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask

 def average_over_durations_test():  # pylint: disable=no-self-use