mirror of https://github.com/coqui-ai/TTS.git
Make style
parent b3ed6ff6b7
commit 1f0c8179da
@@ -9,7 +9,6 @@ from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
 
 
 phonemizer = Gruut(language="en-us")
 
-
@@ -57,6 +57,12 @@ class BaseAudioConfig(Coqpit):
         do_amp_to_db_mel (bool, optional):
             enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
 
+        pitch_fmax (float, optional):
+            Maximum frequency of the F0 frames. Defaults to ```640```.
+
+        pitch_fmin (float, optional):
+            Minimum frequency of the F0 frames. Defaults to ```0```.
+
         trim_db (int):
             Silence threshold used for silence trimming. Defaults to 45.
 
@@ -135,6 +141,9 @@ class BaseAudioConfig(Coqpit):
     spec_gain: int = 20
     do_amp_to_db_linear: bool = True
     do_amp_to_db_mel: bool = True
+    # f0 params
+    pitch_fmax: float = 640.0
+    pitch_fmin: float = 0.0
     # normalization params
     signal_norm: bool = True
     min_level_db: int = -100
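The two new fields slot into the existing audio config. A minimal usage sketch — the field names and defaults come from the hunk above, while the import path and the other values are assumptions:

```python
from TTS.config.shared_configs import BaseAudioConfig  # import path is an assumption

# pitch_fmin/pitch_fmax bound the F0 search range (in Hz) used when
# extracting pitch frames; 0-640 Hz are the new defaults shown above.
audio_config = BaseAudioConfig(
    sample_rate=22050,  # illustrative value, not from the diff
    pitch_fmin=0.0,
    pitch_fmax=640.0,
)
print(audio_config.pitch_fmin, audio_config.pitch_fmax)
```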
@@ -1,5 +1,4 @@
 import collections
-from email.mime import audio
 import os
 import random
 from typing import Dict, List, Union
@@ -291,7 +290,8 @@ class TTSDataset(Dataset):
         samples[offset:end_offset] = temp_items
         return samples
 
-    def _select_samples_by_idx(self, idxs, samples):
+    @staticmethod
+    def _select_samples_by_idx(idxs, samples):
         samples_new = []
         for idx in idxs:
             samples_new.append(samples[idx])
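With the `self` argument gone, `_select_samples_by_idx` can be exercised straight off the class; a toy sketch (the import path and sample values are assumptions):

```python
from TTS.tts.datasets.dataset import TTSDataset  # import path is an assumption

# As a @staticmethod it needs no dataset instance; it simply gathers the
# samples at the given indices, preserving the index order.
samples = [{"text": "a"}, {"text": "b"}, {"text": "c"}]
print(TTSDataset._select_samples_by_idx([2, 0], samples))
# [{'text': 'c'}, {'text': 'a'}]
```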
@@ -307,9 +307,7 @@ class TTSDataset(Dataset):
         text_lengths = [i[-1] for i in samples]
         audio_lengths = [i[-2] for i in samples]
         text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
-        audio_ignore_idx, audio_keep_idx = self.filter_by_length(
-            audio_lengths, self.min_audio_len, self.max_audio_len
-        )
+        audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
         keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
         ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
 
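The set algebra above decides which samples survive both length filters; a self-contained sketch with made-up index lists:

```python
# A sample is kept only if BOTH its text and audio lengths pass (set
# intersection); it is ignored if EITHER filter rejects it (set union).
text_keep_idx, text_ignore_idx = [0, 1, 3], [2]
audio_keep_idx, audio_ignore_idx = [1, 2, 3], [0]

keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))
print(sorted(keep_idx), sorted(ignore_idx))  # [1, 3] [0, 2]
```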
@@ -740,7 +740,7 @@ class ForwardTTSLoss(nn.Module):
         alignment_logprob=None,
         alignment_hard=None,
         alignment_soft=None,
-        binary_loss_weight=None
+        binary_loss_weight=None,
     ):
         loss = 0
         return_dict = {}
@@ -774,7 +774,9 @@ class ForwardTTSLoss(nn.Module):
             binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
             loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
             if binary_loss_weight:
-                return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
+                return_dict["loss_binary_alignment"] = (
+                    self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
+                )
             else:
                 return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
 
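The wrapped expression computes the weighted binary alignment term; a standalone sketch with illustrative numbers (the warm-up interpretation of `binary_loss_weight` is an assumption):

```python
import torch

# alpha scales the term's contribution to the total loss; binary_loss_weight
# can additionally modulate it (e.g. as a training warm-up factor).
binary_alignment_loss = torch.tensor(0.8)  # illustrative value
binary_alignment_loss_alpha = 0.1
binary_loss_weight = 0.5

print(binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight)
# tensor(0.0400)
```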
@@ -646,7 +646,7 @@ class ForwardTTS(BaseTTS):
             alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
             alignment_soft=outputs["alignment_soft"],
             alignment_hard=outputs["alignment_mas"],
-            binary_loss_weight=self.binary_loss_weight
+            binary_loss_weight=self.binary_loss_weight,
         )
         # compute duration error
         durations_pred = outputs["durations"]
@@ -364,7 +364,9 @@ class Vits(BaseTTS):
         )
 
         upsample_rate = math.prod(self.args.upsample_rates_decoder)
-        assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
+        assert (
+            upsample_rate == self.config.audio.hop_length
+        ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
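The reflowed assertion guards a VITS invariant: the decoder must upsample one spectrogram frame into exactly one hop of audio. A quick check with common (assumed) defaults:

```python
import math

# Each transposed-conv stage multiplies the time resolution; their product
# must land exactly on hop_length samples per input frame.
upsample_rates_decoder = [8, 8, 2, 2]  # illustrative VITS-style rates
hop_length = 256

upsample_rate = math.prod(upsample_rates_decoder)
assert upsample_rate == hop_length, f"{upsample_rate} vs {hop_length}"
```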
@@ -666,7 +668,7 @@ class Vits(BaseTTS):
             waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
-            pad_short=True
+            pad_short=True,
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -83,7 +83,9 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
     return segments
 
 
-def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False):
+def rand_segments(
+    x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False
+):
     """Create random segments based on the input lengths.
 
     Args:
@@ -110,7 +112,9 @@ def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=
         _x_lenghts[len_diff < 0] = segment_size
         len_diff = _x_lenghts - segment_size + 1
     else:
-        assert all(len_diff > 0), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
+        assert all(
+            len_diff > 0
+        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
     segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
     ret = segment(x, segment_indices, segment_size)
     return ret, segment_indices
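A minimal usage sketch of the reformatted `rand_segments`, with illustrative shapes:

```python
import torch

from TTS.tts.utils.helpers import rand_segments

# Slice a random segment_size window out of each batch item, sampling the
# start index within each item's valid length.
x = torch.randn(2, 80, 32)          # [batch, channels, time]
x_lengths = torch.tensor([32, 20])  # valid frames per item
segments, indices = rand_segments(x, x_lengths, segment_size=8)
print(segments.shape)  # torch.Size([2, 80, 8])
```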
@@ -1,7 +1,6 @@
 import json
 import os
 from typing import Dict, List
-from TTS.config import check_config_and_model_args
 
 import fsspec
 import numpy as np
@@ -9,6 +8,8 @@ import torch
 from coqpit import Coqpit
 from torch.utils.data.sampler import WeightedRandomSampler
 
+from TTS.config import check_config_and_model_args
+
 
 class LanguageManager:
     """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
@@ -104,7 +104,7 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
     fig, ax = plt.subplots()
 
     x = np.array(range(len(chars)))
-    my_xticks = [c for c in chars]
+    my_xticks = chars
     plt.xticks(x, my_xticks)
 
     ax.set_xlabel("characters")
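The dropped comprehension was a no-op: `plt.xticks` accepts any sequence of labels, so `chars` can be passed through unchanged. A tiny sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

chars = list("hello")  # illustrative character labels
x = np.array(range(len(chars)))
plt.plot(x, np.zeros_like(x))
plt.xticks(x, chars)  # a list of one-char strings works directly
```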
@@ -5,10 +5,8 @@ import numpy as np
 import pysbd
 import torch
 
-from TTS.config import check_config_and_model_args, get_from_config_or_model_args_with_default, load_config
+from TTS.config import load_config
 from TTS.tts.models import setup_model as setup_tts_model
-from TTS.tts.utils.languages import LanguageManager
-from TTS.tts.utils.speakers import SpeakerManager
 
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
@@ -306,7 +306,7 @@ class GAN(BaseVocoder):
         x, y = batch
         return {"input": x, "waveform": y}
 
-    def get_data_loader(  # pylint: disable=no-self-use
+    def get_data_loader(  # pylint: disable=no-self-use, unused-argument
         self,
         config: Coqpit,
         assets: Dict,
@@ -63,7 +63,7 @@ class TestTTSDataset(unittest.TestCase):
             max_text_len=c.max_text_len,
             min_audio_len=c.min_audio_len,
             max_audio_len=c.max_audio_len,
-            start_by_longest=start_by_longest
+            start_by_longest=start_by_longest,
         )
         dataloader = DataLoader(
             dataset,
@@ -1,6 +1,6 @@
 import torch as T
 
-from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask, rand_segments
+from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask
 
 
 def average_over_durations_test():  # pylint: disable=no-self-use