mirror of https://github.com/coqui-ai/TTS.git
Make style
This commit is contained in:
parent b3ed6ff6b7
commit 1f0c8179da
@@ -9,7 +9,6 @@ from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
 
-
 phonemizer = Gruut(language="en-us")
 
 
@@ -57,6 +57,12 @@ class BaseAudioConfig(Coqpit):
         do_amp_to_db_mel (bool, optional):
             enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
 
+        pitch_fmax (float, optional):
+            Maximum frequency of the F0 frames. Defaults to ```640```.
+
+        pitch_fmin (float, optional):
+            Minimum frequency of the F0 frames. Defaults to ```0```.
+
         trim_db (int):
             Silence threshold used for silence trimming. Defaults to 45.
 
@@ -135,6 +141,9 @@ class BaseAudioConfig(Coqpit):
     spec_gain: int = 20
     do_amp_to_db_linear: bool = True
     do_amp_to_db_mel: bool = True
+    # f0 params
+    pitch_fmax: float = 640.0
+    pitch_fmin: float = 0.0
     # normalization params
     signal_norm: bool = True
     min_level_db: int = -100
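Note: the two new fields are plain Coqpit dataclass attributes, so they can be overridden like any other audio option. A minimal sketch, assuming the class still lives at TTS.config.shared_configs as in this revision:

    from TTS.config.shared_configs import BaseAudioConfig

    # Narrow the F0 search band for a low-pitched voice; the defaults
    # added by this commit are pitch_fmin=0.0 and pitch_fmax=640.0.
    audio_config = BaseAudioConfig(pitch_fmin=60.0, pitch_fmax=400.0)
    print(audio_config.pitch_fmin, audio_config.pitch_fmax)  # 60.0 400.0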
@@ -1,5 +1,4 @@
 import collections
-from email.mime import audio
 import os
 import random
 from typing import Dict, List, Union
@@ -256,7 +255,7 @@ class TTSDataset(Dataset):
         new_samples = []
         for item in samples:
             text, wav_file, *_ = _parse_sample(item)
-            audio_length = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio
+            audio_length = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
             text_lenght = len(text)
             new_samples += [item + [audio_length, text_lenght]]
         return new_samples
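Note: `/ 16 * 8` reduces to dividing the byte count by two, i.e. the sample count of 16-bit PCM audio (2 bytes per sample), ignoring the WAV header. A standalone sketch of the same arithmetic, with a hypothetical file path:

    import os

    wav_file = "sample.wav"  # hypothetical path
    n_bytes = os.path.getsize(wav_file)
    # 16 bits per sample -> bytes / 16 * 8 == bytes / 2 == sample count
    audio_length = n_bytes / 16 * 8
    assert audio_length == n_bytes / 2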
@@ -291,7 +290,8 @@ class TTSDataset(Dataset):
         samples[offset:end_offset] = temp_items
         return samples
 
-    def _select_samples_by_idx(self, idxs, samples):
+    @staticmethod
+    def _select_samples_by_idx(idxs, samples):
         samples_new = []
         for idx in idxs:
             samples_new.append(samples[idx])
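Note: the @staticmethod conversion is safe because the body never touches self; it only indexes into the samples it is given. The pattern in isolation:

    class Example:
        @staticmethod
        def _select_samples_by_idx(idxs, samples):
            # no self: callable on the class or on an instance alike
            return [samples[idx] for idx in idxs]

    print(Example._select_samples_by_idx([2, 0], ["a", "b", "c"]))  # ['c', 'a']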
@@ -307,9 +307,7 @@ class TTSDataset(Dataset):
         text_lengths = [i[-1] for i in samples]
         audio_lengths = [i[-2] for i in samples]
         text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
-        audio_ignore_idx, audio_keep_idx = self.filter_by_length(
-            audio_lengths, self.min_audio_len, self.max_audio_len
-        )
+        audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
         keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
         ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
 
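Note: a sample survives filtering only if it passes both the text-length and the audio-length checks, which is the set intersection of the keep indices (and the union of the ignore indices). A toy version with made-up index lists:

    text_keep_idx = [0, 1, 3]
    audio_keep_idx = [1, 2, 3]
    text_ignore_idx = [2]
    audio_ignore_idx = [0]

    keep_idx = sorted(set(audio_keep_idx) & set(text_keep_idx))        # [1, 3]
    ignore_idx = sorted(set(audio_ignore_idx) | set(text_ignore_idx))  # [0, 2]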
@@ -740,7 +740,7 @@ class ForwardTTSLoss(nn.Module):
         alignment_logprob=None,
         alignment_hard=None,
         alignment_soft=None,
-        binary_loss_weight=None
+        binary_loss_weight=None,
     ):
         loss = 0
         return_dict = {}
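Note: the only change is the trailing comma after the last keyword argument. black adds a "magic trailing comma" to argument lists it keeps exploded across lines, so appending a parameter later touches a single diff line:

    # without the trailing comma, adding c would also rewrite the b line
    def forward_args(
        a=None,
        b=None,
        c=None,  # new argument: a one-line diff
    ):
        return a, b, c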
@@ -774,7 +774,9 @@ class ForwardTTSLoss(nn.Module):
             binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
             loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
             if binary_loss_weight:
-                return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
+                return_dict["loss_binary_alignment"] = (
+                    self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
+                )
             else:
                 return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
 
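Note: the wrapped expression scales the aligner's binary alignment loss by a configured alpha and, when given, an extra scheduling weight. A minimal sketch of just the weighting logic, with placeholder values:

    import torch

    binary_alignment_loss = torch.tensor(0.8)  # placeholder loss value
    binary_alignment_loss_alpha = 0.1
    binary_loss_weight = 0.5  # e.g. ramped up over training; may be None

    if binary_loss_weight:
        weighted = binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
    else:
        weighted = binary_alignment_loss_alpha * binary_alignment_loss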
@@ -646,7 +646,7 @@ class ForwardTTS(BaseTTS):
             alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
             alignment_soft=outputs["alignment_soft"],
             alignment_hard=outputs["alignment_mas"],
-            binary_loss_weight=self.binary_loss_weight
+            binary_loss_weight=self.binary_loss_weight,
         )
         # compute duration error
         durations_pred = outputs["durations"]
@@ -364,7 +364,9 @@ class Vits(BaseTTS):
         )
 
         upsample_rate = math.prod(self.args.upsample_rates_decoder)
-        assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
+        assert (
+            upsample_rate == self.config.audio.hop_length
+        ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
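Note: the reflowed assert encodes a real constraint: the HiFi-GAN decoder upsamples by the product of its stride list, and that product must reproduce exactly hop_length waveform samples per spectrogram frame. Checked by hand with a typical configuration:

    import math

    upsample_rates_decoder = [8, 8, 2, 2]  # a common HiFi-GAN stride list
    hop_length = 256

    upsample_rate = math.prod(upsample_rates_decoder)
    assert upsample_rate == hop_length, f"{upsample_rate} vs {hop_length}"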
@@ -666,7 +668,7 @@ class Vits(BaseTTS):
                 waveform,
                 slice_ids * self.config.audio.hop_length,
                 self.args.spec_segment_size * self.config.audio.hop_length,
-                pad_short=True
+                pad_short=True,
             )
 
             if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -688,7 +690,7 @@ class Vits(BaseTTS):
             outputs.update(
                 {
                     "model_outputs": o,
-                    "alignments" : attn.squeeze(1),
+                    "alignments": attn.squeeze(1),
                     "m_p": m_p,
                     "logs_p": logs_p,
                     "z": z,
@@ -951,7 +953,7 @@ class Vits(BaseTTS):
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        figures, audios = self._log(self.ap, batch, outputs, "eval")
+        figures, audios = self._log(self.ap, batch, outputs, "eval")
         logger.eval_figures(steps, figures)
         logger.eval_audios(steps, audios, self.ap.sample_rate)
 
@@ -68,7 +68,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
     """
     # pad the input tensor if it is shorter than the segment size
     if pad_short and x.shape[-1] < segment_size:
-        x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
+        x = torch.nn.functional.pad(x, (0, segment_size - x.size(2)))
 
     segments = torch.zeros_like(x[:, :, :segment_size])
 
@@ -78,12 +78,14 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
         x_i = x[i]
         if pad_short and index_end > x.size(2):
             # pad the sample if it is shorter than the segment size
-            x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
+            x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
         segments[i] = x_i[:, index_start:index_end]
     return segments
 
 
-def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False):
+def rand_segments(
+    x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
+):
     """Create random segments based on the input lengths.
 
     Args:
@@ -110,7 +112,9 @@ def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=
         _x_lenghts[len_diff < 0] = segment_size
         len_diff = _x_lenghts - segment_size + 1
     else:
-        assert all(len_diff > 0), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
+        assert all(
+            len_diff > 0
+        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
     segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
     ret = segment(x, segment_indices, segment_size)
     return ret, segment_indices
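Note: taken together, rand_segments draws one random start index per batch item and segment slices the fixed-size windows out. A usage sketch against the signatures shown above:

    import torch

    from TTS.tts.utils.helpers import rand_segments

    x = torch.randn(2, 80, 30)       # (batch, mel bins, frames)
    x_lens = torch.tensor([30, 25])
    segments, seg_idxs = rand_segments(x, x_lens, segment_size=8)
    print(segments.shape)            # torch.Size([2, 80, 8])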
@@ -1,7 +1,6 @@
 import json
 import os
 from typing import Dict, List
-from TTS.config import check_config_and_model_args
 
 import fsspec
 import numpy as np
@@ -9,6 +8,8 @@ import torch
 from coqpit import Coqpit
 from torch.utils.data.sampler import WeightedRandomSampler
 
+from TTS.config import check_config_and_model_args
+
 
 class LanguageManager:
     """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
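Note: moving check_config_and_model_args below the third-party block follows the isort layout the repo formats with: standard library, then third-party, then first-party imports, each group separated by a blank line. Schematically:

    import json  # 1. standard library
    import os

    import numpy as np  # 2. third-party packages
    import torch

    from TTS.config import check_config_and_model_args  # 3. first-party (TTS)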
@@ -104,7 +104,7 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
     fig, ax = plt.subplots()
 
     x = np.array(range(len(chars)))
-    my_xticks = [c for c in chars]
+    my_xticks = chars
     plt.xticks(x, my_xticks)
 
     ax.set_xlabel("characters")
@@ -5,10 +5,8 @@ import numpy as np
 import pysbd
 import torch
 
-from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config
+from TTS.config import load_config
 from TTS.tts.models import setup_model as setup_tts_model
-from TTS.tts.utils.languages import LanguageManager
-from TTS.tts.utils.speakers import SpeakerManager
 
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
@@ -19,7 +19,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
 
 
 class GAN(BaseVocoder):
-    def __init__(self, config: Coqpit, ap: AudioProcessor=None):
+    def __init__(self, config: Coqpit, ap: AudioProcessor = None):
         """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
         It also helps mixing and matching different generator and disciminator networks easily.
 
@@ -306,7 +306,7 @@ class GAN(BaseVocoder):
         x, y = batch
         return {"input": x, "waveform": y}
 
-    def get_data_loader(  # pylint: disable=no-self-use
+    def get_data_loader(  # pylint: disable=no-self-use, unused-argument
         self,
         config: Coqpit,
         assets: Dict,
@@ -63,7 +63,7 @@ class TestTTSDataset(unittest.TestCase):
             max_text_len=c.max_text_len,
             min_audio_len=c.min_audio_len,
             max_audio_len=c.max_audio_len,
-            start_by_longest=start_by_longest
+            start_by_longest=start_by_longest,
         )
         dataloader = DataLoader(
             dataset,
@@ -1,6 +1,6 @@
 import torch as T
 
-from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments
+from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask
 
 
 def average_over_durations_test():  # pylint: disable=no-self-use
@@ -57,12 +57,12 @@ def rand_segments_test():
     assert segments.shape == (2, 3, 3)
     assert all(seg_idxs >= 0), seg_idxs
     try:
-        segments, _ = rand_segments(x, x_lens, segment_size=5)
+        segments, _ = rand_segments(x, x_lens, segment_size=5)
         raise Exception("Should have failed")
     except:
         pass
     x_lens_back = x_lens.clone()
-    segments, seg_idxs= rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True)
+    segments, seg_idxs = rand_segments(x, x_lens.clone(), segment_size=5, pad_short=True, let_short_samples=True)
     assert segments.shape == (2, 3, 5)
     assert all(seg_idxs >= 0), seg_idxs
     assert all(x_lens_back == x_lens)
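Note: the try/raise/except block asserts that rand_segments rejects a segment size longer than a sample. An equivalent, more idiomatic phrasing (a sketch, not what this commit does) would use pytest.raises:

    import pytest
    import torch

    from TTS.tts.utils.helpers import rand_segments

    x = torch.rand(2, 3, 4)
    x_lens = torch.tensor([4, 3])
    with pytest.raises(AssertionError):
        rand_segments(x, x_lens, segment_size=5)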