Mirror of https://github.com/coqui-ai/TTS.git
Commit 97de55595f: Merge remote-tracking branch 'upstream/main'
@@ -42,16 +42,18 @@ jobs:
       - uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
-      - run: |
-          python -m pip install -U pip setuptools wheel build
-      - run: |
-          python -m build
-      - run: |
-          python -m pip install dist/*.whl
+      - name: Install pip requirements
+        run: |
+          python -m pip install -U pip setuptools wheel build
+          python -m pip install -r requirements.txt
+      - name: Setup and install manylinux1_x86_64 wheel
+        run: |
+          python setup.py bdist_wheel --plat-name=manylinux1_x86_64
+          python -m pip install dist/*-manylinux*.whl
       - uses: actions/upload-artifact@v2
        with:
          name: wheel-${{ matrix.python-version }}
-          path: dist/*.whl
+          path: dist/*-manylinux*.whl
   publish-artifacts:
     runs-on: ubuntu-20.04
     needs: [build-sdist, build-wheels]
@@ -11,4 +11,5 @@ recursive-include TTS *.md
 recursive-include TTS *.py
 recursive-include TTS *.pyx
 recursive-include images *.png
+recursive-exclude tests *
+prune tests*
README.md (46 changed lines)
@@ -130,7 +130,7 @@ pip install -e .[all,dev,notebooks]  # Select the relevant extras
 If you are on Ubuntu (Debian), you can also run following commands for installation.

 ```bash
-$ make system-deps  # intended to be used on Ubuntu (Debian). Let us know if you have a diffent OS.
+$ make system-deps  # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
 $ make install
 ```
@@ -145,25 +145,61 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
 ```
 $ tts --list_models
 ```
+
+- Get model info (for both tts_models and vocoder_models):
+
+  - Query by type/name:
+    The model_info_by_name uses the name as it from the --list_models.
+    ```
+    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
+    ```
+    For example:
+
+    ```
+    $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
+    ```
+    ```
+    $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
+    ```
+  - Query by type/idx:
+    The model_query_idx uses the corresponding idx from --list_models.
+    ```
+    $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
+    ```
+    For example:
+
+    ```
+    $ tts --model_info_by_idx tts_models/3
+    ```
+
 - Run TTS with default models:

   ```
-  $ tts --text "Text for TTS"
+  $ tts --text "Text for TTS" --out_path output/path/speech.wav
   ```

 - Run a TTS model with its default vocoder model:

   ```
-  $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>
+  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
   ```
+  For example:
+
+  ```
+  $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
+  ```

 - Run with specific TTS and vocoder models from the list:

   ```
-  $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
+  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
   ```
+  For example:
+
+  ```
+  $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
+  ```

 - Run your own TTS model (Using Griffin-Lim Vocoder):

   ```
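Note: the same model listing used by the `tts` CLI above is reachable from Python. A minimal sketch, assuming the `TTS.utils.manage.ModelManager` API of this release behaves as it does for the CLI:

```python
# Sketch: list the released models from Python instead of the `tts` CLI.
# Model names follow "<model_type>/<language>/<dataset>/<model_name>" as in the README.
from TTS.utils.manage import ModelManager

manager = ModelManager()
for model_name in manager.list_models():
    print(model_name)  # e.g. "tts_models/en/ljspeech/glow-tts"
```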
@@ -215,6 +215,14 @@
                "author": "@thorstenMueller",
                "license": "apache 2.0",
                "commit": "unknown"
+            },
+            "tacotron2-DDC": {
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--de--thorsten--tacotron2-DDC.zip",
+                "default_vocoder": "vocoder_models/de/thorsten/hifigan_v1",
+                "description": "Thorsten-Dec2021-22k-DDC",
+                "author": "@thorstenMueller",
+                "license": "apache 2.0",
+                "commit": "unknown"
            }
        }
    },
@@ -460,6 +468,13 @@
                "author": "@thorstenMueller",
                "license": "apache 2.0",
                "commit": "unknown"
+            },
+            "hifigan_v1": {
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/vocoder_models--de--thorsten--hifigan_v1.zip",
+                "description": "HifiGAN vocoder model for Thorsten Neutral Dec2021 22k Samplerate Tacotron2 DDC model",
+                "author": "@thorstenMueller",
+                "license": "apache 2.0",
+                "commit": "unknown"
            }
        }
    },
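Note: the two new Thorsten (German) entries point at release zips that the model manager resolves by name. A sketch of fetching them programmatically, assuming `ModelManager.download_model` returns the model path, config path, and model item as in this release:

```python
# Sketch: download the newly added German Thorsten models by their catalog names.
from TTS.utils.manage import ModelManager

manager = ModelManager()
model_path, config_path, _ = manager.download_model("tts_models/de/thorsten/tacotron2-DDC")
vocoder_path, vocoder_config_path, _ = manager.download_model("vocoder_models/de/thorsten/hifigan_v1")
print(model_path, vocoder_path)
```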
@@ -1 +1 @@
-0.7.1
+0.8.0
@@ -60,13 +60,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
 - Run a TTS model with its default vocoder model:

   ```
-  $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>"
+  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>
   ```

 - Run with specific TTS and vocoder models from the list:

   ```
-  $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
+  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --output_path
   ```

 - Run your own TTS model (Using Griffin-Lim Vocoder):
@@ -13,13 +13,13 @@ from trainer.trainer_utils import get_optimizer

 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
 from TTS.utils.io import copy_model_files
+from TTS.utils.samplers import PerfectBatchSampler
 from TTS.utils.training import check_update

 torch.backends.cudnn.enabled = True
@@ -1,4 +1,4 @@
-"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
+"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""
 import argparse
 from itertools import product as cartesian_product
@@ -7,94 +7,97 @@ import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm

+from TTS.config import load_config
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.io import load_config
 from TTS.vocoder.datasets.preprocess import load_wav_data
 from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
-from TTS.vocoder.utils.generic_utils import setup_generator
+from TTS.vocoder.models import setup_model

-parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
-parser.add_argument("--config_path", type=str, help="Path to model config file.")
-parser.add_argument("--data_path", type=str, help="Path to data directory.")
-parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
-parser.add_argument(
-    "--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for."
-)
-parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.")
-parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
-parser.add_argument(
-    "--search_depth",
-    type=int,
-    default=3,
-    help="Search granularity. Increasing this increases the run-time exponentially.",
-)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
+    parser.add_argument("--config_path", type=str, help="Path to model config file.")
+    parser.add_argument("--data_path", type=str, help="Path to data directory.")
+    parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
+    parser.add_argument(
+        "--num_iter",
+        type=int,
+        help="Number of model inference iterations that you like to optimize noise schedule for.",
+    )
+    parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
+    parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
+    parser.add_argument(
+        "--search_depth",
+        type=int,
+        default=3,
+        help="Search granularity. Increasing this increases the run-time exponentially.",
+    )

 # load config
 args = parser.parse_args()
 config = load_config(args.config_path)

 # setup audio processor
 ap = AudioProcessor(**config.audio)

 # load dataset
 _, train_data = load_wav_data(args.data_path, 0)
 train_data = train_data[: args.num_samples]
 dataset = WaveGradDataset(
     ap=ap,
     items=train_data,
     seq_len=-1,
     hop_len=ap.hop_length,
     pad_short=config.pad_short,
     conv_pad=config.conv_pad,
     is_training=True,
     return_segments=False,
     use_noise_augment=False,
     use_cache=False,
     verbose=True,
 )
 loader = DataLoader(
     dataset,
     batch_size=1,
     shuffle=False,
     collate_fn=dataset.collate_full_clips,
     drop_last=False,
     num_workers=config.num_loader_workers,
     pin_memory=False,
 )

 # setup the model
-model = setup_generator(config)
+model = setup_model(config)
 if args.use_cuda:
     model.cuda()

 # setup optimization parameters
 base_values = sorted(10 * np.random.uniform(size=args.search_depth))
-print(base_values)
+print(f" > base values: {base_values}")
 exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
 best_error = float("inf")
-best_schedule = None
+best_schedule = None  # pylint: disable=C0103
 total_search_iter = len(base_values) ** args.num_iter
 for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
     beta = exponents * base
     model.compute_noise_level(beta)
     for data in loader:
         mel, audio = data
         y_hat = model.inference(mel.cuda() if args.use_cuda else mel)

         if args.use_cuda:
             y_hat = y_hat.cpu()
         y_hat = y_hat.numpy()

         mel_hat = []
         for i in range(y_hat.shape[0]):
             m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
             mel_hat.append(torch.from_numpy(m))

         mel_hat = torch.stack(mel_hat)
         mse = torch.sum((mel - mel_hat) ** 2).mean()
         if mse.item() < best_error:
             best_error = mse.item()
             best_schedule = {"beta": beta}
             print(f" > Found a better schedule. - MSE: {mse.item()}")
             np.save(args.output_path, best_schedule)
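Note: the script stores the winning schedule with `np.save(args.output_path, best_schedule)`, i.e. a pickled dict holding a `beta` array. A minimal sketch of reading that file back (the file name below is hypothetical):

```python
# Sketch: load a noise schedule produced by the tuning script above.
# np.save on a dict stores a 0-d object array, so allow_pickle plus .item()
# are needed to recover the {"beta": ...} dict.
import numpy as np

best_schedule = np.load("best_noise_schedule.npy", allow_pickle=True).item()
beta = best_schedule["beta"]
print(beta.shape)  # one value per inference iteration (--num_iter)
```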
@@ -62,7 +62,7 @@ def _process_model_name(config_dict: Dict) -> str:
     return model_name


-def load_config(config_path: str) -> None:
+def load_config(config_path: str) -> Coqpit:
     """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
     to find the corresponding Config class. Then initialize the Config.
@@ -1,7 +1,7 @@
 import torch
 from torch import nn

-# from TTS.utils.audio import TorchSTFT
+# from TTS.utils.audio.torch_transforms import TorchSTFT
 from TTS.encoder.models.base_encoder import BaseEncoder
@@ -200,9 +200,6 @@ class BaseTTSConfig(BaseTrainingConfig):
         loss_masking (bool):
             enable / disable masking loss values against padded segments of samples in a batch.

-        sort_by_audio_len (bool):
-            If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`.
-
         min_text_len (int):
             Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.
@@ -303,7 +300,6 @@ class BaseTTSConfig(BaseTrainingConfig):
     batch_group_size: int = 0
     loss_masking: bool = None
     # dataloading
-    sort_by_audio_len: bool = False
     min_audio_len: int = 1
     max_audio_len: int = float("inf")
     min_text_len: int = 1
@@ -53,7 +53,7 @@ class TacotronConfig(BaseTTSConfig):
             enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True.
         stopnet_pos_weight (float):
             Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
-            datasets with longer sentences. Defaults to 10.
+            datasets with longer sentences. Defaults to 0.2.
         max_decoder_steps (int):
             Max number of steps allowed for the decoder. Defaults to 50.
         encoder_in_features (int):
@@ -161,8 +161,8 @@ class TacotronConfig(BaseTTSConfig):
     prenet_dropout_at_inference: bool = False
     stopnet: bool = True
     separate_stopnet: bool = True
-    stopnet_pos_weight: float = 10.0
+    stopnet_pos_weight: float = 0.2
-    max_decoder_steps: int = 500
+    max_decoder_steps: int = 10000
     encoder_in_features: int = 256
     decoder_in_features: int = 256
     decoder_output_dim: int = 80
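Note: `stopnet_pos_weight` feeds the `pos_weight` of a BCE-with-logits loss (the `BCELossMasked` docstring added later in this commit says "If set < 1, penalize early stopping"), so dropping the default from 10.0 to 0.2 reverses the direction of the weighting. A small sketch of the effect in plain PyTorch, with made-up numbers:

```python
# Sketch: pos_weight scales only the positive ("stop") term of BCE-with-logits.
# > 1 over-weights stop frames; < 1 (the new 0.2 default) down-weights them.
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0])   # model is fairly sure this frame is a stop frame
target = torch.tensor([1.0])   # and it really is one

for w in (10.0, 1.0, 0.2):
    loss = F.binary_cross_entropy_with_logits(logits, target, pos_weight=torch.tensor([w]))
    print(w, loss.item())  # the positive-sample loss shrinks as pos_weight shrinks
```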
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
 from typing import List

 from TTS.tts.configs.shared_configs import BaseTTSConfig
-from TTS.tts.models.vits import VitsArgs
+from TTS.tts.models.vits import VitsArgs, VitsAudioConfig


 @dataclass
@@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
         model_args (VitsArgs):
             Model architecture arguments. Defaults to `VitsArgs()`.

+        audio (VitsAudioConfig):
+            Audio processing configuration. Defaults to `VitsAudioConfig()`.
+
         grad_clip (List):
             Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
|
||||||
compute_linear_spec (bool):
|
compute_linear_spec (bool):
|
||||||
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
||||||
|
|
||||||
|
use_weighted_sampler (bool):
|
||||||
|
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
|
||||||
|
|
||||||
|
weighted_sampler_attrs (dict):
|
||||||
|
Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
|
||||||
|
by overweighting `root_path` by 2.0. Defaults to `{}`.
|
||||||
|
|
||||||
|
weighted_sampler_multipliers (dict):
|
||||||
|
Weight each unique value of a key returned by the formatter for weighted sampling.
|
||||||
|
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
|
||||||
|
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
|
||||||
|
|
||||||
r (int):
|
r (int):
|
||||||
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
||||||
|
|
||||||
|
@@ -94,6 +109,7 @@ class VitsConfig(BaseTTSConfig):
     model: str = "vits"
     # model specific params
     model_args: VitsArgs = field(default_factory=VitsArgs)
+    audio: VitsAudioConfig = VitsAudioConfig()

     # optimizer
     grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
@@ -120,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
     return_wav: bool = True
     compute_linear_spec: bool = True

+    # sampler params
+    use_weighted_sampler: bool = False  # TODO: move it to the base config
+    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
+    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
+
     # overrides
     r: int = 1  # DO NOT CHANGE
     add_blank: bool = True
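Note: putting the new fields together, a training config that balances two datasets by their `root_path` could look like the sketch below. The attribute name must be a key emitted by the dataset formatters (e.g. `root_path`, `speaker_name`); the paths and weights are illustrative only.

```python
# Sketch: enabling the new weighted sampler in a VITS training config.
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(
    use_weighted_sampler=True,
    weighted_sampler_attrs={"root_path": 1.0},
    weighted_sampler_multipliers={
        "root_path": {
            "/data/dataset_a/": 1.0,
            "/data/dataset_b/": 0.5,  # sampled half as often as dataset_a
        }
    },
)
```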
@@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
                 "audio_file": audio_path,
                 "speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
                 "emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
+                "root_path": root_path,
             }
         )
     if not_found_counter > 0:

@@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             wav_file = cols[1].strip()
             text = cols[0].strip()
             wav_file = os.path.join(root_path, "wavs", wav_file)
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argume
             text = cols[1].strip()
             folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
             wav_file = os.path.join(root_path, folder_name, wav_file)
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                 wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                 if os.path.isfile(wav_file):
                     text = cols[1].strip()
-                    items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+                    items.append(
+                        {"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
+                    )
                 else:
                     # M-AI-Labs have some missing samples, so just print the warning
                     print("> File %s does not exist!" % (wav_file))

@@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-arg
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
+            )
     return items

@@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-arg
         if not os.path.exists(wav_file):
             print(f" [!] {wav_file} in metafile does not exist. Skipping...")
             continue
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
                 if speaker_name in ignored_speakers:
                     continue
             wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
+            )
     return items

@@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_name in ignored_speakers:
                     continue
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
+            items.append(
+                {
+                    "text": text,
+                    "audio_file": wav_file,
+                    "speaker_name": f"LTTS_{speaker_name}",
+                    "root_path": root_path,
+                }
+            )
     for item in items:
         assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
     return items

@@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-ar
                 skipped_files.append(wav_file)
                 continue
             text = cols[1].strip()
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
     return items

@@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_id in ignored_speakers:
                     continue
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path})
     return items

@@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
         else:
             wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
         if os.path.exists(wav_file):
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
+            )
         else:
             print(f" [!] wav files don't exist - {wav_file}")
     return items

@@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non
         with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
+        items.append(
+            {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path}
+        )
     return items

@@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs):  # pylint: disable=unused-ar
         if os.path.exists(txt_file) and os.path.exists(wav_file):
             with open(txt_file, "r", encoding="utf-8") as file_text:
                 text = file_text.readlines()[0]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno
         if ignore_digits_sentences and any(map(str.isdigit, text)):
             continue
         wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac")
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path})
     return items

@@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
         if isinstance(ignored_speakers, list):
             if speaker in ignored_speakers:
                 continue
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
+        items.append(
+            {"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path}
+        )
     return items

@@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_id in ignored_speakers:
                     continue
-            items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
+            items.append(
+                {"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path}
+            )
     return items

@@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]:  # pylin
         for line in ttf:
             wav_name, text = line.rstrip("\n").split("|")
             wav_path = os.path.join(root_path, "clips_22", wav_name)
-            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -554,7 +576,7 @@ def kokoro(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2].replace(" ", "")
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
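Note: every formatter now emits a `root_path` key next to `text`, `audio_file` and `speaker_name`, which is the field the new weighted sampler can key on. A minimal sketch of a custom formatter following the same contract (the dataset layout and names are hypothetical):

```python
# Sketch: a custom formatter returning the same fields the built-in ones now
# return, including "root_path" for dataset-level balancing.
import os


def my_dataset(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    items = []
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
        for line in ttf:
            file_id, text = line.strip().split("|")
            wav_file = os.path.join(root_path, "wavs", file_id + ".wav")
            items.append(
                {"text": text, "audio_file": wav_file, "speaker_name": "my_speaker", "root_path": root_path}
            )
    return items
```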
@@ -67,9 +67,14 @@ class WN(torch.nn.Module):
         for i in range(num_layers):
             dilation = dilation_rate**i
             padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(
-                hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
-            )
+            if i == 0:
+                in_layer = torch.nn.Conv1d(
+                    in_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
+                )
+            else:
+                in_layer = torch.nn.Conv1d(
+                    hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
+                )
             in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
             self.in_layers.append(in_layer)
@@ -29,11 +29,11 @@ def squeeze(x, x_mask=None, num_sqz=2):


 def unsqueeze(x, x_mask=None, num_sqz=2):
-    """GlowTTS unsqueeze operation
+    """GlowTTS unsqueeze operation (revert the squeeze)

     Note:
         each 's' is a n-dimensional vector.
-        ``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]``
+        ``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5, s2, s4, s6]]``
     """
     b, c, t = x.size()
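Note: the corrected docstring now shows the interleaving that `unsqueeze` undoes. A small sketch of the round trip on a toy tensor; the view/permute pattern below mirrors the GlowTTS helpers but is written here only for illustration:

```python
# Sketch: the time-axis squeeze/unsqueeze round trip on a [B=1, C=1, T=6] tensor.
import torch

x = torch.arange(1, 7, dtype=torch.float32).view(1, 1, 6)  # [[s1 .. s6]]
num_sqz = 2

b, c, t = x.size()
squeezed = x.view(b, c, t // num_sqz, num_sqz).permute(0, 3, 1, 2).reshape(b, c * num_sqz, t // num_sqz)
print(squeezed)    # [[[1., 3., 5.], [2., 4., 6.]]]  -> [[s1, s3, s5], [s2, s4, s6]]

b, c2, t2 = squeezed.size()
unsqueezed = squeezed.view(b, num_sqz, c2 // num_sqz, t2).permute(0, 2, 3, 1).reshape(b, c2 // num_sqz, t2 * num_sqz)
print(unsqueezed)  # [[[1., 2., 3., 4., 5., 6.]]]    -> [[s1, s3, s5, s2, s4, s6]] flattened back in order
```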
@@ -197,7 +197,7 @@ class CouplingBlock(nn.Module):
         end.bias.data.zero_()
         self.end = end
         # coupling layers
-        self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)
+        self.wn = WN(hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)

     def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):  # pylint: disable=unused-argument
         """
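Note: these two changes go together. WN's first dilated convolution now consumes `in_channels` rather than `hidden_channels`, and CouplingBlock passes `hidden_channels` as WN's input width, presumably because the tensor it feeds into WN has already been projected to `hidden_channels`. A condensed sketch of the new per-layer sizing (channel sizes below are made up):

```python
# Sketch: per-layer input widths WN uses after this change - the first dilated
# conv maps in_channels, every later one maps hidden_channels.
import torch
from torch import nn

in_channels, hidden_channels, kernel_size, num_layers, dilation_rate = 192, 256, 5, 4, 1
in_layers = []
for i in range(num_layers):
    dilation = dilation_rate**i
    padding = int((kernel_size * dilation - dilation) / 2)
    conv_in = in_channels if i == 0 else hidden_channels
    in_layers.append(nn.Conv1d(conv_in, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding))

x = torch.randn(1, in_channels, 50)
print(in_layers[0](x).shape)  # torch.Size([1, 512, 50]) - the 2 * hidden_channels gated output
```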
@@ -7,8 +7,8 @@ from torch import nn
 from torch.nn import functional

 from TTS.tts.utils.helpers import sequence_mask
-from TTS.tts.utils.ssim import ssim
-from TTS.utils.audio import TorchSTFT
+from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
+from TTS.utils.audio.torch_transforms import TorchSTFT


 # pylint: disable=abstract-method
@@ -91,30 +91,55 @@ class MSELossMasked(nn.Module):
         return loss


+def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Min-Max normalize tensor through first dimension
+    Shapes:
+        - x: :math:`[B, D1, D2]`
+        - m: :math:`[B, D1, 1]`
+    """
+    maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True)
+    minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)
+    return (x - minimum) / (maximum - minimum + 1e-8)
+
+
 class SSIMLoss(torch.nn.Module):
-    """SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity"""
+    """SSIM loss as (1 - SSIM)
+    SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
+    """

     def __init__(self):
         super().__init__()
-        self.loss_func = ssim
+        self.loss_func = _SSIMLoss()

-    def forward(self, y_hat, y, length=None):
+    def forward(self, y_hat, y, length):
         """
         Args:
             y_hat (tensor): model prediction values.
             y (tensor): target values.
-            length (tensor): length of each sample in a batch.
+            length (tensor): length of each sample in a batch for masking.

         Shapes:
             y_hat: B x T X D
             y: B x T x D
             length: B

         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
-        if length is not None:
-            m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device)
-            y_hat, y = y_hat * m, y * m
-        return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
+        mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2)
+        y_norm = sample_wise_min_max(y, mask)
+        y_hat_norm = sample_wise_min_max(y_hat, mask)
+        ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))
+
+        if ssim_loss.item() > 1.0:
+            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
+            ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
+
+        if ssim_loss.item() < 0.0:
+            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
+            ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
+
+        return ssim_loss


 class AttentionEntropyLoss(nn.Module):
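Note: the new helper normalizes each sample into [0, 1] using only the unmasked positions, which keeps the SSIM value itself in a sane range before the clamping above. A small self-contained sketch of the same masked min-max trick on toy shapes:

```python
# Sketch: masked per-sample min-max normalization as in sample_wise_min_max.
import numpy as np
import torch

x = torch.tensor([[[1.0, 5.0], [9.0, -3.0]]])   # [B=1, D1=2, D2=2]
mask = torch.tensor([[[True], [False]]])          # second row is padding

maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True)        # 5.0, padding ignored
minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)   # 1.0
normalized = (x - minimum) / (maximum - minimum + 1e-8)
print(normalized[0, 0])  # approximately tensor([0., 1.]) - valid frames mapped to [0, 1]
```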
@@ -123,9 +148,6 @@ class AttentionEntropyLoss(nn.Module):
         """
         Forces attention to be more decisive by penalizing
         soft attention weights
-
-        TODO: arguments
-        TODO: unit_test
         """
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
@@ -133,9 +155,17 @@ class AttentionEntropyLoss(nn.Module):


 class BCELossMasked(nn.Module):
-    def __init__(self, pos_weight):
+    """BCE loss with masking.
+
+    Used mainly for stopnet in autoregressive models.
+
+    Args:
+        pos_weight (float): weight for positive samples. If set < 1, penalize early stopping. Defaults to None.
+    """
+
+    def __init__(self, pos_weight: float = None):
         super().__init__()
-        self.pos_weight = pos_weight
+        self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False)

     def forward(self, x, target, length):
         """

@@ -155,16 +185,17 @@ class BCELossMasked(nn.Module):
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
-        # mask: (batch, max_len, 1)
         target.requires_grad = False
         if length is not None:
-            mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
-            x = x * mask
-            target = target * mask
+            # mask: (batch, max_len, 1)
+            mask = sequence_mask(sequence_length=length, max_len=target.size(1))
             num_items = mask.sum()
+            loss = functional.binary_cross_entropy_with_logits(
+                x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum"
+            )
         else:
+            loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
             num_items = torch.numel(x)
-            loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
         loss = loss / num_items
         return loss
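Note: the rewritten forward selects only the valid positions with `masked_select` instead of zeroing padded logits and targets, so padding no longer adds a constant log(2) per padded frame to the summed BCE. A minimal sketch of the difference, with made-up values:

```python
# Sketch: masking via masked_select versus zeroing padded frames.
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, -1.0, 0.0, 0.0])   # last two frames are padding
target = torch.tensor([1.0, 0.0, 0.0, 0.0])
mask = torch.tensor([True, True, False, False])

zeroed = F.binary_cross_entropy_with_logits(logits * mask, target * mask, reduction="sum")
selected = F.binary_cross_entropy_with_logits(logits.masked_select(mask), target.masked_select(mask), reduction="sum")
print(zeroed.item(), selected.item())  # zeroed is larger by roughly 2 * log(2)
```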
@@ -53,6 +53,7 @@ class CapacitronVAE(nn.Module):
             text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device)
             enc_out = torch.cat([enc_out, text_summary_out], dim=-1)
         if speaker_embedding is not None:
+            speaker_embedding = torch.squeeze(speaker_embedding)
             enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)

         # Feed the output of the ref encoder and information about text/speaker into
@@ -137,7 +137,7 @@ class BaseTTS(BaseTrainerModel):
         if hasattr(self, "speaker_manager"):
             if config.use_d_vector_file:
                 if speaker_name is None:
-                    d_vector = self.speaker_manager.get_random_embeddings()
+                    d_vector = self.speaker_manager.get_random_embedding()
                 else:
                     d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
             elif config.use_speaker_embedding:
@@ -514,7 +514,7 @@ class GlowTTS(BaseTTS):
             y = y[:, :, :y_max_length]
             if attn is not None:
                 attn = attn[:, :, :, :y_max_length]
-        y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze
+        y_lengths = torch.div(y_lengths, self.num_squeeze, rounding_mode="floor") * self.num_squeeze
         return y, y_lengths, y_max_length, attn

     def store_inverse(self):
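Note: both spellings round `y_lengths` down to a multiple of `num_squeeze`; `torch.div(..., rounding_mode="floor")` just avoids the floor-division deprecation warning some PyTorch versions emit for `//` on tensors. A quick check:

```python
# Sketch: equivalence of the old and new floor-division spellings.
import torch

y_lengths = torch.tensor([37, 64, 101])
num_squeeze = 2

a = (y_lengths // num_squeeze) * num_squeeze
b = torch.div(y_lengths, num_squeeze, rounding_mode="floor") * num_squeeze
print(a.tolist(), b.tolist())  # [36, 64, 100] for both
```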
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union

+import numpy as np
 import torch
 import torch.distributed as dist
 import torchaudio

@@ -13,6 +14,8 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler

 from TTS.tts.configs.shared_configs import CharactersConfig

@@ -29,6 +32,8 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
+from TTS.utils.io import load_fsspec
+from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -200,11 +205,51 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
     return spec


+#############################
+# CONFIGS
+#############################
+
+
+@dataclass
+class VitsAudioConfig(Coqpit):
+    fft_size: int = 1024
+    sample_rate: int = 22050
+    win_length: int = 1024
+    hop_length: int = 256
+    num_mels: int = 80
+    mel_fmin: int = 0
+    mel_fmax: int = None
+
+
 ##############################
 # DATASET
 ##############################


+def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
+    """Create inverse frequency weights for balancing the dataset.
+    Use `multi_dict` to scale relative weights."""
+    attr_names_samples = np.array([item[attr_name] for item in items])
+    unique_attr_names = np.unique(attr_names_samples).tolist()
+    attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
+    attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
+    weight_attr = 1.0 / attr_count
+    dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    if multi_dict is not None:
+        # check if all keys are in the multi_dict
+        for k in multi_dict:
+            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
+        # scale weights
+        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
+        dataset_samples_weight *= multiplier_samples
+    return (
+        torch.from_numpy(dataset_samples_weight).float(),
+        unique_attr_names,
+        np.unique(dataset_samples_weight).tolist(),
+    )
+
+
 class VitsDataset(TTSDataset):
     def __init__(self, model_args, *args, **kwargs):
         super().__init__(*args, **kwargs)
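Note: the helper turns one attribute of the formatter output into inverse-frequency sample weights, which `get_sampler` then feeds to `WeightedRandomSampler`. A toy sketch of that flow (the item list is made up):

```python
# Sketch: inverse-frequency weights for a tiny, imbalanced item list, then a
# WeightedRandomSampler over them.
import numpy as np
import torch
from torch.utils.data.sampler import WeightedRandomSampler

items = [
    {"root_path": "/data/a/"}, {"root_path": "/data/a/"}, {"root_path": "/data/a/"},
    {"root_path": "/data/b/"},
]
attrs = np.array([it["root_path"] for it in items])
unique, counts = np.unique(attrs, return_counts=True)
weight_per_value = 1.0 / counts
weights = np.array([weight_per_value[list(unique).index(a)] for a in attrs])
weights = weights / np.linalg.norm(weights)

sampler = WeightedRandomSampler(torch.from_numpy(weights).float(), len(weights))
print(list(sampler))  # indices; the lone "/data/b/" item is drawn about as often as all "/data/a/" items combined
```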
@@ -786,7 +831,7 @@ class Vits(BaseTTS):
             print(" > Text Encoder was reinit.")

     def get_aux_input(self, aux_input: Dict):
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, _ = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}

     def _freeze_layers(self):
@@ -817,7 +862,7 @@ class Vits(BaseTTS):
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
-        sid, g, lid = None, None, None
+        sid, g, lid, durations = None, None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
             if sid.ndim == 0:
@@ -832,7 +877,10 @@ class Vits(BaseTTS):
             if lid.ndim == 0:
                 lid = lid.unsqueeze_(0)

-        return sid, g, lid
+        if "durations" in aux_input and aux_input["durations"] is not None:
+            durations = aux_input["durations"]
+
+        return sid, g, lid, durations

     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
@@ -946,7 +994,7 @@ class Vits(BaseTTS):
             - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, _ = self._set_cond_input(aux_input)
         # speaker embedding
         if self.args.use_speaker_embedding and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
@@ -1028,7 +1076,9 @@ class Vits(BaseTTS):

     @torch.no_grad()
     def inference(
-        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
+        self,
+        x,
+        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1048,7 +1098,7 @@ class Vits(BaseTTS):
             - m_p: :math:`[B, C, T_dec]`
             - logs_p: :math:`[B, C, T_dec]`
         """
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, durations = self._set_cond_input(aux_input)
         x_lengths = self._set_x_lengths(x, aux_input)

         # speaker embedding
@@ -1062,21 +1112,25 @@ class Vits(BaseTTS):

         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

-        if self.args.use_sdp:
-            logw = self.duration_predictor(
-                x,
-                x_mask,
-                g=g if self.args.condition_dp_on_speaker else None,
-                reverse=True,
-                noise_scale=self.inference_noise_scale_dp,
-                lang_emb=lang_emb,
-            )
+        if durations is None:
+            if self.args.use_sdp:
+                logw = self.duration_predictor(
+                    x,
+                    x_mask,
+                    g=g if self.args.condition_dp_on_speaker else None,
+                    reverse=True,
+                    noise_scale=self.inference_noise_scale_dp,
+                    lang_emb=lang_emb,
+                )
+            else:
+                logw = self.duration_predictor(
+                    x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
+                )
+            w = torch.exp(logw) * x_mask * self.length_scale
         else:
-            logw = self.duration_predictor(
-                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
-            )
-
-        w = torch.exp(logw) * x_mask * self.length_scale
+            assert durations.shape[-1] == x.shape[-1]
+            w = durations.unsqueeze(0)
+
         w_ceil = torch.ceil(w)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
         y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]

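With this change, `Vits.inference` skips the duration predictor when explicit per-token durations are passed through `aux_input`. A minimal sketch of such a call, assuming a loaded `model` and dummy token ids (all names and shapes below are illustrative, not from this commit):

```python
import torch

# Hypothetical call: one duration value per input token overrides the predictor.
token_ids = torch.randint(0, 100, (1, 12))   # [B, T_enc] dummy token sequence
durations = torch.full((1, 12), 3.0)         # e.g. 3 decoder frames per token
outputs = model.inference(
    token_ids,
    aux_input={
        "x_lengths": None,
        "d_vectors": None,
        "speaker_ids": None,
        "language_ids": None,
        "durations": durations,  # must satisfy durations.shape[-1] == token_ids.shape[-1]
    },
)
```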
@@ -1341,7 +1395,7 @@ class Vits(BaseTTS):
        if hasattr(self, "speaker_manager"):
            if config.use_d_vector_file:
                if speaker_name is None:
-                    d_vector = self.speaker_manager.get_random_embeddings()
+                    d_vector = self.speaker_manager.get_random_embedding()
                else:
                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
            elif config.use_speaker_embedding:

@@ -1485,6 +1539,42 @@ class Vits(BaseTTS):
         batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
         return batch

+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
+        weights = None
+        data_items = dataset.samples
+        if getattr(config, "use_weighted_sampler", False):
+            for attr_name, alpha in config.weighted_sampler_attrs.items():
+                print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
+                multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
+                print(multi_dict)
+                weights, attr_names, attr_weights = get_attribute_balancer_weights(
+                    attr_name=attr_name, items=data_items, multi_dict=multi_dict
+                )
+                weights = weights * alpha
+                print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
+
+        # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]
+
+        if weights is not None:
+            w_sampler = WeightedRandomSampler(weights, len(weights))
+            batch_sampler = BucketBatchSampler(
+                w_sampler,
+                data=data_items,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                sort_key=lambda x: os.path.getsize(x["audio_file"]),
+                drop_last=True,
+            )
+        else:
+            batch_sampler = None
+        # sampler for DDP
+        if batch_sampler is None:
+            batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+        else:  # If a sampler is already defined use this sampler and DDP sampler together
+            batch_sampler = (
+                DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
+            )  # TODO: check batch_sampler with multi-gpu
+        return batch_sampler
+
     def get_data_loader(
         self,
         config: Coqpit,

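The new `get_sampler` only builds the weighted `BucketBatchSampler` when the config enables it. A hedged sketch of the config fields it reads; the attribute name, alpha, and multiplier values below are made-up examples:

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.use_weighted_sampler = True
# balance batches over an attribute of the dataset samples, e.g. the speaker
config.weighted_sampler_attrs = {"speaker_name": 1.0}          # attr -> alpha
# optional per-value multipliers for that attribute
config.weighted_sampler_multipliers = {"speaker_name": {"p225": 2.0}}
```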
@@ -1523,17 +1613,24 @@ class Vits(BaseTTS):

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)
-            loader = DataLoader(
-                dataset,
-                batch_size=config.eval_batch_size if is_eval else config.batch_size,
-                shuffle=False,  # shuffle is done in the dataset.
-                drop_last=False,  # setting this False might cause issues in AMP training.
-                sampler=sampler,
-                collate_fn=dataset.collate_fn,
-                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
-                pin_memory=False,
-            )
+            if sampler is None:
+                loader = DataLoader(
+                    dataset,
+                    batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                    shuffle=False,  # shuffle is done in the dataset.
+                    collate_fn=dataset.collate_fn,
+                    drop_last=False,  # setting this False might cause issues in AMP training.
+                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                    pin_memory=False,
+                )
+            else:
+                loader = DataLoader(
+                    dataset,
+                    batch_sampler=sampler,
+                    collate_fn=dataset.collate_fn,
+                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                    pin_memory=False,
+                )
        return loader

    def get_optimizer(self) -> List:

@@ -1590,7 +1687,7 @@ class Vits(BaseTTS):
        strict=True,
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load the model checkpoint and setup for training or inference"""
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        # compat band-aid for the pre-trained models to not use the encoder baked into the model
        # TODO: consider baking the speaker encoder into the model and call it from there.
        # as it is probably easier for model distribution.

@@ -76,7 +76,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_
        index_start = segment_indices[i]
        index_end = index_start + segment_size
        x_i = x[i]
-        if pad_short and index_end > x.size(2):
+        if pad_short and index_end >= x.size(2):
            # pad the sample if it is shorter than the segment size
            x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
        segments[i] = x_i[:, index_start:index_end]

@@ -107,16 +107,16 @@ def rand_segments(
    T = segment_size
    if _x_lenghts is None:
        _x_lenghts = T
-    len_diff = _x_lenghts - segment_size + 1
+    len_diff = _x_lenghts - segment_size
    if let_short_samples:
        _x_lenghts[len_diff < 0] = segment_size
-        len_diff = _x_lenghts - segment_size + 1
+        len_diff = _x_lenghts - segment_size
    else:
        assert all(
            len_diff > 0
        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
-    segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
-    ret = segment(x, segment_indices, segment_size)
+    segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long()
+    ret = segment(x, segment_indices, segment_size, pad_short=pad_short)
    return ret, segment_indices

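The sign changes in `rand_segments` fix an off-by-one in how random start indices are drawn. A quick check with assumed numbers:

```python
# A sample of length 10 with segment_size 4 has 7 valid start positions (0..6).
segment_size, x_len = 4, 10
len_diff = x_len - segment_size          # 6, as computed by the new code
assert len_diff + 1 == 7                 # starts are drawn from [0, len_diff + 1)
```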
@@ -1,73 +1,383 @@
-# taken from https://github.com/Po-Hsun-Su/pytorch-ssim
-
-from math import exp
-
-import torch
-import torch.nn.functional as F
-from torch.autograd import Variable
-
-
-def gaussian(window_size, sigma):
-    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
-    return gauss / gauss.sum()
-
-
-def create_window(window_size, channel):
-    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
-    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
-    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
-    return window
-
-
-def _ssim(img1, img2, window, window_size, channel, size_average=True):
-    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
-    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
-
-    # TODO: check if you need AMP disabled
-    # with torch.cuda.amp.autocast(enabled=False):
-    mu1_sq = mu1.float().pow(2)
-    mu2_sq = mu2.float().pow(2)
-    mu1_mu2 = mu1 * mu2
-
-    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
-    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
-    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
-
-    C1 = 0.01**2
-    C2 = 0.03**2
-
-    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
-
-    if size_average:
-        return ssim_map.mean()
-    return ssim_map.mean(1).mean(1).mean(1)
-
-
-class SSIM(torch.nn.Module):
-    def __init__(self, window_size=11, size_average=True):
-        super().__init__()
-        self.window_size = window_size
-        self.size_average = size_average
-        self.channel = 1
-        self.window = create_window(window_size, self.channel)
-
-    def forward(self, img1, img2):
-        (_, channel, _, _) = img1.size()
-
-        if channel == self.channel and self.window.data.type() == img1.data.type():
-            window = self.window
-        else:
-            window = create_window(self.window_size, channel)
-            window = window.type_as(img1)
-
-            self.window = window
-            self.channel = channel
-
-        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
-
-
-def ssim(img1, img2, window_size=11, size_average=True):
-    (_, channel, _, _) = img1.size()
-    window = create_window(window_size, channel).type_as(img1)
-    window = window.type_as(img1)
-    return _ssim(img1, img2, window, window_size, channel, size_average)
+# Adopted from https://github.com/photosynthesis-team/piq
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch.nn.modules.loss import _Loss
+
+
+def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
+    r"""Reduce input in batch dimension if needed.
+
+    Args:
+        x: Tensor with shape (N, *).
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
+    """
+    if reduction == "none":
+        return x
+    if reduction == "mean":
+        return x.mean(dim=0)
+    if reduction == "sum":
+        return x.sum(dim=0)
+    raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
+
+
+def _validate_input(
+    tensors: List[torch.Tensor],
+    dim_range: Tuple[int, int] = (0, -1),
+    data_range: Tuple[float, float] = (0.0, -1.0),
+    # size_dim_range: Tuple[float, float] = (0., -1.),
+    size_range: Optional[Tuple[int, int]] = None,
+) -> None:
+    r"""Check that input(-s) satisfies the requirements
+
+    Args:
+        tensors: Tensors to check
+        dim_range: Allowed number of dimensions. (min, max)
+        data_range: Allowed range of values in tensors. (min, max)
+        size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
+    """
+    if not __debug__:
+        return
+
+    x = tensors[0]
+
+    for t in tensors:
+        assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}"
+        assert t.device == x.device, f"Expected tensors to be on {x.device}, got {t.device}"
+
+        if size_range is None:
+            assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}"
+        else:
+            assert (
+                t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]]
+            ), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
+
+        if dim_range[0] == dim_range[1]:
+            assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
+        elif dim_range[0] < dim_range[1]:
+            assert (
+                dim_range[0] <= t.dim() <= dim_range[1]
+            ), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
+
+        if data_range[0] < data_range[1]:
+            assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
+            assert t.max() <= data_range[1], f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}"
+
+
+def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
+    r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
+
+    Args:
+        size: Size of the kernel
+        sigma: Std of the distribution
+
+    Returns:
+        gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
+    """
+    coords = torch.arange(kernel_size, dtype=torch.float32)
+    coords -= (kernel_size - 1) / 2.0
+
+    g = coords**2
+    g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
+
+    g /= g.sum()
+    return g.unsqueeze(0)
+
+
+def ssim(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel_size: int = 11,
+    kernel_sigma: float = 1.5,
+    data_range: Union[int, float] = 1.0,
+    reduction: str = "mean",
+    full: bool = False,
+    downsample: bool = True,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> List[torch.Tensor]:
+    r"""Interface of Structural Similarity (SSIM) index.
+    Inputs supposed to be in range ``[0, data_range]``.
+    To match performance with skimage and tensorflow set ``'downsample' = True``.
+
+    Args:
+        x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+        y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
+        kernel_sigma: Sigma of normal distribution.
+        data_range: Maximum value range of images (usually 1.0 or 255).
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
+        full: Return cs map or not.
+        downsample: Perform average pool before SSIM computation. Default: True
+        k1: Algorithm parameter, K1 (small constant).
+        k2: Algorithm parameter, K2 (small constant).
+            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+
+    Returns:
+        Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
+        as a tensor of size 2.
+
+    References:
+        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
+        Image quality assessment: From error visibility to structural similarity.
+        IEEE Transactions on Image Processing, 13, 600-612.
+        https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
+        DOI: `10.1109/TIP.2003.819861`
+    """
+    assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
+    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))
+
+    x = x / float(data_range)
+    y = y / float(data_range)
+
+    # Averagepool image if the size is large enough
+    f = max(1, round(min(x.size()[-2:]) / 256))
+    if (f > 1) and downsample:
+        x = F.avg_pool2d(x, kernel_size=f)
+        y = F.avg_pool2d(y, kernel_size=f)
+
+    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
+    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
+    ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, k1=k1, k2=k2)
+    ssim_val = ssim_map.mean(1)
+    cs = cs_map.mean(1)
+
+    ssim_val = _reduce(ssim_val, reduction)
+    cs = _reduce(cs, reduction)
+
+    if full:
+        return [ssim_val, cs]
+
+    return ssim_val
+
+
+class SSIMLoss(_Loss):
+    r"""Creates a criterion that measures the structural similarity index error between
+    each element in the input :math:`x` and target :math:`y`.
+
+    To match performance with skimage and tensorflow set ``'downsample' = True``.
+
+    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+    .. math::
+        SSIM = \{ssim_1,\dots,ssim_{N \times C}\}\\
+        ssim_{l}(x, y) = \frac{(2 \mu_x \mu_y + c_1) (2 \sigma_{xy} + c_2)}
+        {(\mu_x^2 +\mu_y^2 + c_1)(\sigma_x^2 +\sigma_y^2 + c_2)},
+
+    where :math:`N` is the batch size, `C` is the channel size. If :attr:`reduction` is not ``'none'``
+    (default ``'mean'``), then:
+
+    .. math::
+        SSIMLoss(x, y) =
+        \begin{cases}
+            \operatorname{mean}(1 - SSIM), &  \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(1 - SSIM),  &  \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
+    of :math:`n` elements each.
+
+    The sum operation still operates over all the elements, and divides by :math:`n`.
+    The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
+    In case of 5D input tensors, complex value is returned as a tensor of size 2.
+
+    Args:
+        kernel_size: By default, the mean and covariance of a pixel is obtained
+            by convolution with given filter_size.
+        kernel_sigma: Standard deviation for Gaussian kernel.
+        k1: Coefficient related to c1 in the above equation.
+        k2: Coefficient related to c2 in the above equation.
+        downsample: Perform average pool before SSIM computation. Default: True
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
+        data_range: Maximum value range of images (usually 1.0 or 255).
+
+    Examples:
+        >>> loss = SSIMLoss()
+        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
+        >>> y = torch.rand(3, 3, 256, 256)
+        >>> output = loss(x, y)
+        >>> output.backward()
+
+    References:
+        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
+        Image quality assessment: From error visibility to structural similarity.
+        IEEE Transactions on Image Processing, 13, 600-612.
+        https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
+        DOI:`10.1109/TIP.2003.819861`
+    """
+
+    __constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"]
+
+    def __init__(
+        self,
+        kernel_size: int = 11,
+        kernel_sigma: float = 1.5,
+        k1: float = 0.01,
+        k2: float = 0.03,
+        downsample: bool = True,
+        reduction: str = "mean",
+        data_range: Union[int, float] = 1.0,
+    ) -> None:
+        super().__init__()
+
+        # Generic loss parameters.
+        self.reduction = reduction
+
+        # Loss-specific parameters.
+        self.kernel_size = kernel_size
+
+        # This check might look redundant because kernel size is checked within the ssim function anyway.
+        # However, this check allows to fail fast when the loss is being initialised and training has not been started.
+        assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
+        self.kernel_sigma = kernel_sigma
+        self.k1 = k1
+        self.k2 = k2
+        self.downsample = downsample
+        self.data_range = data_range
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        r"""Computation of Structural Similarity (SSIM) index as a loss function.
+
+        Args:
+            x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+            y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+
+        Returns:
+            Value of SSIM loss to be minimized, i.e ``1 - ssim`` in [0, 1] range. In case of 5D input tensors,
+            complex value is returned as a tensor of size 2.
+        """
+        score = ssim(
+            x=x,
+            y=y,
+            kernel_size=self.kernel_size,
+            kernel_sigma=self.kernel_sigma,
+            downsample=self.downsample,
+            data_range=self.data_range,
+            reduction=self.reduction,
+            full=False,
+            k1=self.k1,
+            k2=self.k2,
+        )
+        return torch.ones_like(score) - score
+
+
+def _ssim_per_channel(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel: torch.Tensor,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    r"""Calculate Structural Similarity (SSIM) index for X and Y per channel.
+
+    Args:
+        x: An input tensor. Shape :math:`(N, C, H, W)`.
+        y: A target tensor. Shape :math:`(N, C, H, W)`.
+        kernel: 2D Gaussian kernel.
+        k1: Algorithm parameter, K1 (small constant, see [1]).
+        k2: Algorithm parameter, K2 (small constant, see [1]).
+            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+
+    Returns:
+        Full Value of Structural Similarity (SSIM) index.
+    """
+    if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
+        raise ValueError(
+            f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
+            f"Kernel size: {kernel.size()}"
+        )
+
+    c1 = k1**2
+    c2 = k2**2
+    n_channels = x.size(1)
+    mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels)
+    mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels)
+
+    mu_xx = mu_x**2
+    mu_yy = mu_y**2
+    mu_xy = mu_x * mu_y
+
+    sigma_xx = F.conv2d(x**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx
+    sigma_yy = F.conv2d(y**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy
+    sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy
+
+    # Contrast sensitivity (CS) with alpha = beta = gamma = 1.
+    cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
+
+    # Structural similarity (SSIM)
+    ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
+
+    ssim_val = ss.mean(dim=(-1, -2))
+    cs = cs.mean(dim=(-1, -2))
+    return ssim_val, cs
+
+
+def _ssim_per_channel_complex(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel: torch.Tensor,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel.
+
+    Args:
+        x: An input tensor. Shape :math:`(N, C, H, W, 2)`.
+        y: A target tensor. Shape :math:`(N, C, H, W, 2)`.
+        kernel: 2-D gauss kernel.
+        k1: Algorithm parameter, K1 (small constant, see [1]).
+        k2: Algorithm parameter, K2 (small constant, see [1]).
+            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+
+    Returns:
+        Full Value of Complex Structural Similarity (SSIM) index.
+    """
+    n_channels = x.size(1)
+    if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2):
+        raise ValueError(
+            f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
+            f"Kernel size: {kernel.size()}"
+        )
+
+    c1 = k1**2
+    c2 = k2**2
+
+    x_real = x[..., 0]
+    x_imag = x[..., 1]
+    y_real = y[..., 0]
+    y_imag = y[..., 1]
+
+    mu1_real = F.conv2d(x_real, weight=kernel, stride=1, padding=0, groups=n_channels)
+    mu1_imag = F.conv2d(x_imag, weight=kernel, stride=1, padding=0, groups=n_channels)
+    mu2_real = F.conv2d(y_real, weight=kernel, stride=1, padding=0, groups=n_channels)
+    mu2_imag = F.conv2d(y_imag, weight=kernel, stride=1, padding=0, groups=n_channels)
+
+    mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2)
+    mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2)
+    mu1_mu2_real = mu1_real * mu2_real - mu1_imag * mu2_imag
+    mu1_mu2_imag = mu1_real * mu2_imag + mu1_imag * mu2_real
+
+    compensation = 1.0
+
+    x_sq = x_real.pow(2) + x_imag.pow(2)
+    y_sq = y_real.pow(2) + y_imag.pow(2)
+    x_y_real = x_real * y_real - x_imag * y_imag
+    x_y_imag = x_real * y_imag + x_imag * y_real
+
+    sigma1_sq = F.conv2d(x_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_sq
+    sigma2_sq = F.conv2d(y_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu2_sq
+    sigma12_real = F.conv2d(x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_real
+    sigma12_imag = F.conv2d(x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_imag
+    sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1)
+    mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1)
+    # Set alpha = beta = gamma = 1.
+    cs_map = (sigma12 * 2 + c2 * compensation) / (sigma1_sq.unsqueeze(-1) + sigma2_sq.unsqueeze(-1) + c2 * compensation)
+    ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation)
+    ssim_map = ssim_map * cs_map
+
+    ssim_val = ssim_map.mean(dim=(-2, -3))
+    cs = cs_map.mean(dim=(-2, -3))
+
+    return ssim_val, cs

@@ -1,4 +1,5 @@
 import logging
+import re
 import subprocess
 from typing import Dict, List

@@ -163,6 +164,13 @@ class ESpeak(BasePhonemizer):

            # dealing with the conditions descrived above
            ph_decoded = ph_decoded[:1].replace("_", "") + ph_decoded[1:]
+
+            # espeak-ng backend can add language flags that need to be removed:
+            #   "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."
+            # phonemize needs to remove the language flags of the returned text:
+            #   "sɛʁtˈɛ̃ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."
+            ph_decoded = re.sub(r"\(.+?\)", "", ph_decoded)
+
            phonemes += ph_decoded.strip()
        return phonemes.replace("_", separator)

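The new regex strips the parenthesised language switches that espeak-ng may emit. Applied to the example string from the comment above:

```python
import re

ph = "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."
print(re.sub(r"\(.+?\)", "", ph))
# -> "sɛʁtˈɛ̃ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."
```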
@@ -137,7 +137,7 @@ class Punctuation:

        # nothing have been phonemized, returns the puncs alone
        if not text:
-            return ["".join(m.mark for m in puncs)]
+            return ["".join(m.punc for m in puncs)]

        current = puncs[0]

@@ -0,0 +1 @@
+from TTS.utils.audio.processor import AudioProcessor

@@ -0,0 +1,425 @@
+from typing import Tuple
+
+import librosa
+import numpy as np
+import pyworld as pw
+import scipy
+import soundfile as sf
+
+# For using kwargs
+# pylint: disable=unused-argument
+
+
+def build_mel_basis(
+    *,
+    sample_rate: int = None,
+    fft_size: int = None,
+    num_mels: int = None,
+    mel_fmax: int = None,
+    mel_fmin: int = None,
+    **kwargs,
+) -> np.ndarray:
+    """Build melspectrogram basis.
+
+    Returns:
+        np.ndarray: melspectrogram basis.
+    """
+    if mel_fmax is not None:
+        assert mel_fmax <= sample_rate // 2
+        assert mel_fmax - mel_fmin > 0
+    return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax)
+
+
+def millisec_to_length(
+    *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs
+) -> Tuple[int, int]:
+    """Compute hop and window length from milliseconds.
+
+    Returns:
+        Tuple[int, int]: hop length and window length for STFT.
+    """
+    factor = frame_length_ms / frame_shift_ms
+    assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
+    win_length = int(frame_length_ms / 1000.0 * sample_rate)
+    hop_length = int(win_length / float(factor))
+    return win_length, hop_length
+
+
+def _log(x, base):
+    if base == 10:
+        return np.log10(x)
+    return np.log(x)
+
+
+def _exp(x, base):
+    if base == 10:
+        return np.power(10, x)
+    return np.exp(x)
+
+
+def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
+    """Convert amplitude values to decibels.
+
+    Args:
+        x (np.ndarray): Amplitude spectrogram.
+        gain (float): Gain factor. Defaults to 1.
+        base (int): Logarithm base. Defaults to 10.
+
+    Returns:
+        np.ndarray: Decibels spectrogram.
+    """
+    assert (x < 0).sum() == 0, " [!] Input values must be non-negative."
+    return gain * _log(np.maximum(1e-8, x), base)
+
+
+# pylint: disable=no-self-use
+def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
+    """Convert decibels spectrogram to amplitude spectrogram.
+
+    Args:
+        x (np.ndarray): Decibels spectrogram.
+        gain (float): Gain factor. Defaults to 1.
+        base (int): Logarithm base. Defaults to 10.
+
+    Returns:
+        np.ndarray: Amplitude spectrogram.
+    """
+    return _exp(x / gain, base)
+
+
+def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray:
+    """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
+
+    Args:
+        x (np.ndarray): Audio signal.
+
+    Raises:
+        RuntimeError: Preemphasis coeff is set to 0.
+
+    Returns:
+        np.ndarray: Decorrelated audio signal.
+    """
+    if coef == 0:
+        raise RuntimeError(" [!] Preemphasis is set 0.0.")
+    return scipy.signal.lfilter([1, -coef], [1], x)
+
+
+def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray:
+    """Reverse pre-emphasis."""
+    if coef == 0:
+        raise RuntimeError(" [!] Preemphasis is set 0.0.")
+    return scipy.signal.lfilter([1], [1, -coef], x)
+
+
+def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
+    """Convert a full scale linear spectrogram output of a network to a melspectrogram.
+
+    Args:
+        spec (np.ndarray): Normalized full scale linear spectrogram.
+
+    Shapes:
+        - spec: :math:`[C, T]`
+
+    Returns:
+        np.ndarray: Normalized melspectrogram.
+    """
+    return np.dot(mel_basis, spec)
+
+
+def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray:
+    """Convert a melspectrogram to full scale spectrogram."""
+    assert (mel < 0).sum() == 0, " [!] Input values must be non-negative."
+    inv_mel_basis = np.linalg.pinv(mel_basis)
+    return np.maximum(1e-10, np.dot(inv_mel_basis, mel))
+
+
+def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray:
+    """Compute a spectrogram from a waveform.
+
+    Args:
+        wav (np.ndarray): Waveform. Shape :math:`[T_wav,]`
+
+    Returns:
+        np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length`
+    """
+    D = stft(y=wav, **kwargs)
+    S = np.abs(D)
+    return S.astype(np.float32)
+
+
+def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray:
+    """Compute a melspectrogram from a waveform."""
+    D = stft(y=wav, **kwargs)
+    S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs)
+    return S.astype(np.float32)
+
+
+def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray:
+    """Convert a spectrogram to a waveform using Griffi-Lim vocoder."""
+    S = spec.copy()
+    return griffin_lim(spec=S**power, **kwargs)
+
+
+def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray:
+    """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
+    S = mel.copy()
+    S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"])  # Convert back to linear
+    return griffin_lim(spec=S**power, **kwargs)
+
+
+### STFT and ISTFT ###
+def stft(
+    *,
+    y: np.ndarray = None,
+    fft_size: int = None,
+    hop_length: int = None,
+    win_length: int = None,
+    pad_mode: str = "reflect",
+    window: str = "hann",
+    center: bool = True,
+    **kwargs,
+) -> np.ndarray:
+    """Librosa STFT wrapper.
+
+    Check http://librosa.org/doc/main/generated/librosa.stft.html argument details.
+
+    Returns:
+        np.ndarray: Complex number array.
+    """
+    return librosa.stft(
+        y=y,
+        n_fft=fft_size,
+        hop_length=hop_length,
+        win_length=win_length,
+        pad_mode=pad_mode,
+        window=window,
+        center=center,
+    )
+
+
+def istft(
+    *,
+    y: np.ndarray = None,
+    fft_size: int = None,
+    hop_length: int = None,
+    win_length: int = None,
+    window: str = "hann",
+    center: bool = True,
+    **kwargs,
+) -> np.ndarray:
+    """Librosa iSTFT wrapper.
+
+    Check http://librosa.org/doc/main/generated/librosa.istft.html argument details.
+
+    Returns:
+        np.ndarray: Complex number array.
+    """
+    return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window)
+
+
+def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray:
+    angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
+    S_complex = np.abs(spec).astype(np.complex)
+    y = istft(y=S_complex * angles, **kwargs)
+    if not np.isfinite(y).all():
+        print(" [!] Waveform is not finite everywhere. Skipping the GL.")
+        return np.array([0.0])
+    for _ in range(num_iter):
+        angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))
+        y = istft(y=S_complex * angles, **kwargs)
+    return y
+
+
+def compute_stft_paddings(
+    *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs
+) -> Tuple[int, int]:
+    """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
+    (first and final frames)"""
+    pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0]
+    if not pad_two_sides:
+        return 0, pad
+    return pad // 2, pad // 2 + pad % 2
+
+
+def compute_f0(
+    *, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs
+) -> np.ndarray:
+    """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
+
+    Args:
+        x (np.ndarray): Waveform. Shape :math:`[T_wav,]`
+
+    Returns:
+        np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length`
+
+    Examples:
+        >>> WAV_FILE = filename = librosa.util.example_audio_file()
+        >>> from TTS.config import BaseAudioConfig
+        >>> from TTS.utils.audio.processor import AudioProcessor
+        >>> conf = BaseAudioConfig(pitch_fmax=8000)
+        >>> ap = AudioProcessor(**conf)
+        >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
+        >>> pitch = ap.compute_f0(wav)
+    """
+    assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`."
+
+    f0, t = pw.dio(
+        x.astype(np.double),
+        fs=sample_rate,
+        f0_ceil=pitch_fmax,
+        frame_period=1000 * hop_length / sample_rate,
+    )
+    f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate)
+    return f0
+
+
+### Audio Processing ###
+def find_endpoint(
+    *,
+    wav: np.ndarray = None,
+    trim_db: float = -40,
+    sample_rate: int = None,
+    min_silence_sec=0.8,
+    gain: float = None,
+    base: int = None,
+    **kwargs,
+) -> int:
+    """Find the last point without silence at the end of a audio signal.
+
+    Args:
+        wav (np.ndarray): Audio signal.
+        threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.
+        min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8.
+        gian (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None.
+        base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10.
+
+    Returns:
+        int: Last point without silence.
+    """
+    window_length = int(sample_rate * min_silence_sec)
+    hop_length = int(window_length / 4)
+    threshold = db_to_amp(x=-trim_db, gain=gain, base=base)
+    for x in range(hop_length, len(wav) - window_length, hop_length):
+        if np.max(wav[x : x + window_length]) < threshold:
+            return x + hop_length
+    return len(wav)
+
+
+def trim_silence(
+    *,
+    wav: np.ndarray = None,
+    sample_rate: int = None,
+    trim_db: float = None,
+    win_length: int = None,
+    hop_length: int = None,
+    **kwargs,
+) -> np.ndarray:
+    """Trim silent parts with a threshold and 0.01 sec margin"""
+    margin = int(sample_rate * 0.01)
+    wav = wav[margin:-margin]
+    return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0]
+
+
+def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray:
+    """Normalize the volume of an audio signal.
+
+    Args:
+        x (np.ndarray): Raw waveform.
+        coef (float): Coefficient to rescale the maximum value. Defaults to 0.95.
+
+    Returns:
+        np.ndarray: Volume normalized waveform.
+    """
+    return x / abs(x).max() * coef
+
+
+def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray:
+    r = 10 ** (db_level / 20)
+    a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
+    return wav * a
+
+
+def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray:
+    """Normalize the volume based on RMS of the signal.
+
+    Args:
+        x (np.ndarray): Raw waveform.
+        db_level (float): Target dB level in RMS. Defaults to -27.0.
+
+    Returns:
+        np.ndarray: RMS normalized waveform.
+    """
+    assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
+    wav = rms_norm(wav=x, db_level=db_level)
+    return wav
+
+
+def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray:
+    """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
+
+    Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
+
+    Args:
+        filename (str): Path to the wav file.
+        sr (int, optional): Sampling rate for resampling. Defaults to None.
+        resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False.
+
+    Returns:
+        np.ndarray: Loaded waveform.
+    """
+    if resample:
+        # loading with resampling. It is significantly slower.
+        x, _ = librosa.load(filename, sr=sample_rate)
+    else:
+        # SF is faster than librosa for loading files
+        x, _ = sf.read(filename)
+    return x
+
+
+def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None:
+    """Save float waveform to a file using Scipy.
+
+    Args:
+        wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
+        path (str): Path to a output file.
+        sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
+    """
+    wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
+    scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16))
+
+
+def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
+    mu = 2**mulaw_qc - 1
+    signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
+    signal = (signal + 1) / 2 * mu + 0.5
+    return np.floor(
+        signal,
+    )
+
+
+def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray:
+    """Recovers waveform from quantized values."""
+    mu = 2**mulaw_qc - 1
+    x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
+    return x
+
+
+def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray:
+    return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
+
+
+def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray:
+    """Quantize a waveform to a given number of bits.
+
+    Args:
+        x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
+        quantize_bits (int): Number of quantization bits.
+
+    Returns:
+        np.ndarray: Quantized waveform.
+    """
+    return (x + 1.0) * (2**quantize_bits - 1) / 2
+
+
+def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray:
+    """Dequantize a waveform from the given number of bits."""
+    return 2 * x / (2**quantize_bits - 1) - 1

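A hedged usage sketch for the new keyword-only mu-law helpers defined above; the import path is not shown in this hunk, and the rescaling step before decoding is an assumption about how the two functions are meant to be paired:

```python
import numpy as np

wav = np.sin(np.linspace(0, 2 * np.pi, 100)).astype(np.float32)   # waveform in [-1, 1]
encoded = mulaw_encode(wav=wav, mulaw_qc=10)                       # integers in [0, 2**10 - 1]
# map back to [-1, 1] before expanding, since mulaw_decode expects the companded signal
decoded = mulaw_decode(wav=2 * encoded / (2**10 - 1) - 1, mulaw_qc=10)
```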
@@ -6,179 +6,14 @@ import pyworld as pw
 import scipy.io.wavfile
 import scipy.signal
 import soundfile as sf
-import torch
-from torch import nn

 from TTS.tts.utils.helpers import StandardScaler


-class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
-    """Some of the audio processing funtions using Torch for faster batch processing.
-
-    TODO: Merge this with audio.py
 [the rest of the removed block is the TorchSTFT implementation (arguments docstring, __init__,
  __call__, _build_mel_basis, _amp_to_db, _db_to_amp); the same class is added as a new module
  in the following hunk of this commit]

 # pylint: disable=too-many-public-methods
-class AudioProcessor(object):
-    """Audio Processor for TTS used by all the data pipelines.
-
-    TODO: Make this a dataclass to replace `BaseAudioConfig`.
+class AudioProcessor(object):
+    """Audio Processor for TTS.

     Note:
         All the class arguments are set to default values to enable a flexible initialization

@@ -0,0 +1,163 @@
import librosa
import torch
from torch import nn


class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """Some of the audio processing functions using Torch for faster batch processing.

    Args:
        n_fft (int):
            FFT window size for STFT.
        hop_length (int):
            number of frames between STFT columns.
        win_length (int, optional):
            STFT window length.
        pad_wav (bool, optional):
            If True pad the audio with (n_fft - hop_length) / 2. Defaults to False.
        window (str, optional):
            The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window"
        sample_rate (int, optional):
            target audio sampling rate. Defaults to None.
        mel_fmin (int, optional):
            minimum filter frequency for computing melspectrograms. Defaults to None.
        mel_fmax (int, optional):
            maximum filter frequency for computing melspectrograms. Defaults to None.
        n_mels (int, optional):
            number of melspectrogram dimensions. Defaults to None.
        use_mel (bool, optional):
            If True compute melspectrograms, otherwise return linear spectrograms. Defaults to False.
        do_amp_to_db_linear (bool, optional):
            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False.
        spec_gain (float, optional):
            gain applied when converting amplitude to DB. Defaults to 1.0.
        power (float, optional):
            Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None.
        use_htk (bool, optional):
            Use HTK formula in mel filter instead of Slaney.
        mel_norm (None, 'slaney', or number, optional):
            If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization).
            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
            See `librosa.util.normalize` for a full description of supported norm values
            (including `+-np.inf`).
            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
    """

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
        self.power = power
        self.use_htk = use_htk
        self.mel_norm = mel_norm
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch based stft.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: spectrogram frames.

        Shapes:
            x: [B x T] or [:math:`[B, 1, T]`]
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # B x D x T x 2
        o = torch.stft(
            x.squeeze(1),
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
            pad_mode="reflect",  # compatible with audio.py
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))

        if self.power is not None:
            S = S**self.power

        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
            S = self._amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):
        mel_basis = librosa.filters.mel(
            self.sample_rate,
            self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

    @staticmethod
    def _amp_to_db(x, spec_gain=1.0):
        return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

    @staticmethod
    def _db_to_amp(x, spec_gain=1.0):
        return torch.exp(x) / spec_gain
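A minimal usage sketch for the class above, computing log-mel frames for a batch of waveforms. The constructor values are illustrative, and it assumes a PyTorch version that still accepts `torch.stft(..., return_complex=False)`:

```python
import torch

# batch of 2 one-second waveforms at 22050 Hz
wavs = torch.randn(2, 22050)

stft = TorchSTFT(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    sample_rate=22050,
    n_mels=80,
    use_mel=True,        # project magnitudes onto the mel basis
    do_amp_to_db=True,   # log-compress with spec_gain
    spec_gain=1.0,
)
mels = stft(wavs)        # shape: [B, n_mels, T_frames]
print(mels.shape)
```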
@@ -34,6 +34,8 @@ class CapacitronOptimizer:
         self.primary_optimizer.zero_grad()

     def step(self):
+        # Update param groups to display the correct learning rate
+        self.param_groups = self.primary_optimizer.param_groups
         self.primary_optimizer.step()

     def zero_grad(self):
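The added lines keep the wrapper's `param_groups` in sync with the wrapped primary optimizer so schedulers and loggers read the learning rate actually being applied. A simplified sketch of that wrapper pattern (not the repository's full `CapacitronOptimizer`, just the idea):

```python
import torch


class WrappedOptimizer:
    """Minimal optimizer wrapper that mirrors the primary optimizer's param groups."""

    def __init__(self, params, lr=1e-3):
        self.primary_optimizer = torch.optim.Adam(params, lr=lr)
        self.param_groups = self.primary_optimizer.param_groups

    def step(self):
        # re-point param_groups at the live groups before stepping so external
        # code (schedulers, dashboards) sees the current learning rate
        self.param_groups = self.primary_optimizer.param_groups
        self.primary_optimizer.step()

    def zero_grad(self):
        self.primary_optimizer.zero_grad()
```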
@@ -1,4 +1,3 @@
-import io
 import json
 import os
 import zipfile
@@ -7,6 +6,7 @@ from shutil import copyfile, rmtree
 from typing import Dict, Tuple

 import requests
+from tqdm import tqdm

 from TTS.config import load_config
 from TTS.utils.generic_utils import get_user_data_dir
@@ -337,11 +337,20 @@ class ModelManager(object):
     def _download_zip_file(file_url, output_folder):
         """Download the github releases"""
         # download the file
-        r = requests.get(file_url)
+        r = requests.get(file_url, stream=True)
         # extract the file
         try:
-            with zipfile.ZipFile(io.BytesIO(r.content)) as z:
+            total_size_in_bytes = int(r.headers.get("content-length", 0))
+            block_size = 1024  # 1 Kibibyte
+            progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+            temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
+            with open(temp_zip_name, "wb") as file:
+                for data in r.iter_content(block_size):
+                    progress_bar.update(len(data))
+                    file.write(data)
+            with zipfile.ZipFile(temp_zip_name) as z:
                 z.extractall(output_folder)
+            os.remove(temp_zip_name)  # delete zip after extract
         except zipfile.BadZipFile:
             print(f" > Error: Bad zip file - {file_url}")
             raise zipfile.BadZipFile  # pylint: disable=raise-missing-from
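The change above streams the release archive to disk in 1 KiB chunks with a tqdm progress bar instead of buffering the whole zip in memory. A self-contained sketch of the same pattern, with a hypothetical URL and output folder (not taken from the repository):

```python
import os
import zipfile

import requests
from tqdm import tqdm


def download_and_extract(file_url: str, output_folder: str) -> None:
    """Stream a zip file to disk with a progress bar, then extract it."""
    r = requests.get(file_url, stream=True)
    total = int(r.headers.get("content-length", 0))
    block_size = 1024  # read the response in 1 KiB chunks
    zip_path = os.path.join(output_folder, file_url.split("/")[-1])
    with tqdm(total=total, unit="iB", unit_scale=True) as bar, open(zip_path, "wb") as f:
        for chunk in r.iter_content(block_size):
            bar.update(len(chunk))
            f.write(chunk)
    with zipfile.ZipFile(zip_path) as z:
        z.extractall(output_folder)
    os.remove(zip_path)  # keep only the extracted files
```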
@@ -1,6 +1,8 @@
+import math
 import random
+from typing import Callable, List, Union

-from torch.utils.data.sampler import Sampler, SubsetRandomSampler
+from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler


 class SubsetSampler(Sampler):
@@ -112,3 +114,89 @@ class PerfectBatchSampler(Sampler):
     def __len__(self):
         class_batch_size = self._batch_size // self._num_classes_in_batch
         return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
+
+
+def identity(x):
+    return x
+
+
+class SortedSampler(Sampler):
+    """Samples elements sequentially, always in the same order.
+
+    Taken from https://github.com/PetrochukM/PyTorch-NLP
+
+    Args:
+        data (iterable): Iterable data.
+        sort_key (callable): Specifies a function of one argument that is used to extract a
+            numerical comparison key from each list element.
+
+    Example:
+        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
+        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
+    """
+
+    def __init__(self, data, sort_key: Callable = identity):
+        super().__init__(data)
+        self.data = data
+        self.sort_key = sort_key
+        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
+        zip_ = sorted(zip_, key=lambda r: r[1])
+        self.sorted_indexes = [item[0] for item in zip_]
+
+    def __iter__(self):
+        return iter(self.sorted_indexes)
+
+    def __len__(self):
+        return len(self.data)
+
+
+class BucketBatchSampler(BatchSampler):
+    """Bucket batch sampler
+
+    Adapted from https://github.com/PetrochukM/PyTorch-NLP
+
+    Args:
+        sampler (torch.data.utils.sampler.Sampler):
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
+            than `batch_size`.
+        data (list): List of data samples.
+        sort_key (callable, optional): Callable to specify a comparison key for sorting.
+        bucket_size_multiplier (int, optional): Buckets are of size
+            `batch_size * bucket_size_multiplier`.
+
+    Example:
+        >>> sampler = WeightedRandomSampler(weights, len(weights))
+        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
+    """
+
+    def __init__(
+        self,
+        sampler,
+        data,
+        batch_size,
+        drop_last,
+        sort_key: Union[Callable, List] = identity,
+        bucket_size_multiplier=100,
+    ):
+        super().__init__(sampler, batch_size, drop_last)
+        self.data = data
+        self.sort_key = sort_key
+        _bucket_size = batch_size * bucket_size_multiplier
+        if hasattr(sampler, "__len__"):
+            _bucket_size = min(_bucket_size, len(sampler))
+        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)
+
+    def __iter__(self):
+        for idxs in self.bucket_sampler:
+            bucket_data = [self.data[idx] for idx in idxs]
+            sorted_sampler = SortedSampler(bucket_data, self.sort_key)
+            for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
+                sorted_idxs = [idxs[i] for i in batch_idx]
+                yield sorted_idxs
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        return math.ceil(len(self.sampler) / self.batch_size)
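The new `BucketBatchSampler` yields lists of indices whose samples have similar sort-key values, while keeping the order of the resulting batches random. A minimal sketch of plugging it into a `DataLoader`; the toy dataset and `collate_fn` here are placeholders, not from the repository:

```python
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

# assumed: a list-like dataset of dicts with a "text" key
dataset = [{"text": "hello"}, {"text": "a much longer sentence"}, {"text": "mid size"}] * 10

base_sampler = SequentialSampler(dataset)
batch_sampler = BucketBatchSampler(
    base_sampler,
    data=dataset,
    batch_size=4,
    drop_last=True,
    sort_key=lambda item: len(item["text"]),
    bucket_size_multiplier=2,
)

# it yields index lists, so it goes in as `batch_sampler`, not `sampler`
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=lambda batch: batch)
for batch in loader:
    print([len(item["text"]) for item in batch])
```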
@@ -321,7 +321,7 @@ class Synthesizer(object):
             waveform = waveform.squeeze()

             # trim silence
-            if self.tts_config.audio["do_trim_silence"] is True:
+            if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
                 waveform = trim_silence(waveform, self.tts_model.ap)

             wavs += list(waveform)
@@ -149,4 +149,4 @@ class WaveGradDataset(Dataset):
             mels[idx, :, : mel.shape[1]] = mel
             audios[idx, : audio.shape[0]] = audio

-        return audios, mels
+        return mels, audios
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from TTS.utils.audio import TorchSTFT
+from TTS.utils.audio.torch_transforms import TorchSTFT
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss

 #################################
@@ -185,8 +185,7 @@ class GAN(BaseVocoder):
         outputs = {"model_outputs": self.y_hat_g}
         return outputs, loss_dict

-    @staticmethod
-    def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
+    def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
         """Logging shared by the training and evaluation.

         Args:
@@ -198,7 +197,7 @@ class GAN(BaseVocoder):
         Returns:
             Tuple[Dict, Dict]: log figures and audio samples.
         """
-        y_hat = outputs[0]["model_outputs"] if outputs[0] is not None else outputs[1]["model_outputs"]
+        y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
         y = batch["waveform"]
         figures = plot_results(y_hat, y, ap, name)
         sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 from torch import nn
 from torch.nn.utils import spectral_norm, weight_norm

-from TTS.utils.audio import TorchSTFT
+from TTS.utils.audio.torch_transforms import TorchSTFT
 from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator

 LRELU_SLOPE = 0.1
@@ -233,6 +233,7 @@ class Wavernn(BaseVocoder):
         else:
             raise RuntimeError("Unknown model mode value - ", self.args.mode)

+        self.ap = AudioProcessor(**config.audio.to_dict())
         self.aux_dims = self.args.res_out_dims // 4

         if self.args.use_upsample_net:
@@ -571,7 +572,7 @@ class Wavernn(BaseVocoder):
     def test(
         self, assets: Dict, test_loader: "DataLoader", output: Dict  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
-        ap = assets["audio_processor"]
+        ap = self.ap
         figures = {}
         audios = {}
         samples = test_loader.dataset.load_test_samples(1)
@@ -587,8 +588,16 @@ class Wavernn(BaseVocoder):
                 }
             )
             audios.update({f"test_{idx}/audio": y_hat})
+            # audios.update({f"real_{idx}/audio": y_hat})
         return figures, audios

+    def test_log(
+        self, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
+        figures, audios = outputs
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)
+
     @staticmethod
     def format_batch(batch: Dict) -> Dict:
         waveform = batch[0]
@@ -605,7 +614,7 @@ class Wavernn(BaseVocoder):
         verbose: bool,
         num_gpus: int,
     ):
-        ap = assets["audio_processor"]
+        ap = self.ap
         dataset = WaveRNNDataset(
             ap=ap,
             items=samples,
@@ -45,7 +45,7 @@
   "source": [
    "NUM_PROC = 8\n",
    "DATASET_CONFIG = BaseDatasetConfig(\n",
-   "    name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=\"/home/ubuntu/TTS/depot/data/male_dataset1_44k/\"\n",
+   "    name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=\"/absolute/path/to/your/dataset/\"\n",
    ")"
   ]
  },
@@ -58,13 +58,13 @@
   "def formatter(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument\n",
   "    txt_file = os.path.join(root_path, meta_file)\n",
   "    items = []\n",
-  "    speaker_name = \"maledataset1\"\n",
+  "    speaker_name = \"myspeaker\"\n",
   "    with open(txt_file, \"r\", encoding=\"utf-8\") as ttf:\n",
   "        for line in ttf:\n",
   "            cols = line.split(\"|\")\n",
-  "            wav_file = os.path.join(root_path, \"wavs\", cols[0])\n",
+  "            wav_file = os.path.join(root_path, \"wavs\", cols[0] + \".wav\") \n",
   "            text = cols[1]\n",
-  "            items.append([text, wav_file, speaker_name])\n",
+  "            items.append({\"text\": text, \"audio_file\": wav_file, \"speaker_name\": speaker_name})\n",
   "    return items"
  ]
 },
@@ -78,7 +78,10 @@
  "source": [
   "# use your own preprocessor at this stage - TTS/datasets/proprocess.py\n",
   "train_samples, eval_samples = load_tts_samples(DATASET_CONFIG, eval_split=True, formatter=formatter)\n",
-  "items = train_samples + eval_samples\n",
+  "if eval_samples is not None:\n",
+  "    items = train_samples + eval_samples\n",
+  "else:\n",
+  "    items = train_samples\n",
   "print(\" > Number of audio files: {}\".format(len(items)))\n",
   "print(items[1])"
  ]
@@ -94,7 +97,7 @@
  "# check wavs if exist\n",
  "wav_files = []\n",
  "for item in items:\n",
- "    wav_file = item[1].strip()\n",
+ "    wav_file = item[\"audio_file\"].strip()\n",
  "    wav_files.append(wav_file)\n",
  "    if not os.path.exists(wav_file):\n",
  "        print(waf_path)"
@@ -131,8 +134,8 @@
  "outputs": [],
  "source": [
   "def load_item(item):\n",
-  "    text = item[0].strip()\n",
-  "    file_name = item[1].strip()\n",
+  "    text = item[\"text\"].strip()\n",
+  "    file_name = item[\"audio_file\"].strip()\n",
   "    audio, sr = librosa.load(file_name, sr=None)\n",
   "    audio_len = len(audio) / sr\n",
   "    text_len = len(text)\n",
@@ -416,7 +419,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.5"
+  "version": "3.9.12"
  }
 },
 "nbformat": 4,
@@ -37,7 +37,7 @@
  "# set some vars\n",
  "# TTS_PATH = \"/home/thorsten/___dev/tts/mozilla/TTS\"\n",
  "CONFIG_FILE = \"/path/to/config/config.json\"\n",
- "CHARS_TO_REMOVE = \".,:!?'\""
+ "CHARS_TO_REMOVE = \".,:!?'\"\n"
 ]
},
{
@@ -59,7 +59,8 @@
  "# extra imports that might not be included in requirements.txt\n",
  "import collections\n",
  "import operator\n",
- "\n"
+ "\n",
+ "%matplotlib inline"
 ]
},
{
@@ -75,7 +76,7 @@
  "CONFIG = load_config(CONFIG_FILE)\n",
  "\n",
  "# Load some properties from config.json\n",
- "CONFIG_METADATA = sorted(load_tts_samples(CONFIG.datasets)[0])\n",
+ "CONFIG_METADATA = load_tts_samples(CONFIG.datasets)[0]\n",
  "CONFIG_METADATA = CONFIG_METADATA\n",
  "CONFIG_DATASET = CONFIG.datasets[0]\n",
  "CONFIG_PHONEME_LANGUAGE = CONFIG.phoneme_language\n",
@@ -84,7 +85,10 @@
  "\n",
  "# Will be printed on generated output graph\n",
  "CONFIG_RUN_NAME = CONFIG.run_name\n",
- "CONFIG_RUN_DESC = CONFIG.run_description"
+ "CONFIG_RUN_DESC = CONFIG.run_description\n",
+ "\n",
+ "# Needed to convert text to phonemes and phonemes to ids\n",
+ "tokenizer, config = TTSTokenizer.init_from_config(CONFIG)"
 ]
},
{
@@ -112,12 +116,13 @@
 "source": [
  "def get_phoneme_from_sequence(text):\n",
  "    temp_list = []\n",
- "    if len(text[0]) > 0:\n",
- "        temp_text = text[0].rstrip('\\n')\n",
+ "    if len(text[\"text\"]) > 0:\n",
+ "        #temp_text = text[0].rstrip('\\n')\n",
+ "        temp_text = text[\"text\"].rstrip('\\n')\n",
  "        for rm_bad_chars in CHARS_TO_REMOVE:\n",
  "            temp_text = temp_text.replace(rm_bad_chars,\"\")\n",
- "        seq = phoneme_to_sequence(temp_text, [CONFIG_TEXT_CLEANER], CONFIG_PHONEME_LANGUAGE, CONFIG_ENABLE_EOS_BOS_CHARS)\n",
- "        text = sequence_to_phoneme(seq)\n",
+ "        seq = tokenizer.text_to_ids(temp_text)\n",
+ "        text = tokenizer.ids_to_text(seq)\n",
  "        text = text.replace(\" \",\"\")\n",
  "        temp_list.append(text)\n",
  "    return temp_list"
@@ -229,7 +234,7 @@
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3",
+  "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
@@ -243,7 +248,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.8.5"
+ "version": "3.9.12"
 }
},
"nbformat": 4,
@@ -48,7 +48,6 @@ config = TacotronConfig(
     precompute_num_workers=24,
     run_eval=True,
     test_delay_epochs=5,
-    ga_alpha=0.0,
     r=2,
     optimizer="CapacitronOptimizer",
     optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
@@ -68,16 +67,15 @@ config = TacotronConfig(
     datasets=[dataset_config],
     lr=1e-3,
     lr_scheduler="StepwiseGradualLR",
-    lr_scheduler_params={"gradual_learning_rates": [[0, 1e-3], [2e4, 5e-4], [4e5, 3e-4], [6e4, 1e-4], [8e4, 5e-5]]},
+    lr_scheduler_params={"gradual_learning_rates": [[0, 1e-3], [2e4, 5e-4], [4e4, 3e-4], [6e4, 1e-4], [8e4, 5e-5]]},
     scheduler_after_epoch=False,  # scheduler doesn't work without this flag
-    # Need to experiment with these below for capacitron
     loss_masking=False,
     decoder_loss_alpha=1.0,
     postnet_loss_alpha=1.0,
-    postnet_diff_spec_alpha=0.0,
-    decoder_diff_spec_alpha=0.0,
-    decoder_ssim_alpha=0.0,
-    postnet_ssim_alpha=0.0,
+    postnet_diff_spec_alpha=1.0,
+    decoder_diff_spec_alpha=1.0,
+    decoder_ssim_alpha=1.0,
+    postnet_ssim_alpha=1.0,
 )

 ap = AudioProcessor(**config.audio.to_dict())
@@ -52,7 +52,6 @@ config = Tacotron2Config(
     precompute_num_workers=24,
     run_eval=True,
     test_delay_epochs=5,
-    ga_alpha=0.0,
     r=2,
     optimizer="CapacitronOptimizer",
     optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
@@ -77,23 +76,20 @@ config = Tacotron2Config(
         "gradual_learning_rates": [
             [0, 1e-3],
             [2e4, 5e-4],
-            [4e5, 3e-4],
+            [4e4, 3e-4],
             [6e4, 1e-4],
             [8e4, 5e-5],
         ]
     },
     scheduler_after_epoch=False,  # scheduler doesn't work without this flag
-    # dashboard_logger='wandb',
-    # sort_by_audio_len=True,
     seq_len_norm=True,
-    # Need to experiment with these below for capacitron
     loss_masking=False,
     decoder_loss_alpha=1.0,
     postnet_loss_alpha=1.0,
-    postnet_diff_spec_alpha=0.0,
-    decoder_diff_spec_alpha=0.0,
-    decoder_ssim_alpha=0.0,
-    postnet_ssim_alpha=0.0,
+    postnet_diff_spec_alpha=1.0,
+    decoder_diff_spec_alpha=1.0,
+    decoder_ssim_alpha=1.0,
+    postnet_ssim_alpha=1.0,
 )

 ap = AudioProcessor(**config.audio.to_dict())
@@ -54,7 +54,6 @@ config = FastPitchConfig(
     print_step=50,
     print_eval=False,
     mixed_precision=False,
-    sort_by_audio_len=True,
     max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],
@@ -53,7 +53,6 @@ config = FastSpeechConfig(
     print_step=50,
     print_eval=False,
     mixed_precision=False,
-    sort_by_audio_len=True,
     max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],
@@ -46,7 +46,6 @@ config = SpeedySpeechConfig(
     print_step=50,
     print_eval=False,
     mixed_precision=False,
-    sort_by_audio_len=True,
     max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],
@@ -68,7 +68,6 @@ config = Tacotron2Config(
     print_step=25,
     print_eval=True,
     mixed_precision=False,
-    sort_by_audio_len=True,
     seq_len_norm=True,
     output_path=output_path,
     datasets=[dataset_config],
@@ -2,11 +2,10 @@ import os

 from trainer import Trainer, TrainerArgs

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits
+from TTS.tts.models.vits import Vits, VitsAudioConfig
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

@@ -14,21 +13,8 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
     name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )
-audio_config = BaseAudioConfig(
-    sample_rate=22050,
-    win_length=1024,
-    hop_length=256,
-    num_mels=80,
-    preemphasis=0.0,
-    ref_level_db=20,
-    log_func="np.log",
-    do_trim_silence=True,
-    trim_db=45,
-    mel_fmin=0,
-    mel_fmax=None,
-    spec_gain=1.0,
-    signal_norm=False,
-    do_amp_to_db_linear=False,
+audio_config = VitsAudioConfig(
+    sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
 )

 config = VitsConfig(
@@ -37,7 +23,7 @@ config = VitsConfig(
     batch_size=32,
     eval_batch_size=16,
     batch_group_size=5,
-    num_loader_workers=0,
+    num_loader_workers=8,
     num_eval_loader_workers=4,
     run_eval=True,
     test_delay_epochs=-1,
@@ -52,6 +38,7 @@ config = VitsConfig(
     mixed_precision=True,
     output_path=output_path,
     datasets=[dataset_config],
+    cudnn_benchmark=False,
 )

 # INITIALIZE THE AUDIO PROCESSOR
@@ -3,11 +3,10 @@ from glob import glob

 from trainer import Trainer, TrainerArgs

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
+from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -22,22 +21,13 @@ dataset_config = [
     for path in dataset_paths
 ]

-audio_config = BaseAudioConfig(
+audio_config = VitsAudioConfig(
     sample_rate=16000,
     win_length=1024,
     hop_length=256,
     num_mels=80,
-    preemphasis=0.0,
-    ref_level_db=20,
-    log_func="np.log",
-    do_trim_silence=False,
-    trim_db=23.0,
     mel_fmin=0,
     mel_fmax=None,
-    spec_gain=1.0,
-    signal_norm=True,
-    do_amp_to_db_linear=False,
-    resample=False,
 )

 vitsArgs = VitsArgs(
@@ -69,7 +59,6 @@ config = VitsConfig(
     use_language_weighted_sampler=True,
     print_eval=False,
     mixed_precision=False,
-    sort_by_audio_len=True,
     min_audio_len=32 * 256 * 4,
     max_audio_len=160000,
     output_path=output_path,
@@ -60,7 +60,6 @@ config = SpeedySpeechConfig(
         "Dieser Kuchen ist großartig. Er ist so lecker und feucht.",
         "Vor dem 22. November 1963.",
     ],
-    sort_by_audio_len=True,
     max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],
@@ -2,11 +2,10 @@ import os

 from trainer import Trainer, TrainerArgs

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits
+from TTS.tts.models.vits import Vits, VitsAudioConfig
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.downloaders import download_thorsten_de
@@ -21,21 +20,13 @@ if not os.path.exists(dataset_config.path):
     print("Downloading dataset")
     download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0])

-audio_config = BaseAudioConfig(
+audio_config = VitsAudioConfig(
     sample_rate=22050,
     win_length=1024,
     hop_length=256,
     num_mels=80,
-    preemphasis=0.0,
-    ref_level_db=20,
-    log_func="np.log",
-    do_trim_silence=True,
-    trim_db=45,
     mel_fmin=0,
     mel_fmax=None,
-    spec_gain=1.0,
-    signal_norm=False,
-    do_amp_to_db_linear=False,
 )

 config = VitsConfig(
@@ -2,7 +2,7 @@
 # take the scripts's parent's directory to prefix all the output paths.
 RUN_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
 echo $RUN_DIR
-# download LJSpeech dataset
+# download VCTK dataset
 wget https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip -O VCTK-Corpus-0.92.zip
 # extract
 mkdir VCTK
@@ -2,11 +2,10 @@ import os

 from trainer import Trainer, TrainerArgs

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

@@ -17,22 +16,8 @@ dataset_config = BaseDatasetConfig(
 )


-audio_config = BaseAudioConfig(
-    sample_rate=22050,
-    win_length=1024,
-    hop_length=256,
-    num_mels=80,
-    preemphasis=0.0,
-    ref_level_db=20,
-    log_func="np.log",
-    do_trim_silence=True,
-    trim_db=23.0,
-    mel_fmin=0,
-    mel_fmax=None,
-    spec_gain=1.0,
-    signal_norm=False,
-    do_amp_to_db_linear=False,
-    resample=True,
+audio_config = VitsAudioConfig(
+    sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
 )

 vitsArgs = VitsArgs(
@@ -62,6 +47,7 @@ config = VitsConfig(
     max_text_len=325,  # change this if you have a larger VRAM than 16GB
     output_path=output_path,
     datasets=[dataset_config],
+    cudnn_benchmark=False,
 )

 # INITIALIZE THE AUDIO PROCESSOR
@@ -1,13 +1,15 @@
 # core deps
-numpy==1.21.6
+numpy==1.21.6;python_version<"3.10"
+numpy==1.22.4;python_version=="3.10"
 cython==0.29.28
 scipy>=1.4.0
 torch>=1.7
 torchaudio
 soundfile
 librosa==0.8.0
-numba==0.55.1
-inflect
+numba==0.55.1;python_version<"3.10"
+numba==0.55.2;python_version=="3.10"
+inflect==5.6.0
 tqdm
 anyascii
 pyyaml
@@ -4,5 +4,4 @@ TF_CPP_MIN_LOG_LEVEL=3
 # runtime bash based tests
 # TODO: move these to python
 ./tests/bash_tests/test_demo_server.sh && \
-./tests/bash_tests/test_resample.sh && \
 ./tests/bash_tests/test_compute_statistics.sh

setup.py
@@ -90,7 +90,7 @@ setup(
     # ext_modules=find_cython_extensions(),
     # package
     include_package_data=True,
-    packages=find_packages(include=["TTS*"]),
+    packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]),
     package_data={
         "TTS": [
             "VERSION",
@@ -3,7 +3,7 @@ import unittest

 from tests import get_tests_input_path, get_tests_output_path, get_tests_path
 from TTS.config import BaseAudioConfig
-from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.processor import AudioProcessor

 TESTS_PATH = get_tests_path()
 OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
@@ -0,0 +1,105 @@
import math
import os
import unittest
from dataclasses import dataclass

import librosa
import numpy as np
from coqpit import Coqpit

from tests import get_tests_input_path, get_tests_output_path, get_tests_path
from TTS.utils.audio import numpy_transforms as np_transforms

TESTS_PATH = get_tests_path()
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")

os.makedirs(OUT_PATH, exist_ok=True)


# pylint: disable=no-self-use


class TestNumpyTransforms(unittest.TestCase):
    def setUp(self) -> None:
        @dataclass
        class AudioConfig(Coqpit):
            sample_rate: int = 22050
            fft_size: int = 1024
            num_mels: int = 256
            mel_fmax: int = 1800
            mel_fmin: int = 0
            hop_length: int = 256
            win_length: int = 1024
            pitch_fmax: int = 450
            trim_db: int = -1
            min_silence_sec: float = 0.01
            gain: float = 1.0
            base: float = 10.0

        self.config = AudioConfig()
        self.sample_wav, _ = librosa.load(WAV_FILE, sr=self.config.sample_rate)

    def test_build_mel_basis(self):
        """Check if the mel basis is correctly built"""
        print(" > Testing mel basis building.")
        mel_basis = np_transforms.build_mel_basis(**self.config)
        self.assertEqual(mel_basis.shape, (self.config.num_mels, self.config.fft_size // 2 + 1))

    def test_millisec_to_length(self):
        """Check if the conversion from milliseconds to length is correct"""
        print(" > Testing millisec to length conversion.")
        win_len, hop_len = np_transforms.millisec_to_length(
            frame_length_ms=1000, frame_shift_ms=12.5, sample_rate=self.config.sample_rate
        )
        self.assertEqual(hop_len, int(12.5 / 1000.0 * self.config.sample_rate))
        self.assertEqual(win_len, self.config.sample_rate)

    def test_amplitude_db_conversion(self):
        di = np.random.rand(11)
        o1 = np_transforms.amp_to_db(x=di, gain=1.0, base=10)
        o2 = np_transforms.db_to_amp(x=o1, gain=1.0, base=10)
        np.testing.assert_almost_equal(di, o2, decimal=5)

    def test_preemphasis_deemphasis(self):
        di = np.random.rand(11)
        o1 = np_transforms.preemphasis(x=di, coeff=0.95)
        o2 = np_transforms.deemphasis(x=o1, coeff=0.95)
        np.testing.assert_almost_equal(di, o2, decimal=5)

    def test_spec_to_mel(self):
        mel_basis = np_transforms.build_mel_basis(**self.config)
        spec = np.random.rand(self.config.fft_size // 2 + 1, 20)  # [C, T]
        mel = np_transforms.spec_to_mel(spec=spec, mel_basis=mel_basis)
        self.assertEqual(mel.shape, (self.config.num_mels, 20))

    def mel_to_spec(self):
        mel_basis = np_transforms.build_mel_basis(**self.config)
        mel = np.random.rand(self.config.num_mels, 20)  # [C, T]
        spec = np_transforms.mel_to_spec(mel=mel, mel_basis=mel_basis)
        self.assertEqual(spec.shape, (self.config.fft_size // 2 + 1, 20))

    def test_wav_to_spec(self):
        spec = np_transforms.wav_to_spec(wav=self.sample_wav, **self.config)
        self.assertEqual(
            spec.shape, (self.config.fft_size // 2 + 1, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
        )

    def test_wav_to_mel(self):
        mel_basis = np_transforms.build_mel_basis(**self.config)
        mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
        self.assertEqual(
            mel.shape, (self.config.num_mels, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
        )

    def test_compute_f0(self):
        pitch = np_transforms.compute_f0(x=self.sample_wav, **self.config)
        mel_basis = np_transforms.build_mel_basis(**self.config)
        mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
        assert pitch.shape[0] == mel.shape[1]

    def test_load_wav(self):
        wav = np_transforms.load_wav(filename=WAV_FILE, resample=False, sample_rate=22050)
        wav_resample = np_transforms.load_wav(filename=WAV_FILE, resample=True, sample_rate=16000)
        self.assertEqual(wav.shape, (self.sample_wav.shape[0],))
        self.assertNotEqual(wav_resample.shape, (self.sample_wav.shape[0],))
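The tests above exercise `TTS.utils.audio.numpy_transforms`. A compact sketch of the wav-to-mel path they cover, passing the same kind of keyword set the tests unpack; the wav path and values here are placeholders:

```python
import librosa
from TTS.utils.audio import numpy_transforms as np_transforms

# keyword set mirroring the test config; values are illustrative
config = {
    "sample_rate": 22050,
    "fft_size": 1024,
    "num_mels": 80,
    "mel_fmin": 0,
    "mel_fmax": 8000,
    "hop_length": 256,
    "win_length": 1024,
}

wav, _ = librosa.load("example.wav", sr=config["sample_rate"])  # placeholder wav path
mel_basis = np_transforms.build_mel_basis(**config)
mel = np_transforms.wav_to_mel(wav=wav, mel_basis=mel_basis, **config)
print(mel.shape)  # (num_mels, num_frames)
```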
@@ -1,29 +0,0 @@
import os
import unittest

import torch

from tests import get_tests_input_path, get_tests_output_path, run_cli

torch.manual_seed(1)


# pylint: disable=protected-access
class TestRemoveSilenceVAD(unittest.TestCase):
    @staticmethod
    def test():
        # set paths
        wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs")
        output_path = os.path.join(get_tests_output_path(), "output_wavs_removed_silence/")
        output_resample_path = os.path.join(get_tests_output_path(), "output_ljspeech_16khz/")

        # resample audios
        run_cli(
            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/resample.py --input_dir "{wav_path}" --output_dir "{output_resample_path}" --output_sr 16000'
        )

        # run test
        run_cli(
            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/remove_silence_using_vad.py --input_dir "{output_resample_path}" --output_dir "{output_path}"'
        )
        run_cli(f'rm -rf "{output_resample_path}"')
        run_cli(f'rm -rf "{output_path}"')
@@ -1,16 +0,0 @@
#!/usr/bin/env bash
set -xe
BASEDIR=$(dirname "$0")
TARGET_SR=16000
echo "$BASEDIR"
#run the resample script
python TTS/bin/resample.py --input_dir $BASEDIR/../data/ljspeech --output_dir $BASEDIR/outputs/resample_tests --output_sr $TARGET_SR
#check samplerate of output
OUT_SR=$( (echo "import librosa" ; echo "y, sr = librosa.load('"$BASEDIR"/outputs/resample_tests/wavs/LJ001-0012.wav', sr=None)" ; echo "print(sr)") | python )
OUT_SR=$(($OUT_SR + 0))
if [[ $OUT_SR -ne $TARGET_SR ]]; then
    echo "Missmatch between target and output sample rates"
    exit 1
fi
#cleaning up
rm -rf $BASEDIR/outputs/resample_tests
@@ -5,11 +5,11 @@ import unittest
 import torch

 from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import get_language_balancer_weights
 from TTS.tts.utils.speakers import get_speaker_balancer_weights
+from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler

 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -163,3 +163,31 @@ class TestSamplers(unittest.TestCase):
             else:
                 len2 += 1
         assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"
+
+    def test_bucket_batch_sampler(self):
+        bucket_size_multiplier = 2
+        sampler = range(len(train_samples))
+        sampler = BucketBatchSampler(
+            sampler,
+            data=train_samples,
+            batch_size=7,
+            drop_last=True,
+            sort_key=lambda x: len(x["text"]),
+            bucket_size_multiplier=bucket_size_multiplier,
+        )
+
+        # check if the samples are sorted by text length while bucketing
+        min_text_len_in_bucket = 0
+        bucket_items = []
+        for batch_idx, batch in enumerate(list(sampler)):
+            if (batch_idx + 1) % bucket_size_multiplier == 0:
+                for bucket_item in bucket_items:
+                    self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
+                    min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
+                min_text_len_in_bucket = 0
+                bucket_items = []
+            else:
+                bucket_items += batch
+
+        # check sampler length
+        self.assertEqual(len(sampler), len(train_samples) // 7)
@@ -30,6 +30,13 @@ class TestTTSTokenizer(unittest.TestCase):
         test_hat = self.tokenizer_ph.ids_to_text(ids)
         self.assertEqual(text_ph, test_hat)

+    def test_text_to_ids_phonemes_punctuation(self):
+        text = "..."
+        text_ph = self.ph.phonemize(text, separator="")
+        ids = self.tokenizer_ph.text_to_ids(text)
+        test_hat = self.tokenizer_ph.ids_to_text(ids)
+        self.assertEqual(text_ph, test_hat)
+
     def test_text_to_ids_phonemes_with_eos_bos(self):
         text = "Bu bir Örnek."
         self.tokenizer_ph.use_eos_bos = True
@ -0,0 +1,239 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch as T
|
||||||
|
|
||||||
|
from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss
|
||||||
|
from TTS.tts.utils.helpers import sequence_mask
|
||||||
|
|
||||||
|
|
||||||
|
class L1LossMaskedTests(unittest.TestCase):
|
||||||
|
def test_in_out(self): # pylint: disable=no-self-use
|
||||||
|
# test input == target
|
||||||
|
layer = L1LossMasked(seq_len_norm=False)
|
||||||
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
|
dummy_target = T.ones(4, 8, 128).float()
|
||||||
|
dummy_length = (T.ones(4) * 8).long()
|
||||||
|
output = layer(dummy_input, dummy_target, dummy_length)
|
||||||
|
assert output.item() == 0.0
|
||||||
|
|
||||||
|
# test input != target
|
||||||
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
|
dummy_target = T.zeros(4, 8, 128).float()
|
||||||
|
dummy_length = (T.ones(4) * 8).long()
|
||||||
|
output = layer(dummy_input, dummy_target, dummy_length)
|
||||||
|
assert output.item() == 1.0, "1.0 vs {}".format(output.item())
|
||||||
|
|
||||||
|
# test if padded values of input makes any difference
|
||||||
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
|
dummy_target = T.zeros(4, 8, 128).float()
|
||||||
|
dummy_length = (T.arange(5, 9)).long()
|
||||||
|
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
||||||
|
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||||
|
assert output.item() == 1.0, "1.0 vs {}".format(output.item())
|
||||||
|
|
||||||
|
dummy_input = T.rand(4, 8, 128).float()
|
||||||
|
dummy_target = dummy_input.detach()
|
||||||
|
dummy_length = (T.arange(5, 9)).long()
|
||||||
|
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
||||||
|
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||||
|
assert output.item() == 0, "0 vs {}".format(output.item())
|
||||||
|
|
||||||
|
# seq_len_norm = True
|
||||||
|
# test input == target
|
||||||
|
layer = L1LossMasked(seq_len_norm=True)
|
||||||
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
|
dummy_target = T.ones(4, 8, 128).float()
|
||||||
|
dummy_length = (T.ones(4) * 8).long()
|
||||||
|
output = layer(dummy_input, dummy_target, dummy_length)
|
||||||
|
assert output.item() == 0.0
|
||||||
|
|
||||||
|
# test input != target
|
||||||
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
|
dummy_target = T.zeros(4, 8, 128).float()
|
||||||
|
dummy_length = (T.ones(4) * 8).long()
|
||||||
|
output = layer(dummy_input, dummy_target, dummy_length)
|
||||||
|
assert output.item() == 1.0, "1.0 vs {}".format(output.item())
|
||||||
|
|
||||||
|
# test if padded values of input makes any difference
|
||||||
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
|
dummy_target = T.zeros(4, 8, 128).float()
|
||||||
|
dummy_length = (T.arange(5, 9)).long()
|
||||||
|
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
||||||
|
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||||
|
assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
|
||||||
|
|
||||||
|
dummy_input = T.rand(4, 8, 128).float()
|
||||||
|
dummy_target = dummy_input.detach()
|
||||||
|
dummy_length = (T.arange(5, 9)).long()
|
||||||
|
mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
||||||
|
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||||
|
assert output.item() == 0, "0 vs {}".format(output.item())
|
||||||
|
|
||||||
|
|
||||||
|
class MSELossMaskedTests(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        # test input == target
        layer = MSELossMasked(seq_len_norm=False)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test whether padded values of the input make any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

        # seq_len_norm = True
        # test input == target
        layer = MSELossMasked(seq_len_norm=True)
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.ones(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test whether padded values of the input make any difference
        dummy_input = T.ones(4, 8, 128).float()
        dummy_target = T.zeros(4, 8, 128).float()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 8, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(5, 9)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

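For reference, a compact sketch of what a masked regression loss with optional per-sequence normalization computes; it reproduces the values asserted above (all-ones vs all-zeros gives 1.0, identical tensors give 0.0), but it is only an illustration under those assumptions, not the `MSELossMasked` implementation from `TTS.tts.layers.losses`:

```python
import torch

def masked_mse_sketch(x, target, lengths, seq_len_norm=False):
    # x, target: [B, T, D]; lengths: [B]
    mask = (torch.arange(x.shape[1])[None, :] < lengths[:, None]).float()  # [B, T]
    mask = mask.unsqueeze(2)                                               # [B, T, 1]
    sq_err = (x - target) ** 2 * mask
    if seq_len_norm:
        # normalize each sequence by its own number of valid elements, then average
        per_seq = sq_err.sum(dim=(1, 2)) / (lengths.float() * x.shape[2])
        return per_seq.mean()
    # single masked mean over all valid elements in the batch
    return sq_err.sum() / (mask.sum() * x.shape[2])
```
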
class SSIMLossTests(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        # test input == target
        layer = SSIMLoss()
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.ones(4, 57, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.arange(0, 4 * 57 * 128)
        dummy_input = dummy_input.reshape(4, 57, 128).float()
        dummy_target = T.arange(-4 * 57 * 128, 0)
        dummy_target = dummy_target.reshape(4, 57, 128).float()
        dummy_target = -dummy_target

        dummy_length = (T.ones(4) * 58).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() >= 1.0, "0 vs {}".format(output.item())

        # test whether padded values of the input make any difference
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.zeros(4, 57, 128).float()
        dummy_length = (T.arange(54, 58)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0.0

        dummy_input = T.rand(4, 57, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(54, 58)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

        # seq_len_norm = True
        # test input == target
        layer = L1LossMasked(seq_len_norm=True)
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.ones(4, 57, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 0.0

        # test input != target
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.zeros(4, 57, 128).float()
        dummy_length = (T.ones(4) * 8).long()
        output = layer(dummy_input, dummy_target, dummy_length)
        assert output.item() == 1.0, "1.0 vs {}".format(output.item())

        # test whether padded values of the input make any difference
        dummy_input = T.ones(4, 57, 128).float()
        dummy_target = T.zeros(4, 57, 128).float()
        dummy_length = (T.arange(54, 58)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())

        dummy_input = T.rand(4, 57, 128).float()
        dummy_target = dummy_input.detach()
        dummy_length = (T.arange(54, 58)).long()
        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
        output = layer(dummy_input + mask, dummy_target, dummy_length)
        assert output.item() == 0, "0 vs {}".format(output.item())

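A note on the expected ranges in the SSIM checks above: SSIM compares local mean, variance, and covariance statistics of the two spectrograms treated as images, SSIM(x, y) = ((2·μx·μy + C1)·(2·σxy + C2)) / ((μx² + μy² + C1)·(σx² + σy² + C2)), averaged over windows, and the loss used here is, in essence, 1 − SSIM. Identical inputs therefore give 0.0, while strongly anti-correlated inputs can push SSIM below zero, which is why the "input != target" case only asserts that the loss is at least 1.0. Note also that the `seq_len_norm = True` block in this test exercises `L1LossMasked` rather than `SSIMLoss`.
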
class BCELossTest(unittest.TestCase):
    def test_in_out(self):  # pylint: disable=no-self-use
        layer = BCELossMasked(pos_weight=5.0)

        length = T.tensor([95])
        target = (
            1.0 - sequence_mask(length - 1, 100).float()
        )  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
        true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
        zero_x = T.zeros(target.shape) - 100.0  # simulate logits if it never stops decoding
        early_x = -200.0 * sequence_mask(length - 3, 100).float() + 100.0  # simulate logits on early stopping
        late_x = -200.0 * sequence_mask(length + 1, 100).float() + 100.0  # simulate logits on late stopping

        loss = layer(true_x, target, length)
        self.assertEqual(loss.item(), 0.0)

        loss = layer(early_x, target, length)
        self.assertAlmostEqual(loss.item(), 2.1053, places=4)

        loss = layer(late_x, target, length)
        self.assertAlmostEqual(loss.item(), 5.2632, places=4)

        loss = layer(zero_x, target, length)
        self.assertAlmostEqual(loss.item(), 5.2632, places=4)

        # pos_weight should be < 1 to penalize early stopping
        layer = BCELossMasked(pos_weight=0.2)
        loss = layer(true_x, target, length)
        self.assertEqual(loss.item(), 0.0)

        # when pos_weight < 1, the early-stopping loss is weighted more heavily
        loss_early = layer(early_x, target, length)
        loss_late = layer(late_x, target, length)
        self.assertGreater(loss_early.item(), loss_late.item())
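
A rough accounting of the expected constants, assuming the loss sums per-frame binary cross-entropy over the first `length` frames and divides by `length` (an interpretation consistent with the assertions above, not a quote of the `BCELossMasked` source): with logits of ±100, the per-frame BCE is ≈ 100 when the sign disagrees with the target and ≈ 0 otherwise, and positive-target frames are scaled by `pos_weight`. The hypothetical re-computation below reproduces the asserted values:

```python
# sketch: where 2.1053 and 5.2632 come from, under the assumptions stated above
length, pos_weight = 95, 5.0
early_false_positives = 2   # two frames predict "stop" (logit +100) before the true stop frame
late_missed_positives = 1   # the single true stop frame is predicted as "continue"
print(early_false_positives * 100.0 / length)               # ~2.1053
print(late_missed_positives * pos_weight * 100.0 / length)  # ~5.2632
```
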

@@ -2,9 +2,7 @@ import unittest

 import torch as T

-from TTS.tts.layers.losses import L1LossMasked, SSIMLoss
 from TTS.tts.layers.tacotron.tacotron import CBHG, Decoder, Encoder, Prenet
-from TTS.tts.utils.helpers import sequence_mask

 # pylint: disable=unused-variable

@@ -85,131 +83,3 @@ class EncoderTests(unittest.TestCase):
         assert output.shape[0] == 4
         assert output.shape[1] == 8
         assert output.shape[2] == 256  # 128 * 2 BiRNN
-
-
-class L1LossMaskedTests(unittest.TestCase):
-    def test_in_out(self):  # pylint: disable=no-self-use
-        # test input == target
-        layer = L1LossMasked(seq_len_norm=False)
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.ones(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 0.0
-
-        # test input != target
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
-
-        # test if padded values of input makes any difference
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
-
-        dummy_input = T.rand(4, 8, 128).float()
-        dummy_target = dummy_input.detach()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.item())
-
-        # seq_len_norm = True
-        # test input == target
-        layer = L1LossMasked(seq_len_norm=True)
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.ones(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 0.0
-
-        # test input != target
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
-
-        # test if padded values of input makes any difference
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
-
-        dummy_input = T.rand(4, 8, 128).float()
-        dummy_target = dummy_input.detach()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.item())
-
-
-class SSIMLossTests(unittest.TestCase):
-    def test_in_out(self):  # pylint: disable=no-self-use
-        # test input == target
-        layer = SSIMLoss()
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.ones(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 0.0
-
-        # test input != target
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
-
-        # test if padded values of input makes any difference
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
-
-        dummy_input = T.rand(4, 8, 128).float()
-        dummy_target = dummy_input.detach()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.item())
-
-        # seq_len_norm = True
-        # test input == target
-        layer = L1LossMasked(seq_len_norm=True)
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.ones(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 0.0
-
-        # test input != target
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
-
-        # test if padded values of input makes any difference
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
-
-        dummy_input = T.rand(4, 8, 128).float()
-        dummy_target = dummy_input.detach()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.item())

@@ -9,7 +9,17 @@ from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_test
 from TTS.config import load_config
 from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
+from TTS.tts.models.vits import (
+    Vits,
+    VitsArgs,
+    VitsAudioConfig,
+    amp_to_db,
+    db_to_amp,
+    load_audio,
+    spec_to_mel,
+    wav_to_mel,
+    wav_to_spec,
+)
 from TTS.tts.utils.speakers import SpeakerManager

 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")

@@ -421,8 +431,10 @@ class TestVits(unittest.TestCase):
         self._check_parameter_changes(model, model_ref)

     def test_train_step_upsampling(self):
+        """Upsampling by the decoder upsampling layers"""
         # setup the model
         with torch.autograd.set_detect_anomaly(True):
+            audio_config = VitsAudioConfig(sample_rate=22050)
             model_args = VitsArgs(
                 num_chars=32,
                 spec_segment_size=10,

@@ -430,7 +442,7 @@ class TestVits(unittest.TestCase):
                 interpolate_z=False,
                 upsample_rates_decoder=[8, 8, 4, 2],
             )
-            config = VitsConfig(model_args=model_args)
+            config = VitsConfig(model_args=model_args, audio=audio_config)
             model = Vits(config).to(device)
             model.train()
             # model to train

@@ -459,10 +471,18 @@ class TestVits(unittest.TestCase):
         self._check_parameter_changes(model, model_ref)

     def test_train_step_upsampling_interpolation(self):
+        """Upsampling by interpolation"""
         # setup the model
         with torch.autograd.set_detect_anomaly(True):
-            model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
-            config = VitsConfig(model_args=model_args)
+            audio_config = VitsAudioConfig(sample_rate=22050)
+            model_args = VitsArgs(
+                num_chars=32,
+                spec_segment_size=10,
+                encoder_sample_rate=11025,
+                interpolate_z=True,
+                upsample_rates_decoder=[8, 8, 2, 2],
+            )
+            config = VitsConfig(model_args=model_args, audio=audio_config)
             model = Vits(config).to(device)
             model.train()
             # model to train
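
The net effect of these VITS test changes is that the upsampling tests now pin the audio sampling rate explicitly instead of relying on the defaults. A minimal configuration sketch in the spirit of the updated test (names taken from the diff above; the surrounding training-step machinery is omitted, so treat this as an illustration rather than a complete recipe):

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig

audio_config = VitsAudioConfig(sample_rate=22050)
model_args = VitsArgs(
    num_chars=32,
    spec_segment_size=10,
    encoder_sample_rate=11025,            # encoder side runs at half the output sample rate
    interpolate_z=True,                   # upsample the latent by interpolation ("Upsampling by interpolation")
    upsample_rates_decoder=[8, 8, 2, 2],
)
config = VitsConfig(model_args=model_args, audio=audio_config)
model = Vits(config)
```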