mirror of https://github.com/coqui-ai/TTS.git
commit
98a372bca2
|
@ -4,8 +4,8 @@ runs:
|
|||
using: 'composite'
|
||||
steps:
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: "0.5.1"
|
||||
version: "0.5.4"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "**/pyproject.toml"
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
name: integration
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
trainer_branch:
|
||||
description: "Branch of Trainer to test"
|
||||
required: false
|
||||
default: "main"
|
||||
coqpit_branch:
|
||||
description: "Branch of Coqpit to test"
|
||||
required: false
|
||||
default: "main"
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.12"]
|
||||
subset: ["test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup uv
|
||||
uses: ./.github/actions/setup-uv
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
run: uv python install ${{ matrix.python-version }}
|
||||
- name: Install Espeak
|
||||
if: contains(fromJSON('["test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset)
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install espeak espeak-ng
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y --no-install-recommends git make gcc
|
||||
make system-deps
|
||||
- name: Install custom Trainer and/or Coqpit if requested
|
||||
run: |
|
||||
if [[ -n "${{ github.event.inputs.trainer_branch }}" ]]; then
|
||||
uv add git+https://github.com/idiap/coqui-ai-Trainer --branch ${{ github.event.inputs.trainer_branch }}
|
||||
fi
|
||||
if [[ -n "${{ github.event.inputs.coqpit_branch }}" ]]; then
|
||||
uv add git+https://github.com/idiap/coqui-ai-coqpit --branch ${{ github.event.inputs.coqpit_branch }}
|
||||
fi
|
||||
- name: Integration tests
|
||||
run: |
|
||||
resolution=highest
|
||||
if [ "${{ matrix.python-version }}" == "3.9" ]; then
|
||||
resolution=lowest-direct
|
||||
fi
|
||||
uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }}
|
||||
- name: Upload coverage data
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
include-hidden-files: true
|
||||
name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
|
||||
path: .coverage.*
|
||||
if-no-files-found: ignore
|
||||
coverage:
|
||||
if: always()
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup uv
|
||||
uses: ./.github/actions/setup-uv
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: coverage-data-*
|
||||
merge-multiple: true
|
||||
- name: Combine coverage
|
||||
run: |
|
||||
uv python install
|
||||
uvx coverage combine
|
||||
uvx coverage html --skip-covered --skip-empty
|
||||
uvx coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
|
@ -1,4 +1,4 @@
|
|||
name: unit
|
||||
name: test
|
||||
|
||||
on:
|
||||
push:
|
||||
|
@ -17,7 +17,7 @@ on:
|
|||
required: false
|
||||
default: "main"
|
||||
jobs:
|
||||
test:
|
||||
unit:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
@ -62,9 +62,54 @@ jobs:
|
|||
name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
|
||||
path: .coverage.*
|
||||
if-no-files-found: ignore
|
||||
integration:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.12"]
|
||||
subset: ["test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup uv
|
||||
uses: ./.github/actions/setup-uv
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
run: uv python install ${{ matrix.python-version }}
|
||||
- name: Install Espeak
|
||||
if: contains(fromJSON('["test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset)
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install espeak espeak-ng
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y --no-install-recommends git make gcc
|
||||
make system-deps
|
||||
- name: Install custom Trainer and/or Coqpit if requested
|
||||
run: |
|
||||
if [[ -n "${{ github.event.inputs.trainer_branch }}" ]]; then
|
||||
uv add git+https://github.com/idiap/coqui-ai-Trainer --branch ${{ github.event.inputs.trainer_branch }}
|
||||
fi
|
||||
if [[ -n "${{ github.event.inputs.coqpit_branch }}" ]]; then
|
||||
uv add git+https://github.com/idiap/coqui-ai-coqpit --branch ${{ github.event.inputs.coqpit_branch }}
|
||||
fi
|
||||
- name: Integration tests
|
||||
run: |
|
||||
resolution=highest
|
||||
if [ "${{ matrix.python-version }}" == "3.9" ]; then
|
||||
resolution=lowest-direct
|
||||
fi
|
||||
uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }}
|
||||
- name: Upload coverage data
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
include-hidden-files: true
|
||||
name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
|
||||
path: .coverage.*
|
||||
if-no-files-found: ignore
|
||||
coverage:
|
||||
if: always()
|
||||
needs: test
|
||||
needs: [unit, integration]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
|
|
@ -80,7 +80,7 @@ Example run:
|
|||
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
||||
# TODO: handle multi-speaker
|
||||
model = setup_model(C)
|
||||
model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
|
||||
model, _ = load_checkpoint(model, args.model_path, use_cuda=args.use_cuda, eval=True)
|
||||
|
||||
# data loader
|
||||
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
|
||||
|
|
|
@ -5,10 +5,10 @@ import torch
|
|||
import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from trainer.generic_utils import set_partial_state_dict
|
||||
from trainer.io import load_fsspec
|
||||
|
||||
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.utils.generic_utils import set_init_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -130,7 +130,7 @@ class BaseEncoder(nn.Module):
|
|||
|
||||
logger.info("Partial model initialization.")
|
||||
model_dict = self.state_dict()
|
||||
model_dict = set_init_dict(model_dict, state["model"], c)
|
||||
model_dict = set_partial_state_dict(model_dict, state["model"], config)
|
||||
self.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
|
|
|
@ -2,9 +2,9 @@ import os
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from coqpit import Coqpit
|
||||
from trainer import TrainerArgs, get_last_checkpoint
|
||||
from trainer import TrainerArgs
|
||||
from trainer.generic_utils import get_experiment_folder_path, get_git_branch
|
||||
from trainer.io import copy_model_files
|
||||
from trainer.io import copy_model_files, get_last_checkpoint
|
||||
from trainer.logging import logger_factory
|
||||
from trainer.logging.console_logger import ConsoleLogger
|
||||
|
||||
|
|
|
@ -63,6 +63,31 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int:
|
|||
raise RuntimeError(msg) from e
|
||||
|
||||
|
||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: Optional[dict] = None):
|
||||
"""Create inverse frequency weights for balancing the dataset.
|
||||
|
||||
Use `multi_dict` to scale relative weights."""
|
||||
attr_names_samples = np.array([item[attr_name] for item in items])
|
||||
unique_attr_names = np.unique(attr_names_samples).tolist()
|
||||
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
|
||||
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
|
||||
weight_attr = 1.0 / attr_count
|
||||
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
|
||||
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||
if multi_dict is not None:
|
||||
# check if all keys are in the multi_dict
|
||||
for k in multi_dict:
|
||||
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
|
||||
# scale weights
|
||||
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
|
||||
dataset_samples_weight *= multiplier_samples
|
||||
return (
|
||||
torch.from_numpy(dataset_samples_weight).float(),
|
||||
unique_attr_names,
|
||||
np.unique(dataset_samples_weight).tolist(),
|
||||
)
|
||||
|
||||
|
||||
class TTSDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -14,6 +14,8 @@ from torch import nn
|
|||
from torchaudio.functional import resample
|
||||
from transformers import HubertModel
|
||||
|
||||
from TTS.utils.generic_utils import exists
|
||||
|
||||
|
||||
def round_down_nearest_multiple(num, divisor):
|
||||
return num // divisor * divisor
|
||||
|
@ -26,14 +28,6 @@ def curtail_to_multiple(t, mult, from_left=False):
|
|||
return t[..., seq_slice]
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
class CustomHubert(nn.Module):
|
||||
"""
|
||||
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
|
||||
|
|
|
@ -12,18 +12,6 @@ from torch import nn
|
|||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
|
||||
|
||||
def __init__(self, ndim, bias):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(ndim))
|
||||
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
|
||||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
@ -119,9 +107,9 @@ class MLP(nn.Module):
|
|||
class Block(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
||||
self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
|
||||
self.attn = CausalSelfAttention(config)
|
||||
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
||||
self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
|
||||
self.mlp = MLP(config)
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
|
@ -158,7 +146,7 @@ class GPT(nn.Module):
|
|||
wpe=nn.Embedding(config.block_size, config.n_embd),
|
||||
drop=nn.Dropout(config.dropout),
|
||||
h=nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]),
|
||||
ln_f=LayerNorm(config.n_embd, bias=config.bias),
|
||||
ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
|
||||
)
|
||||
)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
|
||||
|
|
|
@ -12,7 +12,6 @@ from TTS.tts.layers.delightful_tts.conformer import Conformer
|
|||
from TTS.tts.layers.delightful_tts.encoders import (
|
||||
PhonemeLevelProsodyEncoder,
|
||||
UtteranceLevelProsodyEncoder,
|
||||
get_mask_from_lengths,
|
||||
)
|
||||
from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor
|
||||
from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding
|
||||
|
@ -20,7 +19,7 @@ from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProso
|
|||
from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor
|
||||
from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
|
||||
from TTS.tts.layers.generic.aligner import AlignmentNetwork
|
||||
from TTS.tts.utils.helpers import generate_path, sequence_mask
|
||||
from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -231,42 +230,6 @@ class AcousticModel(torch.nn.Module):
|
|||
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
|
||||
self.embedded_speaker_dim = self.args.d_vector_dim
|
||||
|
||||
@staticmethod
|
||||
def generate_attn(dr, x_mask, y_mask=None):
|
||||
"""Generate an attention mask from the linear scale durations.
|
||||
|
||||
Args:
|
||||
dr (Tensor): Linear scale durations.
|
||||
x_mask (Tensor): Mask for the input (character) sequence.
|
||||
y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
|
||||
if None. Defaults to None.
|
||||
|
||||
Shapes
|
||||
- dr: :math:`(B, T_{en})`
|
||||
- x_mask: :math:`(B, T_{en})`
|
||||
- y_mask: :math:`(B, T_{de})`
|
||||
"""
|
||||
# compute decode mask from the durations
|
||||
if y_mask is None:
|
||||
y_lengths = dr.sum(1).long()
|
||||
y_lengths[y_lengths < 1] = 1
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
|
||||
return attn
|
||||
|
||||
def _expand_encoder_with_durations(
|
||||
self,
|
||||
o_en: torch.FloatTensor,
|
||||
dr: torch.IntTensor,
|
||||
x_mask: torch.IntTensor,
|
||||
y_lengths: torch.IntTensor,
|
||||
):
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
|
||||
attn = self.generate_attn(dr, x_mask, y_mask)
|
||||
o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en])
|
||||
return y_mask, o_en_ex, attn.transpose(1, 2)
|
||||
|
||||
def _forward_aligner(
|
||||
self,
|
||||
x: torch.FloatTensor,
|
||||
|
@ -340,8 +303,8 @@ class AcousticModel(torch.nn.Module):
|
|||
{"d_vectors": d_vectors, "speaker_ids": speaker_idx}
|
||||
) # pylint: disable=unused-variable
|
||||
|
||||
src_mask = get_mask_from_lengths(src_lens) # [B, T_src]
|
||||
mel_mask = get_mask_from_lengths(mel_lens) # [B, T_mel]
|
||||
src_mask = ~sequence_mask(src_lens) # [B, T_src]
|
||||
mel_mask = ~sequence_mask(mel_lens) # [B, T_mel]
|
||||
|
||||
# Token embeddings
|
||||
token_embeddings = self.src_word_emb(tokens) # [B, T_src, C_hidden]
|
||||
|
@ -420,8 +383,8 @@ class AcousticModel(torch.nn.Module):
|
|||
encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb
|
||||
log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask)
|
||||
|
||||
mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
|
||||
o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None]
|
||||
encoder_outputs_ex, alignments, mel_pred_mask = expand_encoder_outputs(
|
||||
encoder_outputs, y_lengths=mel_lens, duration=dr, x_mask=~src_mask[:, None]
|
||||
)
|
||||
|
||||
x = self.decoder(
|
||||
|
@ -435,7 +398,7 @@ class AcousticModel(torch.nn.Module):
|
|||
dr = torch.log(dr + 1)
|
||||
|
||||
dr_pred = torch.exp(log_duration_prediction) - 1
|
||||
alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2']
|
||||
alignments_dp = generate_attention(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2']
|
||||
|
||||
return {
|
||||
"model_outputs": x,
|
||||
|
@ -448,7 +411,7 @@ class AcousticModel(torch.nn.Module):
|
|||
"p_prosody_pred": p_prosody_pred,
|
||||
"p_prosody_ref": p_prosody_ref,
|
||||
"alignments_dp": alignments_dp,
|
||||
"alignments": alignments, # [B, T_de, T_en]
|
||||
"alignments": alignments.transpose(1, 2), # [B, T_de, T_en]
|
||||
"aligner_soft": aligner_soft,
|
||||
"aligner_mas": aligner_mas,
|
||||
"aligner_durations": aligner_durations,
|
||||
|
@ -469,7 +432,7 @@ class AcousticModel(torch.nn.Module):
|
|||
pitch_transform: Callable = None,
|
||||
energy_transform: Callable = None,
|
||||
) -> torch.Tensor:
|
||||
src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
|
||||
src_mask = ~sequence_mask(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
|
||||
src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) # pylint: disable=unused-variable
|
||||
sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable
|
||||
{"d_vectors": d_vectors, "speaker_ids": speaker_idx}
|
||||
|
@ -536,11 +499,11 @@ class AcousticModel(torch.nn.Module):
|
|||
duration_pred = torch.round(duration_pred) # -> [B, T_src]
|
||||
mel_lens = duration_pred.sum(1) # -> [B,]
|
||||
|
||||
_, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
|
||||
o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
|
||||
encoder_outputs_ex, alignments, _ = expand_encoder_outputs(
|
||||
encoder_outputs, y_lengths=mel_lens, duration=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
|
||||
)
|
||||
|
||||
mel_mask = get_mask_from_lengths(
|
||||
mel_mask = ~sequence_mask(
|
||||
torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device)
|
||||
)
|
||||
|
||||
|
@ -557,7 +520,7 @@ class AcousticModel(torch.nn.Module):
|
|||
x = self.to_mel(x)
|
||||
outputs = {
|
||||
"model_outputs": x,
|
||||
"alignments": alignments,
|
||||
"alignments": alignments.transpose(1, 2),
|
||||
# "pitch": pitch_emb_pred,
|
||||
"durations": duration_pred,
|
||||
"pitch": pitch_pred,
|
||||
|
|
|
@ -1,20 +1,14 @@
|
|||
### credit: https://github.com/dunky11/voicesmith
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn # pylint: disable=consider-using-from-import
|
||||
import torch.nn.functional as F
|
||||
|
||||
from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d
|
||||
from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d, calc_same_padding
|
||||
from TTS.tts.layers.delightful_tts.networks import GLUActivation
|
||||
|
||||
|
||||
def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
|
||||
pad = kernel_size // 2
|
||||
return (pad, pad - (kernel_size + 1) % 2)
|
||||
|
||||
|
||||
class Conformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -322,7 +316,7 @@ class ConformerMultiHeadedSelfAttention(nn.Module):
|
|||
value: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
encoding: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, seq_length, _ = key.size() # pylint: disable=unused-variable
|
||||
encoding = encoding[:, : key.shape[1]]
|
||||
encoding = encoding.repeat(batch_size, 1, 1)
|
||||
|
@ -378,7 +372,7 @@ class RelativeMultiHeadAttention(nn.Module):
|
|||
value: torch.Tensor,
|
||||
pos_embedding: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = query.shape[0]
|
||||
query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
|
||||
key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
|
||||
|
@ -411,40 +405,3 @@ class RelativeMultiHeadAttention(nn.Module):
|
|||
padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
|
||||
pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
|
||||
return pos_score
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
input:
|
||||
query --- [N, T_q, query_dim]
|
||||
key --- [N, T_k, key_dim]
|
||||
output:
|
||||
out --- [N, T_q, num_units]
|
||||
"""
|
||||
|
||||
def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int):
|
||||
super().__init__()
|
||||
self.num_units = num_units
|
||||
self.num_heads = num_heads
|
||||
self.key_dim = key_dim
|
||||
|
||||
self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
|
||||
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
|
||||
self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
|
||||
|
||||
def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
|
||||
querys = self.W_query(query) # [N, T_q, num_units]
|
||||
keys = self.W_key(key) # [N, T_k, num_units]
|
||||
values = self.W_value(key)
|
||||
split_size = self.num_units // self.num_heads
|
||||
querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h]
|
||||
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
|
||||
values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
|
||||
# score = softmax(QK^T / (d_k ** 0.5))
|
||||
scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
|
||||
scores = scores / (self.key_dim**0.5)
|
||||
scores = F.softmax(scores, dim=3)
|
||||
# out = score * V
|
||||
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
|
||||
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
|
||||
return out
|
||||
|
|
|
@ -3,9 +3,6 @@ from typing import Tuple
|
|||
import torch
|
||||
import torch.nn as nn # pylint: disable=consider-using-from-import
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor
|
||||
|
||||
|
||||
def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
|
||||
|
@ -530,142 +527,3 @@ class CoordConv2d(nn.modules.conv.Conv2d):
|
|||
x = self.addcoords(x)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class LVCBlock(torch.nn.Module):
|
||||
"""the location-variable convolutions"""
|
||||
|
||||
def __init__( # pylint: disable=dangerous-default-value
|
||||
self,
|
||||
in_channels,
|
||||
cond_channels,
|
||||
stride,
|
||||
dilations=[1, 3, 9, 27],
|
||||
lReLU_slope=0.2,
|
||||
conv_kernel_size=3,
|
||||
cond_hop_length=256,
|
||||
kpnet_hidden_channels=64,
|
||||
kpnet_conv_size=3,
|
||||
kpnet_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cond_hop_length = cond_hop_length
|
||||
self.conv_layers = len(dilations)
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
|
||||
self.kernel_predictor = KernelPredictor(
|
||||
cond_channels=cond_channels,
|
||||
conv_in_channels=in_channels,
|
||||
conv_out_channels=2 * in_channels,
|
||||
conv_layers=len(dilations),
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||
kpnet_conv_size=kpnet_conv_size,
|
||||
kpnet_dropout=kpnet_dropout,
|
||||
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
|
||||
)
|
||||
|
||||
self.convt_pre = nn.Sequential(
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
nn.utils.parametrizations.weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
2 * stride,
|
||||
stride=stride,
|
||||
padding=stride // 2 + stride % 2,
|
||||
output_padding=stride % 2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self.conv_blocks = nn.ModuleList()
|
||||
for dilation in dilations:
|
||||
self.conv_blocks.append(
|
||||
nn.Sequential(
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
nn.utils.parametrizations.weight_norm(
|
||||
nn.Conv1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
conv_kernel_size,
|
||||
padding=dilation * (conv_kernel_size - 1) // 2,
|
||||
dilation=dilation,
|
||||
)
|
||||
),
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
"""forward propagation of the location-variable convolutions.
|
||||
Args:
|
||||
x (Tensor): the input sequence (batch, in_channels, in_length)
|
||||
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||
|
||||
Returns:
|
||||
Tensor: the output sequence (batch, in_channels, in_length)
|
||||
"""
|
||||
_, in_channels, _ = x.shape # (B, c_g, L')
|
||||
|
||||
x = self.convt_pre(x) # (B, c_g, stride * L')
|
||||
kernels, bias = self.kernel_predictor(c)
|
||||
|
||||
for i, conv in enumerate(self.conv_blocks):
|
||||
output = conv(x) # (B, c_g, stride * L')
|
||||
|
||||
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
||||
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
||||
|
||||
output = self.location_variable_convolution(
|
||||
output, k, b, hop_size=self.cond_hop_length
|
||||
) # (B, 2 * c_g, stride * L'): LVC
|
||||
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
||||
output[:, in_channels:, :]
|
||||
) # (B, c_g, stride * L'): GAU
|
||||
|
||||
return x
|
||||
|
||||
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): # pylint: disable=no-self-use
|
||||
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
||||
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
||||
Args:
|
||||
x (Tensor): the input sequence (batch, in_channels, in_length).
|
||||
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
||||
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
||||
dilation (int): the dilation of convolution.
|
||||
hop_size (int): the hop_size of the conditioning sequence.
|
||||
Returns:
|
||||
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
||||
"""
|
||||
batch, _, in_length = x.shape
|
||||
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
||||
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
||||
|
||||
padding = dilation * int((kernel_size - 1) / 2)
|
||||
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
|
||||
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||
|
||||
if hop_size < dilation:
|
||||
x = F.pad(x, (0, dilation), "constant", 0)
|
||||
x = x.unfold(
|
||||
3, dilation, dilation
|
||||
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
||||
x = x[:, :, :, :, :hop_size]
|
||||
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||
|
||||
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
|
||||
o = o.to(memory_format=torch.channels_last_3d)
|
||||
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
||||
o = o + bias
|
||||
o = o.contiguous().view(batch, out_channels, -1)
|
||||
|
||||
return o
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.kernel_predictor.remove_weight_norm()
|
||||
parametrize.remove_parametrizations(self.convt_pre[1], "weight")
|
||||
for block in self.conv_blocks:
|
||||
parametrize.remove_parametrizations(block[1], "weight")
|
||||
|
|
|
@ -7,14 +7,7 @@ import torch.nn.functional as F
|
|||
from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
|
||||
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
|
||||
from TTS.tts.layers.delightful_tts.networks import STL
|
||||
|
||||
|
||||
def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = lengths.shape[0]
|
||||
max_len = torch.max(lengths).item()
|
||||
ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
|
||||
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
|
||||
return mask
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
|
||||
|
||||
def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
|
||||
|
@ -93,7 +86,7 @@ class ReferenceEncoder(nn.Module):
|
|||
outputs --- [N, E//2]
|
||||
"""
|
||||
|
||||
mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
|
||||
mel_masks = ~sequence_mask(mel_lens).unsqueeze(1)
|
||||
x = x.masked_fill(mel_masks, 0)
|
||||
for conv, norm in zip(self.convs, self.norms):
|
||||
x = conv(x)
|
||||
|
@ -103,7 +96,7 @@ class ReferenceEncoder(nn.Module):
|
|||
for _ in range(2):
|
||||
mel_lens = stride_lens(mel_lens)
|
||||
|
||||
mel_masks = get_mask_from_lengths(mel_lens)
|
||||
mel_masks = ~sequence_mask(mel_lens)
|
||||
|
||||
x = x.masked_fill(mel_masks.unsqueeze(1), 0)
|
||||
x = x.permute((0, 2, 1))
|
||||
|
|
|
@ -1,128 +0,0 @@
|
|||
import torch.nn as nn # pylint: disable=consider-using-from-import
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
|
||||
class KernelPredictor(nn.Module):
|
||||
"""Kernel predictor for the location-variable convolutions
|
||||
|
||||
Args:
|
||||
cond_channels (int): number of channel for the conditioning sequence,
|
||||
conv_in_channels (int): number of channel for the input sequence,
|
||||
conv_out_channels (int): number of channel for the output sequence,
|
||||
conv_layers (int): number of layers
|
||||
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=dangerous-default-value
|
||||
self,
|
||||
cond_channels,
|
||||
conv_in_channels,
|
||||
conv_out_channels,
|
||||
conv_layers,
|
||||
conv_kernel_size=3,
|
||||
kpnet_hidden_channels=64,
|
||||
kpnet_conv_size=3,
|
||||
kpnet_dropout=0.0,
|
||||
kpnet_nonlinear_activation="LeakyReLU",
|
||||
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in_channels = conv_in_channels
|
||||
self.conv_out_channels = conv_out_channels
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.conv_layers = conv_layers
|
||||
|
||||
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
||||
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
||||
|
||||
self.input_conv = nn.Sequential(
|
||||
nn.utils.parametrizations.weight_norm(
|
||||
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
|
||||
self.residual_convs = nn.ModuleList()
|
||||
padding = (kpnet_conv_size - 1) // 2
|
||||
for _ in range(3):
|
||||
self.residual_convs.append(
|
||||
nn.Sequential(
|
||||
nn.Dropout(kpnet_dropout),
|
||||
nn.utils.parametrizations.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
kpnet_hidden_channels,
|
||||
kpnet_conv_size,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
nn.utils.parametrizations.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
kpnet_hidden_channels,
|
||||
kpnet_conv_size,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
)
|
||||
self.kernel_conv = nn.utils.parametrizations.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
kpnet_kernel_channels,
|
||||
kpnet_conv_size,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
)
|
||||
self.bias_conv = nn.utils.parametrizations.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
kpnet_bias_channels,
|
||||
kpnet_conv_size,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, c):
|
||||
"""
|
||||
Args:
|
||||
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||
"""
|
||||
batch, _, cond_length = c.shape
|
||||
c = self.input_conv(c)
|
||||
for residual_conv in self.residual_convs:
|
||||
residual_conv.to(c.device)
|
||||
c = c + residual_conv(c)
|
||||
k = self.kernel_conv(c)
|
||||
b = self.bias_conv(c)
|
||||
kernels = k.contiguous().view(
|
||||
batch,
|
||||
self.conv_layers,
|
||||
self.conv_in_channels,
|
||||
self.conv_out_channels,
|
||||
self.conv_kernel_size,
|
||||
cond_length,
|
||||
)
|
||||
bias = b.contiguous().view(
|
||||
batch,
|
||||
self.conv_layers,
|
||||
self.conv_out_channels,
|
||||
cond_length,
|
||||
)
|
||||
|
||||
return kernels, bias
|
||||
|
||||
def remove_weight_norm(self):
|
||||
parametrize.remove_parametrizations(self.input_conv[0], "weight")
|
||||
parametrize.remove_parametrizations(self.kernel_conv, "weight")
|
||||
parametrize.remove_parametrizations(self.bias_conv, "weight")
|
||||
for block in self.residual_convs:
|
||||
parametrize.remove_parametrizations(block[1], "weight")
|
||||
parametrize.remove_parametrizations(block[3], "weight")
|
|
@ -309,6 +309,24 @@ class ForwardSumLoss(nn.Module):
|
|||
return total_loss
|
||||
|
||||
|
||||
class NLLLoss(nn.Module):
|
||||
"""Negative log likelihood loss."""
|
||||
|
||||
def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
|
||||
"""Compute the loss.
|
||||
|
||||
Args:
|
||||
logits (Tensor): [B, T, D]
|
||||
|
||||
Returns:
|
||||
Tensor: [1]
|
||||
|
||||
"""
|
||||
return_dict = {}
|
||||
return_dict["loss"] = -log_prob.mean()
|
||||
return return_dict
|
||||
|
||||
|
||||
########################
|
||||
# MODEL LOSS LAYERS
|
||||
########################
|
||||
|
@ -619,6 +637,28 @@ class AlignTTSLoss(nn.Module):
|
|||
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
|
||||
|
||||
|
||||
def feature_loss(feats_real, feats_generated):
|
||||
loss = 0
|
||||
for dr, dg in zip(feats_real, feats_generated):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
return loss * 2
|
||||
|
||||
|
||||
def generator_loss(scores_fake):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in scores_fake:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
|
||||
class VitsGeneratorLoss(nn.Module):
|
||||
def __init__(self, c: Coqpit):
|
||||
super().__init__()
|
||||
|
@ -640,28 +680,6 @@ class VitsGeneratorLoss(nn.Module):
|
|||
do_amp_to_db=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def feature_loss(feats_real, feats_generated):
|
||||
loss = 0
|
||||
for dr, dg in zip(feats_real, feats_generated):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
return loss * 2
|
||||
|
||||
@staticmethod
|
||||
def generator_loss(scores_fake):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in scores_fake:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
@staticmethod
|
||||
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
||||
"""
|
||||
|
@ -722,10 +740,8 @@ class VitsGeneratorLoss(nn.Module):
|
|||
self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
|
||||
* self.kl_loss_alpha
|
||||
)
|
||||
loss_feat = (
|
||||
self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
|
||||
)
|
||||
loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_feat = feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
|
||||
loss_gen = generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
|
||||
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
||||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||
|
@ -779,6 +795,15 @@ class VitsDiscriminatorLoss(nn.Module):
|
|||
return return_dict
|
||||
|
||||
|
||||
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||
"""Binary loss that forces soft alignments to match the hard alignments.
|
||||
|
||||
Explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||
"""
|
||||
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||
return -log_sum / alignment_hard.sum()
|
||||
|
||||
|
||||
class ForwardTTSLoss(nn.Module):
|
||||
"""Generic configurable ForwardTTS loss."""
|
||||
|
||||
|
@ -820,14 +845,6 @@ class ForwardTTSLoss(nn.Module):
|
|||
self.dur_loss_alpha = c.dur_loss_alpha
|
||||
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
|
||||
|
||||
@staticmethod
|
||||
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||
"""
|
||||
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||
return -log_sum / alignment_hard.sum()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_output,
|
||||
|
@ -879,7 +896,7 @@ class ForwardTTSLoss(nn.Module):
|
|||
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
|
||||
|
||||
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
|
||||
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
|
||||
binary_alignment_loss = _binary_alignment_loss(alignment_hard, alignment_soft)
|
||||
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||
if binary_loss_weight:
|
||||
return_dict["loss_binary_alignment"] = (
|
||||
|
|
|
@ -3,6 +3,8 @@ from torch import nn
|
|||
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height
|
||||
|
||||
|
||||
class CapacitronVAE(nn.Module):
|
||||
"""Effective Use of Variational Embedding Capacity for prosody transfer.
|
||||
|
@ -97,7 +99,7 @@ class ReferenceEncoder(nn.Module):
|
|||
self.training = False
|
||||
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
|
||||
|
||||
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
|
||||
post_conv_height = calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
|
||||
self.recurrence = nn.LSTM(
|
||||
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
|
||||
)
|
||||
|
@ -155,13 +157,6 @@ class ReferenceEncoder(nn.Module):
|
|||
|
||||
return last_output.to(inputs.device) # [B, 128]
|
||||
|
||||
@staticmethod
|
||||
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
|
||||
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
|
||||
for _ in range(n_convs):
|
||||
height = (height - kernel_size + 2 * pad) // stride + 1
|
||||
return height
|
||||
|
||||
|
||||
class TextSummary(nn.Module):
|
||||
def __init__(self, embedding_dim, encoder_output_dim):
|
||||
|
|
|
@ -3,6 +3,13 @@ from torch import nn
|
|||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def calculate_post_conv_height(height: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
|
||||
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
|
||||
for _ in range(n_convs):
|
||||
height = (height - kernel_size + 2 * pad) // stride + 1
|
||||
return height
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
"""Linear layer with a specific initialization.
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height
|
||||
|
||||
|
||||
class GST(nn.Module):
|
||||
"""Global Style Token Module for factorizing prosody in speech.
|
||||
|
@ -44,7 +46,7 @@ class ReferenceEncoder(nn.Module):
|
|||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
|
||||
|
||||
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers)
|
||||
post_conv_height = calculate_post_conv_height(num_mel, 3, 2, 1, num_layers)
|
||||
self.recurrence = nn.GRU(
|
||||
input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True
|
||||
)
|
||||
|
@ -71,13 +73,6 @@ class ReferenceEncoder(nn.Module):
|
|||
|
||||
return out.squeeze(0)
|
||||
|
||||
@staticmethod
|
||||
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
|
||||
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
|
||||
for _ in range(n_convs):
|
||||
height = (height - kernel_size + 2 * pad) // stride + 1
|
||||
return height
|
||||
|
||||
|
||||
class StyleTokenLayer(nn.Module):
|
||||
"""NN Module attending to style tokens based on prosody encodings."""
|
||||
|
@ -117,7 +112,7 @@ class MultiHeadAttention(nn.Module):
|
|||
out --- [N, T_q, num_units]
|
||||
"""
|
||||
|
||||
def __init__(self, query_dim, key_dim, num_units, num_heads):
|
||||
def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int):
|
||||
super().__init__()
|
||||
self.num_units = num_units
|
||||
self.num_heads = num_heads
|
||||
|
@ -127,7 +122,7 @@ class MultiHeadAttention(nn.Module):
|
|||
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
|
||||
self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
|
||||
|
||||
def forward(self, query, key):
|
||||
def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
|
||||
queries = self.W_query(query) # [N, T_q, num_units]
|
||||
keys = self.W_key(key) # [N, T_k, num_units]
|
||||
values = self.W_value(key)
|
||||
|
@ -137,13 +132,11 @@ class MultiHeadAttention(nn.Module):
|
|||
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
|
||||
values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
|
||||
|
||||
# score = softmax(QK^T / (d_k**0.5))
|
||||
# score = softmax(QK^T / (d_k ** 0.5))
|
||||
scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k]
|
||||
scores = scores / (self.key_dim**0.5)
|
||||
scores = F.softmax(scores, dim=3)
|
||||
|
||||
# out = score * V
|
||||
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
|
||||
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
|
||||
|
||||
return out
|
||||
return torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
|
||||
|
|
|
@ -93,12 +93,10 @@ class AttentionBlock(nn.Module):
|
|||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
do_checkpoint=True,
|
||||
relative_pos_embeddings=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.do_checkpoint = do_checkpoint
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
|
@ -185,114 +183,6 @@ class Downsample(nn.Module):
|
|||
return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_conv=False,
|
||||
use_scale_shift_norm=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=3,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
padding = 1 if kernel_size == 3 else 2
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False)
|
||||
self.x_upd = Upsample(channels, False)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False)
|
||||
self.x_upd = Downsample(channels, False)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
|
||||
else:
|
||||
self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class AudioMiniEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
spec_dim,
|
||||
embedding_dim,
|
||||
base_channels=128,
|
||||
depth=2,
|
||||
resnet_blocks=2,
|
||||
attn_blocks=4,
|
||||
num_attn_heads=4,
|
||||
dropout=0,
|
||||
downsample_factor=2,
|
||||
kernel_size=3,
|
||||
):
|
||||
super().__init__()
|
||||
self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
|
||||
ch = base_channels
|
||||
res = []
|
||||
for l in range(depth):
|
||||
for r in range(resnet_blocks):
|
||||
res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
|
||||
res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
|
||||
ch *= 2
|
||||
self.res = nn.Sequential(*res)
|
||||
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
|
||||
attn = []
|
||||
for a in range(attn_blocks):
|
||||
attn.append(
|
||||
AttentionBlock(
|
||||
embedding_dim,
|
||||
num_attn_heads,
|
||||
)
|
||||
)
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
def forward(self, x):
|
||||
h = self.init(x)
|
||||
h = self.res(h)
|
||||
h = self.final(h)
|
||||
h = self.attn(h)
|
||||
return h[:, :, 0]
|
||||
|
||||
|
||||
DEFAULT_MEL_NORM_FILE = "https://github.com/coqui-ai/TTS/releases/download/v0.14.1_models/mel_norms.pth"
|
||||
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import torchaudio
|
||||
from scipy.io.wavfile import read
|
||||
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT
|
||||
from TTS.utils.audio.torch_transforms import TorchSTFT, amp_to_db
|
||||
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -88,24 +88,6 @@ def normalize_tacotron_mel(mel):
|
|||
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor
|
||||
"""
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor used to compress
|
||||
"""
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def get_voices(extra_voice_dirs: List[str] = []):
|
||||
dirs = extra_voice_dirs
|
||||
voices: Dict[str, List[str]] = {}
|
||||
|
@ -175,7 +157,7 @@ def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
|
|||
)
|
||||
stft = stft.to(device)
|
||||
mel = stft(wav)
|
||||
mel = dynamic_range_compression(mel)
|
||||
mel = amp_to_db(mel)
|
||||
if do_normalization:
|
||||
mel = normalize_tacotron_mel(mel)
|
||||
return mel
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# AGPL: a notification must be added stating that changes have been made to that file.
|
||||
import functools
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
@ -123,7 +124,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
else:
|
||||
emb = self.embeddings(input_ids)
|
||||
emb = emb + self.text_pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1] - mel_len, attention_mask.device
|
||||
attention_mask.shape[1] - (mel_len + 1), attention_mask.device
|
||||
)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
|
@ -175,8 +176,6 @@ class ConditioningEncoder(nn.Module):
|
|||
embedding_dim,
|
||||
attn_blocks=6,
|
||||
num_attn_heads=4,
|
||||
do_checkpointing=False,
|
||||
mean=False,
|
||||
):
|
||||
super().__init__()
|
||||
attn = []
|
||||
|
@ -185,34 +184,46 @@ class ConditioningEncoder(nn.Module):
|
|||
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
self.do_checkpointing = do_checkpointing
|
||||
self.mean = mean
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: (b, 80, s)
|
||||
"""
|
||||
h = self.init(x)
|
||||
h = self.attn(h)
|
||||
if self.mean:
|
||||
return h.mean(dim=2)
|
||||
else:
|
||||
return h[:, :, 0]
|
||||
return h
|
||||
|
||||
|
||||
class LearnedPositionEmbeddings(nn.Module):
|
||||
def __init__(self, seq_len, model_dim, init=0.02):
|
||||
def __init__(self, seq_len, model_dim, init=0.02, relative=False):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(seq_len, model_dim)
|
||||
# Initializing this way is standard for GPT-2
|
||||
self.emb.weight.data.normal_(mean=0.0, std=init)
|
||||
self.relative = relative
|
||||
self.seq_len = seq_len
|
||||
|
||||
def forward(self, x):
|
||||
sl = x.shape[1]
|
||||
return self.emb(torch.arange(0, sl, device=x.device))
|
||||
if self.relative:
|
||||
start = random.randint(sl, self.seq_len) - sl
|
||||
return self.emb(torch.arange(start, start + sl, device=x.device))
|
||||
else:
|
||||
return self.emb(torch.arange(0, sl, device=x.device))
|
||||
|
||||
def get_fixed_embedding(self, ind, dev):
|
||||
return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
|
||||
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||
|
||||
|
||||
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
|
||||
def build_hf_gpt_transformer(
|
||||
layers: int,
|
||||
model_dim: int,
|
||||
heads: int,
|
||||
max_mel_seq_len: int,
|
||||
max_text_seq_len: int,
|
||||
checkpointing: bool,
|
||||
max_prompt_len: int = 0,
|
||||
):
|
||||
"""
|
||||
GPT-2 implemented by the HuggingFace library.
|
||||
"""
|
||||
|
@ -220,8 +231,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
|||
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=256, # Unused.
|
||||
n_positions=max_mel_seq_len + max_text_seq_len,
|
||||
n_ctx=max_mel_seq_len + max_text_seq_len,
|
||||
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
|
@ -234,13 +245,18 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
|
|||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||
# Built-in token embeddings are unused.
|
||||
del gpt.wte
|
||||
return (
|
||||
gpt,
|
||||
LearnedPositionEmbeddings(max_mel_seq_len, model_dim),
|
||||
LearnedPositionEmbeddings(max_text_seq_len, model_dim),
|
||||
None,
|
||||
None,
|
||||
|
||||
mel_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_dim)
|
||||
)
|
||||
text_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_text_seq_len, model_dim)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_dim)
|
||||
)
|
||||
return gpt, mel_pos_emb, text_pos_emb, None, None
|
||||
|
||||
|
||||
class MelEncoder(nn.Module):
|
||||
|
@ -334,12 +350,12 @@ class UnifiedVoice(nn.Module):
|
|||
self.mel_layer_pos_embedding,
|
||||
self.text_layer_pos_embedding,
|
||||
) = build_hf_gpt_transformer(
|
||||
layers,
|
||||
model_dim,
|
||||
heads,
|
||||
self.max_mel_tokens + 2 + self.max_conditioning_inputs,
|
||||
self.max_text_tokens + 2,
|
||||
checkpointing,
|
||||
layers=layers,
|
||||
model_dim=model_dim,
|
||||
heads=heads,
|
||||
max_mel_seq_len=self.max_mel_tokens + 2 + self.max_conditioning_inputs,
|
||||
max_text_seq_len=self.max_text_tokens + 2,
|
||||
checkpointing=checkpointing,
|
||||
)
|
||||
if train_solo_embeddings:
|
||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
|
||||
|
@ -455,7 +471,7 @@ class UnifiedVoice(nn.Module):
|
|||
)
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])[:, :, 0])
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1)
|
||||
return conds
|
||||
|
|
|
@ -16,7 +16,6 @@ class ResBlock(nn.Module):
|
|||
up=False,
|
||||
down=False,
|
||||
kernel_size=3,
|
||||
do_checkpoint=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
@ -24,7 +23,6 @@ class ResBlock(nn.Module):
|
|||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
self.do_checkpoint = do_checkpoint
|
||||
padding = 1 if kernel_size == 3 else 2
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
|
@ -92,14 +90,14 @@ class AudioMiniEncoder(nn.Module):
|
|||
self.layers = depth
|
||||
for l in range(depth):
|
||||
for r in range(resnet_blocks):
|
||||
res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
|
||||
res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
|
||||
res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
|
||||
ch *= 2
|
||||
self.res = nn.Sequential(*res)
|
||||
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
|
||||
attn = []
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
|
|
@ -8,10 +8,6 @@ from TTS.tts.layers.tortoise.transformer import Transformer
|
|||
from TTS.tts.layers.tortoise.xtransformers import Encoder
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def masked_mean(t, mask, dim=1):
|
||||
t = t.masked_fill(~mask[:, :, None], 0.0)
|
||||
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
|
||||
|
|
|
@ -196,31 +196,26 @@ class DiffusionTts(nn.Module):
|
|||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
do_checkpoint=False,
|
||||
),
|
||||
AttentionBlock(
|
||||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
do_checkpoint=False,
|
||||
),
|
||||
AttentionBlock(
|
||||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
do_checkpoint=False,
|
||||
),
|
||||
AttentionBlock(
|
||||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
do_checkpoint=False,
|
||||
),
|
||||
AttentionBlock(
|
||||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
do_checkpoint=False,
|
||||
),
|
||||
)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
|
||||
|
|
|
@ -1,22 +1,19 @@
from typing import TypeVar, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from TTS.utils.generic_utils import exists

# helpers
_T = TypeVar("_T")


def exists(val):
return val is not None


def default(val, d):
return val if exists(val) else d


def cast_tuple(val, depth=1):
def cast_tuple(val: Union[tuple[_T], list[_T], _T], depth: int = 1) -> tuple[_T]:
if isinstance(val, list):
val = tuple(val)
return tuple(val)
return val if isinstance(val, tuple) else (val,) * depth
@ -1,13 +1,15 @@
import math
from collections import namedtuple
from functools import partial
from inspect import isfunction

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn

from TTS.tts.layers.tortoise.transformer import cast_tuple, max_neg_value
from TTS.utils.generic_utils import default, exists

DEFAULT_DIM_HEAD = 64

Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])

@ -25,20 +27,6 @@ LayerIntermediates = namedtuple(
# helpers


def exists(val):
return val is not None


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth


class always:
def __init__(self, val):
self.val = val

@ -63,10 +51,6 @@ class equals:
return x == self.val


def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max


def l2norm(t):
return F.normalize(t, p=2, dim=-1)
@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn.modules.conv import Conv1d

from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
from TTS.vocoder.models.hifigan_discriminator import LRELU_SLOPE, DiscriminatorP


class DiscriminatorS(torch.nn.Module):

@ -39,7 +39,7 @@ class DiscriminatorS(torch.nn.Module):
feat = []
for l in self.convs:
x = l(x)
x = torch.nn.functional.leaky_relu(x, 0.1)
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
@ -14,10 +14,6 @@ from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__)


def default(val, d):
return val if val is not None else d


def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
@ -1,6 +1,5 @@
# ported from: https://github.com/neonbjb/tortoise-tts

import functools
import random

import torch

@ -8,83 +7,16 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config

from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation
from TTS.tts.layers.tortoise.autoregressive import (
ConditioningEncoder,
LearnedPositionEmbeddings,
_prepare_attention_mask_for_generation,
build_hf_gpt_transformer,
)
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler


def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=0.02, relative=False):
super().__init__()
# nn.Embedding
self.emb = torch.nn.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
self.relative = relative
self.seq_len = seq_len

def forward(self, x):
sl = x.shape[1]
if self.relative:
start = random.randint(sl, self.seq_len) - sl
return self.emb(torch.arange(start, start + sl, device=x.device))
else:
return self.emb(torch.arange(0, sl, device=x.device))

def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)


def build_hf_gpt_transformer(
layers,
model_dim,
heads,
max_mel_seq_len,
max_text_seq_len,
max_prompt_len,
checkpointing,
):
"""
GPT-2 implemented by the HuggingFace library.
"""
from transformers import GPT2Config, GPT2Model

gpt_config = GPT2Config(
vocab_size=256, # Unused.
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing,
)
gpt = GPT2Model(gpt_config)
# Override the built in positional embeddings
del gpt.wpe
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del gpt.wte

mel_pos_emb = (
LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
if max_mel_seq_len != -1
else functools.partial(null_position_embeddings, dim=model_dim)
)
text_pos_emb = (
LearnedPositionEmbeddings(max_text_seq_len, model_dim)
if max_mel_seq_len != -1
else functools.partial(null_position_embeddings, dim=model_dim)
)
# gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
return gpt, mel_pos_emb, text_pos_emb, None, None


class GPT(nn.Module):
def __init__(
self,
@ -149,13 +81,13 @@ class GPT(nn.Module):
self.mel_layer_pos_embedding,
self.text_layer_pos_embedding,
) = build_hf_gpt_transformer(
layers,
model_dim,
heads,
self.max_mel_tokens,
self.max_text_tokens,
self.max_prompt_tokens,
checkpointing,
layers=layers,
model_dim=model_dim,
heads=heads,
max_mel_seq_len=self.max_mel_tokens,
max_text_seq_len=self.max_text_tokens,
max_prompt_len=self.max_prompt_tokens,
checkpointing=checkpointing,
)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
@ -303,19 +235,6 @@ class GPT(nn.Module):
else:
return first_logits

def get_conditioning(self, speech_conditioning_input):
speech_conditioning_input = (
speech_conditioning_input.unsqueeze(1)
if len(speech_conditioning_input.shape) == 3
else speech_conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1)
return conds

def get_prompts(self, prompt_codes):
"""
Create a prompt from the mel codes. This is used to condition the model on the mel codes.

@ -354,6 +273,7 @@ class GPT(nn.Module):
"""
cond_input: (b, 80, s) or (b, 1, 80, s)
conds: (b, 1024, s)
output: (b, 1024, 32)
"""
conds = None
if not return_latent:
@ -1,618 +1,13 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
from trainer.io import load_fsspec
|
||||
|
||||
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
||||
from TTS.vocoder.models.hifigan_generator import get_padding
|
||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|
||||
|--------------------------------------------------------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input tensor.
|
||||
Returns:
|
||||
Tensor: output tensor.
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
"""
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.convs2:
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
|
||||
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|
||||
|---------------------------------------------------|
|
||||
|
||||
|
||||
Args:
|
||||
channels (int): number of hidden channels for the convolutional layers.
|
||||
kernel_size (int): size of the convolution filter in each layer.
|
||||
dilations (list): list of dilation value for each conv layer in a block.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super().__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_parametrizations(l, "weight")
|
||||
|
||||
|
||||
class HifiganGenerator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
resblock_type,
|
||||
resblock_dilation_sizes,
|
||||
resblock_kernel_sizes,
|
||||
upsample_kernel_sizes,
|
||||
upsample_initial_channel,
|
||||
upsample_factors,
|
||||
inference_padding=5,
|
||||
cond_channels=0,
|
||||
conv_pre_weight_norm=True,
|
||||
conv_post_weight_norm=True,
|
||||
conv_post_bias=True,
|
||||
cond_in_each_up_layer=False,
|
||||
):
|
||||
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
|
||||
|
||||
Network:
|
||||
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
|
||||
.. -> zI ---|
|
||||
resblockN_kNx1 -> zN ---'
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input tensor channels.
|
||||
out_channels (int): number of output tensor channels.
|
||||
resblock_type (str): type of the `ResBlock`. '1' or '2'.
|
||||
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
|
||||
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
|
||||
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
|
||||
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
|
||||
for each consecutive upsampling layer.
|
||||
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
|
||||
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inference_padding = inference_padding
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_factors)
|
||||
self.cond_in_each_up_layer = cond_in_each_up_layer
|
||||
|
||||
# initial upsampling layers
|
||||
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
|
||||
# upsampling layers
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
# MRF blocks
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
# post convolution layer
|
||||
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
|
||||
if cond_channels > 0:
|
||||
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
||||
|
||||
if not conv_pre_weight_norm:
|
||||
remove_parametrizations(self.conv_pre, "weight")
|
||||
|
||||
if not conv_post_weight_norm:
|
||||
remove_parametrizations(self.conv_post, "weight")
|
||||
|
||||
if self.cond_in_each_up_layer:
|
||||
self.conds = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
self.conds.append(nn.Conv1d(cond_channels, ch, 1))
|
||||
|
||||
def forward(self, x, g=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): feature input tensor.
|
||||
g (Tensor): global conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
o = self.conv_pre(x)
|
||||
if hasattr(self, "cond_layer"):
|
||||
o = o + self.cond_layer(g)
|
||||
for i in range(self.num_upsamples):
|
||||
o = F.leaky_relu(o, LRELU_SLOPE)
|
||||
o = self.ups[i](o)
|
||||
|
||||
if self.cond_in_each_up_layer:
|
||||
o = o + self.conds[i](g)
|
||||
|
||||
z_sum = None
|
||||
for j in range(self.num_kernels):
|
||||
if z_sum is None:
|
||||
z_sum = self.resblocks[i * self.num_kernels + j](o)
|
||||
else:
|
||||
z_sum += self.resblocks[i * self.num_kernels + j](o)
|
||||
o = z_sum / self.num_kernels
|
||||
o = F.leaky_relu(o)
|
||||
o = self.conv_post(o)
|
||||
o = torch.tanh(o)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): conditioning input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: output waveform.
|
||||
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
Tensor: [B, 1, T]
|
||||
"""
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
logger.info("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_parametrizations(self.conv_pre, "weight")
|
||||
remove_parametrizations(self.conv_post, "weight")
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
|
||||
self.load_state_dict(state["model"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super(SELayer, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
y = self.avg_pool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y
|
||||
|
||||
|
||||
class SEBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
|
||||
super(SEBasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se = SELayer(planes, reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.relu(out)
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.se(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
def set_init_dict(model_dict, checkpoint_state, c):
|
||||
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
|
||||
for k, v in checkpoint_state.items():
|
||||
if k not in model_dict:
|
||||
logger.warning("Layer missing in the model definition: %s", k)
|
||||
# 1. filter out unnecessary keys
|
||||
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
|
||||
# 2. filter out different size layers
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
|
||||
# 3. skip reinit layers
|
||||
if c.has("reinit_layers") and c.reinit_layers is not None:
|
||||
for reinit_layer_name in c.reinit_layers:
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
|
||||
# 4. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrained_dict)
|
||||
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
|
||||
return model_dict
|
||||
|
||||
|
||||
class PreEmphasis(nn.Module):
|
||||
def __init__(self, coefficient=0.97):
|
||||
super().__init__()
|
||||
self.coefficient = coefficient
|
||||
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
def forward(self, x):
|
||||
assert len(x.size()) == 2
|
||||
|
||||
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
|
||||
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
|
||||
|
||||
|
||||
class ResNetSpeakerEncoder(nn.Module):
|
||||
"""This is copied from 🐸TTS to remove it from the dependencies."""
|
||||
|
||||
# pylint: disable=W0102
|
||||
def __init__(
|
||||
self,
|
||||
input_dim=64,
|
||||
proj_dim=512,
|
||||
layers=[3, 4, 6, 3],
|
||||
num_filters=[32, 64, 128, 256],
|
||||
encoder_type="ASP",
|
||||
log_input=False,
|
||||
use_torch_spec=False,
|
||||
audio_config=None,
|
||||
):
|
||||
super(ResNetSpeakerEncoder, self).__init__()
|
||||
|
||||
self.encoder_type = encoder_type
|
||||
self.input_dim = input_dim
|
||||
self.log_input = log_input
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.audio_config = audio_config
|
||||
self.proj_dim = proj_dim
|
||||
|
||||
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.bn1 = nn.BatchNorm2d(num_filters[0])
|
||||
|
||||
self.inplanes = num_filters[0]
|
||||
self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
|
||||
self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
|
||||
self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
|
||||
self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
|
||||
|
||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||
|
||||
if self.use_torch_spec:
|
||||
self.torch_spec = torch.nn.Sequential(
|
||||
PreEmphasis(audio_config["preemphasis"]),
|
||||
torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=audio_config["sample_rate"],
|
||||
n_fft=audio_config["fft_size"],
|
||||
win_length=audio_config["win_length"],
|
||||
hop_length=audio_config["hop_length"],
|
||||
window_fn=torch.hamming_window,
|
||||
n_mels=audio_config["num_mels"],
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
self.torch_spec = None
|
||||
|
||||
outmap_size = int(self.input_dim / 8)
|
||||
|
||||
self.attention = nn.Sequential(
|
||||
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
|
||||
nn.Softmax(dim=2),
|
||||
)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
out_dim = num_filters[3] * outmap_size
|
||||
elif self.encoder_type == "ASP":
|
||||
out_dim = num_filters[3] * outmap_size * 2
|
||||
else:
|
||||
raise ValueError("Undefined encoder")
|
||||
|
||||
self.fc = nn.Linear(out_dim, proj_dim)
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def create_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
# pylint: disable=R0201
|
||||
def new_parameter(self, *size):
|
||||
out = nn.Parameter(torch.FloatTensor(*size))
|
||||
nn.init.xavier_normal_(out)
|
||||
return out
|
||||
|
||||
def forward(self, x, l2_norm=False):
|
||||
"""Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
|
||||
to compute the spectrogram on-the-fly.
|
||||
l2_norm (bool): Whether to L2-normalize the outputs.
|
||||
|
||||
Shapes:
|
||||
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
|
||||
"""
|
||||
x.squeeze_(1)
|
||||
# if you torch spec compute it otherwise use the mel spec computed by the AP
|
||||
if self.use_torch_spec:
|
||||
x = self.torch_spec(x)
|
||||
|
||||
if self.log_input:
|
||||
x = (x + 1e-6).log()
|
||||
x = self.instancenorm(x).unsqueeze(1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.bn1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = x.reshape(x.size()[0], -1, x.size()[-1])
|
||||
|
||||
w = self.attention(x)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
x = torch.sum(x * w, dim=2)
|
||||
elif self.encoder_type == "ASP":
|
||||
mu = torch.sum(x * w, dim=2)
|
||||
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
|
||||
x = torch.cat((mu, sg), 1)
|
||||
|
||||
x = x.view(x.size()[0], -1)
|
||||
x = self.fc(x)
|
||||
|
||||
if l2_norm:
|
||||
x = torch.nn.functional.normalize(x, p=2, dim=1)
|
||||
return x
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
checkpoint_path: str,
|
||||
eval: bool = False,
|
||||
use_cuda: bool = False,
|
||||
criterion=None,
|
||||
cache=False,
|
||||
):
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
try:
|
||||
self.load_state_dict(state["model"])
|
||||
logger.info("Model fully restored.")
|
||||
except (KeyError, RuntimeError) as error:
|
||||
# If eval raise the error
|
||||
if eval:
|
||||
raise error
|
||||
|
||||
logger.info("Partial model initialization.")
|
||||
model_dict = self.state_dict()
|
||||
model_dict = set_init_dict(model_dict, state["model"])
|
||||
self.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
# load the criterion for restore_path
|
||||
if criterion is not None and "criterion" in state:
|
||||
try:
|
||||
criterion.load_state_dict(state["criterion"])
|
||||
except (KeyError, RuntimeError) as error:
|
||||
logger.exception("Criterion load ignored because of: %s", error)
|
||||
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
if criterion is not None:
|
||||
criterion = criterion.cuda()
|
||||
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
if not eval:
|
||||
return criterion, state["step"]
|
||||
return criterion
|
||||
|
||||
|
||||
class HifiDecoder(torch.nn.Module):
|
||||
def __init__(
|
||||
|
|
|
@ -6,10 +6,7 @@ import torch
from torch import nn
from torch.nn import functional as F


class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module


def conv_nd(dims, *args, **kwargs):

@ -22,24 +19,6 @@ def conv_nd(dims, *args, **kwargs):
raise ValueError(f"unsupported dimensions: {dims}")


def normalization(channels):
groups = 32
if channels <= 16:
groups = 8
elif channels <= 64:
groups = 16
while channels % groups != 0:
groups = int(groups / 2)
assert groups > 2
return GroupNorm32(groups, channels)


def zero_module(module):
for p in module.parameters():
p.detach().zero_()
return module


class QKVAttention(nn.Module):
def __init__(self, n_heads):
super().__init__()

@ -114,28 +93,3 @@ class AttentionBlock(nn.Module):
h = self.proj_out(h)
xp = self.x_proj(x)
return (xp + h).reshape(b, xp.shape[1], *spatial)


class ConditioningEncoder(nn.Module):
def __init__(
self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim

def forward(self, x):
"""
x: (b, 80, s)
"""
h = self.init(x)
h = self.attn(h)
return h
@ -9,9 +9,8 @@ from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import einsum, nn


def exists(val):
return val is not None
from TTS.tts.layers.tortoise.transformer import GEGLU
from TTS.utils.generic_utils import default, exists


def once(fn):

@ -151,12 +150,6 @@ def Sequential(*mods):
return nn.Sequential(*filter(exists, mods))


def default(val, d):
if exists(val):
return val
return d() if callable(d) else d


class RMSNorm(nn.Module):
def __init__(self, dim, scale=True, dim_cond=None):
super().__init__()

@ -194,12 +187,6 @@ class CausalConv1d(nn.Conv1d):
return super().forward(causal_padded_x)


class GEGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.gelu(gate) * x


def FeedForward(dim, mult=4, causal_conv=False):
dim_inner = int(dim * mult * 2 / 3)
@ -15,6 +15,7 @@ from spacy.lang.zh import Chinese
from tokenizers import Tokenizer

from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
from TTS.tts.utils.text.cleaners import collapse_whitespace, lowercase

logger = logging.getLogger(__name__)

@ -72,8 +73,6 @@ def split_sentence(text, lang, text_split_length=250):
return text_splits


_whitespace_re = re.compile(r"\s+")

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = {
"en": [

@ -564,14 +563,6 @@ def expand_numbers_multilingual(text, lang="en"):
return text


def lowercase(text):
return text.lower()


def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)


def multilingual_cleaners(text, lang):
text = text.replace('"', "")
if lang == "tr":

@ -586,13 +577,6 @@ def multilingual_cleaners(text, lang):
return text


def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text


def chinese_transliterate(text):
try:
import pypinyin
@ -13,7 +13,7 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
|
|||
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import generate_path, sequence_mask
|
||||
from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
@ -169,35 +169,6 @@ class AlignTTS(BaseTTS):
|
|||
dr_mas = torch.sum(attn, -1)
|
||||
return dr_mas.squeeze(1), log_p
|
||||
|
||||
@staticmethod
|
||||
def generate_attn(dr, x_mask, y_mask=None):
|
||||
# compute decode mask from the durations
|
||||
if y_mask is None:
|
||||
y_lengths = dr.sum(1).long()
|
||||
y_lengths[y_lengths < 1] = 1
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
|
||||
return attn
|
||||
|
||||
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
|
||||
"""Generate attention alignment map from durations and
|
||||
expand encoder outputs
|
||||
|
||||
Examples::
|
||||
- encoder output: [a,b,c,d]
|
||||
- durations: [1, 3, 2, 1]
|
||||
|
||||
- expanded: [a, b, b, b, c, c, d]
|
||||
- attention map: [[0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1, 1, 0],
|
||||
[0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0, 0]]
|
||||
"""
|
||||
attn = self.generate_attn(dr, x_mask, y_mask)
|
||||
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
|
||||
return o_en_ex, attn
|
||||
|
||||
def format_durations(self, o_dr_log, x_mask):
|
||||
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
|
||||
o_dr[o_dr < 1] = 1.0
|
||||
|
@ -243,9 +214,8 @@ class AlignTTS(BaseTTS):
|
|||
return o_en, o_en_dp, x_mask, g
|
||||
|
||||
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
|
||||
# expand o_en with durations
|
||||
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
||||
o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths)
|
||||
# positional encoding
|
||||
if hasattr(self, "pos_encoder"):
|
||||
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
|
||||
|
@ -282,7 +252,7 @@ class AlignTTS(BaseTTS):
|
|||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
|
||||
attn = self.generate_attn(dr_mas, x_mask, y_mask)
|
||||
attn = generate_attention(dr_mas, x_mask, y_mask)
|
||||
elif phase == 1:
|
||||
# train decoder
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
|
|
|
@ -8,30 +8,36 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
from trainer.io import load_fsspec
|
||||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
|
||||
from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample, get_attribute_balancer_weights
|
||||
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
|
||||
from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss
|
||||
from TTS.tts.layers.losses import (
|
||||
ForwardSumLoss,
|
||||
VitsDiscriminatorLoss,
|
||||
_binary_alignment_loss,
|
||||
feature_loss,
|
||||
generator_loss,
|
||||
)
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
from TTS.tts.models.base_tts import BaseTTSE2E
|
||||
from TTS.tts.models.vits import load_audio
|
||||
from TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import embedding_to_torch, id_to_torch, numpy_to_torch
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram
|
||||
from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
|
||||
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
|
||||
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
|
||||
from TTS.utils.audio.processor import AudioProcessor
|
||||
from TTS.utils.audio.torch_transforms import wav_to_mel, wav_to_spec
|
||||
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
@ -39,284 +45,20 @@ from TTS.vocoder.utils.generic_utils import plot_results
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def id_to_torch(aux_id, cuda=False):
|
||||
if aux_id is not None:
|
||||
aux_id = np.asarray(aux_id)
|
||||
aux_id = torch.from_numpy(aux_id)
|
||||
if cuda:
|
||||
return aux_id.cuda()
|
||||
return aux_id
|
||||
|
||||
|
||||
def embedding_to_torch(d_vector, cuda=False):
|
||||
if d_vector is not None:
|
||||
d_vector = np.asarray(d_vector)
|
||||
d_vector = torch.from_numpy(d_vector).float()
|
||||
d_vector = d_vector.squeeze().unsqueeze(0)
|
||||
if cuda:
|
||||
return d_vector.cuda()
|
||||
return d_vector
|
||||
|
||||
|
||||
def numpy_to_torch(np_array, dtype, cuda=False):
|
||||
if np_array is None:
|
||||
return None
|
||||
tensor = torch.as_tensor(np_array, dtype=dtype)
|
||||
if cuda:
|
||||
return tensor.cuda()
|
||||
return tensor
|
||||
|
||||
|
||||
def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = lengths.shape[0]
|
||||
max_len = torch.max(lengths).item()
|
||||
ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
|
||||
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
|
||||
return mask
|
||||
|
||||
|
||||
def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
|
||||
out_list = torch.jit.annotate(List[torch.Tensor], [])
|
||||
for batch in input_ele:
|
||||
if len(batch.shape) == 1:
|
||||
one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0)
|
||||
else:
|
||||
one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0)
|
||||
out_list.append(one_batch_padded)
|
||||
out_padded = torch.stack(out_list)
|
||||
return out_padded
|
||||
|
||||
|
||||
def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
|
||||
return torch.ceil(lens / stride).int()
|
||||
|
||||
|
||||
def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor:
|
||||
assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..."
|
||||
return torch.randn(shape) * np.sqrt(2 / shape[1])
|
||||
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
|
||||
pad = kernel_size // 2
|
||||
return (pad, pad - (kernel_size + 1) % 2)
|
||||
|
||||
|
||||
hann_window = {}
|
||||
mel_basis = {}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def weights_reset(m: nn.Module):
|
||||
# check if the current module has reset_parameters and if it is reset the weight
|
||||
reset_parameters = getattr(m, "reset_parameters", None)
|
||||
if callable(reset_parameters):
|
||||
m.reset_parameters()
|
||||
|
||||
|
||||
def get_module_weights_sum(mdl: nn.Module):
|
||||
dict_sums = {}
|
||||
for name, w in mdl.named_parameters():
|
||||
if "weight" in name:
|
||||
value = w.data.sum().item()
|
||||
dict_sums[name] = value
|
||||
return dict_sums
|
||||
|
||||
|
||||
def load_audio(file_path: str):
|
||||
"""Load the audio file normalized in [-1, 1]
|
||||
|
||||
Return Shapes:
|
||||
- x: :math:`[1, T]`
|
||||
"""
|
||||
x, sr = torchaudio.load(
|
||||
file_path,
|
||||
)
|
||||
assert (x > 1).sum() + (x < -1).sum() == 0
|
||||
return x, sr
|
||||
|
||||
|
||||
def _amp_to_db(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def _db_to_amp(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def amp_to_db(magnitudes):
|
||||
output = _amp_to_db(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def db_to_amp(magnitudes):
|
||||
output = _db_to_amp(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
||||
y = y.squeeze(1)
|
||||
|
||||
if torch.min(y) < -1.0:
|
||||
logger.info("min value is %.3f", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
logger.info("max value is %.3f", torch.max(y))
|
||||
|
||||
global hann_window # pylint: disable=global-statement
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
wnsize_dtype_device = str(win_length) + "_" + dtype_device
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.view_as_real(
|
||||
torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=hann_window[wnsize_dtype_device],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
)
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
||||
"""
|
||||
Args Shapes:
|
||||
- y : :math:`[B, 1, T]`
|
||||
|
||||
Return Shapes:
|
||||
- spec : :math:`[B,C,T]`
|
||||
"""
|
||||
spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
return spec
|
||||
|
||||
|
||||
def wav_to_energy(y, n_fft, hop_length, win_length, center=False):
|
||||
spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
|
||||
return torch.norm(spec, dim=1, keepdim=True)
|
||||
|
||||
|
||||
def name_mel_basis(spec, n_fft, fmax):
|
||||
n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
|
||||
return n_fft_len
|
||||
|
||||
|
||||
def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
|
||||
"""
|
||||
Args Shapes:
|
||||
- spec : :math:`[B,C,T]`
|
||||
|
||||
Return Shapes:
|
||||
- mel : :math:`[B,C,T]`
|
||||
"""
|
||||
global mel_basis # pylint: disable=global-statement
|
||||
mel_basis_key = name_mel_basis(spec, n_fft, fmax)
|
||||
# pylint: disable=too-many-function-args
|
||||
if mel_basis_key not in mel_basis:
|
||||
# pylint: disable=missing-kwoa
|
||||
mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
|
||||
mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||
mel = torch.matmul(mel_basis[mel_basis_key], spec)
|
||||
mel = amp_to_db(mel)
|
||||
return mel
|
||||
|
||||
|
||||
def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
|
||||
"""
|
||||
Args Shapes:
|
||||
- y : :math:`[B, 1, T_y]`
|
||||
|
||||
Return Shapes:
|
||||
- spec : :math:`[B,C,T_spec]`
|
||||
"""
|
||||
y = y.squeeze(1)
|
||||
|
||||
if torch.min(y) < -1.0:
|
||||
logger.info("min value is %.3f", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
logger.info("max value is %.3f", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window # pylint: disable=global-statement
|
||||
mel_basis_key = name_mel_basis(y, n_fft, fmax)
|
||||
wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device)
|
||||
if mel_basis_key not in mel_basis:
|
||||
# pylint: disable=missing-kwoa
|
||||
mel = librosa_mel_fn(
|
||||
sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
) # pylint: disable=too-many-function-args
|
||||
mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.view_as_real(
|
||||
torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=hann_window[wnsize_dtype_device],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
spec = torch.matmul(mel_basis[mel_basis_key], spec)
|
||||
spec = amp_to_db(spec)
|
||||
return spec
|
||||
|
||||
|
||||
##############################
|
||||
# DATASET
|
||||
##############################
|
||||
|
||||
|
||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
|
||||
"""Create balancer weight for torch WeightedSampler"""
|
||||
attr_names_samples = np.array([item[attr_name] for item in items])
|
||||
unique_attr_names = np.unique(attr_names_samples).tolist()
|
||||
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
|
||||
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
|
||||
weight_attr = 1.0 / attr_count
|
||||
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
|
||||
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||
if multi_dict is not None:
|
||||
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
|
||||
dataset_samples_weight *= multiplier_samples
|
||||
return (
|
||||
torch.from_numpy(dataset_samples_weight).float(),
|
||||
unique_attr_names,
|
||||
np.unique(dataset_samples_weight).tolist(),
|
||||
)
|
||||
|
||||
|
||||
class ForwardTTSE2eF0Dataset(F0Dataset):
|
||||
"""Override F0Dataset to avoid slow computing of pitches"""
|
||||
|
||||
|
@ -1196,7 +938,7 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
# TODO: add cloning support with ref_waveform
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# convert text to sequence of token IDs
|
||||
text_inputs = np.asarray(
|
||||
|
@ -1210,14 +952,14 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
if isinstance(speaker_id, str) and self.args.use_speaker_embedding:
|
||||
# get the speaker id for the speaker embedding layer
|
||||
_speaker_id = self.speaker_manager.name_to_id[speaker_id]
|
||||
_speaker_id = id_to_torch(_speaker_id, cuda=is_cuda)
|
||||
_speaker_id = id_to_torch(_speaker_id, device=device)
|
||||
|
||||
if speaker_id is not None and self.args.use_d_vector_file:
|
||||
# get the average d_vector for the speaker
|
||||
d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False)
|
||||
d_vector = embedding_to_torch(d_vector, cuda=is_cuda)
|
||||
d_vector = embedding_to_torch(d_vector, device=device)
|
||||
|
||||
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
|
||||
text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
|
||||
text_inputs = text_inputs.unsqueeze(0)
|
||||
|
||||
# synthesize voice
|
||||
|
@ -1240,7 +982,7 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
return return_dict
|
||||
|
||||
def synthesize_with_gl(self, text: str, speaker_id, d_vector):
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# convert text to sequence of token IDs
|
||||
text_inputs = np.asarray(
|
||||
|
@ -1249,12 +991,12 @@ class DelightfulTTS(BaseTTSE2E):
|
|||
)
|
||||
# pass tensors to backend
|
||||
if speaker_id is not None:
|
||||
speaker_id = id_to_torch(speaker_id, cuda=is_cuda)
|
||||
speaker_id = id_to_torch(speaker_id, device=device)
|
||||
|
||||
if d_vector is not None:
|
||||
d_vector = embedding_to_torch(d_vector, cuda=is_cuda)
|
||||
d_vector = embedding_to_torch(d_vector, device=device)
|
||||
|
||||
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
|
||||
text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
|
||||
text_inputs = text_inputs.unsqueeze(0)
|
||||
|
||||
# synthesize voice
|
||||
|
@ -1601,36 +1343,6 @@ class DelightfulTTSLoss(nn.Module):
|
|||
self.gen_loss_alpha = config.gen_loss_alpha
|
||||
self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha
|
||||
|
||||
@staticmethod
|
||||
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||
"""
|
||||
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||
return -log_sum / alignment_hard.sum()
|
||||
|
||||
@staticmethod
|
||||
def feature_loss(feats_real, feats_generated):
|
||||
loss = 0
|
||||
for dr, dg in zip(feats_real, feats_generated):
|
||||
for rl, gl in zip(dr, dg):
|
||||
rl = rl.float().detach()
|
||||
gl = gl.float()
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
return loss * 2
|
||||
|
||||
@staticmethod
|
||||
def generator_loss(scores_fake):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in scores_fake:
|
||||
dg = dg.float()
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
|
||||
def forward(
|
||||
self,
|
||||
mel_output,
|
||||
|
@ -1728,7 +1440,7 @@ class DelightfulTTSLoss(nn.Module):
|
|||
)
|
||||
|
||||
if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None:
|
||||
binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft)
|
||||
binary_alignment_loss = _binary_alignment_loss(aligner_hard, aligner_soft)
|
||||
total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
|
||||
if binary_loss_weight:
|
||||
loss_dict["loss_binary_alignment"] = (
|
||||
|
@ -1748,8 +1460,8 @@ class DelightfulTTSLoss(nn.Module):
|
|||
|
||||
# vocoder losses
|
||||
if not skip_disc:
|
||||
loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
|
||||
loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
|
||||
loss_feat = feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha
|
||||
loss_gen = generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha
|
||||
loss_dict["vocoder_loss_feat"] = loss_feat
|
||||
loss_dict["vocoder_loss_gen"] = loss_gen
|
||||
loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen
|
||||
|
|
|
@ -14,7 +14,7 @@ from TTS.tts.layers.generic.aligner import AlignmentNetwork
|
|||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import average_over_durations, generate_path, sequence_mask
|
||||
from TTS.tts.utils.helpers import average_over_durations, expand_encoder_outputs, generate_attention, sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
|
||||
|
@ -310,49 +310,6 @@ class ForwardTTS(BaseTTS):
|
|||
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
@staticmethod
|
||||
def generate_attn(dr, x_mask, y_mask=None):
|
||||
"""Generate an attention mask from the durations.
|
||||
|
||||
Shapes
|
||||
- dr: :math:`(B, T_{en})`
|
||||
- x_mask: :math:`(B, T_{en})`
|
||||
- y_mask: :math:`(B, T_{de})`
|
||||
"""
|
||||
# compute decode mask from the durations
|
||||
if y_mask is None:
|
||||
y_lengths = dr.sum(1).long()
|
||||
y_lengths[y_lengths < 1] = 1
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
|
||||
return attn
|
||||
|
||||
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
|
||||
"""Generate attention alignment map from durations and
|
||||
expand encoder outputs
|
||||
|
||||
Shapes:
|
||||
- en: :math:`(B, D_{en}, T_{en})`
|
||||
- dr: :math:`(B, T_{en})`
|
||||
- x_mask: :math:`(B, T_{en})`
|
||||
- y_mask: :math:`(B, T_{de})`
|
||||
|
||||
Examples::
|
||||
|
||||
encoder output: [a,b,c,d]
|
||||
durations: [1, 3, 2, 1]
|
||||
|
||||
expanded: [a, b, b, b, c, c, d]
|
||||
attention map: [[0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1, 1, 0],
|
||||
[0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0, 0]]
|
||||
"""
|
||||
attn = self.generate_attn(dr, x_mask, y_mask)
|
||||
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
|
||||
return o_en_ex, attn
|
||||
|
||||
def format_durations(self, o_dr_log, x_mask):
|
||||
"""Format predicted durations.
|
||||
1. Convert to linear scale from log scale
|
||||
|
@ -443,9 +400,8 @@ class ForwardTTS(BaseTTS):
|
|||
Returns:
|
||||
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
|
||||
"""
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
|
||||
# expand o_en with durations
|
||||
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
||||
o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths)
|
||||
# positional encoding
|
||||
if hasattr(self, "pos_encoder"):
|
||||
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
|
||||
|
@ -624,7 +580,7 @@ class ForwardTTS(BaseTTS):
|
|||
o_dr_log = self.duration_predictor(o_en, x_mask)
|
||||
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
|
||||
# generate attn mask from predicted durations
|
||||
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
|
||||
o_attn = generate_attention(o_dr.squeeze(1), x_mask)
|
||||
# aligner
|
||||
o_alignment_dur = None
|
||||
alignment_soft = None
|
||||
|
|
|
@ -8,6 +8,7 @@ from torch import nn
from trainer.io import load_fsspec
from trainer.logging.tensorboard_logger import TensorboardLogger

from TTS.tts.layers.losses import NLLLoss
from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
from TTS.tts.layers.overflow.plotting_utils import (

@ -373,21 +374,3 @@ class NeuralhmmTTS(BaseTTS):
) -> None:
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
logger.test_figures(steps, outputs[0])


class NLLLoss(nn.Module):
"""Negative log likelihood loss."""

def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
"""Compute the loss.

Args:
logits (Tensor): [B, T, D]

Returns:
Tensor: [1]

"""
return_dict = {}
return_dict["loss"] = -log_prob.mean()
return return_dict
@ -8,6 +8,7 @@ from torch import nn
from trainer.io import load_fsspec
from trainer.logging.tensorboard_logger import TensorboardLogger

from TTS.tts.layers.losses import NLLLoss
from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
from TTS.tts.layers.overflow.decoder import Decoder
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM

@ -389,21 +390,3 @@ class Overflow(BaseTTS):
) -> None:
logger.test_audios(steps, outputs[1], self.ap.sample_rate)
logger.test_figures(steps, outputs[0])


class NLLLoss(nn.Module):
"""Negative log likelihood loss."""

def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
"""Compute the loss.

Args:
logits (Tensor): [B, T, D]

Returns:
Tensor: [1]

"""
return_dict = {}
return_dict["loss"] = -log_prob.mean()
return return_dict
|
@ -10,7 +10,6 @@ import torch
import torch.distributed as dist
import torchaudio
from coqpit import Coqpit
from librosa.filters import mel as librosa_mel_fn
from monotonic_alignment_search import maximum_path
from torch import nn
from torch.nn import functional as F

@ -21,7 +20,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, get_attribute_balancer_weights
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder

@ -35,6 +34,7 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio.torch_transforms import spec_to_mel, wav_to_mel, wav_to_spec
from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results

@ -45,10 +45,6 @@ logger = logging.getLogger(__name__)
# IO / Feature extraction
##############################

# pylint: disable=global-statement
hann_window = {}
mel_basis = {}


@torch.no_grad()
def weights_reset(m: nn.Module):

@ -78,143 +74,6 @@ def load_audio(file_path):
return x, sr


def _amp_to_db(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)


def _db_to_amp(x, C=1):
return torch.exp(x) / C


def amp_to_db(magnitudes):
output = _amp_to_db(magnitudes)
return output


def db_to_amp(magnitudes):
output = _db_to_amp(magnitudes)
return output


def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
"""
Args Shapes:
- y : :math:`[B, 1, T]`

Return Shapes:
- spec : :math:`[B,C,T]`
"""
y = y.squeeze(1)

if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))

global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_length) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec


def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
"""
Args Shapes:
- spec : :math:`[B,C,T]`

Return Shapes:
- mel : :math:`[B,C,T]`
"""
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
mel = amp_to_db(mel)
return mel


def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
"""
Args Shapes:
- y : :math:`[B, 1, T]`

Return Shapes:
- spec : :math:`[B,C,T]`
"""
y = y.squeeze(1)

if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))

global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_length) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = amp_to_db(spec)
return spec


#############################
# CONFIGS
#############################

@ -236,30 +95,6 @@ class VitsAudioConfig(Coqpit):
##############################


def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
"""Create inverse frequency weights for balancing the dataset.
Use `multi_dict` to scale relative weights."""
attr_names_samples = np.array([item[attr_name] for item in items])
unique_attr_names = np.unique(attr_names_samples).tolist()
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
weight_attr = 1.0 / attr_count
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
if multi_dict is not None:
# check if all keys are in the multi_dict
for k in multi_dict:
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
# scale weights
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
dataset_samples_weight *= multiplier_samples
return (
torch.from_numpy(dataset_samples_weight).float(),
unique_attr_names,
np.unique(dataset_samples_weight).tolist(),
)


class VitsDataset(TTSDataset):
def __init__(self, model_args, *args, **kwargs):
super().__init__(*args, **kwargs)
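For context, a small sketch of the dataset-balancing helper moved out of vits.py above; it is imported from TTS.tts.datasets.dataset per the updated import, and the metadata items are made up for illustration:

    from TTS.tts.datasets.dataset import get_attribute_balancer_weights

    items = [  # hypothetical samples; only the attribute field matters here
        {"speaker_name": "spk_a"}, {"speaker_name": "spk_a"},
        {"speaker_name": "spk_a"}, {"speaker_name": "spk_b"},
    ]
    weights, names, unique_w = get_attribute_balancer_weights(items, "speaker_name")
    # Rarer speakers get larger inverse-frequency weights, e.g. for a WeightedRandomSampler.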
@ -93,25 +93,6 @@ def load_audio(audiopath, sampling_rate):
return audio


def pad_or_truncate(t, length):
"""
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.

Args:
t (torch.Tensor): The input tensor to be padded or truncated.
length (int): The desired length of the tensor.

Returns:
torch.Tensor: The padded or truncated tensor.
"""
tp = t[..., :length]
if t.shape[-1] == length:
tp = t
elif t.shape[-1] < length:
tp = F.pad(t, (0, length - t.shape[-1]))
return tp


@dataclass
class XttsAudioConfig(Coqpit):
"""
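A quick sketch of the padding/truncation behaviour documented above, assuming pad_or_truncate from this module is in scope:

    import torch

    t = torch.ones(1, 5)
    assert pad_or_truncate(t, 8).shape[-1] == 8  # zero-padded on the right
    assert pad_or_truncate(t, 3).shape[-1] == 3  # clipped to the first 3 samples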
@ -1,3 +1,5 @@
from typing import Optional

import numpy as np
import torch
from scipy.stats import betabinom

@ -33,7 +35,7 @@ class StandardScaler:

# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
def sequence_mask(sequence_length: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
"""Create a sequence mask for filtering padding in a sequence tensor.

Args:

@ -44,7 +46,7 @@ def sequence_mask(sequence_length, max_len=None):
- mask: :math:`[B, T_max]`
"""
if max_len is None:
max_len = sequence_length.max()
max_len = int(sequence_length.max())
seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
# B x T_max
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

@ -143,22 +145,75 @@ def convert_pad_shape(pad_shape: list[list]) -> list:
return [item for sublist in l for item in sublist]


def generate_path(duration, mask):
"""
def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate alignment path based on the given segment durations.

Shapes:
- duration: :math:`[B, T_en]`
- mask: :math:'[B, T_en, T_de]`
- path: :math:`[B, T_en, T_de]`
"""
b, t_x, t_y = mask.shape
cum_duration = torch.cumsum(duration, 1)
cum_duration = torch.cumsum(duration, dim=1)

cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
return path * mask


def generate_attention(
duration: torch.Tensor, x_mask: torch.Tensor, y_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Generate an attention map from the linear scale durations.

Args:
duration (Tensor): Linear scale durations.
x_mask (Tensor): Mask for the input (character) sequence.
y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
if None. Defaults to None.

Shapes
- duration: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
"""
# compute decode mask from the durations
if y_mask is None:
y_lengths = duration.sum(dim=1).long()
y_lengths[y_lengths < 1] = 1
y_mask = sequence_mask(y_lengths).unsqueeze(1).to(duration.dtype)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
return generate_path(duration, attn_mask.squeeze(1)).to(duration.dtype)


def expand_encoder_outputs(
x: torch.Tensor, duration: torch.Tensor, x_mask: torch.Tensor, y_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate attention alignment map from durations and expand encoder outputs.

Shapes:
- x: Encoder output :math:`(B, D_{en}, T_{en})`
- duration: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_lengths: :math:`(B)`

Examples::

encoder output: [a,b,c,d]
durations: [1, 3, 2, 1]

expanded: [a, b, b, b, c, c, d]
attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
"""
y_mask = sequence_mask(y_lengths).unsqueeze(1).to(x.dtype)
attn = generate_attention(duration, x_mask, y_mask)
x_expanded = torch.einsum("kmn, kjm -> kjn", [attn.float(), x])
return x_expanded, attn, y_mask


def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
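The docstring example above can be run directly; a hedged sketch of the new helper, with x_mask carrying a channel dimension as in the updated tests:

    import torch
    from TTS.tts.utils.helpers import expand_encoder_outputs

    x = torch.rand(1, 8, 4)                  # (B, D_en, T_en): 4 encoder frames
    duration = torch.tensor([[1, 3, 2, 1]])  # durations from the docstring example
    x_mask = torch.ones(1, 1, 4)
    y_lengths = duration.sum(1)              # 7 decoder frames
    expanded, attn, y_mask = expand_encoder_outputs(x, duration, x_mask, y_lengths)
    assert expanded.shape == (1, 8, 7) and attn.shape == (1, 4, 7)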
@ -1,17 +1,16 @@
from typing import Dict
from typing import Dict, Optional, Union

import numpy as np
import torch
from torch import nn


def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
if cuda:
device = "cuda"
def numpy_to_torch(
np_array: np.ndarray, dtype: torch.dtype, device: Union[str, torch.device] = "cpu"
) -> Optional[torch.Tensor]:
if np_array is None:
return None
tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
return tensor
return torch.as_tensor(np_array, dtype=dtype, device=device)


def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):

@ -76,18 +75,14 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav


def id_to_torch(aux_id, cuda=False, device="cpu"):
if cuda:
device = "cuda"
def id_to_torch(aux_id, device: Union[str, torch.device] = "cpu") -> Optional[torch.Tensor]:
if aux_id is not None:
aux_id = np.asarray(aux_id)
aux_id = torch.from_numpy(aux_id).to(device)
return aux_id


def embedding_to_torch(d_vector, cuda=False, device="cpu"):
if cuda:
device = "cuda"
def embedding_to_torch(d_vector, device: Union[str, torch.device] = "cpu") -> Optional[torch.Tensor]:
if d_vector is not None:
d_vector = np.asarray(d_vector)
d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
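A short sketch of the updated converter API above, which takes a device argument instead of the removed cuda flag (assuming the functions are in scope, e.g. imported from TTS.tts.utils.synthesis):

    import numpy as np
    import torch

    arr = np.zeros((2, 3), dtype=np.float32)
    t = numpy_to_torch(arr, torch.float32, device="cpu")
    assert t is not None and t.device.type == "cpu"
    assert numpy_to_torch(None, torch.float32) is None  # None passes through unchanged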
@ -59,7 +59,7 @@ def _exp(x, base):
return np.exp(x)


def amp_to_db(*, x: np.ndarray, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray:
def amp_to_db(*, x: np.ndarray, gain: float = 1, base: float = 10, **kwargs) -> np.ndarray:
"""Convert amplitude values to decibels.

Args:
@ -1,7 +1,113 @@
import logging

import librosa
import torch
from torch import nn

logger = logging.getLogger(__name__)


hann_window = {}
mel_basis = {}


def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
"""Spectral normalization / dynamic range compression."""
return torch.log(torch.clamp(x, min=clip_val) * spec_gain)


def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
"""Spectral denormalization / dynamic range decompression."""
return torch.exp(x) / spec_gain


def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor:
"""
Args Shapes:
- y : :math:`[B, 1, T]`

Return Shapes:
- spec : :math:`[B,C,T]`
"""
y = y.squeeze(1)

if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))

global hann_window
wnsize_dtype_device = f"{win_length}_{y.dtype}_{y.device}"
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)


def spec_to_mel(
spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float
) -> torch.Tensor:
"""
Args Shapes:
- spec : :math:`[B,C,T]`

Return Shapes:
- mel : :math:`[B,C,T]`
"""
global mel_basis
fmax_dtype_device = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
if fmax_dtype_device not in mel_basis:
# TODO: switch librosa to torchaudio
mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
return amp_to_db(mel)


def wav_to_mel(
y: torch.Tensor,
n_fft: int,
num_mels: int,
sample_rate: int,
hop_length: int,
win_length: int,
fmin: float,
fmax: float,
*,
center: bool = False,
) -> torch.Tensor:
"""
Args Shapes:
- y : :math:`[B, 1, T]`

Return Shapes:
- spec : :math:`[B,C,T]`
"""
spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
return spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax)


class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.

@ -157,11 +263,3 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()

@staticmethod
def _amp_to_db(x, spec_gain=1.0):
return torch.log(torch.clamp(x, min=1e-5) * spec_gain)

@staticmethod
def _db_to_amp(x, spec_gain=1.0):
return torch.exp(x) / spec_gain
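A minimal sketch of the consolidated mel pipeline above; the STFT/mel settings are illustrative, not taken from any config in this diff:

    import torch
    from TTS.utils.audio.torch_transforms import wav_to_mel

    y = torch.rand(1, 1, 22050) * 2 - 1  # one second of fake audio in [-1, 1]
    mel = wav_to_mel(
        y, n_fft=1024, num_mels=80, sample_rate=22050,
        hop_length=256, win_length=1024, fmin=0, fmax=8000,
    )
    # wav_to_mel is now simply wav_to_spec followed by spec_to_mel.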
@ -4,13 +4,26 @@ import importlib
import logging
import re
from pathlib import Path
from typing import Dict, Optional
from typing import Callable, Dict, Optional, TypeVar, Union

import torch
from packaging.version import Version
from typing_extensions import TypeIs

logger = logging.getLogger(__name__)

_T = TypeVar("_T")


def exists(val: Union[_T, None]) -> TypeIs[_T]:
return val is not None


def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T:
if exists(val):
return val
return d() if callable(d) else d


def to_camel(text):
text = text.capitalize()

@ -54,25 +67,6 @@ def get_import_path(obj: object) -> str:
return ".".join([type(obj).__module__, type(obj).__name__])


def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
logger.warning("Layer missing in the model finition %s", k)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
# 3. skip reinit layers
if c.has("reinit_layers") and c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict


def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
"""Format kwargs to hande auxilary inputs to models.
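A hedged sketch of the two small helpers added above (exists/default), with a hypothetical wrapper to show the intended fallback pattern:

    from typing import Optional

    def pick_batch_size(requested: Optional[int]) -> int:
        # `default` falls back to a value (or factory) when the argument is None.
        return default(requested, 32)

    assert pick_batch_size(None) == 32 and pick_batch_size(8) == 8
    assert exists(0) and not exists(None)  # `exists` only checks for None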
@ -6,11 +6,6 @@ from typing import Dict, List, Union
logger = logging.getLogger(__name__)


def to_camel(text):
text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
logger.info("Using model: %s", config.model)
# fetch the right model implementation.
@ -6,15 +6,15 @@ import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec

import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
from TTS.tts.layers.vits.discriminator import DiscriminatorS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.vc.configs.freevc_config import FreeVCConfig

@ -23,7 +23,7 @@ from TTS.vc.modules.freevc.commons import init_weights
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm
from TTS.vocoder.models.hifigan_generator import get_padding
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP

logger = logging.getLogger(__name__)

@ -164,75 +164,6 @@ class Generator(torch.nn.Module):
remove_parametrizations(l, "weight")


class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

def forward(self, x):
fmap = []

# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)

for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)

return x, fmap


class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

def forward(self, x):
fmap = []

for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)

return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__()
@ -3,7 +3,7 @@ import math
import torch
from torch.nn import functional as F

from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
from TTS.tts.utils.helpers import convert_pad_shape


def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:

@ -96,37 +96,11 @@ def subsequent_mask(length):
return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts


def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x


def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)

cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
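For reference, a tiny sketch of the gated activation that FreeVC now imports from the shared wavenet layer instead of defining here; shapes are illustrative:

    import torch
    from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply

    a, b = torch.rand(1, 8, 5), torch.rand(1, 8, 5)
    n_channels = torch.IntTensor([4])
    acts = fused_add_tanh_sigmoid_multiply(a, b, n_channels)
    assert acts.shape == (1, 4, 5)  # tanh(first half) * sigmoid(second half)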
@ -4,91 +4,16 @@ import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn

from TTS.utils.audio.torch_transforms import amp_to_db

logger = logging.getLogger(__name__)

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output


def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0:
logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("Max value is: %.3f", torch.max(y))

global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec


def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
logger.info("Min value is: %.3f", torch.min(y))

@ -128,6 +53,6 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
spec = amp_to_db(spec)

return spec
@ -5,8 +5,8 @@ from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

import TTS.vc.modules.freevc.commons as commons
from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply
from TTS.vc.modules.freevc.commons import init_weights
from TTS.vocoder.models.hifigan_generator import get_padding

@ -99,7 +99,7 @@ class WN(torch.nn.Module):
else:
g_l = torch.zeros_like(x_in)

acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)

res_skip_acts = self.res_skip_layers[i](acts)
@ -4,14 +4,11 @@ import re

from coqpit import Coqpit

from TTS.utils.generic_utils import to_camel

logger = logging.getLogger(__name__)


def to_camel(text):
text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(config: Coqpit):
"""Load models directly from configuration."""
if "discriminator_model" in config and "generator_model" in config:
@ -178,6 +178,7 @@ class HifiganGenerator(torch.nn.Module):
conv_pre_weight_norm=True,
conv_post_weight_norm=True,
conv_post_bias=True,
cond_in_each_up_layer=False,
):
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)

@ -202,6 +203,8 @@ class HifiganGenerator(torch.nn.Module):
self.inference_padding = inference_padding
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_factors)
self.cond_in_each_up_layer = cond_in_each_up_layer

# initial upsampling layers
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2

@ -236,6 +239,12 @@ class HifiganGenerator(torch.nn.Module):
if not conv_post_weight_norm:
remove_parametrizations(self.conv_post, "weight")

if self.cond_in_each_up_layer:
self.conds = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
self.conds.append(nn.Conv1d(cond_channels, ch, 1))

def forward(self, x, g=None):
"""
Args:

@ -255,6 +264,10 @@ class HifiganGenerator(torch.nn.Module):
for i in range(self.num_upsamples):
o = F.leaky_relu(o, LRELU_SLOPE)
o = self.ups[i](o)

if self.cond_in_each_up_layer:
o = o + self.conds[i](g)

z_sum = None
for j in range(self.num_kernels):
if z_sum is None:
@ -12,6 +12,13 @@ from TTS.vocoder.layers.upsample import ConvUpsample
logger = logging.getLogger(__name__)


def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1


class ParallelWaveganGenerator(torch.nn.Module):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
It is similar to WaveNet with no causal convolution.

@ -144,16 +151,9 @@ class ParallelWaveganGenerator(torch.nn.Module):

self.apply(_apply_weight_norm)

@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1

@property
def receptive_field_size(self):
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
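A worked example of the module-level receptive-field helper above; the layer counts are typical ParallelWaveGAN settings assumed here for illustration, not taken from the diff:

    # layers=30, stacks=3, kernel_size=3 -> 10 layers per cycle with dilations 1..512,
    # sum(dilations) = 3 * 1023 = 3069, receptive field = (3 - 1) * 3069 + 1 = 6139 samples.
    assert _get_receptive_field_size(30, 3, 3) == 6139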
@ -7,6 +7,7 @@ import torch.nn.functional as F
from torch.nn.utils import parametrize

from TTS.vocoder.layers.lvc_block import LVCBlock
from TTS.vocoder.models.parallel_wavegan_generator import _get_receptive_field_size

logger = logging.getLogger(__name__)

@ -133,17 +134,10 @@ class UnivnetGenerator(torch.nn.Module):

self.apply(_apply_weight_norm)

@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1

@property
def receptive_field_size(self):
"""Return receptive field size."""
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

@torch.no_grad()
def inference(self, c):
@ -17,6 +17,7 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.layers.upsample import Stretch2d
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian

@ -66,19 +67,6 @@ class MelResNet(nn.Module):
return x


class Stretch2d(nn.Module):
def __init__(self, x_scale, y_scale):
super().__init__()
self.x_scale = x_scale
self.y_scale = y_scale

def forward(self, x):
b, c, h, w = x.size()
x = x.unsqueeze(-1).unsqueeze(3)
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
return x.view(b, c, h * self.y_scale, w * self.x_scale)


class UpsampleNetwork(nn.Module):
def __init__(
self,
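Stretch2d is now imported from the shared upsample layer instead of being redefined here; a small sketch, assuming the shared class keeps the (x_scale, y_scale) interface of the copy removed above:

    import torch
    from TTS.vocoder.layers.upsample import Stretch2d

    x = torch.rand(1, 1, 80, 10)      # (B, C, n_mels, frames)
    y = Stretch2d(4, 1)(x)            # stretch the time axis by 4
    assert y.shape == (1, 1, 80, 40)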
@ -1,6 +1,14 @@
import torch as T

from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.helpers import (
average_over_durations,
expand_encoder_outputs,
generate_attention,
generate_path,
rand_segments,
segment,
sequence_mask,
)


def test_average_over_durations(): # pylint: disable=no-self-use

@ -86,3 +94,24 @@ def test_generate_path():
assert all(path[b, t, :current_idx] == 0.0)
assert all(path[b, t, current_idx + durations[b, t].item() :] == 0.0)
current_idx += durations[b, t].item()

assert T.all(path == generate_attention(durations, x_mask, y_mask))
assert T.all(path == generate_attention(durations, x_mask))


def test_expand_encoder_outputs():
inputs = T.rand(2, 5, 57)
durations = T.randint(1, 4, (2, 57))

x_mask = T.ones(2, 1, 57)
y_lengths = T.ones(2) * durations.sum(1).max()

expanded, _, _ = expand_encoder_outputs(inputs, durations, x_mask, y_lengths)

for b in range(durations.shape[0]):
index = 0
for idx, dur in enumerate(durations[b]):
idx_expanded = expanded[b, :, index : index + dur.item()]
diff = (idx_expanded - inputs[b, :, idx].repeat(int(dur)).view(idx_expanded.shape)).sum()
assert abs(diff) < 1e-6, diff
index += dur
@ -0,0 +1,16 @@
import numpy as np
import torch

from TTS.utils.audio import numpy_transforms as np_transforms
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp


def test_amplitude_db_conversion():
x = torch.rand(11)
o1 = amp_to_db(x=x, spec_gain=1.0)
o2 = db_to_amp(x=o1, spec_gain=1.0)
np_o1 = np_transforms.amp_to_db(x=x, base=np.e)
np_o2 = np_transforms.db_to_amp(x=np_o1, base=np.e)
assert torch.allclose(x, o2)
assert torch.allclose(o1, np_o1)
assert torch.allclose(o2, np_o2)
@ -4,7 +4,7 @@ import os
import shutil

import torch
from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig

@ -4,7 +4,7 @@ import os
import shutil

import torch
from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.overflow_config import OverflowConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.tacotron2_config import Tacotron2Config

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.tacotron2_config import Tacotron2Config

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.tacotron2_config import Tacotron2Config

@ -2,7 +2,7 @@ import glob
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.tacotron_config import TacotronConfig
@ -13,14 +13,10 @@ from TTS.tts.models.vits import (
Vits,
VitsArgs,
VitsAudioConfig,
amp_to_db,
db_to_amp,
load_audio,
spec_to_mel,
wav_to_mel,
wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp, spec_to_mel, wav_to_mel, wav_to_spec

LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseDatasetConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.align_tts_config import AlignTTSConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
@ -6,29 +6,7 @@ from TTS.tts.utils.helpers import sequence_mask
# pylint: disable=unused-variable


def expand_encoder_outputs_test():
model = ForwardTTS(ForwardTTSArgs(num_chars=10))

inputs = T.rand(2, 5, 57)
durations = T.randint(1, 4, (2, 57))

x_mask = T.ones(2, 1, 57)
y_mask = T.ones(2, 1, durations.sum(1).max())

expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask)

for b in range(durations.shape[0]):
index = 0
for idx, dur in enumerate(durations[b]):
diff = (
expanded[b, :, index : index + dur.item()]
- inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape)
).sum()
assert abs(diff) < 1e-6, diff
index += dur


def model_input_output_test():
def test_model_input_output():
"""Assert the output shapes of the model in different modes"""

# VANILLA MODEL
@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.glow_tts_config import GlowTTSConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.glow_tts_config import GlowTTSConfig

@ -3,7 +3,7 @@ import json
import os
import shutil

from trainer import get_last_checkpoint
from trainer.io import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.glow_tts_config import GlowTTSConfig