Mirror of https://github.com/coqui-ai/TTS.git
Commit 97de55595f: Merge remote-tracking branch 'upstream/main'
@@ -42,16 +42,18 @@ jobs:
      - uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - run: |
      - name: Install pip requirements
        run: |
          python -m pip install -U pip setuptools wheel build
      - run: |
          python -m build
      - run: |
          python -m pip install dist/*.whl
          python -m pip install -r requirements.txt
      - name: Setup and install manylinux1_x86_64 wheel
        run: |
          python setup.py bdist_wheel --plat-name=manylinux1_x86_64
          python -m pip install dist/*-manylinux*.whl
      - uses: actions/upload-artifact@v2
        with:
          name: wheel-${{ matrix.python-version }}
          path: dist/*.whl
          path: dist/*-manylinux*.whl
  publish-artifacts:
    runs-on: ubuntu-20.04
    needs: [build-sdist, build-wheels]
@@ -11,4 +11,5 @@ recursive-include TTS *.md
recursive-include TTS *.py
recursive-include TTS *.pyx
recursive-include images *.png

recursive-exclude tests *
prune tests*
README.md
@@ -130,7 +130,7 @@ pip install -e .[all,dev,notebooks] # Select the relevant extras
If you are on Ubuntu (Debian), you can also run the following commands for installation.

```bash
$ make system-deps  # intended to be used on Ubuntu (Debian). Let us know if you have a diffent OS.
$ make system-deps  # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
$ make install
```
@@ -145,25 +145,61 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
```
$ tts --list_models
```

- Get model info (for both tts_models and vocoder_models):
  - Query by type/name:
    The model_info_by_name option uses the name as it appears in the output of --list_models.
    ```
    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
    ```
    For example:

    ```
    $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
    ```
    ```
    $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
    ```
  - Query by type/idx:
    The model_query_idx uses the corresponding idx from --list_models.
    ```
    $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
    ```
    For example:

    ```
    $ tts --model_info_by_idx tts_models/3
    ```

- Run TTS with default models:

```
$ tts --text "Text for TTS"
$ tts --text "Text for TTS" --out_path output/path/speech.wav
```

- Run a TTS model with its default vocoder model:

```
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
```
For example:

```
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
```

- Run with specific TTS and vocoder models from the list:

```
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
```

For example:

```
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
```

- Run your own TTS model (Using Griffin-Lim Vocoder):

```
@@ -215,6 +215,14 @@
                "author": "@thorstenMueller",
                "license": "apache 2.0",
                "commit": "unknown"
            },
            "tacotron2-DDC": {
                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--de--thorsten--tacotron2-DDC.zip",
                "default_vocoder": "vocoder_models/de/thorsten/hifigan_v1",
                "description": "Thorsten-Dec2021-22k-DDC",
                "author": "@thorstenMueller",
                "license": "apache 2.0",
                "commit": "unknown"
            }
        }
    },
@@ -460,6 +468,13 @@
                "author": "@thorstenMueller",
                "license": "apache 2.0",
                "commit": "unknown"
            },
            "hifigan_v1": {
                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/vocoder_models--de--thorsten--hifigan_v1.zip",
                "description": "HifiGAN vocoder model for Thorsten Neutral Dec2021 22k Samplerate Tacotron2 DDC model",
                "author": "@thorstenMueller",
                "license": "apache 2.0",
                "commit": "unknown"
            }
        }
    },
@@ -1 +1 @@
0.7.1
0.8.0
@@ -60,13 +60,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
- Run a TTS model with its default vocoder model:

```
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>"
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>
```

- Run with specific TTS and vocoder models from the list:

```
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --output_path
```

- Run your own TTS model (Using Griffin-Lim Vocoder):
@ -13,13 +13,13 @@ from trainer.trainer_utils import get_optimizer
|
|||
|
||||
from TTS.encoder.dataset import EncoderDataset
|
||||
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
|
||||
from TTS.encoder.utils.samplers import PerfectBatchSampler
|
||||
from TTS.encoder.utils.training import init_training
|
||||
from TTS.encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
|
||||
from TTS.utils.io import copy_model_files
|
||||
from TTS.utils.samplers import PerfectBatchSampler
|
||||
from TTS.utils.training import check_update
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
|
|
@@ -1,4 +1,4 @@
"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""
import argparse
from itertools import product as cartesian_product
@ -7,94 +7,97 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||
from TTS.vocoder.utils.generic_utils import setup_generator
|
||||
from TTS.vocoder.models import setup_model
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
|
||||
parser.add_argument("--config_path", type=str, help="Path to model config file.")
|
||||
parser.add_argument("--data_path", type=str, help="Path to data directory.")
|
||||
parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
|
||||
parser.add_argument(
|
||||
"--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for."
|
||||
)
|
||||
parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.")
|
||||
parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
|
||||
parser.add_argument(
|
||||
"--search_depth",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Search granularity. Increasing this increases the run-time exponentially.",
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
|
||||
parser.add_argument("--config_path", type=str, help="Path to model config file.")
|
||||
parser.add_argument("--data_path", type=str, help="Path to data directory.")
|
||||
parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
|
||||
parser.add_argument(
|
||||
"--num_iter",
|
||||
type=int,
|
||||
help="Number of model inference iterations that you like to optimize noise schedule for.",
|
||||
)
|
||||
parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
|
||||
parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
|
||||
parser.add_argument(
|
||||
"--search_depth",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Search granularity. Increasing this increases the run-time exponentially.",
|
||||
)
|
||||
|
||||
# load config
|
||||
args = parser.parse_args()
|
||||
config = load_config(args.config_path)
|
||||
# load config
|
||||
args = parser.parse_args()
|
||||
config = load_config(args.config_path)
|
||||
|
||||
# setup audio processor
|
||||
ap = AudioProcessor(**config.audio)
|
||||
# setup audio processor
|
||||
ap = AudioProcessor(**config.audio)
|
||||
|
||||
# load dataset
|
||||
_, train_data = load_wav_data(args.data_path, 0)
|
||||
train_data = train_data[: args.num_samples]
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=train_data,
|
||||
seq_len=-1,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
is_training=True,
|
||||
return_segments=False,
|
||||
use_noise_augment=False,
|
||||
use_cache=False,
|
||||
verbose=True,
|
||||
)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_full_clips,
|
||||
drop_last=False,
|
||||
num_workers=config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
# load dataset
|
||||
_, train_data = load_wav_data(args.data_path, 0)
|
||||
train_data = train_data[: args.num_samples]
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=train_data,
|
||||
seq_len=-1,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
is_training=True,
|
||||
return_segments=False,
|
||||
use_noise_augment=False,
|
||||
use_cache=False,
|
||||
verbose=True,
|
||||
)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_full_clips,
|
||||
drop_last=False,
|
||||
num_workers=config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
|
||||
# setup the model
|
||||
model = setup_generator(config)
|
||||
if args.use_cuda:
|
||||
model.cuda()
|
||||
# setup the model
|
||||
model = setup_model(config)
|
||||
if args.use_cuda:
|
||||
model.cuda()
|
||||
|
||||
# setup optimization parameters
|
||||
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
|
||||
print(base_values)
|
||||
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
|
||||
best_error = float("inf")
|
||||
best_schedule = None
|
||||
total_search_iter = len(base_values) ** args.num_iter
|
||||
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
|
||||
beta = exponents * base
|
||||
model.compute_noise_level(beta)
|
||||
for data in loader:
|
||||
mel, audio = data
|
||||
y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
|
||||
# setup optimization parameters
|
||||
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
|
||||
print(f" > base values: {base_values}")
|
||||
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
|
||||
best_error = float("inf")
|
||||
best_schedule = None # pylint: disable=C0103
|
||||
total_search_iter = len(base_values) ** args.num_iter
|
||||
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
|
||||
beta = exponents * base
|
||||
model.compute_noise_level(beta)
|
||||
for data in loader:
|
||||
mel, audio = data
|
||||
y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
|
||||
|
||||
if args.use_cuda:
|
||||
y_hat = y_hat.cpu()
|
||||
y_hat = y_hat.numpy()
|
||||
if args.use_cuda:
|
||||
y_hat = y_hat.cpu()
|
||||
y_hat = y_hat.numpy()
|
||||
|
||||
mel_hat = []
|
||||
for i in range(y_hat.shape[0]):
|
||||
m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
|
||||
mel_hat.append(torch.from_numpy(m))
|
||||
mel_hat = []
|
||||
for i in range(y_hat.shape[0]):
|
||||
m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
|
||||
mel_hat.append(torch.from_numpy(m))
|
||||
|
||||
mel_hat = torch.stack(mel_hat)
|
||||
mse = torch.sum((mel - mel_hat) ** 2).mean()
|
||||
if mse.item() < best_error:
|
||||
best_error = mse.item()
|
||||
best_schedule = {"beta": beta}
|
||||
print(f" > Found a better schedule. - MSE: {mse.item()}")
|
||||
np.save(args.output_path, best_schedule)
|
||||
mel_hat = torch.stack(mel_hat)
|
||||
mse = torch.sum((mel - mel_hat) ** 2).mean()
|
||||
if mse.item() < best_error:
|
||||
best_error = mse.item()
|
||||
best_schedule = {"beta": beta}
|
||||
print(f" > Found a better schedule. - MSE: {mse.item()}")
|
||||
np.save(args.output_path, best_schedule)
|
||||
|
|
|
@ -62,7 +62,7 @@ def _process_model_name(config_dict: Dict) -> str:
|
|||
return model_name
|
||||
|
||||
|
||||
def load_config(config_path: str) -> None:
|
||||
def load_config(config_path: str) -> Coqpit:
|
||||
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
|
||||
to find the corresponding Config class. Then initialize the Config.
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
# from TTS.utils.audio import TorchSTFT
|
||||
# from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.encoder.models.base_encoder import BaseEncoder
|
||||
|
||||
|
||||
|
|
|
@ -200,9 +200,6 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
loss_masking (bool):
|
||||
enable / disable masking loss values against padded segments of samples in a batch.
|
||||
|
||||
sort_by_audio_len (bool):
|
||||
If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
|
||||
|
||||
min_text_len (int):
|
||||
Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.
|
||||
|
||||
|
@ -303,7 +300,6 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
batch_group_size: int = 0
|
||||
loss_masking: bool = None
|
||||
# dataloading
|
||||
sort_by_audio_len: bool = False
|
||||
min_audio_len: int = 1
|
||||
max_audio_len: int = float("inf")
|
||||
min_text_len: int = 1
|
||||
|
|
|
@ -53,7 +53,7 @@ class TacotronConfig(BaseTTSConfig):
|
|||
enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True.
|
||||
stopnet_pos_weight (float):
|
||||
Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
|
||||
datasets with longer sentences. Defaults to 10.
|
||||
datasets with longer sentences. Defaults to 0.2.
|
||||
max_decoder_steps (int):
|
||||
Max number of steps allowed for the decoder. Defaults to 50.
|
||||
encoder_in_features (int):
|
||||
|
@ -161,8 +161,8 @@ class TacotronConfig(BaseTTSConfig):
|
|||
prenet_dropout_at_inference: bool = False
|
||||
stopnet: bool = True
|
||||
separate_stopnet: bool = True
|
||||
stopnet_pos_weight: float = 10.0
|
||||
max_decoder_steps: int = 500
|
||||
stopnet_pos_weight: float = 0.2
|
||||
max_decoder_steps: int = 10000
|
||||
encoder_in_features: int = 256
|
||||
decoder_in_features: int = 256
|
||||
decoder_output_dim: int = 80
|
||||
|
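The retuned stopnet defaults above can be overridden per experiment. A minimal sketch, assuming the standard `TacotronConfig` import path; the values shown are just the new defaults from this change:

```python
from TTS.tts.configs.tacotron_config import TacotronConfig

# Sketch only: override the stopnet-related fields touched by this change.
config = TacotronConfig(
    stopnet_pos_weight=0.2,   # use larger values for datasets with long sentences
    max_decoder_steps=10000,  # hard limit on autoregressive decoding length
)
```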
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass, field
|
|||
from typing import List
|
||||
|
||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||
from TTS.tts.models.vits import VitsArgs
|
||||
from TTS.tts.models.vits import VitsArgs, VitsAudioConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
|
|||
model_args (VitsArgs):
|
||||
Model architecture arguments. Defaults to `VitsArgs()`.
|
||||
|
||||
audio (VitsAudioConfig):
|
||||
Audio processing configuration. Defaults to `VitsAudioConfig()`.
|
||||
|
||||
grad_clip (List):
|
||||
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
|
||||
|
||||
|
@ -67,6 +70,18 @@ class VitsConfig(BaseTTSConfig):
|
|||
compute_linear_spec (bool):
|
||||
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
||||
|
||||
use_weighted_sampler (bool):
|
||||
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
|
||||
|
||||
weighted_sampler_attrs (dict):
|
||||
Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
|
||||
by overweighting `root_path` by 2.0. Defaults to `{}`.
|
||||
|
||||
weighted_sampler_multipliers (dict):
|
||||
Weight each unique value of a key returned by the formatter for weighted sampling.
|
||||
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`.
|
||||
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
|
||||
|
||||
r (int):
|
||||
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
||||
|
||||
|
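Putting the weighted-sampler fields described above together, a hypothetical configuration could look like the following sketch (dataset paths are placeholders, not real data):

```python
from TTS.tts.configs.vits_config import VitsConfig

# Sketch only: balance speakers more aggressively than datasets, and
# additionally over-sample one dataset relative to another.
config = VitsConfig(
    use_weighted_sampler=True,
    weighted_sampler_attrs={"speaker_name": 2.0, "root_path": 1.0},
    weighted_sampler_multipliers={
        "root_path": {
            "/data/LibriTTS/train-clean-100/": 1.0,
            "/data/LibriTTS/train-clean-360/": 0.5,
        }
    },
)
```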
@ -94,6 +109,7 @@ class VitsConfig(BaseTTSConfig):
|
|||
model: str = "vits"
|
||||
# model specific params
|
||||
model_args: VitsArgs = field(default_factory=VitsArgs)
|
||||
audio: VitsAudioConfig = VitsAudioConfig()
|
||||
|
||||
# optimizer
|
||||
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
|
||||
|
@ -120,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
|
|||
return_wav: bool = True
|
||||
compute_linear_spec: bool = True
|
||||
|
||||
# sampler params
|
||||
use_weighted_sampler: bool = False # TODO: move it to the base config
|
||||
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
|
||||
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
|
||||
|
||||
# overrides
|
||||
r: int = 1 # DO NOT CHANGE
|
||||
add_blank: bool = True
|
||||
|
|
|
@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
|
|||
"audio_file": audio_path,
|
||||
"speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
|
||||
"emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
|
||||
"root_path": root_path,
|
||||
}
|
||||
)
|
||||
if not_found_counter > 0:
|
||||
|
@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("\t")
|
||||
wav_file = os.path.join(root_path, cols[0] + ".wav")
|
||||
text = cols[1]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
wav_file = cols[1].strip()
|
||||
text = cols[0].strip()
|
||||
wav_file = os.path.join(root_path, "wavs", wav_file)
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
|
|||
text = cols[1].strip()
|
||||
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
|
||||
wav_file = os.path.join(root_path, folder_name, wav_file)
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
|
|||
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
|
||||
if os.path.isfile(wav_file):
|
||||
text = cols[1].strip()
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
|
||||
)
|
||||
else:
|
||||
# M-AI-Labs have some missing samples, so just print the warning
|
||||
print("> File %s does not exist!" % (wav_file))
|
||||
|
@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[2]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[2]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[1]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
|
|||
if not os.path.exists(wav_file):
|
||||
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
|
||||
text = cols[1]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
|
|||
if speaker_name in ignored_speakers:
|
||||
continue
|
||||
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker_name in ignored_speakers:
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
|
||||
items.append(
|
||||
{
|
||||
"text": text,
|
||||
"audio_file": wav_file,
|
||||
"speaker_name": f"LTTS_{speaker_name}",
|
||||
"root_path": root_path,
|
||||
}
|
||||
)
|
||||
for item in items:
|
||||
assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
|
||||
return items
|
||||
|
@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
|
|||
skipped_files.append(wav_file)
|
||||
continue
|
||||
text = cols[1].strip()
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
|
||||
return items
|
||||
|
||||
|
@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker_id in ignored_speakers:
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
|
|||
else:
|
||||
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
|
||||
if os.path.exists(wav_file):
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
|
||||
)
|
||||
else:
|
||||
print(f" [!] wav files don't exist - {wav_file}")
|
||||
return items
|
||||
|
@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non
|
|||
with open(meta_file, "r", encoding="utf-8") as file_text:
|
||||
text = file_text.readlines()[0]
|
||||
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-ar
|
|||
if os.path.exists(txt_file) and os.path.exists(wav_file):
|
||||
with open(txt_file, "r", encoding="utf-8") as file_text:
|
||||
text = file_text.readlines()[0]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno
|
|||
if ignore_digits_sentences and any(map(str.isdigit, text)):
|
||||
continue
|
||||
wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac")
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker in ignored_speakers:
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker_id in ignored_speakers:
|
||||
continue
|
||||
items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
|
||||
items.append(
|
||||
{"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
|
|||
for line in ttf:
|
||||
wav_name, text = line.rstrip("\n").split("|")
|
||||
wav_path = os.path.join(root_path, "clips_22", wav_name)
|
||||
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -554,7 +576,7 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[2].replace(" ", "")
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
|
|
@ -67,9 +67,14 @@ class WN(torch.nn.Module):
|
|||
for i in range(num_layers):
|
||||
dilation = dilation_rate**i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = torch.nn.Conv1d(
|
||||
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
|
||||
)
|
||||
if i == 0:
|
||||
in_layer = torch.nn.Conv1d(
|
||||
in_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
|
||||
)
|
||||
else:
|
||||
in_layer = torch.nn.Conv1d(
|
||||
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
|
||||
)
|
||||
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
|
|
|
@ -29,11 +29,11 @@ def squeeze(x, x_mask=None, num_sqz=2):
|
|||
|
||||
|
||||
def unsqueeze(x, x_mask=None, num_sqz=2):
|
||||
"""GlowTTS unsqueeze operation
|
||||
"""GlowTTS unsqueeze operation (revert the squeeze)
|
||||
|
||||
Note:
|
||||
each 's' is an n-dimensional vector.
|
||||
``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]``
|
||||
``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5, s2, s4, s6]]``
|
||||
"""
|
||||
b, c, t = x.size()
|
||||
|
||||
|
|
|
@ -197,7 +197,7 @@ class CouplingBlock(nn.Module):
|
|||
end.bias.data.zero_()
|
||||
self.end = end
|
||||
# coupling layers
|
||||
self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)
|
||||
self.wn = WN(hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)
|
||||
|
||||
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument
|
||||
"""
|
||||
|
|
|
@ -7,8 +7,8 @@ from torch import nn
|
|||
from torch.nn import functional
|
||||
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.ssim import ssim
|
||||
from TTS.utils.audio import TorchSTFT
|
||||
from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
|
@ -91,30 +91,55 @@ class MSELossMasked(nn.Module):
|
|||
return loss
|
||||
|
||||
|
||||
def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
"""Min-Max normalize tensor through first dimension
|
||||
Shapes:
|
||||
- x: :math:`[B, D1, D2]`
|
||||
- m: :math:`[B, D1, 1]`
|
||||
"""
|
||||
maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True)
|
||||
minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)
|
||||
return (x - minimum) / (maximum - minimum + 1e-8)
|
||||
|
||||
|
||||
class SSIMLoss(torch.nn.Module):
|
||||
"""SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity"""
|
||||
"""SSIM loss as (1 - SSIM)
|
||||
SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_func = ssim
|
||||
self.loss_func = _SSIMLoss()
|
||||
|
||||
def forward(self, y_hat, y, length=None):
|
||||
def forward(self, y_hat, y, length):
|
||||
"""
|
||||
Args:
|
||||
y_hat (tensor): model prediction values.
|
||||
y (tensor): target values.
|
||||
length (tensor): length of each sample in a batch.
|
||||
length (tensor): length of each sample in a batch for masking.
|
||||
|
||||
Shapes:
|
||||
y_hat: B x T X D
|
||||
y: B x T x D
|
||||
length: B
|
||||
|
||||
Returns:
|
||||
loss: An average loss value in range [0, 1] masked by the length.
|
||||
"""
|
||||
if length is not None:
|
||||
m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device)
|
||||
y_hat, y = y_hat * m, y * m
|
||||
return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
|
||||
mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2)
|
||||
y_norm = sample_wise_min_max(y, mask)
|
||||
y_hat_norm = sample_wise_min_max(y_hat, mask)
|
||||
ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))
|
||||
|
||||
if ssim_loss.item() > 1.0:
|
||||
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
|
||||
ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
|
||||
|
||||
if ssim_loss.item() < 0.0:
|
||||
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
|
||||
ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
|
||||
|
||||
return ssim_loss
|
||||
|
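A hypothetical usage sketch of the reworked `SSIMLoss` (shapes follow the docstring above; the tensors here are random stand-ins for mel spectrograms):

```python
import torch

from TTS.tts.layers.losses import SSIMLoss

criterion = SSIMLoss()
y_hat = torch.randn(4, 120, 80)            # predicted frames [B, T, D]
y = torch.randn(4, 120, 80)                # target frames [B, T, D]
length = torch.tensor([120, 100, 90, 60])  # valid frames per sample [B]
loss = criterion(y_hat, y, length)         # scalar, clamped to [0, 1]
```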
||||
|
||||
class AttentionEntropyLoss(nn.Module):
|
||||
|
@ -123,9 +148,6 @@ class AttentionEntropyLoss(nn.Module):
|
|||
"""
|
||||
Forces attention to be more decisive by penalizing
|
||||
soft attention weights
|
||||
|
||||
TODO: arguments
|
||||
TODO: unit_test
|
||||
"""
|
||||
entropy = torch.distributions.Categorical(probs=align).entropy()
|
||||
loss = (entropy / np.log(align.shape[1])).mean()
|
||||
|
@ -133,9 +155,17 @@ class AttentionEntropyLoss(nn.Module):
|
|||
|
||||
|
||||
class BCELossMasked(nn.Module):
|
||||
def __init__(self, pos_weight):
|
||||
"""BCE loss with masking.
|
||||
|
||||
Used mainly for stopnet in autoregressive models.
|
||||
|
||||
Args:
|
||||
pos_weight (float): weight for positive samples. If set < 1, penalize early stopping. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, pos_weight: float = None):
|
||||
super().__init__()
|
||||
self.pos_weight = pos_weight
|
||||
self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False)
|
||||
|
||||
def forward(self, x, target, length):
|
||||
"""
|
||||
|
@ -155,16 +185,17 @@ class BCELossMasked(nn.Module):
|
|||
Returns:
|
||||
loss: An average loss value in range [0, 1] masked by the length.
|
||||
"""
|
||||
# mask: (batch, max_len, 1)
|
||||
target.requires_grad = False
|
||||
if length is not None:
|
||||
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
|
||||
x = x * mask
|
||||
target = target * mask
|
||||
# mask: (batch, max_len, 1)
|
||||
mask = sequence_mask(sequence_length=length, max_len=target.size(1))
|
||||
num_items = mask.sum()
|
||||
loss = functional.binary_cross_entropy_with_logits(
|
||||
x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum"
|
||||
)
|
||||
else:
|
||||
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
|
||||
num_items = torch.numel(x)
|
||||
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
|
||||
loss = loss / num_items
|
||||
return loss
|
||||
|
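A hypothetical sketch of how the masked stopnet loss above is typically driven; the logits, binary stop targets, and lengths are fabricated here:

```python
import torch

from TTS.tts.layers.losses import BCELossMasked

criterion = BCELossMasked(pos_weight=0.2)          # matches the new TacotronConfig default
logits = torch.randn(4, 120)                       # stopnet outputs [B, T]
target = torch.zeros(4, 120)
target[:, -1] = 1.0                                # stop flag on the final frame
length = torch.full((4,), 120, dtype=torch.long)   # valid decoder steps per sample
loss = criterion(logits, target, length)
```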
||||
|
|
|
@ -53,6 +53,7 @@ class CapacitronVAE(nn.Module):
|
|||
text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device)
|
||||
enc_out = torch.cat([enc_out, text_summary_out], dim=-1)
|
||||
if speaker_embedding is not None:
|
||||
speaker_embedding = torch.squeeze(speaker_embedding)
|
||||
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
|
||||
|
||||
# Feed the output of the ref encoder and information about text/speaker into
|
||||
|
|
|
@ -137,7 +137,7 @@ class BaseTTS(BaseTrainerModel):
|
|||
if hasattr(self, "speaker_manager"):
|
||||
if config.use_d_vector_file:
|
||||
if speaker_name is None:
|
||||
d_vector = self.speaker_manager.get_random_embeddings()
|
||||
d_vector = self.speaker_manager.get_random_embedding()
|
||||
else:
|
||||
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
|
||||
elif config.use_speaker_embedding:
|
||||
|
|
|
@ -514,7 +514,7 @@ class GlowTTS(BaseTTS):
|
|||
y = y[:, :, :y_max_length]
|
||||
if attn is not None:
|
||||
attn = attn[:, :, :, :y_max_length]
|
||||
y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze
|
||||
y_lengths = torch.div(y_lengths, self.num_squeeze, rounding_mode="floor") * self.num_squeeze
|
||||
return y, y_lengths, y_max_length, attn
|
||||
|
||||
def store_inverse(self):
|
||||
|
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass, field, replace
|
|||
from itertools import chain
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torchaudio
|
||||
|
@ -13,6 +14,8 @@ from torch import nn
|
|||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
|
@ -29,6 +32,8 @@ from TTS.tts.utils.synthesis import synthesis
|
|||
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.samplers import BucketBatchSampler
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
@ -200,11 +205,51 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
|
|||
return spec
|
||||
|
||||
|
||||
#############################
|
||||
# CONFIGS
|
||||
#############################
|
||||
|
||||
|
||||
@dataclass
|
||||
class VitsAudioConfig(Coqpit):
|
||||
fft_size: int = 1024
|
||||
sample_rate: int = 22050
|
||||
win_length: int = 1024
|
||||
hop_length: int = 256
|
||||
num_mels: int = 80
|
||||
mel_fmin: int = 0
|
||||
mel_fmax: int = None
|
||||
|
||||
|
||||
##############################
|
||||
# DATASET
|
||||
##############################
|
||||
|
||||
|
||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
|
||||
"""Create inverse frequency weights for balancing the dataset.
|
||||
Use `multi_dict` to scale relative weights."""
|
||||
attr_names_samples = np.array([item[attr_name] for item in items])
|
||||
unique_attr_names = np.unique(attr_names_samples).tolist()
|
||||
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
|
||||
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
|
||||
weight_attr = 1.0 / attr_count
|
||||
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
|
||||
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||
if multi_dict is not None:
|
||||
# check if all keys are in the multi_dict
|
||||
for k in multi_dict:
|
||||
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
|
||||
# scale weights
|
||||
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
|
||||
dataset_samples_weight *= multiplier_samples
|
||||
return (
|
||||
torch.from_numpy(dataset_samples_weight).float(),
|
||||
unique_attr_names,
|
||||
np.unique(dataset_samples_weight).tolist(),
|
||||
)
|
||||
|
||||
|
||||
class VitsDataset(TTSDataset):
|
||||
def __init__(self, model_args, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
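The inverse-frequency weighting used by `get_attribute_balancer_weights` above can be illustrated with a tiny, self-contained example (fabricated items, not repo data):

```python
import numpy as np

items = [{"speaker_name": "spk_a"}] * 3 + [{"speaker_name": "spk_b"}]
names = np.array([item["speaker_name"] for item in items])
unique_names, counts = np.unique(names, return_counts=True)
weight_per_value = 1.0 / counts                      # rarer values get larger weights
weights = np.array([weight_per_value[list(unique_names).index(n)] for n in names])
weights = weights / np.linalg.norm(weights)          # same normalization as above
print(dict(zip(unique_names.tolist(), weight_per_value.tolist())))
print(weights)                                       # the spk_b sample is over-weighted
```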
@ -786,7 +831,7 @@ class Vits(BaseTTS):
|
|||
print(" > Text Encoder was reinit.")
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
sid, g, lid, _ = self._set_cond_input(aux_input)
|
||||
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
|
||||
|
||||
def _freeze_layers(self):
|
||||
|
@ -817,7 +862,7 @@ class Vits(BaseTTS):
|
|||
@staticmethod
|
||||
def _set_cond_input(aux_input: Dict):
|
||||
"""Set the speaker conditioning input based on the multi-speaker mode."""
|
||||
sid, g, lid = None, None, None
|
||||
sid, g, lid, durations = None, None, None, None
|
||||
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
|
||||
sid = aux_input["speaker_ids"]
|
||||
if sid.ndim == 0:
|
||||
|
@ -832,7 +877,10 @@ class Vits(BaseTTS):
|
|||
if lid.ndim == 0:
|
||||
lid = lid.unsqueeze_(0)
|
||||
|
||||
return sid, g, lid
|
||||
if "durations" in aux_input and aux_input["durations"] is not None:
|
||||
durations = aux_input["durations"]
|
||||
|
||||
return sid, g, lid, durations
|
||||
|
||||
def _set_speaker_input(self, aux_input: Dict):
|
||||
d_vectors = aux_input.get("d_vectors", None)
|
||||
|
@ -946,7 +994,7 @@ class Vits(BaseTTS):
|
|||
- syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
|
||||
"""
|
||||
outputs = {}
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
sid, g, lid, _ = self._set_cond_input(aux_input)
|
||||
# speaker embedding
|
||||
if self.args.use_speaker_embedding and sid is not None:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
|
@ -1028,7 +1076,9 @@ class Vits(BaseTTS):
|
|||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
|
||||
self,
|
||||
x,
|
||||
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Note:
|
||||
|
@ -1048,7 +1098,7 @@ class Vits(BaseTTS):
|
|||
- m_p: :math:`[B, C, T_dec]`
|
||||
- logs_p: :math:`[B, C, T_dec]`
|
||||
"""
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
sid, g, lid, durations = self._set_cond_input(aux_input)
|
||||
x_lengths = self._set_x_lengths(x, aux_input)
|
||||
|
||||
# speaker embedding
|
||||
|
@ -1062,21 +1112,25 @@ class Vits(BaseTTS):
|
|||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
|
||||
|
||||
if self.args.use_sdp:
|
||||
logw = self.duration_predictor(
|
||||
x,
|
||||
x_mask,
|
||||
g=g if self.args.condition_dp_on_speaker else None,
|
||||
reverse=True,
|
||||
noise_scale=self.inference_noise_scale_dp,
|
||||
lang_emb=lang_emb,
|
||||
)
|
||||
if durations is None:
|
||||
if self.args.use_sdp:
|
||||
logw = self.duration_predictor(
|
||||
x,
|
||||
x_mask,
|
||||
g=g if self.args.condition_dp_on_speaker else None,
|
||||
reverse=True,
|
||||
noise_scale=self.inference_noise_scale_dp,
|
||||
lang_emb=lang_emb,
|
||||
)
|
||||
else:
|
||||
logw = self.duration_predictor(
|
||||
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
||||
)
|
||||
w = torch.exp(logw) * x_mask * self.length_scale
|
||||
else:
|
||||
logw = self.duration_predictor(
|
||||
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
||||
)
|
||||
assert durations.shape[-1] == x.shape[-1]
|
||||
w = durations.unsqueeze(0)
|
||||
|
||||
w = torch.exp(logw) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec]
|
||||
|
@ -1341,7 +1395,7 @@ class Vits(BaseTTS):
|
|||
if hasattr(self, "speaker_manager"):
|
||||
if config.use_d_vector_file:
|
||||
if speaker_name is None:
|
||||
d_vector = self.speaker_manager.get_random_embeddings()
|
||||
d_vector = self.speaker_manager.get_random_embedding()
|
||||
else:
|
||||
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
|
||||
elif config.use_speaker_embedding:
|
||||
|
@ -1485,6 +1539,42 @@ class Vits(BaseTTS):
|
|||
batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
|
||||
return batch
|
||||
|
||||
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
|
||||
weights = None
|
||||
data_items = dataset.samples
|
||||
if getattr(config, "use_weighted_sampler", False):
|
||||
for attr_name, alpha in config.weighted_sampler_attrs.items():
|
||||
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
|
||||
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
|
||||
print(multi_dict)
|
||||
weights, attr_names, attr_weights = get_attribute_balancer_weights(
|
||||
attr_name=attr_name, items=data_items, multi_dict=multi_dict
|
||||
)
|
||||
weights = weights * alpha
|
||||
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
|
||||
|
||||
# input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]
|
||||
|
||||
if weights is not None:
|
||||
w_sampler = WeightedRandomSampler(weights, len(weights))
|
||||
batch_sampler = BucketBatchSampler(
|
||||
w_sampler,
|
||||
data=data_items,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
sort_key=lambda x: os.path.getsize(x["audio_file"]),
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
batch_sampler = None
|
||||
# sampler for DDP
|
||||
if batch_sampler is None:
|
||||
batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
else: # If a sampler is already defined use this sampler and DDP sampler together
|
||||
batch_sampler = (
|
||||
DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
|
||||
) # TODO: check batch_sampler with multi-gpu
|
||||
return batch_sampler
|
||||
|
||||
def get_data_loader(
|
||||
self,
|
||||
config: Coqpit,
|
||||
|
@ -1523,17 +1613,24 @@ class Vits(BaseTTS):
|
|||
|
||||
# get samplers
|
||||
sampler = self.get_sampler(config, dataset, num_gpus)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
shuffle=False, # shuffle is done in the dataset.
|
||||
drop_last=False, # setting this False might cause issues in AMP training.
|
||||
sampler=sampler,
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
if sampler is None:
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
shuffle=False, # shuffle is done in the dataset.
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False, # setting this False might cause issues in AMP training.
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
else:
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=sampler,
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def get_optimizer(self) -> List:
|
||||
|
@ -1590,7 +1687,7 @@ class Vits(BaseTTS):
|
|||
strict=True,
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
"""Load the model checkpoint and setup for training or inference"""
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
# compat band-aid for the pre-trained models to not use the encoder baked into the model
|
||||
# TODO: consider baking the speaker encoder into the model and call it from there.
|
||||
# as it is probably easier for model distribution.
|
||||
|
|
|
@ -76,7 +76,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
|
|||
index_start = segment_indices[i]
|
||||
index_end = index_start + segment_size
|
||||
x_i = x[i]
|
||||
if pad_short and index_end > x.size(2):
|
||||
if pad_short and index_end >= x.size(2):
|
||||
# pad the sample if it is shorter than the segment size
|
||||
x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
|
||||
segments[i] = x_i[:, index_start:index_end]
|
||||
|
@ -107,16 +107,16 @@ def rand_segments(
|
|||
T = segment_size
|
||||
if _x_lenghts is None:
|
||||
_x_lenghts = T
|
||||
len_diff = _x_lenghts - segment_size + 1
|
||||
len_diff = _x_lenghts - segment_size
|
||||
if let_short_samples:
|
||||
_x_lenghts[len_diff < 0] = segment_size
|
||||
len_diff = _x_lenghts - segment_size + 1
|
||||
len_diff = _x_lenghts - segment_size
|
||||
else:
|
||||
assert all(
|
||||
len_diff > 0
|
||||
), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
|
||||
segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
|
||||
ret = segment(x, segment_indices, segment_size)
|
||||
segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long()
|
||||
ret = segment(x, segment_indices, segment_size, pad_short=pad_short)
|
||||
return ret, segment_indices
|
||||
|
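A hypothetical sketch of `rand_segments` after the off-by-one fix above; the second positional argument is assumed to be the per-sample length tensor:

```python
import torch

from TTS.tts.utils.helpers import rand_segments

x = torch.randn(3, 80, 120)            # [B, C, T] feature frames
x_lens = torch.tensor([120, 100, 40])  # true lengths; one sample is shorter than 64
segments, indices = rand_segments(
    x, x_lens, segment_size=64, let_short_samples=True, pad_short=True
)
print(segments.shape)                  # (3, 80, 64)
```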
||||
|
||||
|
|
|
@ -1,73 +1,383 @@
|
|||
# taken from https://github.com/Po-Hsun-Su/pytorch-ssim
|
||||
# Adopted from https://github.com/photosynthesis-team/piq
|
||||
|
||||
from math import exp
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
|
||||
return gauss / gauss.sum()
|
||||
def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
|
||||
r"""Reduce input in batch dimension if needed.
|
||||
Args:
|
||||
x: Tensor with shape (N, *).
|
||||
reduction: Specifies the reduction type:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
|
||||
"""
|
||||
if reduction == "none":
|
||||
return x
|
||||
if reduction == "mean":
|
||||
return x.mean(dim=0)
|
||||
if reduction == "sum":
|
||||
return x.sum(dim=0)
|
||||
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
|
||||
|
||||
|
||||
def create_window(window_size, channel):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
||||
return window
|
||||
def _validate_input(
|
||||
tensors: List[torch.Tensor],
|
||||
dim_range: Tuple[int, int] = (0, -1),
|
||||
data_range: Tuple[float, float] = (0.0, -1.0),
|
||||
# size_dim_range: Tuple[float, float] = (0., -1.),
|
||||
size_range: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
r"""Check that input(-s) satisfies the requirements
|
||||
Args:
|
||||
tensors: Tensors to check
|
||||
dim_range: Allowed number of dimensions. (min, max)
|
||||
data_range: Allowed range of values in tensors. (min, max)
|
||||
size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
|
||||
"""
|
||||
|
||||
if not __debug__:
|
||||
return
|
||||
|
||||
def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
||||
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
||||
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
||||
x = tensors[0]
|
||||
|
||||
# TODO: check if you need AMP disabled
|
||||
# with torch.cuda.amp.autocast(enabled=False):
|
||||
mu1_sq = mu1.float().pow(2)
|
||||
mu2_sq = mu2.float().pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
for t in tensors:
|
||||
assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}"
|
||||
assert t.device == x.device, f"Expected tensors to be on {x.device}, got {t.device}"
|
||||
|
||||
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
||||
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
||||
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
||||
|
||||
C1 = 0.01**2
|
||||
C2 = 0.03**2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||
|
||||
if size_average:
|
||||
return ssim_map.mean()
|
||||
return ssim_map.mean(1).mean(1).mean(1)
|
||||
|
||||
|
||||
class SSIM(torch.nn.Module):
|
||||
def __init__(self, window_size=11, size_average=True):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.size_average = size_average
|
||||
self.channel = 1
|
||||
self.window = create_window(window_size, self.channel)
|
||||
|
||||
def forward(self, img1, img2):
|
||||
(_, channel, _, _) = img1.size()
|
||||
|
||||
if channel == self.channel and self.window.data.type() == img1.data.type():
|
||||
window = self.window
|
||||
if size_range is None:
|
||||
assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}"
|
||||
else:
|
||||
window = create_window(self.window_size, channel)
|
||||
window = window.type_as(img1)
|
||||
assert (
|
||||
t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]]
|
||||
), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
|
||||
|
||||
self.window = window
|
||||
self.channel = channel
|
||||
if dim_range[0] == dim_range[1]:
|
||||
assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
|
||||
elif dim_range[0] < dim_range[1]:
|
||||
assert (
|
||||
dim_range[0] <= t.dim() <= dim_range[1]
|
||||
), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
|
||||
|
||||
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
|
||||
if data_range[0] < data_range[1]:
|
||||
assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
|
||||
assert t.max() <= data_range[1], f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}"
|
||||
|
||||
|
||||
def ssim(img1, img2, window_size=11, size_average=True):
|
||||
(_, channel, _, _) = img1.size()
|
||||
window = create_window(window_size, channel).type_as(img1)
|
||||
window = window.type_as(img1)
|
||||
return _ssim(img1, img2, window, window_size, channel, size_average)
|
||||
def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
|
||||
r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
|
||||
Args:
|
||||
kernel_size: Size of the kernel
|
||||
sigma: Std of the distribution
|
||||
Returns:
|
||||
gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
|
||||
"""
|
||||
coords = torch.arange(kernel_size, dtype=torch.float32)
|
||||
coords -= (kernel_size - 1) / 2.0
|
||||
|
||||
g = coords**2
|
||||
g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
|
||||
|
||||
g /= g.sum()
|
||||
return g.unsqueeze(0)
|
||||
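# Illustrative check (not part of the original file): the returned kernel is
# normalized and shaped (1, kernel_size, kernel_size).
# >>> k = gaussian_filter(kernel_size=11, sigma=1.5)
# >>> k.shape
# torch.Size([1, 11, 11])
# >>> bool(torch.isclose(k.sum(), torch.tensor(1.0)))
# True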
|
||||
|
||||
def ssim(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
kernel_size: int = 11,
|
||||
kernel_sigma: float = 1.5,
|
||||
data_range: Union[int, float] = 1.0,
|
||||
reduction: str = "mean",
|
||||
full: bool = False,
|
||||
downsample: bool = True,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03,
|
||||
) -> List[torch.Tensor]:
|
||||
r"""Interface of Structural Similarity (SSIM) index.
|
||||
Inputs supposed to be in range ``[0, data_range]``.
|
||||
To match performance with skimage and tensorflow set ``'downsample' = True``.
|
||||
|
||||
Args:
|
||||
x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
|
||||
y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
|
||||
kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
|
||||
kernel_sigma: Sigma of normal distribution.
|
||||
data_range: Maximum value range of images (usually 1.0 or 255).
|
||||
reduction: Specifies the reduction type:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
||||
full: Return cs map or not.
|
||||
downsample: Perform average pool before SSIM computation. Default: True
|
||||
k1: Algorithm parameter, K1 (small constant).
|
||||
k2: Algorithm parameter, K2 (small constant).
|
||||
Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
|
||||
|
||||
Returns:
|
||||
Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
|
||||
as a tensor of size 2.
|
||||
|
||||
References:
|
||||
Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
|
||||
Image quality assessment: From error visibility to structural similarity.
|
||||
IEEE Transactions on Image Processing, 13, 600-612.
|
||||
https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
|
||||
DOI: `10.1109/TIP.2003.819861`
|
||||
"""
|
||||
assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
|
||||
_validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))
|
||||
|
||||
x = x / float(data_range)
|
||||
y = y / float(data_range)
|
||||
|
||||
# Averagepool image if the size is large enough
|
||||
f = max(1, round(min(x.size()[-2:]) / 256))
|
||||
if (f > 1) and downsample:
|
||||
x = F.avg_pool2d(x, kernel_size=f)
|
||||
y = F.avg_pool2d(y, kernel_size=f)
|
||||
|
||||
kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
|
||||
_compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
|
||||
ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, k1=k1, k2=k2)
|
||||
ssim_val = ssim_map.mean(1)
|
||||
cs = cs_map.mean(1)
|
||||
|
||||
ssim_val = _reduce(ssim_val, reduction)
|
||||
cs = _reduce(cs, reduction)
|
||||
|
||||
if full:
|
||||
return [ssim_val, cs]
|
||||
|
||||
return ssim_val
|
||||
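# Usage sketch (illustrative, not part of the original file): identical images
# yield an SSIM of 1.0 (up to floating-point error).
# >>> x = torch.rand(4, 1, 128, 128)
# >>> y = x.clone()
# >>> round(ssim(x=x, y=y, data_range=1.0).item(), 4)
# 1.0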
|
||||
|
||||
class SSIMLoss(_Loss):
|
||||
r"""Creates a criterion that measures the structural similarity index error between
|
||||
each element in the input :math:`x` and target :math:`y`.
|
||||
|
||||
To match performance with skimage and tensorflow set ``'downsample' = True``.
|
||||
|
||||
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
|
||||
|
||||
.. math::
|
||||
SSIM = \{ssim_1,\dots,ssim_{N \times C}\}\\
|
||||
ssim_{l}(x, y) = \frac{(2 \mu_x \mu_y + c_1) (2 \sigma_{xy} + c_2)}
|
||||
{(\mu_x^2 +\mu_y^2 + c_1)(\sigma_x^2 +\sigma_y^2 + c_2)},
|
||||
|
||||
where :math:`N` is the batch size, `C` is the channel size. If :attr:`reduction` is not ``'none'``
|
||||
(default ``'mean'``), then:
|
||||
|
||||
.. math::
|
||||
SSIMLoss(x, y) =
|
||||
\begin{cases}
|
||||
\operatorname{mean}(1 - SSIM), & \text{if reduction} = \text{'mean';}\\
|
||||
\operatorname{sum}(1 - SSIM), & \text{if reduction} = \text{'sum'.}
|
||||
\end{cases}
|
||||
|
||||
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total
|
||||
of :math:`n` elements each.
|
||||
|
||||
The sum operation still operates over all the elements, and divides by :math:`n`.
|
||||
The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
|
||||
In case of 5D input tensors, complex value is returned as a tensor of size 2.
|
||||
|
||||
Args:
|
||||
kernel_size: By default, the mean and covariance of a pixel is obtained
|
||||
by convolution with the given kernel_size.
|
||||
kernel_sigma: Standard deviation for Gaussian kernel.
|
||||
k1: Coefficient related to c1 in the above equation.
|
||||
k2: Coefficient related to c2 in the above equation.
|
||||
downsample: Perform average pool before SSIM computation. Default: True
|
||||
reduction: Specifies the reduction type:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
|
||||
data_range: Maximum value range of images (usually 1.0 or 255).
|
||||
|
||||
Examples:
|
||||
>>> loss = SSIMLoss()
|
||||
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
|
||||
>>> y = torch.rand(3, 3, 256, 256)
|
||||
>>> output = loss(x, y)
|
||||
>>> output.backward()
|
||||
|
||||
References:
|
||||
Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
|
||||
Image quality assessment: From error visibility to structural similarity.
|
||||
IEEE Transactions on Image Processing, 13, 600-612.
|
||||
https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
|
||||
DOI:`10.1109/TIP.2003.819861`
|
||||
"""
|
||||
__constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: int = 11,
|
||||
kernel_sigma: float = 1.5,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03,
|
||||
downsample: bool = True,
|
||||
reduction: str = "mean",
|
||||
data_range: Union[int, float] = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Generic loss parameters.
|
||||
self.reduction = reduction
|
||||
|
||||
# Loss-specific parameters.
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
# This check might look redundant because kernel size is checked within the ssim function anyway.
|
||||
# However, this check allows failing fast when the loss is initialised, before training has started.
|
||||
assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
|
||||
self.kernel_sigma = kernel_sigma
|
||||
self.k1 = k1
|
||||
self.k2 = k2
|
||||
self.downsample = downsample
|
||||
self.data_range = data_range
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
r"""Computation of Structural Similarity (SSIM) index as a loss function.
|
||||
|
||||
Args:
|
||||
x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
|
||||
y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
|
||||
|
||||
Returns:
|
||||
Value of SSIM loss to be minimized, i.e ``1 - ssim`` in [0, 1] range. In case of 5D input tensors,
|
||||
complex value is returned as a tensor of size 2.
|
||||
"""
|
||||
|
||||
score = ssim(
|
||||
x=x,
|
||||
y=y,
|
||||
kernel_size=self.kernel_size,
|
||||
kernel_sigma=self.kernel_sigma,
|
||||
downsample=self.downsample,
|
||||
data_range=self.data_range,
|
||||
reduction=self.reduction,
|
||||
full=False,
|
||||
k1=self.k1,
|
||||
k2=self.k2,
|
||||
)
|
||||
return torch.ones_like(score) - score
|
||||
|
||||
|
||||
def _ssim_per_channel(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
r"""Calculate Structural Similarity (SSIM) index for X and Y per channel.
|
||||
|
||||
Args:
|
||||
x: An input tensor. Shape :math:`(N, C, H, W)`.
|
||||
y: A target tensor. Shape :math:`(N, C, H, W)`.
|
||||
kernel: 2D Gaussian kernel.
|
||||
k1: Algorithm parameter, K1 (small constant, see [1]).
|
||||
k2: Algorithm parameter, K2 (small constant, see [1]).
|
||||
Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
|
||||
|
||||
Returns:
|
||||
Full Value of Structural Similarity (SSIM) index.
|
||||
"""
|
||||
if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
|
||||
raise ValueError(
|
||||
f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
|
||||
f"Kernel size: {kernel.size()}"
|
||||
)
|
||||
|
||||
c1 = k1**2
|
||||
c2 = k2**2
|
||||
n_channels = x.size(1)
|
||||
mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels)
|
||||
mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels)
|
||||
|
||||
mu_xx = mu_x**2
|
||||
mu_yy = mu_y**2
|
||||
mu_xy = mu_x * mu_y
|
||||
|
||||
sigma_xx = F.conv2d(x**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx
|
||||
sigma_yy = F.conv2d(y**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy
|
||||
sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy
|
||||
|
||||
# Contrast sensitivity (CS) with alpha = beta = gamma = 1.
|
||||
cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
|
||||
|
||||
# Structural similarity (SSIM)
|
||||
ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
|
||||
|
||||
ssim_val = ss.mean(dim=(-1, -2))
|
||||
cs = cs.mean(dim=(-1, -2))
|
||||
return ssim_val, cs
|
||||
|
||||
|
||||
def _ssim_per_channel_complex(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel.
|
||||
|
||||
Args:
|
||||
x: An input tensor. Shape :math:`(N, C, H, W, 2)`.
|
||||
y: A target tensor. Shape :math:`(N, C, H, W, 2)`.
|
||||
kernel: 2-D gauss kernel.
|
||||
k1: Algorithm parameter, K1 (small constant, see [1]).
|
||||
k2: Algorithm parameter, K2 (small constant, see [1]).
|
||||
Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
|
||||
|
||||
Returns:
|
||||
Full Value of Complex Structural Similarity (SSIM) index.
|
||||
"""
|
||||
n_channels = x.size(1)
|
||||
if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2):
|
||||
raise ValueError(
|
||||
f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
|
||||
f"Kernel size: {kernel.size()}"
|
||||
)
|
||||
|
||||
c1 = k1**2
|
||||
c2 = k2**2
|
||||
|
||||
x_real = x[..., 0]
|
||||
x_imag = x[..., 1]
|
||||
y_real = y[..., 0]
|
||||
y_imag = y[..., 1]
|
||||
|
||||
mu1_real = F.conv2d(x_real, weight=kernel, stride=1, padding=0, groups=n_channels)
|
||||
mu1_imag = F.conv2d(x_imag, weight=kernel, stride=1, padding=0, groups=n_channels)
|
||||
mu2_real = F.conv2d(y_real, weight=kernel, stride=1, padding=0, groups=n_channels)
|
||||
mu2_imag = F.conv2d(y_imag, weight=kernel, stride=1, padding=0, groups=n_channels)
|
||||
|
||||
mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2)
|
||||
mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2)
|
||||
mu1_mu2_real = mu1_real * mu2_real - mu1_imag * mu2_imag
|
||||
mu1_mu2_imag = mu1_real * mu2_imag + mu1_imag * mu2_real
|
||||
|
||||
compensation = 1.0
|
||||
|
||||
x_sq = x_real.pow(2) + x_imag.pow(2)
|
||||
y_sq = y_real.pow(2) + y_imag.pow(2)
|
||||
x_y_real = x_real * y_real - x_imag * y_imag
|
||||
x_y_imag = x_real * y_imag + x_imag * y_real
|
||||
|
||||
sigma1_sq = F.conv2d(x_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_sq
|
||||
sigma2_sq = F.conv2d(y_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu2_sq
|
||||
sigma12_real = F.conv2d(x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_real
|
||||
sigma12_imag = F.conv2d(x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_imag
|
||||
sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1)
|
||||
mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1)
|
||||
# Set alpha = beta = gamma = 1.
|
||||
cs_map = (sigma12 * 2 + c2 * compensation) / (sigma1_sq.unsqueeze(-1) + sigma2_sq.unsqueeze(-1) + c2 * compensation)
|
||||
ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation)
|
||||
ssim_map = ssim_map * cs_map
|
||||
|
||||
ssim_val = ssim_map.mean(dim=(-2, -3))
|
||||
cs = cs_map.mean(dim=(-2, -3))
|
||||
|
||||
return ssim_val, cs
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
from typing import Dict, List
|
||||
|
||||
|
@ -163,6 +164,13 @@ class ESpeak(BasePhonemizer):
|
|||
|
||||
# dealing with the conditions described above
|
||||
ph_decoded = ph_decoded[:1].replace("_", "") + ph_decoded[1:]
|
||||
|
||||
# espeak-ng backend can add language flags that need to be removed:
|
||||
# "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."
|
||||
# phonemize needs to remove the language flags of the returned text:
|
||||
# "sɛʁtˈɛ̃ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."
|
||||
ph_decoded = re.sub(r"\(.+?\)", "", ph_decoded)
|
||||
|
||||
phonemes += ph_decoded.strip()
|
||||
return phonemes.replace("_", separator)
|
||||
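# Illustrative check of the language-flag stripping above (not part of the
# original file), using the example string from the comments:
# >>> re.sub(r"\(.+?\)", "", "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ.")
# 'sɛʁtˈɛ̃ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ.'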
|
||||
|
|
|
@ -137,7 +137,7 @@ class Punctuation:
|
|||
|
||||
# nothing has been phonemized, return the puncs alone
|
||||
if not text:
|
||||
return ["".join(m.mark for m in puncs)]
|
||||
return ["".join(m.punc for m in puncs)]
|
||||
|
||||
current = puncs[0]
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from TTS.utils.audio.processor import AudioProcessor
|
|
@ -0,0 +1,425 @@
|
|||
from typing import Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pyworld as pw
|
||||
import scipy
|
||||
import soundfile as sf
|
||||
|
||||
# For using kwargs
|
||||
# pylint: disable=unused-argument
|
||||
|
||||
|
||||
def build_mel_basis(
|
||||
*,
|
||||
sample_rate: int = None,
|
||||
fft_size: int = None,
|
||||
num_mels: int = None,
|
||||
mel_fmax: int = None,
|
||||
mel_fmin: int = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""Build melspectrogram basis.
|
||||
|
||||
Returns:
|
||||
np.ndarray: melspectrogram basis.
|
||||
"""
|
||||
if mel_fmax is not None:
|
||||
assert mel_fmax <= sample_rate // 2
|
||||
assert mel_fmax - mel_fmin > 0
|
||||
return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax)
|
||||
|
||||
|
||||
def millisec_to_length(
|
||||
*, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
|
||||
) -> Tuple[int, int]:
|
||||
"""Compute hop and window length from milliseconds.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: hop length and window length for STFT.
|
||||
"""
|
||||
factor = frame_length_ms / frame_shift_ms
|
||||
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
|
||||
win_length = int(frame_length_ms / 1000.0 * sample_rate)
|
||||
hop_length = int(win_length / float(factor))
|
||||
return win_length, hop_length
|
||||
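# Usage sketch (illustrative, values are assumptions): a 50 ms window with a
# 12.5 ms shift at 22050 Hz.
# >>> millisec_to_length(frame_length_ms=50, frame_shift_ms=12.5, sample_rate=22050)
# (1102, 275)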
|
||||
|
||||
def _log(x, base):
|
||||
if base == 10:
|
||||
return np.log10(x)
|
||||
return np.log(x)
|
||||
|
||||
|
||||
def _exp(x, base):
|
||||
if base == 10:
|
||||
return np.power(10, x)
|
||||
return np.exp(x)
|
||||
|
||||
|
||||
def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
|
||||
"""Convert amplitude values to decibels.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Amplitude spectrogram.
|
||||
gain (float): Gain factor. Defaults to 1.
|
||||
base (int): Logarithm base. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decibels spectrogram.
|
||||
"""
|
||||
assert (x < 0).sum() == 0, " [!] Input values must be non-negative."
|
||||
return gain * _log(np.maximum(1e-8, x), base)
|
||||
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
|
||||
"""Convert decibels spectrogram to amplitude spectrogram.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Decibels spectrogram.
|
||||
gain (float): Gain factor. Defaults to 1.
|
||||
base (int): Logarithm base. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Amplitude spectrogram.
|
||||
"""
|
||||
return _exp(x / gain, base)
|
||||
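# Round-trip sketch (illustrative, not part of the original file): `db_to_amp`
# inverts `amp_to_db` for values above the 1e-8 floor.
# >>> s = np.array([0.1, 1.0, 10.0])
# >>> np.allclose(db_to_amp(x=amp_to_db(x=s, gain=20, base=10), gain=20, base=10), s)
# True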
|
||||
|
||||
def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
|
||||
"""Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Audio signal.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Preemphasis coeff is set to 0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decorrelated audio signal.
|
||||
"""
|
||||
if coef == 0:
|
||||
raise RuntimeError(" [!] Preemphasis is set 0.0.")
|
||||
return scipy.signal.lfilter([1, -coef], [1], x)
|
||||
|
||||
|
||||
def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
|
||||
"""Reverse pre-emphasis."""
|
||||
if coef == 0:
|
||||
raise RuntimeError(" [!] Preemphasis is set 0.0.")
|
||||
return scipy.signal.lfilter([1], [1, -coef], x)
|
||||
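# Round-trip sketch (illustrative, not part of the original file): `deemphasis`
# applies the inverse filter of `preemphasis`.
# >>> x = np.random.rand(1000) * 2 - 1
# >>> np.allclose(deemphasis(x=preemphasis(x=x, coef=0.97), coef=0.97), x)
# True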
|
||||
|
||||
def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
|
||||
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
|
||||
|
||||
Args:
|
||||
spec (np.ndarray): Normalized full scale linear spectrogram.
|
||||
|
||||
Shapes:
|
||||
- spec: :math:`[C, T]`
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized melspectrogram.
|
||||
"""
|
||||
return np.dot(mel_basis, spec)
|
||||
|
||||
|
||||
def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
|
||||
"""Convert a melspectrogram to full scale spectrogram."""
|
||||
assert (mel < 0).sum() == 0, " [!] Input values must be non-negative."
|
||||
inv_mel_basis = np.linalg.pinv(mel_basis)
|
||||
return np.maximum(1e-10, np.dot(inv_mel_basis, mel))
|
||||
|
||||
|
||||
def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
|
||||
"""Compute a spectrogram from a waveform.
|
||||
|
||||
Args:
|
||||
wav (np.ndarray): Waveform. Shape :math:`[T_wav,]`
|
||||
|
||||
Returns:
|
||||
np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length`
|
||||
"""
|
||||
D = stft(y=wav, **kwargs)
|
||||
S = np.abs(D)
|
||||
return S.astype(np.float32)
|
||||
|
||||
|
||||
def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
|
||||
"""Compute a melspectrogram from a waveform."""
|
||||
D = stft(y=wav, **kwargs)
|
||||
S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs)
|
||||
return S.astype(np.float32)
|
||||
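# Illustrative pipeline (not part of the original file) showing how the
# kwargs-based helpers above are typically chained. The STFT/mel parameter
# values below are assumptions, not defaults taken from a released config.
def _example_wav_to_logmel(wav: np.ndarray, sample_rate: int = 22050) -> np.ndarray:
    # Build the mel filterbank once and reuse it for every call.
    mel_basis = build_mel_basis(
        sample_rate=sample_rate, fft_size=1024, num_mels=80, mel_fmin=0, mel_fmax=8000
    )
    # STFT -> mel projection, then compress to decibels.
    mel = wav_to_mel(wav=wav, mel_basis=mel_basis, fft_size=1024, hop_length=256, win_length=1024)
    return amp_to_db(x=mel, gain=1, base=10)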
|
||||
|
||||
def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
|
||||
"""Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
|
||||
S = spec.copy()
|
||||
return griffin_lim(spec=S**power, **kwargs)
|
||||
|
||||
|
||||
def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
|
||||
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
|
||||
S = mel.copy()
|
||||
S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear
|
||||
return griffin_lim(spec=S**power, **kwargs)
|
||||
|
||||
|
||||
### STFT and ISTFT ###
|
||||
def stft(
|
||||
*,
|
||||
y: np.ndarray = None,
|
||||
fft_size: int = None,
|
||||
hop_length: int = None,
|
||||
win_length: int = None,
|
||||
pad_mode: str = "reflect",
|
||||
window: str = "hann",
|
||||
center: bool = True,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""Librosa STFT wrapper.
|
||||
|
||||
Check http://librosa.org/doc/main/generated/librosa.stft.html for argument details.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Complex number array.
|
||||
"""
|
||||
return librosa.stft(
|
||||
y=y,
|
||||
n_fft=fft_size,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
pad_mode=pad_mode,
|
||||
window=window,
|
||||
center=center,
|
||||
)
|
||||
|
||||
|
||||
def istft(
|
||||
*,
|
||||
y: np.ndarray = None,
|
||||
fft_size: int = None,
|
||||
hop_length: int = None,
|
||||
win_length: int = None,
|
||||
window: str = "hann",
|
||||
center: bool = True,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""Librosa iSTFT wrapper.
|
||||
|
||||
Check http://librosa.org/doc/main/generated/librosa.istft.html for argument details.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Reconstructed time-domain waveform.
|
||||
"""
|
||||
return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window)
|
||||
|
||||
|
||||
def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
    """Estimate a waveform from a magnitude spectrogram with the Griffin-Lim algorithm."""
|
||||
angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
|
||||
S_complex = np.abs(spec).astype(complex)
|
||||
y = istft(y=S_complex * angles, **kwargs)
|
||||
if not np.isfinite(y).all():
|
||||
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
|
||||
return np.array([0.0])
|
||||
for _ in range(num_iter):
|
||||
angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))
|
||||
y = istft(y=S_complex * angles, **kwargs)
|
||||
return y
|
||||
|
||||
|
||||
def compute_stft_paddings(
|
||||
*, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
|
||||
) -> Tuple[int, int]:
|
||||
"""Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
|
||||
(first and final frames)"""
|
||||
pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0]
|
||||
if not pad_two_sides:
|
||||
return 0, pad
|
||||
return pad // 2, pad // 2 + pad % 2
|
||||
|
||||
|
||||
def compute_f0(
|
||||
*, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs
|
||||
) -> np.ndarray:
|
||||
"""Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
|
||||
|
||||
Returns:
|
||||
np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`
|
||||
|
||||
Examples:
|
||||
>>> WAV_FILE = filename = librosa.util.example_audio_file()
|
||||
>>> from TTS.config import BaseAudioConfig
|
||||
>>> from TTS.utils.audio.processor import AudioProcessor >>> conf = BaseAudioConfig(pitch_fmax=8000)
|
||||
>>> ap = AudioProcessor(**conf)
|
||||
>>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
|
||||
>>> pitch = ap.compute_f0(wav)
|
||||
"""
|
||||
assert pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`."
|
||||
|
||||
f0, t = pw.dio(
|
||||
x.astype(np.double),
|
||||
fs=sample_rate,
|
||||
f0_ceil=pitch_fmax,
|
||||
frame_period=1000 * hop_length / sample_rate,
|
||||
)
|
||||
f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
|
||||
return f0
|
||||
|
||||
|
||||
### Audio Processing ###
|
||||
def find_endpoint(
|
||||
*,
|
||||
wav: np.ndarray = None,
|
||||
trim_db: float = -40,
|
||||
sample_rate: int = None,
|
||||
min_silence_sec=0.8,
|
||||
gain: float = None,
|
||||
base: int = None,
|
||||
**kwargs,
|
||||
) -> int:
|
||||
"""Find the last point without silence at the end of a audio signal.
|
||||
|
||||
Args:
|
||||
wav (np.ndarray): Audio signal.
|
||||
trim_db (float, optional): Silence threshold in decibels. Defaults to -40.
|
||||
min_silence_sec (float, optional): Ignore silences that are shorter than this, in seconds. Defaults to 0.8.
|
||||
gain (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None.
|
||||
base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
int: Last point without silence.
|
||||
"""
|
||||
window_length = int(sample_rate * min_silence_sec)
|
||||
hop_length = int(window_length / 4)
|
||||
threshold = db_to_amp(x=-trim_db, gain=gain, base=base)
|
||||
for x in range(hop_length, len(wav) - window_length, hop_length):
|
||||
if np.max(wav[x : x + window_length]) < threshold:
|
||||
return x + hop_length
|
||||
return len(wav)
|
||||
|
||||
|
||||
def trim_silence(
|
||||
*,
|
||||
wav: np.ndarray = None,
|
||||
sample_rate: int = None,
|
||||
trim_db: float = None,
|
||||
win_length: int = None,
|
||||
hop_length: int = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""Trim silent parts with a threshold and 0.01 sec margin"""
|
||||
margin = int(sample_rate * 0.01)
|
||||
wav = wav[margin:-margin]
|
||||
return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
|
||||
|
||||
|
||||
def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
|
||||
"""Normalize the volume of an audio signal.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Raw waveform.
|
||||
coef (float): Coefficient to rescale the maximum value. Defaults to 0.95.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Volume normalized waveform.
|
||||
"""
|
||||
return x / abs(x).max() * coef
|
||||
|
||||
|
||||
def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
|
||||
r = 10 ** (db_level / 20)
|
||||
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
|
||||
return wav * a
|
||||
|
||||
|
||||
def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
|
||||
"""Normalize the volume based on RMS of the signal.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Raw waveform.
|
||||
db_level (float): Target dB level in RMS. Defaults to -27.0.
|
||||
|
||||
Returns:
|
||||
np.ndarray: RMS normalized waveform.
|
||||
"""
|
||||
assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
|
||||
wav = rms_norm(wav=x, db_level=db_level)
|
||||
return wav
|
||||
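# Usage sketch (illustrative, not part of the original file): after
# normalization the RMS matches the requested dB level.
# >>> wav = np.random.rand(22050) * 2 - 1
# >>> normed = rms_volume_norm(x=wav, db_level=-27.0)
# >>> bool(np.isclose(np.sqrt(np.mean(normed**2)), 10 ** (-27.0 / 20)))
# True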
|
||||
|
||||
def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
|
||||
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
|
||||
|
||||
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file beforehand.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the wav file.
|
||||
sample_rate (int, optional): Sampling rate for resampling. Defaults to None.
|
||||
resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Loaded waveform.
|
||||
"""
|
||||
if resample:
|
||||
# loading with resampling. It is significantly slower.
|
||||
x, _ = librosa.load(filename, sr=sample_rate)
|
||||
else:
|
||||
# SF is faster than librosa for loading files
|
||||
x, _ = sf.read(filename)
|
||||
return x
|
||||
|
||||
|
||||
def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None:
|
||||
"""Save float waveform to a file using Scipy.
|
||||
|
||||
Args:
|
||||
wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
|
||||
path (str): Path to a output file.
|
||||
sample_rate (int, optional): Sampling rate used for saving to the file. Defaults to None.
|
||||
"""
|
||||
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
|
||||
scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16))
|
||||
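# Usage sketch (illustrative, not part of the original file; file paths are
# placeholders): load, RMS-normalize and save a waveform.
# >>> wav = load_wav(filename="input.wav", sample_rate=22050, resample=False)
# >>> wav = rms_volume_norm(x=wav, db_level=-27.0)
# >>> save_wav(wav=wav, path="output.wav", sample_rate=22050)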
|
||||
|
||||
def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
|
||||
mu = 2**mulaw_qc - 1
|
||||
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
|
||||
signal = (signal + 1) / 2 * mu + 0.5
|
||||
return np.floor(
|
||||
signal,
|
||||
)
|
||||
|
||||
|
||||
def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
|
||||
"""Recovers waveform from quantized values."""
|
||||
mu = 2**mulaw_qc - 1
|
||||
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
|
||||
return x
|
||||
|
||||
|
||||
def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:
|
||||
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
|
||||
|
||||
|
||||
def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray:
|
||||
"""Quantize a waveform to a given number of bits.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
|
||||
quantize_bits (int): Number of quantization bits.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Quantized waveform.
|
||||
"""
|
||||
return (x + 1.0) * (2**quantize_bits - 1) / 2
|
||||
|
||||
|
||||
def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray:
|
||||
"""Dequantize a waveform from the given number of bits."""
|
||||
return 2 * x / (2**quantize_bits - 1) - 1
|
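# Round-trip sketch (illustrative, not part of the original file): `dequantize`
# inverts `quantize` as long as no rounding is applied in between.
# >>> x = np.linspace(-1, 1, num=5)
# >>> np.allclose(dequantize(x=quantize(x=x, quantize_bits=10), quantize_bits=10), x)
# True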
|
@ -6,179 +6,14 @@ import pyworld as pw
|
|||
import scipy.io.wavfile
|
||||
import scipy.signal
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.utils.helpers import StandardScaler
|
||||
|
||||
|
||||
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||
"""Some of the audio processing funtions using Torch for faster batch processing.
|
||||
|
||||
TODO: Merge this with audio.py
|
||||
|
||||
Args:
|
||||
|
||||
n_fft (int):
|
||||
FFT window size for STFT.
|
||||
|
||||
hop_length (int):
|
||||
number of frames between STFT columns.
|
||||
|
||||
win_length (int, optional):
|
||||
STFT window length.
|
||||
|
||||
pad_wav (bool, optional):
|
||||
If True, pad the audio with (n_fft - hop_length) / 2 samples. Defaults to False.
|
||||
|
||||
window (str, optional):
|
||||
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
|
||||
|
||||
sample_rate (int, optional):
|
||||
target audio sampling rate. Defaults to None.
|
||||
|
||||
mel_fmin (int, optional):
|
||||
minimum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
mel_fmax (int, optional):
|
||||
maximum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
n_mels (int, optional):
|
||||
number of melspectrogram dimensions. Defaults to None.
|
||||
|
||||
use_mel (bool, optional):
|
||||
If True, compute melspectrograms; otherwise compute linear spectrograms. Defaults to False.
|
||||
|
||||
do_amp_to_db_linear (bool, optional):
|
||||
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
|
||||
|
||||
spec_gain (float, optional):
|
||||
gain applied when converting amplitude to DB. Defaults to 1.0.
|
||||
|
||||
power (float, optional):
|
||||
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
|
||||
|
||||
use_htk (bool, optional):
|
||||
Use HTK formula in mel filter instead of Slaney.
|
||||
|
||||
mel_norm (None, 'slaney', or number, optional):
|
||||
If 'slaney', divide the triangular mel weights by the width of the mel band
|
||||
(area normalization).
|
||||
|
||||
If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
|
||||
See `librosa.util.normalize` for a full description of supported norm values
|
||||
(including `+-np.inf`).
|
||||
|
||||
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_fft,
|
||||
hop_length,
|
||||
win_length,
|
||||
pad_wav=False,
|
||||
window="hann_window",
|
||||
sample_rate=None,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
n_mels=80,
|
||||
use_mel=False,
|
||||
do_amp_to_db=False,
|
||||
spec_gain=1.0,
|
||||
power=None,
|
||||
use_htk=False,
|
||||
mel_norm="slaney",
|
||||
):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.pad_wav = pad_wav
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_fmin = mel_fmin
|
||||
self.mel_fmax = mel_fmax
|
||||
self.n_mels = n_mels
|
||||
self.use_mel = use_mel
|
||||
self.do_amp_to_db = do_amp_to_db
|
||||
self.spec_gain = spec_gain
|
||||
self.power = power
|
||||
self.use_htk = use_htk
|
||||
self.mel_norm = mel_norm
|
||||
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
|
||||
self.mel_basis = None
|
||||
if use_mel:
|
||||
self._build_mel_basis()
|
||||
|
||||
def __call__(self, x):
|
||||
"""Compute spectrogram frames by torch based stft.
|
||||
|
||||
Args:
|
||||
x (Tensor): input waveform
|
||||
|
||||
Returns:
|
||||
Tensor: spectrogram frames.
|
||||
|
||||
Shapes:
|
||||
x: :math:`[B, T]` or :math:`[B, 1, T]`
|
||||
"""
|
||||
if x.ndim == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if self.pad_wav:
|
||||
padding = int((self.n_fft - self.hop_length) / 2)
|
||||
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
|
||||
# B x D x T x 2
|
||||
o = torch.stft(
|
||||
x.squeeze(1),
|
||||
self.n_fft,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
self.window,
|
||||
center=True,
|
||||
pad_mode="reflect", # compatible with audio.py
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False,
|
||||
)
|
||||
M = o[:, :, :, 0]
|
||||
P = o[:, :, :, 1]
|
||||
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
|
||||
|
||||
if self.power is not None:
|
||||
S = S**self.power
|
||||
|
||||
if self.use_mel:
|
||||
S = torch.matmul(self.mel_basis.to(x), S)
|
||||
if self.do_amp_to_db:
|
||||
S = self._amp_to_db(S, spec_gain=self.spec_gain)
|
||||
return S
|
||||
|
||||
def _build_mel_basis(self):
|
||||
mel_basis = librosa.filters.mel(
|
||||
self.sample_rate,
|
||||
self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.mel_fmin,
|
||||
fmax=self.mel_fmax,
|
||||
htk=self.use_htk,
|
||||
norm=self.mel_norm,
|
||||
)
|
||||
self.mel_basis = torch.from_numpy(mel_basis).float()
|
||||
|
||||
@staticmethod
|
||||
def _amp_to_db(x, spec_gain=1.0):
|
||||
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
|
||||
|
||||
@staticmethod
|
||||
def _db_to_amp(x, spec_gain=1.0):
|
||||
return torch.exp(x) / spec_gain
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class AudioProcessor(object):
|
||||
"""Audio Processor for TTS used by all the data pipelines.
|
||||
|
||||
TODO: Make this a dataclass to replace `BaseAudioConfig`.
|
||||
|
||||
class AudioProcessor(object):
|
||||
"""Audio Processor for TTS.
|
||||
|
||||
Note:
|
||||
All the class arguments are set to default values to enable a flexible initialization
|
|
@ -0,0 +1,163 @@
|
|||
import librosa
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
||||
"""Some of the audio processing funtions using Torch for faster batch processing.
|
||||
|
||||
Args:
|
||||
|
||||
n_fft (int):
|
||||
FFT window size for STFT.
|
||||
|
||||
hop_length (int):
|
||||
number of frames between STFT columns.
|
||||
|
||||
win_length (int, optional):
|
||||
STFT window length.
|
||||
|
||||
pad_wav (bool, optional):
|
||||
If True, pad the audio with (n_fft - hop_length) / 2 samples. Defaults to False.
|
||||
|
||||
window (str, optional):
|
||||
The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
|
||||
|
||||
sample_rate (int, optional):
|
||||
target audio sampling rate. Defaults to None.
|
||||
|
||||
mel_fmin (int, optional):
|
||||
minimum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
mel_fmax (int, optional):
|
||||
maximum filter frequency for computing melspectrograms. Defaults to None.
|
||||
|
||||
n_mels (int, optional):
|
||||
number of melspectrogram dimensions. Defaults to None.
|
||||
|
||||
use_mel (bool, optional):
|
||||
If True, compute melspectrograms; otherwise compute linear spectrograms. Defaults to False.
|
||||
|
||||
do_amp_to_db_linear (bool, optional):
|
||||
enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
|
||||
|
||||
spec_gain (float, optional):
|
||||
gain applied when converting amplitude to DB. Defaults to 1.0.
|
||||
|
||||
power (float, optional):
|
||||
Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
|
||||
|
||||
use_htk (bool, optional):
|
||||
Use HTK formula in mel filter instead of Slaney.
|
||||
|
||||
mel_norm (None, 'slaney', or number, optional):
|
||||
If 'slaney', divide the triangular mel weights by the width of the mel band
|
||||
(area normalization).
|
||||
|
||||
If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
|
||||
See `librosa.util.normalize` for a full description of supported norm values
|
||||
(including `+-np.inf`).
|
||||
|
||||
Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_fft,
|
||||
hop_length,
|
||||
win_length,
|
||||
pad_wav=False,
|
||||
window="hann_window",
|
||||
sample_rate=None,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
n_mels=80,
|
||||
use_mel=False,
|
||||
do_amp_to_db=False,
|
||||
spec_gain=1.0,
|
||||
power=None,
|
||||
use_htk=False,
|
||||
mel_norm="slaney",
|
||||
):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.pad_wav = pad_wav
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_fmin = mel_fmin
|
||||
self.mel_fmax = mel_fmax
|
||||
self.n_mels = n_mels
|
||||
self.use_mel = use_mel
|
||||
self.do_amp_to_db = do_amp_to_db
|
||||
self.spec_gain = spec_gain
|
||||
self.power = power
|
||||
self.use_htk = use_htk
|
||||
self.mel_norm = mel_norm
|
||||
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
|
||||
self.mel_basis = None
|
||||
if use_mel:
|
||||
self._build_mel_basis()
|
||||
|
||||
def __call__(self, x):
|
||||
"""Compute spectrogram frames by torch based stft.
|
||||
|
||||
Args:
|
||||
x (Tensor): input waveform
|
||||
|
||||
Returns:
|
||||
Tensor: spectrogram frames.
|
||||
|
||||
Shapes:
|
||||
x: :math:`[B, T]` or :math:`[B, 1, T]`
|
||||
"""
|
||||
if x.ndim == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if self.pad_wav:
|
||||
padding = int((self.n_fft - self.hop_length) / 2)
|
||||
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
|
||||
# B x D x T x 2
|
||||
o = torch.stft(
|
||||
x.squeeze(1),
|
||||
self.n_fft,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
self.window,
|
||||
center=True,
|
||||
pad_mode="reflect", # compatible with audio.py
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False,
|
||||
)
|
||||
M = o[:, :, :, 0]
|
||||
P = o[:, :, :, 1]
|
||||
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
|
||||
|
||||
if self.power is not None:
|
||||
S = S**self.power
|
||||
|
||||
if self.use_mel:
|
||||
S = torch.matmul(self.mel_basis.to(x), S)
|
||||
if self.do_amp_to_db:
|
||||
S = self._amp_to_db(S, spec_gain=self.spec_gain)
|
||||
return S
|
||||
|
||||
def _build_mel_basis(self):
|
||||
mel_basis = librosa.filters.mel(
|
||||
self.sample_rate,
|
||||
self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.mel_fmin,
|
||||
fmax=self.mel_fmax,
|
||||
htk=self.use_htk,
|
||||
norm=self.mel_norm,
|
||||
)
|
||||
self.mel_basis = torch.from_numpy(mel_basis).float()
|
||||
|
||||
@staticmethod
|
||||
def _amp_to_db(x, spec_gain=1.0):
|
||||
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
|
||||
|
||||
@staticmethod
|
||||
def _db_to_amp(x, spec_gain=1.0):
|
||||
return torch.exp(x) / spec_gain
|
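# Usage sketch (illustrative, not part of the original file; parameter values
# are assumptions): compute log-mel frames for a batch of waveforms.
# >>> stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024,
# ...                  sample_rate=22050, n_mels=80, use_mel=True, do_amp_to_db=True)
# >>> wav = torch.rand(8, 22050) * 2 - 1   # [B, T]
# >>> mel = stft(wav)                      # [B, n_mels, T // hop_length + 1]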
|
@ -34,6 +34,8 @@ class CapacitronOptimizer:
|
|||
self.primary_optimizer.zero_grad()
|
||||
|
||||
def step(self):
|
||||
# Update param groups to display the correct learning rate
|
||||
self.param_groups = self.primary_optimizer.param_groups
|
||||
self.primary_optimizer.step()
|
||||
|
||||
def zero_grad(self):
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import io
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
|
@ -7,6 +6,7 @@ from shutil import copyfile, rmtree
|
|||
from typing import Dict, Tuple
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
|
@ -337,11 +337,20 @@ class ModelManager(object):
|
|||
def _download_zip_file(file_url, output_folder):
|
||||
"""Download the github releases"""
|
||||
# download the file
|
||||
r = requests.get(file_url)
|
||||
r = requests.get(file_url, stream=True)
|
||||
# extract the file
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
|
||||
total_size_in_bytes = int(r.headers.get("content-length", 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
|
||||
with open(temp_zip_name, "wb") as file:
|
||||
for data in r.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
with zipfile.ZipFile(temp_zip_name) as z:
|
||||
z.extractall(output_folder)
|
||||
os.remove(temp_zip_name) # delete zip after extract
|
||||
except zipfile.BadZipFile:
|
||||
print(f" > Error: Bad zip file - {file_url}")
|
||||
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import math
|
||||
import random
|
||||
from typing import Callable, List, Union
|
||||
|
||||
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
|
||||
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
|
||||
|
||||
|
||||
class SubsetSampler(Sampler):
|
||||
|
@ -112,3 +114,89 @@ class PerfectBatchSampler(Sampler):
|
|||
def __len__(self):
|
||||
class_batch_size = self._batch_size // self._num_classes_in_batch
|
||||
return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
|
||||
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
|
||||
class SortedSampler(Sampler):
|
||||
"""Samples elements sequentially, always in the same order.
|
||||
|
||||
Taken from https://github.com/PetrochukM/PyTorch-NLP
|
||||
|
||||
Args:
|
||||
data (iterable): Iterable data.
|
||||
sort_key (callable): Specifies a function of one argument that is used to extract a
|
||||
numerical comparison key from each list element.
|
||||
|
||||
Example:
|
||||
>>> list(SortedSampler(range(10), sort_key=lambda i: -i))
|
||||
[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data, sort_key: Callable = identity):
|
||||
super().__init__(data)
|
||||
self.data = data
|
||||
self.sort_key = sort_key
|
||||
zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
|
||||
zip_ = sorted(zip_, key=lambda r: r[1])
|
||||
self.sorted_indexes = [item[0] for item in zip_]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.sorted_indexes)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class BucketBatchSampler(BatchSampler):
|
||||
"""Bucket batch sampler
|
||||
|
||||
Adapted from https://github.com/PetrochukM/PyTorch-NLP
|
||||
|
||||
Args:
|
||||
sampler (torch.data.utils.sampler.Sampler):
|
||||
batch_size (int): Size of mini-batch.
|
||||
drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
|
||||
than `batch_size`.
|
||||
data (list): List of data samples.
|
||||
sort_key (callable, optional): Callable to specify a comparison key for sorting.
|
||||
bucket_size_multiplier (int, optional): Buckets are of size
|
||||
`batch_size * bucket_size_multiplier`.
|
||||
|
||||
Example:
|
||||
>>> sampler = WeightedRandomSampler(weights, len(weights))
|
||||
>>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampler,
|
||||
data,
|
||||
batch_size,
|
||||
drop_last,
|
||||
sort_key: Union[Callable, List] = identity,
|
||||
bucket_size_multiplier=100,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.data = data
|
||||
self.sort_key = sort_key
|
||||
_bucket_size = batch_size * bucket_size_multiplier
|
||||
if hasattr(sampler, "__len__"):
|
||||
_bucket_size = min(_bucket_size, len(sampler))
|
||||
self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)
|
||||
|
||||
def __iter__(self):
|
||||
for idxs in self.bucket_sampler:
|
||||
bucket_data = [self.data[idx] for idx in idxs]
|
||||
sorted_sampler = SortedSampler(bucket_data, self.sort_key)
|
||||
for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
|
||||
sorted_idxs = [idxs[i] for i in batch_idx]
|
||||
yield sorted_idxs
|
||||
|
||||
def __len__(self):
|
||||
if self.drop_last:
|
||||
return len(self.sampler) // self.batch_size
|
||||
return math.ceil(len(self.sampler) / self.batch_size)
|
|
@ -321,7 +321,7 @@ class Synthesizer(object):
|
|||
waveform = waveform.squeeze()
|
||||
|
||||
# trim silence
|
||||
if self.tts_config.audio["do_trim_silence"] is True:
|
||||
if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
|
||||
waveform = trim_silence(waveform, self.tts_model.ap)
|
||||
|
||||
wavs += list(waveform)
|
||||
|
|
|
@ -149,4 +149,4 @@ class WaveGradDataset(Dataset):
|
|||
mels[idx, :, : mel.shape[1]] = mel
|
||||
audios[idx, : audio.shape[0]] = audio
|
||||
|
||||
return audios, mels
|
||||
return mels, audios
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.utils.audio import TorchSTFT
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
|
||||
|
||||
#################################
|
||||
|
|
|
@ -185,8 +185,7 @@ class GAN(BaseVocoder):
|
|||
outputs = {"model_outputs": self.y_hat_g}
|
||||
return outputs, loss_dict
|
||||
|
||||
@staticmethod
|
||||
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
|
||||
def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
|
||||
"""Logging shared by the training and evaluation.
|
||||
|
||||
Args:
|
||||
|
@ -198,7 +197,7 @@ class GAN(BaseVocoder):
|
|||
Returns:
|
||||
Tuple[Dict, Dict]: log figures and audio samples.
|
||||
"""
|
||||
y_hat = outputs[0]["model_outputs"] if outputs[0] is not None else outputs[1]["model_outputs"]
|
||||
y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
|
||||
y = batch["waveform"]
|
||||
figures = plot_results(y_hat, y, ap, name)
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch.nn.functional as F
|
|||
from torch import nn
|
||||
from torch.nn.utils import spectral_norm, weight_norm
|
||||
|
||||
from TTS.utils.audio import TorchSTFT
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
|
|
@ -233,6 +233,7 @@ class Wavernn(BaseVocoder):
|
|||
else:
|
||||
raise RuntimeError("Unknown model mode value - ", self.args.mode)
|
||||
|
||||
self.ap = AudioProcessor(**config.audio.to_dict())
|
||||
self.aux_dims = self.args.res_out_dims // 4
|
||||
|
||||
if self.args.use_upsample_net:
|
||||
|
@ -571,7 +572,7 @@ class Wavernn(BaseVocoder):
|
|||
def test(
|
||||
self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, Dict]:
|
||||
ap = assets["audio_processor"]
|
||||
ap = self.ap
|
||||
figures = {}
|
||||
audios = {}
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
|
@ -587,8 +588,16 @@ class Wavernn(BaseVocoder):
|
|||
}
|
||||
)
|
||||
audios.update({f"test_{idx}/audio": y_hat})
|
||||
# audios.update({f"real_{idx}/audio": y_hat})
|
||||
return figures, audios
|
||||
|
||||
def test_log(
|
||||
self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, np.ndarray]:
|
||||
figures, audios = outputs
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def format_batch(batch: Dict) -> Dict:
|
||||
waveform = batch[0]
|
||||
|
@ -605,7 +614,7 @@ class Wavernn(BaseVocoder):
|
|||
verbose: bool,
|
||||
num_gpus: int,
|
||||
):
|
||||
ap = assets["audio_processor"]
|
||||
ap = self.ap
|
||||
dataset = WaveRNNDataset(
|
||||
ap=ap,
|
||||
items=samples,
|
||||
|
|
|
@ -45,7 +45,7 @@
|
|||
"source": [
|
||||
"NUM_PROC = 8\n",
|
||||
"DATASET_CONFIG = BaseDatasetConfig(\n",
|
||||
" name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=\"/home/ubuntu/TTS/depot/data/male_dataset1_44k/\"\n",
|
||||
" name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=\"/absolute/path/to/your/dataset/\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
|
@ -58,13 +58,13 @@
|
|||
"def formatter(root_path, meta_file, **kwargs): # pylint: disable=unused-argument\n",
|
||||
" txt_file = os.path.join(root_path, meta_file)\n",
|
||||
" items = []\n",
|
||||
" speaker_name = \"maledataset1\"\n",
|
||||
" speaker_name = \"myspeaker\"\n",
|
||||
" with open(txt_file, \"r\", encoding=\"utf-8\") as ttf:\n",
|
||||
" for line in ttf:\n",
|
||||
" cols = line.split(\"|\")\n",
|
||||
" wav_file = os.path.join(root_path, \"wavs\", cols[0])\n",
|
||||
" wav_file = os.path.join(root_path, \"wavs\", cols[0] + \".wav\") \n",
|
||||
" text = cols[1]\n",
|
||||
" items.append([text, wav_file, speaker_name])\n",
|
||||
" items.append({\"text\": text, \"audio_file\": wav_file, \"speaker_name\": speaker_name})\n",
|
||||
" return items"
|
||||
]
|
||||
},
|
||||
|
@ -78,7 +78,10 @@
|
|||
"source": [
|
||||
"# use your own preprocessor at this stage - TTS/datasets/proprocess.py\n",
|
||||
"train_samples, eval_samples = load_tts_samples(DATASET_CONFIG, eval_split=True, formatter=formatter)\n",
|
||||
"items = train_samples + eval_samples\n",
|
||||
"if eval_samples is not None:\n",
|
||||
" items = train_samples + eval_samples\n",
|
||||
"else:\n",
|
||||
" items = train_samples\n",
|
||||
"print(\" > Number of audio files: {}\".format(len(items)))\n",
|
||||
"print(items[1])"
|
||||
]
|
||||
|
@ -94,7 +97,7 @@
|
|||
"# check wavs if exist\n",
|
||||
"wav_files = []\n",
|
||||
"for item in items:\n",
|
||||
" wav_file = item[1].strip()\n",
|
||||
" wav_file = item[\"audio_file\"].strip()\n",
|
||||
" wav_files.append(wav_file)\n",
|
||||
" if not os.path.exists(wav_file):\n",
|
||||
" print(waf_path)"
|
||||
|
@ -131,8 +134,8 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"def load_item(item):\n",
|
||||
" text = item[0].strip()\n",
|
||||
" file_name = item[1].strip()\n",
|
||||
" text = item[\"text\"].strip()\n",
|
||||
" file_name = item[\"audio_file\"].strip()\n",
|
||||
" audio, sr = librosa.load(file_name, sr=None)\n",
|
||||
" audio_len = len(audio) / sr\n",
|
||||
" text_len = len(text)\n",
|
||||
|
@ -416,7 +419,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.5"
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -37,7 +37,7 @@
|
|||
"# set some vars\n",
|
||||
"# TTS_PATH = \"/home/thorsten/___dev/tts/mozilla/TTS\"\n",
|
||||
"CONFIG_FILE = \"/path/to/config/config.json\"\n",
|
||||
"CHARS_TO_REMOVE = \".,:!?'\""
|
||||
"CHARS_TO_REMOVE = \".,:!?'\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -59,7 +59,8 @@
|
|||
"# extra imports that might not be included in requirements.txt\n",
|
||||
"import collections\n",
|
||||
"import operator\n",
|
||||
"\n"
|
||||
"\n",
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -75,7 +76,7 @@
|
|||
"CONFIG = load_config(CONFIG_FILE)\n",
|
||||
"\n",
|
||||
"# Load some properties from config.json\n",
|
||||
"CONFIG_METADATA = sorted(load_tts_samples(CONFIG.datasets)[0])\n",
|
||||
"CONFIG_METADATA = load_tts_samples(CONFIG.datasets)[0]\n",
|
||||
"CONFIG_METADATA = CONFIG_METADATA\n",
|
||||
"CONFIG_DATASET = CONFIG.datasets[0]\n",
|
||||
"CONFIG_PHONEME_LANGUAGE = CONFIG.phoneme_language\n",
|
||||
|
@ -84,7 +85,10 @@
|
|||
"\n",
|
||||
"# Will be printed on generated output graph\n",
|
||||
"CONFIG_RUN_NAME = CONFIG.run_name\n",
|
||||
"CONFIG_RUN_DESC = CONFIG.run_description"
|
||||
"CONFIG_RUN_DESC = CONFIG.run_description\n",
|
||||
"\n",
|
||||
"# Needed to convert text to phonemes and phonemes to ids\n",
|
||||
"tokenizer, config = TTSTokenizer.init_from_config(CONFIG)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -112,12 +116,13 @@
|
|||
"source": [
|
||||
"def get_phoneme_from_sequence(text):\n",
|
||||
" temp_list = []\n",
|
||||
" if len(text[0]) > 0:\n",
|
||||
" temp_text = text[0].rstrip('\\n')\n",
|
||||
" if len(text[\"text\"]) > 0:\n",
|
||||
" #temp_text = text[0].rstrip('\\n')\n",
|
||||
" temp_text = text[\"text\"].rstrip('\\n')\n",
|
||||
" for rm_bad_chars in CHARS_TO_REMOVE:\n",
|
||||
" temp_text = temp_text.replace(rm_bad_chars,\"\")\n",
|
||||
" seq = phoneme_to_sequence(temp_text, [CONFIG_TEXT_CLEANER], CONFIG_PHONEME_LANGUAGE, CONFIG_ENABLE_EOS_BOS_CHARS)\n",
|
||||
" text = sequence_to_phoneme(seq)\n",
|
||||
" seq = tokenizer.text_to_ids(temp_text)\n",
|
||||
" text = tokenizer.ids_to_text(seq)\n",
|
||||
" text = text.replace(\" \",\"\")\n",
|
||||
" temp_list.append(text)\n",
|
||||
" return temp_list"
|
||||
|
@ -229,7 +234,7 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -243,7 +248,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -48,7 +48,6 @@ config = TacotronConfig(
|
|||
precompute_num_workers=24,
|
||||
run_eval=True,
|
||||
test_delay_epochs=5,
|
||||
ga_alpha=0.0,
|
||||
r=2,
|
||||
optimizer="CapacitronOptimizer",
|
||||
optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
|
||||
|
@ -68,16 +67,15 @@ config = TacotronConfig(
|
|||
datasets=[dataset_config],
|
||||
lr=1e-3,
|
||||
lr_scheduler="StepwiseGradualLR",
|
||||
lr_scheduler_params={"gradual_learning_rates": [[0, 1e-3], [2e4, 5e-4], [4e5, 3e-4], [6e4, 1e-4], [8e4, 5e-5]]},
|
||||
lr_scheduler_params={"gradual_learning_rates": [[0, 1e-3], [2e4, 5e-4], [4e4, 3e-4], [6e4, 1e-4], [8e4, 5e-5]]},
|
||||
scheduler_after_epoch=False, # scheduler doesn't work without this flag
|
||||
# Need to experiment with these below for capacitron
|
||||
loss_masking=False,
|
||||
decoder_loss_alpha=1.0,
|
||||
postnet_loss_alpha=1.0,
|
||||
postnet_diff_spec_alpha=0.0,
|
||||
decoder_diff_spec_alpha=0.0,
|
||||
decoder_ssim_alpha=0.0,
|
||||
postnet_ssim_alpha=0.0,
|
||||
postnet_diff_spec_alpha=1.0,
|
||||
decoder_diff_spec_alpha=1.0,
|
||||
decoder_ssim_alpha=1.0,
|
||||
postnet_ssim_alpha=1.0,
|
||||
)
|
||||
|
||||
ap = AudioProcessor(**config.audio.to_dict())
|
||||
|
|
|
@ -52,7 +52,6 @@ config = Tacotron2Config(
|
|||
precompute_num_workers=24,
|
||||
run_eval=True,
|
||||
test_delay_epochs=5,
|
||||
ga_alpha=0.0,
|
||||
r=2,
|
||||
optimizer="CapacitronOptimizer",
|
||||
optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
|
||||
|
@ -77,23 +76,20 @@ config = Tacotron2Config(
|
|||
"gradual_learning_rates": [
|
||||
[0, 1e-3],
|
||||
[2e4, 5e-4],
|
||||
[4e5, 3e-4],
|
||||
[4e4, 3e-4],
|
||||
[6e4, 1e-4],
|
||||
[8e4, 5e-5],
|
||||
]
|
||||
},
|
||||
scheduler_after_epoch=False, # scheduler doesn't work without this flag
|
||||
# dashboard_logger='wandb',
|
||||
# sort_by_audio_len=True,
|
||||
seq_len_norm=True,
|
||||
# Need to experiment with these below for capacitron
|
||||
loss_masking=False,
|
||||
decoder_loss_alpha=1.0,
|
||||
postnet_loss_alpha=1.0,
|
||||
postnet_diff_spec_alpha=0.0,
|
||||
decoder_diff_spec_alpha=0.0,
|
||||
decoder_ssim_alpha=0.0,
|
||||
postnet_ssim_alpha=0.0,
|
||||
postnet_diff_spec_alpha=1.0,
|
||||
decoder_diff_spec_alpha=1.0,
|
||||
decoder_ssim_alpha=1.0,
|
||||
postnet_ssim_alpha=1.0,
|
||||
)
|
||||
|
||||
ap = AudioProcessor(**config.audio.to_dict())
|
||||
|
|
|
@ -54,7 +54,6 @@ config = FastPitchConfig(
    print_step=50,
    print_eval=False,
    mixed_precision=False,
    sort_by_audio_len=True,
    max_seq_len=500000,
    output_path=output_path,
    datasets=[dataset_config],
@ -53,7 +53,6 @@ config = FastSpeechConfig(
    print_step=50,
    print_eval=False,
    mixed_precision=False,
    sort_by_audio_len=True,
    max_seq_len=500000,
    output_path=output_path,
    datasets=[dataset_config],
@ -46,7 +46,6 @@ config = SpeedySpeechConfig(
    print_step=50,
    print_eval=False,
    mixed_precision=False,
    sort_by_audio_len=True,
    max_seq_len=500000,
    output_path=output_path,
    datasets=[dataset_config],
@ -68,7 +68,6 @@ config = Tacotron2Config(
    print_step=25,
    print_eval=True,
    mixed_precision=False,
    sort_by_audio_len=True,
    seq_len_norm=True,
    output_path=output_path,
    datasets=[dataset_config],
@ -2,11 +2,10 @@ import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

@ -14,21 +13,8 @@ output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    preemphasis=0.0,
    ref_level_db=20,
    log_func="np.log",
    do_trim_silence=True,
    trim_db=45,
    mel_fmin=0,
    mel_fmax=None,
    spec_gain=1.0,
    signal_norm=False,
    do_amp_to_db_linear=False,
audio_config = VitsAudioConfig(
    sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
)

config = VitsConfig(
@ -37,7 +23,7 @@ config = VitsConfig(
    batch_size=32,
    eval_batch_size=16,
    batch_group_size=5,
    num_loader_workers=0,
    num_loader_workers=8,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=-1,
@ -52,6 +38,7 @@ config = VitsConfig(
    mixed_precision=True,
    output_path=output_path,
    datasets=[dataset_config],
    cudnn_benchmark=False,
)

# INITIALIZE THE AUDIO PROCESSOR
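The recipe now builds its audio settings with `VitsAudioConfig`, which keeps only the fields VITS actually consumes, instead of the long `BaseAudioConfig` block. A condensed sketch of the updated wiring, reusing the values from the recipe above (dataset loading, tokenizer, and the `Trainer` call are elided):

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsAudioConfig

# Compact audio config replacing the old BaseAudioConfig block.
audio_config = VitsAudioConfig(
    sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
)
config = VitsConfig(
    audio=audio_config,
    batch_size=32,
    num_loader_workers=8,   # raised from 0 in this change
    cudnn_benchmark=False,  # added by this change
)
model = Vits(config)  # the rest of the recipe stays as before
```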
@ -3,11 +3,10 @@ from glob import glob

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
@ -22,22 +21,13 @@ dataset_config = [
    for path in dataset_paths
]

audio_config = BaseAudioConfig(
audio_config = VitsAudioConfig(
    sample_rate=16000,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    preemphasis=0.0,
    ref_level_db=20,
    log_func="np.log",
    do_trim_silence=False,
    trim_db=23.0,
    mel_fmin=0,
    mel_fmax=None,
    spec_gain=1.0,
    signal_norm=True,
    do_amp_to_db_linear=False,
    resample=False,
)

vitsArgs = VitsArgs(
@ -69,7 +59,6 @@ config = VitsConfig(
    use_language_weighted_sampler=True,
    print_eval=False,
    mixed_precision=False,
    sort_by_audio_len=True,
    min_audio_len=32 * 256 * 4,
    max_audio_len=160000,
    output_path=output_path,
@ -60,7 +60,6 @@ config = SpeedySpeechConfig(
        "Dieser Kuchen ist großartig. Er ist so lecker und feucht.",
        "Vor dem 22. November 1963.",
    ],
    sort_by_audio_len=True,
    max_seq_len=500000,
    output_path=output_path,
    datasets=[dataset_config],
@ -2,11 +2,10 @@ import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.downloaders import download_thorsten_de
@ -21,21 +20,13 @@ if not os.path.exists(dataset_config.path):
    print("Downloading dataset")
    download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0])

audio_config = BaseAudioConfig(
audio_config = VitsAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    preemphasis=0.0,
    ref_level_db=20,
    log_func="np.log",
    do_trim_silence=True,
    trim_db=45,
    mel_fmin=0,
    mel_fmax=None,
    spec_gain=1.0,
    signal_norm=False,
    do_amp_to_db_linear=False,
)

config = VitsConfig(
@ -2,7 +2,7 @@
# take the scripts's parent's directory to prefix all the output paths.
RUN_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
echo $RUN_DIR
# download LJSpeech dataset
# download VCTK dataset
wget https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip -O VCTK-Corpus-0.92.zip
# extract
mkdir VCTK
@ -2,11 +2,10 @@ import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
@ -17,22 +16,8 @@ dataset_config = BaseDatasetConfig(
)


audio_config = BaseAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    preemphasis=0.0,
    ref_level_db=20,
    log_func="np.log",
    do_trim_silence=True,
    trim_db=23.0,
    mel_fmin=0,
    mel_fmax=None,
    spec_gain=1.0,
    signal_norm=False,
    do_amp_to_db_linear=False,
    resample=True,
audio_config = VitsAudioConfig(
    sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
)

vitsArgs = VitsArgs(
@ -62,6 +47,7 @@ config = VitsConfig(
    max_text_len=325,  # change this if you have a larger VRAM than 16GB
    output_path=output_path,
    datasets=[dataset_config],
    cudnn_benchmark=False,
)

# INITIALIZE THE AUDIO PROCESSOR
@ -1,13 +1,15 @@
# core deps
numpy==1.21.6
numpy==1.21.6;python_version<"3.10"
numpy==1.22.4;python_version=="3.10"
cython==0.29.28
scipy>=1.4.0
torch>=1.7
torchaudio
soundfile
librosa==0.8.0
numba==0.55.1
inflect
numba==0.55.1;python_version<"3.10"
numba==0.55.2;python_version=="3.10"
inflect==5.6.0
tqdm
anyascii
pyyaml
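The new numpy and numba pins rely on PEP 508 environment markers, so pip installs a different version depending on the running interpreter. To check how a marker evaluates on a given Python, the `packaging` library can evaluate it directly:

```python
from packaging.markers import Marker

# Evaluated against the current interpreter's environment.
print(Marker('python_version < "3.10"').evaluate())   # True on 3.9 and older
print(Marker('python_version == "3.10"').evaluate())  # True only on 3.10
```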
@ -4,5 +4,4 @@ TF_CPP_MIN_LOG_LEVEL=3
# runtime bash based tests
# TODO: move these to python
./tests/bash_tests/test_demo_server.sh && \
./tests/bash_tests/test_resample.sh && \
./tests/bash_tests/test_compute_statistics.sh
setup.py
@ -90,7 +90,7 @@ setup(
    # ext_modules=find_cython_extensions(),
    # package
    include_package_data=True,
    packages=find_packages(include=["TTS*"]),
    packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]),
    package_data={
        "TTS": [
            "VERSION",
@ -3,7 +3,7 @@ import unittest

from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.config import BaseAudioConfig
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.processor import AudioProcessor

TESTS_PATH = get_tests_path()
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
@ -0,0 +1,105 @@
import math
import os
import unittest
from dataclasses import dataclass

import librosa
import numpy as np
from coqpit import Coqpit

from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.utils.audio import numpy_transforms as np_transforms

TESTS_PATH = get_tests_path()
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")

os.makedirs(OUT_PATH, exist_ok=True)


# pylint: disable=no-self-use


class TestNumpyTransforms(unittest.TestCase):
    def setUp(self) -> None:
        @dataclass
        class AudioConfig(Coqpit):
            sample_rate: int = 22050
            fft_size: int = 1024
            num_mels: int = 256
            mel_fmax: int = 1800
            mel_fmin: int = 0
            hop_length: int = 256
            win_length: int = 1024
            pitch_fmax: int = 450
            trim_db: int = -1
            min_silence_sec: float = 0.01
            gain: float = 1.0
            base: float = 10.0

        self.config = AudioConfig()
        self.sample_wav, _ = librosa.load(WAV_FILE, sr=self.config.sample_rate)

    def test_build_mel_basis(self):
        """Check if the mel basis is correctly built"""
        print(" > Testing mel basis building.")
        mel_basis = np_transforms.build_mel_basis(**self.config)
        self.assertEqual(mel_basis.shape, (self.config.num_mels, self.config.fft_size // 2 + 1))

    def test_millisec_to_length(self):
        """Check if the conversion from milliseconds to length is correct"""
        print(" > Testing millisec to length conversion.")
        win_len, hop_len = np_transforms.millisec_to_length(
            frame_length_ms=1000, frame_shift_ms=12.5, sample_rate=self.config.sample_rate
        )
        self.assertEqual(hop_len, int(12.5 / 1000.0 * self.config.sample_rate))
        self.assertEqual(win_len, self.config.sample_rate)

    def test_amplitude_db_conversion(self):
        di = np.random.rand(11)
        o1 = np_transforms.amp_to_db(x=di, gain=1.0, base=10)
        o2 = np_transforms.db_to_amp(x=o1, gain=1.0, base=10)
        np.testing.assert_almost_equal(di, o2, decimal=5)

    def test_preemphasis_deemphasis(self):
        di = np.random.rand(11)
        o1 = np_transforms.preemphasis(x=di, coeff=0.95)
        o2 = np_transforms.deemphasis(x=o1, coeff=0.95)
        np.testing.assert_almost_equal(di, o2, decimal=5)

    def test_spec_to_mel(self):
        mel_basis = np_transforms.build_mel_basis(**self.config)
        spec = np.random.rand(self.config.fft_size // 2 + 1, 20)  # [C, T]
        mel = np_transforms.spec_to_mel(spec=spec, mel_basis=mel_basis)
        self.assertEqual(mel.shape, (self.config.num_mels, 20))

    def mel_to_spec(self):
        mel_basis = np_transforms.build_mel_basis(**self.config)
        mel = np.random.rand(self.config.num_mels, 20)  # [C, T]
        spec = np_transforms.mel_to_spec(mel=mel, mel_basis=mel_basis)
        self.assertEqual(spec.shape, (self.config.fft_size // 2 + 1, 20))

    def test_wav_to_spec(self):
        spec = np_transforms.wav_to_spec(wav=self.sample_wav, **self.config)
        self.assertEqual(
            spec.shape, (self.config.fft_size // 2 + 1, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
        )

    def test_wav_to_mel(self):
        mel_basis = np_transforms.build_mel_basis(**self.config)
        mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
        self.assertEqual(
            mel.shape, (self.config.num_mels, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
        )

    def test_compute_f0(self):
        pitch = np_transforms.compute_f0(x=self.sample_wav, **self.config)
        mel_basis = np_transforms.build_mel_basis(**self.config)
        mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
        assert pitch.shape[0] == mel.shape[1]

    def test_load_wav(self):
        wav = np_transforms.load_wav(filename=WAV_FILE, resample=False, sample_rate=22050)
        wav_resample = np_transforms.load_wav(filename=WAV_FILE, resample=True, sample_rate=16000)
        self.assertEqual(wav.shape, (self.sample_wav.shape[0],))
        self.assertNotEqual(wav_resample.shape, (self.sample_wav.shape[0],))
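For reference, the `millisec_to_length` assertions in the new test boil down to the usual millisecond-to-sample conversion; at 22050 Hz a 1000 ms window is one full second of samples and a 12.5 ms shift is 275 samples:

```python
# Illustrative arithmetic behind the millisec_to_length assertions above.
sample_rate = 22050
win_len = int(1000 / 1000.0 * sample_rate)  # 22050 samples, i.e. the whole second
hop_len = int(12.5 / 1000.0 * sample_rate)  # 275 samples per frame shift
```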
@ -1,29 +0,0 @@
import os
import unittest

import torch

from tests import get_tests_input_path, get_tests_output_path, run_cli

torch.manual_seed(1)


# pylint: disable=protected-access
class TestRemoveSilenceVAD(unittest.TestCase):
    @staticmethod
    def test():
        # set paths
        wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs")
        output_path = os.path.join(get_tests_output_path(), "output_wavs_removed_silence/")
        output_resample_path = os.path.join(get_tests_output_path(), "output_ljspeech_16khz/")

        # resample audios
        run_cli(
            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/resample.py --input_dir "{wav_path}" --output_dir "{output_resample_path}" --output_sr 16000'
        )

        # run test
        run_cli(
            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/remove_silence_using_vad.py --input_dir "{output_resample_path}" --output_dir "{output_path}"'
        )
        run_cli(f'rm -rf "{output_resample_path}"')
        run_cli(f'rm -rf "{output_path}"')
@ -1,16 +0,0 @@
#!/usr/bin/env bash
set -xe
BASEDIR=$(dirname "$0")
TARGET_SR=16000
echo "$BASEDIR"
#run the resample script
python TTS/bin/resample.py --input_dir $BASEDIR/../data/ljspeech --output_dir $BASEDIR/outputs/resample_tests --output_sr $TARGET_SR
#check samplerate of output
OUT_SR=$( (echo "import librosa" ; echo "y, sr = librosa.load('"$BASEDIR"/outputs/resample_tests/wavs/LJ001-0012.wav', sr=None)" ; echo "print(sr)") | python )
OUT_SR=$(($OUT_SR + 0))
if [[ $OUT_SR -ne $TARGET_SR ]]; then
echo "Missmatch between target and output sample rates"
exit 1
fi
#cleaning up
rm -rf $BASEDIR/outputs/resample_tests
@ -5,11 +5,11 @@ import unittest
import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler

# Fixing random state to avoid random fails
torch.manual_seed(0)
@ -163,3 +163,31 @@ class TestSamplers(unittest.TestCase):
            else:
                len2 += 1
        assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"

    def test_bucket_batch_sampler(self):
        bucket_size_multiplier = 2
        sampler = range(len(train_samples))
        sampler = BucketBatchSampler(
            sampler,
            data=train_samples,
            batch_size=7,
            drop_last=True,
            sort_key=lambda x: len(x["text"]),
            bucket_size_multiplier=bucket_size_multiplier,
        )

        # check if the samples are sorted by text length while bucketing
        min_text_len_in_bucket = 0
        bucket_items = []
        for batch_idx, batch in enumerate(list(sampler)):
            if (batch_idx + 1) % bucket_size_multiplier == 0:
                for bucket_item in bucket_items:
                    self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
                    min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
                min_text_len_in_bucket = 0
                bucket_items = []
            else:
                bucket_items += batch

        # check sampler length
        self.assertEqual(len(sampler), len(train_samples) // 7)
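The new `test_bucket_batch_sampler` case exercises bucketing: the sampler gathers `batch_size * bucket_size_multiplier` items, sorts them by text length, and cuts batches from that sorted bucket, so each batch holds similarly long samples without fixing the global order. A rough sketch of that idea (not the actual `TTS.utils.samplers.BucketBatchSampler` implementation):

```python
def bucketed_batches(indices, data, batch_size, bucket_size_multiplier, sort_key, drop_last=True):
    """Yield batches of similar-length items: sort inside each bucket, then split it into batches."""
    bucket_size = batch_size * bucket_size_multiplier
    for start in range(0, len(indices), bucket_size):
        bucket = sorted(indices[start : start + bucket_size], key=lambda i: sort_key(data[i]))
        for b in range(0, len(bucket), batch_size):
            batch = bucket[b : b + batch_size]
            if len(batch) == batch_size or not drop_last:
                yield batch

# e.g. bucketed_batches(range(len(samples)), samples, 7, 2, lambda s: len(s["text"]))
```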
@ -30,6 +30,13 @@ class TestTTSTokenizer(unittest.TestCase):
        test_hat = self.tokenizer_ph.ids_to_text(ids)
        self.assertEqual(text_ph, test_hat)

    def test_text_to_ids_phonemes_punctuation(self):
        text = "..."
        text_ph = self.ph.phonemize(text, separator="")
        ids = self.tokenizer_ph.text_to_ids(text)
        test_hat = self.tokenizer_ph.ids_to_text(ids)
        self.assertEqual(text_ph, test_hat)

    def test_text_to_ids_phonemes_with_eos_bos(self):
        text = "Bu bir Örnek."
        self.tokenizer_ph.use_eos_bos = True
@ -0,0 +1,239 @@
import unittest

import torch as T

from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss
from TTS.tts.utils.helpers import sequence_mask


class L1LossMaskedTests(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        # test input == target
        layer = L1LossMasked(seq_len_norm=False)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

        # seq_len_norm = True
        # test input == target
        layer = L1LossMasked(seq_len_norm=True)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())


class MSELossMaskedTests(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        # test input == target
        layer = MSELossMasked(seq_len_norm=False)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

        # seq_len_norm = True
        # test input == target
        layer = MSELossMasked(seq_len_norm=True)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())


class SSIMLossTests(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        # test input == target
        layer = SSIMLoss()
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.ones(4, 57, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.arange(0, 4 * 57 * 128)
        dummy_input = dummy_input.reshape(4, 57, 128).float()
        dummy_target = T.arange(-4 * 57 * 128, 0)
        dummy_target = dummy_target.reshape(4, 57, 128).float()
        dummy_target = -dummy_target

        dummy_length = (T.ones(4) * 58).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() >= 1.0, "0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.zeros(4, 57, 128).float()
        dummy_length = (T.arange(54, 58)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0.0

        dummy_input = T.rand(4, 57, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(54, 58)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

        # seq_len_norm = True
        # test input == target
        layer = L1LossMasked(seq_len_norm=True)
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.ones(4, 57, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.zeros(4, 57, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.zeros(4, 57, 128).float()
        dummy_length = (T.arange(54, 58)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 57, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(54, 58)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())


class BCELossTest(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        layer = BCELossMasked(pos_weight=5.0)

        length = T.tensor([95])
        target = (
            1.0 - sequence_mask(length - 1, 100).float()
        )  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
        true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
        zero_x = T.zeros(target.shape) - 100.0  # simulate logits if it never stops decoding
        early_x = -200.0 * sequence_mask(length - 3, 100).float() + 100.0  # simulate logits on early stopping
        late_x = -200.0 * sequence_mask(length + 1, 100).float() + 100.0  # simulate logits on late stopping

        loss = layer(true_x, target, length)
        self.assertEqual(loss.item(), 0.0)

        loss = layer(early_x, target, length)
        self.assertAlmostEqual(loss.item(), 2.1053, places=4)

        loss = layer(late_x, target, length)
        self.assertAlmostEqual(loss.item(), 5.2632, places=4)

        loss = layer(zero_x, target, length)
        self.assertAlmostEqual(loss.item(), 5.2632, places=4)

        # pos_weight should be < 1 to penalize early stopping
        layer = BCELossMasked(pos_weight=0.2)
        loss = layer(true_x, target, length)
        self.assertEqual(loss.item(), 0.0)

        # when pos_weight < 1 overweight the early stopping loss

        loss_early = layer(early_x, target, length)
        loss_late = layer(late_x, target, length)
        self.assertGreater(loss_early.item(), loss_late.item())
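The expected values in `BCELossTest` can be reproduced by hand: with logits of ±100 the per-frame BCE is roughly 100 nats for a wrong prediction and roughly 0 for a correct one, errors on the positive (stop) class are scaled by `pos_weight`, and the masked loss averages over the first `length` frames. A back-of-the-envelope check under that assumption (it is not the `BCELossMasked` code itself):

```python
import torch
import torch.nn.functional as F

length, frames = 95, 100
target = torch.zeros(frames)
target[94:] = 1.0  # stop flag turns on at the last mel frame

def masked_bce(logits, target, length, pos_weight):
    # Sum per-frame BCE over the valid frames, then divide by the sequence length.
    loss = F.binary_cross_entropy_with_logits(
        logits[:length], target[:length], pos_weight=torch.tensor(pos_weight), reduction="sum"
    )
    return (loss / length).item()

early = torch.full((frames,), 100.0); early[:92] = -100.0  # stops 2 frames too early
late = torch.full((frames,), 100.0); late[:96] = -100.0    # stops 1 valid frame too late

print(masked_bce(early, target, length, 5.0))  # 2 wrong negatives * ~100 / 95 ≈ 2.1053
print(masked_bce(late, target, length, 5.0))   # 1 wrong positive * 5 * ~100 / 95 ≈ 5.2632
```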
@ -2,9 +2,7 @@ import unittest

import torch as T

from TTS.tts.layers.losses import L1LossMasked, SSIMLoss
from TTS.tts.layers.tacotron.tacotron import CBHG, Decoder, Encoder, Prenet
from TTS.tts.utils.helpers import sequence_mask

# pylint: disable=unused-variable

@ -85,131 +83,3 @@ class EncoderTests(unittest.TestCase):
        assert output.shape[0] == 4
        assert output.shape[1] == 8
        assert output.shape[2] == 256  # 128 * 2 BiRNN


class L1LossMaskedTests(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        # test input == target
        layer = L1LossMasked(seq_len_norm=False)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

        # seq_len_norm = True
        # test input == target
        layer = L1LossMasked(seq_len_norm=True)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())


class SSIMLossTests(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        # test input == target
        layer = SSIMLoss()
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

        # seq_len_norm = True
        # test input == target
        layer = L1LossMasked(seq_len_norm=True)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test if padded values of input makes any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())
@ -9,7 +9,17 @@ from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_test
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
from TTS.tts.models.vits import (
    Vits,
    VitsArgs,
    VitsAudioConfig,
    amp_to_db,
    db_to_amp,
    load_audio,
    spec_to_mel,
    wav_to_mel,
    wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager

LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
@ -421,8 +431,10 @@ class TestVits(unittest.TestCase):
        self._check_parameter_changes(model, model_ref)

    def test_train_step_upsampling(self):
        """Upsampling by the decoder upsampling layers"""
        # setup the model
        with torch.autograd.set_detect_anomaly(True):
            audio_config = VitsAudioConfig(sample_rate=22050)
            model_args = VitsArgs(
                num_chars=32,
                spec_segment_size=10,
@ -430,7 +442,7 @@ class TestVits(unittest.TestCase):
                interpolate_z=False,
                upsample_rates_decoder=[8, 8, 4, 2],
            )
            config = VitsConfig(model_args=model_args)
            config = VitsConfig(model_args=model_args, audio=audio_config)
            model = Vits(config).to(device)
            model.train()
            # model to train
@ -459,10 +471,18 @@ class TestVits(unittest.TestCase):
        self._check_parameter_changes(model, model_ref)

    def test_train_step_upsampling_interpolation(self):
        """Upsampling by interpolation"""
        # setup the model
        with torch.autograd.set_detect_anomaly(True):
            model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
            config = VitsConfig(model_args=model_args)
            audio_config = VitsAudioConfig(sample_rate=22050)
            model_args = VitsArgs(
                num_chars=32,
                spec_segment_size=10,
                encoder_sample_rate=11025,
                interpolate_z=True,
                upsample_rates_decoder=[8, 8, 2, 2],
            )
            config = VitsConfig(model_args=model_args, audio=audio_config)
            model = Vits(config).to(device)
            model.train()
            # model to train
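The two upsampling tests differ only in how the 2x gap between `encoder_sample_rate=11025` and the 22050 Hz output is closed: with `interpolate_z=False` the decoder's upsample factors multiply out to 512 (twice the 256-sample hop length used across these recipes), while with `interpolate_z=True` the latent is interpolated first and the factors multiply out to the plain hop length. A quick sanity check of that reading (an inference from the configs above, not code taken from the test):

```python
hop_length = 256           # hop length used throughout the recipes in this change
sample_rate = 22050
encoder_sample_rate = 11025
scale = sample_rate // encoder_sample_rate  # 2x gap the decoder has to close

def product(rates):
    out = 1
    for r in rates:
        out *= r
    return out

assert product([8, 8, 4, 2]) == hop_length * scale  # interpolate_z=False: decoder upsamples the extra 2x
assert product([8, 8, 2, 2]) == hop_length          # interpolate_z=True: z is interpolated, factors stay at hop_length
```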