Make style

This commit is contained in:
Eren Gölge 2022-01-25 10:40:29 +00:00
parent b3ed6ff6b7
commit 1f0c8179da
13 changed files with 42 additions and 29 deletions

View File

@ -9,7 +9,6 @@ from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
phonemizer = Gruut(language="en-us")

View File

@ -57,6 +57,12 @@ class BaseAudioConfig(Coqpit):
do_amp_to_db_mel (bool, optional):
enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
pitch_fmax (float, optional):
Maximum frequency of the F0 frames. Defaults to ```640```.
pitch_fmin (float, optional):
Minimum frequency of the F0 frames. Defaults to ```0```.
trim_db (int):
Silence threshold used for silence trimming. Defaults to 45.
@ -135,6 +141,9 @@ class BaseAudioConfig(Coqpit):
spec_gain: int = 20
do_amp_to_db_linear: bool = True
do_amp_to_db_mel: bool = True
# f0 params
pitch_fmax: float = 640.0
pitch_fmin: float = 0.0
# normalization params
signal_norm: bool = True
min_level_db: int = -100

View File

@ -1,5 +1,4 @@
import collections
from email.mime import audio
import os
import random
from typing import Dict, List, Union
@ -256,7 +255,7 @@ class TTSDataset(Dataset):
new_samples = []
for item in samples:
text, wav_file, *_ = _parse_sample(item)
audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio
audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio
text_lenght = len(text)
new_samples += [item + [audio_length, text_lenght]]
return new_samples
@ -291,7 +290,8 @@ class TTSDataset(Dataset):
samples[offset:end_offset] = temp_items
return samples
def _select_samples_by_idx(self, idxs, samples):
@staticmethod
def _select_samples_by_idx(idxs, samples):
samples_new = []
for idx in idxs:
samples_new.append(samples[idx])
@ -307,9 +307,7 @@ class TTSDataset(Dataset):
text_lengths = [i[-1] for i in samples]
audio_lengths = [i[-2] for i in samples]
text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
audio_ignore_idx, audio_keep_idx = self.filter_by_length(
audio_lengths, self.min_audio_len, self.max_audio_len
)
audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))

View File

@ -740,7 +740,7 @@ class ForwardTTSLoss(nn.Module):
alignment_logprob=None,
alignment_hard=None,
alignment_soft=None,
binary_loss_weight=None
binary_loss_weight=None,
):
loss = 0
return_dict = {}
@ -774,7 +774,9 @@ class ForwardTTSLoss(nn.Module):
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
if binary_loss_weight:
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
return_dict["loss_binary_alignment"] = (
self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
)
else:
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss

View File

@ -646,7 +646,7 @@ class ForwardTTS(BaseTTS):
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
alignment_soft=outputs["alignment_soft"],
alignment_hard=outputs["alignment_mas"],
binary_loss_weight=self.binary_loss_weight
binary_loss_weight=self.binary_loss_weight,
)
# compute duration error
durations_pred = outputs["durations"]

View File

@ -364,7 +364,9 @@ class Vits(BaseTTS):
)
upsample_rate = math.prod(self.args.upsample_rates_decoder)
assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
assert (
upsample_rate == self.config.audio.hop_length
), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
self.waveform_decoder = HifiganGenerator(
self.args.hidden_channels,
1,
@ -666,7 +668,7 @@ class Vits(BaseTTS):
waveform,
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
pad_short=True
pad_short=True,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@ -688,7 +690,7 @@ class Vits(BaseTTS):
outputs.update(
{
"model_outputs": o,
"alignments" : attn.squeeze(1),
"alignments": attn.squeeze(1),
"m_p": m_p,
"logs_p": logs_p,
"z": z,
@ -951,7 +953,7 @@ class Vits(BaseTTS):
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
figures, audios = self._log(self.ap, batch, outputs, "eval")
figures, audios = self._log(self.ap, batch, outputs, "eval")
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, self.ap.sample_rate)

View File

@ -68,7 +68,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
"""
# pad the input tensor if it is shorter than the segment size
if pad_short and x.shape[-1] < segment_size:
x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
segments = torch.zeros_like(x[:, :, :segment_size])
@ -78,12 +78,14 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
x_i = x[i]
if pad_short and index_end > x.size(2):
# pad the sample if it is shorter than the segment size
x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
segments[i] = x_i[:, index_start:index_end]
return segments
def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False):
def rand_segments(
x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
):
"""Create random segments based on the input lengths.
Args:
@ -110,7 +112,9 @@ def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=
_x_lenghts[len_diff < 0] = segment_size
len_diff = _x_lenghts - segment_size + 1
else:
assert all(len_diff > 0), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
assert all(
len_diff > 0
), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
ret = segment(x, segment_indices, segment_size)
return ret, segment_indices

View File

@ -1,7 +1,6 @@
import json
import os
from typing import Dict, List
from TTS.config import check_config_and_model_args
import fsspec
import numpy as np
@ -9,6 +8,8 @@ import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
from TTS.config import check_config_and_model_args
class LanguageManager:
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information

View File

@ -104,7 +104,7 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
fig, ax = plt.subplots()
x = np.array(range(len(chars)))
my_xticks = [c for c in chars]
my_xticks = chars
plt.xticks(x, my_xticks)
ax.set_xlabel("characters")

View File

@ -5,10 +5,8 @@ import numpy as np
import pysbd
import torch
from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config
from TTS.config import load_config
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import

View File

@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
class GAN(BaseVocoder):
def __init__(self, config: Coqpit, ap: AudioProcessor=None):
def __init__(self, config: Coqpit, ap: AudioProcessor = None):
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
It also helps mixing and matching different generator and disciminator networks easily.
@ -306,7 +306,7 @@ class GAN(BaseVocoder):
x, y = batch
return {"input": x, "waveform": y}
def get_data_loader( # pylint: disable=no-self-use
def get_data_loader( # pylint: disable=no-self-use, unused-argument
self,
config: Coqpit,
assets: Dict,

View File

@ -63,7 +63,7 @@ class TestTTSDataset(unittest.TestCase):
max_text_len=c.max_text_len,
min_audio_len=c.min_audio_len,
max_audio_len=c.max_audio_len,
start_by_longest=start_by_longest
start_by_longest=start_by_longest,
)
dataloader = DataLoader(
dataset,

View File

@ -1,6 +1,6 @@
import torch as T
from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments
from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask
def average_over_durations_test(): # pylint: disable=no-self-use
@ -57,12 +57,12 @@ def rand_segments_test():
assert segments.shape == (2, 3, 3)
assert all(seg_idxs >= 0), seg_idxs
try:
segments, _ = rand_segments(x, x_lens, segment_size=5)
segments, _ = rand_segments(x, x_lens, segment_size=5)
raise Exception("Should have failed")
except:
pass
x_lens_back = x_lens.clone()
segments, seg_idxs= rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True)
segments, seg_idxs = rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True)
assert segments.shape == (2, 3, 5)
assert all(seg_idxs >= 0), seg_idxs
assert all(x_lens_back == x_lens)