From 1f0c8179da4965dcf6f2048cdbcea710aef3875a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 Jan 2022 10:40:29 +0000 Subject: [PATCH] Make style --- TTS/bin/find_unique_phonemes.py | 1 - TTS/config/shared_configs.py | 9 +++++++++ TTS/tts/datasets/dataset.py | 10 ++++------ TTS/tts/layers/losses.py | 6 ++++-- TTS/tts/models/forward_tts.py | 2 +- TTS/tts/models/vits.py | 10 ++++++---- TTS/tts/utils/helpers.py | 12 ++++++++---- TTS/tts/utils/languages.py | 3 ++- TTS/tts/utils/visual.py | 2 +- TTS/utils/synthesizer.py | 4 +--- TTS/vocoder/models/gan.py | 4 ++-- tests/data_tests/test_loader.py | 2 +- tests/tts_tests/test_helpers.py | 6 +++--- 13 files changed, 42 insertions(+), 29 deletions(-) diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index c5501552..8fe48b2f 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -9,7 +9,6 @@ from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut - phonemizer = Gruut(language="en-us") diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index 217282ad..392f10af 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -57,6 +57,12 @@ class BaseAudioConfig(Coqpit): do_amp_to_db_mel (bool, optional): enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + pitch_fmax (float, optional): + Maximum frequency of the F0 frames. Defaults to ```640```. + + pitch_fmin (float, optional): + Minimum frequency of the F0 frames. Defaults to ```0```. + trim_db (int): Silence threshold used for silence trimming. Defaults to 45. @@ -135,6 +141,9 @@ class BaseAudioConfig(Coqpit): spec_gain: int = 20 do_amp_to_db_linear: bool = True do_amp_to_db_mel: bool = True + # f0 params + pitch_fmax: float = 640.0 + pitch_fmin: float = 0.0 # normalization params signal_norm: bool = True min_level_db: int = -100 diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 62e146e0..499e6b7b 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -1,5 +1,4 @@ import collections -from email.mime import audio import os import random from typing import Dict, List, Union @@ -256,7 +255,7 @@ class TTSDataset(Dataset): new_samples = [] for item in samples: text, wav_file, *_ = _parse_sample(item) - audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio text_lenght = len(text) new_samples += [item + [audio_length, text_lenght]] return new_samples @@ -291,7 +290,8 @@ class TTSDataset(Dataset): samples[offset:end_offset] = temp_items return samples - def _select_samples_by_idx(self, idxs, samples): + @staticmethod + def _select_samples_by_idx(idxs, samples): samples_new = [] for idx in idxs: samples_new.append(samples[idx]) @@ -307,9 +307,7 @@ class TTSDataset(Dataset): text_lengths = [i[-1] for i in samples] audio_lengths = [i[-2] for i in samples] text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) - audio_ignore_idx, audio_keep_idx = self.filter_by_length( - audio_lengths, self.min_audio_len, self.max_audio_len - ) + audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len) keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index f4a472ad..827da751 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -740,7 +740,7 @@ class ForwardTTSLoss(nn.Module): alignment_logprob=None, alignment_hard=None, alignment_soft=None, - binary_loss_weight=None + binary_loss_weight=None, ): loss = 0 return_dict = {} @@ -774,7 +774,9 @@ class ForwardTTSLoss(nn.Module): binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss if binary_loss_weight: - return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + return_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) else: return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index bb8640a3..8d554f76 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -646,7 +646,7 @@ class ForwardTTS(BaseTTS): alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, alignment_soft=outputs["alignment_soft"], alignment_hard=outputs["alignment_mas"], - binary_loss_weight=self.binary_loss_weight + binary_loss_weight=self.binary_loss_weight, ) # compute duration error durations_pred = outputs["durations"] diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 222bbca5..cb4499fb 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -364,7 +364,9 @@ class Vits(BaseTTS): ) upsample_rate = math.prod(self.args.upsample_rates_decoder) - assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}" + assert ( + upsample_rate == self.config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}" self.waveform_decoder = HifiganGenerator( self.args.hidden_channels, 1, @@ -666,7 +668,7 @@ class Vits(BaseTTS): waveform, slice_ids * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length, - pad_short=True + pad_short=True, ) if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None: @@ -688,7 +690,7 @@ class Vits(BaseTTS): outputs.update( { "model_outputs": o, - "alignments" : attn.squeeze(1), + "alignments": attn.squeeze(1), "m_p": m_p, "logs_p": logs_p, "z": z, @@ -951,7 +953,7 @@ class Vits(BaseTTS): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - figures, audios = self._log(self.ap, batch, outputs, "eval") + figures, audios = self._log(self.ap, batch, outputs, "eval") logger.eval_figures(steps, figures) logger.eval_audios(steps, audios, self.ap.sample_rate) diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 32513377..c2e7f561 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -68,7 +68,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_ """ # pad the input tensor if it is shorter than the segment size if pad_short and x.shape[-1] < segment_size: - x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) segments = torch.zeros_like(x[:, :, :segment_size]) @@ -78,12 +78,14 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_ x_i = x[i] if pad_short and index_end > x.size(2): # pad the sample if it is shorter than the segment size - x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) segments[i] = x_i[:, index_start:index_end] return segments -def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False): +def rand_segments( + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False +): """Create random segments based on the input lengths. Args: @@ -110,7 +112,9 @@ def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size= _x_lenghts[len_diff < 0] = segment_size len_diff = _x_lenghts - segment_size + 1 else: - assert all(len_diff > 0), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + assert all( + len_diff > 0 + ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() ret = segment(x, segment_indices, segment_size) return ret, segment_indices diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 54ba40b2..19708c13 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -1,7 +1,6 @@ import json import os from typing import Dict, List -from TTS.config import check_config_and_model_args import fsspec import numpy as np @@ -9,6 +8,8 @@ import torch from coqpit import Coqpit from torch.utils.data.sampler import WeightedRandomSampler +from TTS.config import check_config_and_model_args + class LanguageManager: """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 4fd1f19c..78c12981 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -104,7 +104,7 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False): fig, ax = plt.subplots() x = np.array(range(len(chars))) - my_xticks = [c for c in chars] + my_xticks = chars plt.xticks(x, my_xticks) ax.set_xlabel("characters") diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index ddc2a6a5..6821e975 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -5,10 +5,8 @@ import numpy as np import pysbd import torch -from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config +from TTS.config import load_config from TTS.tts.models import setup_model as setup_tts_model -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 7e03e94f..6978f0e7 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results class GAN(BaseVocoder): - def __init__(self, config: Coqpit, ap: AudioProcessor=None): + def __init__(self, config: Coqpit, ap: AudioProcessor = None): """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. It also helps mixing and matching different generator and disciminator networks easily. @@ -306,7 +306,7 @@ class GAN(BaseVocoder): x, y = batch return {"input": x, "waveform": y} - def get_data_loader( # pylint: disable=no-self-use + def get_data_loader( # pylint: disable=no-self-use, unused-argument self, config: Coqpit, assets: Dict, diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index f96154bc..4d8cc68a 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -63,7 +63,7 @@ class TestTTSDataset(unittest.TestCase): max_text_len=c.max_text_len, min_audio_len=c.min_audio_len, max_audio_len=c.max_audio_len, - start_by_longest=start_by_longest + start_by_longest=start_by_longest, ) dataloader = DataLoader( dataset, diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py index 708ecbf5..23bb440a 100644 --- a/tests/tts_tests/test_helpers.py +++ b/tests/tts_tests/test_helpers.py @@ -1,6 +1,6 @@ import torch as T -from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments +from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask def average_over_durations_test(): # pylint: disable=no-self-use @@ -57,12 +57,12 @@ def rand_segments_test(): assert segments.shape == (2, 3, 3) assert all(seg_idxs >= 0), seg_idxs try: - segments, _ = rand_segments(x, x_lens, segment_size=5) + segments, _ = rand_segments(x, x_lens, segment_size=5) raise Exception("Should have failed") except: pass x_lens_back = x_lens.clone() - segments, seg_idxs= rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True) + segments, seg_idxs = rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True) assert segments.shape == (2, 3, 5) assert all(seg_idxs >= 0), seg_idxs assert all(x_lens_back == x_lens)