mirror of https://github.com/coqui-ai/TTS.git
Add Delightful-TTS implementation (#2095)
* add configs
* Update config file
* Add model configs
* Add model layers
* Add layer files
* Add layer modules
* change config names
* Add emotion manager
* Fix missing ap bug
* Fix missing ap bug
* Add base TTS e2e class
* Fix wrong variable name in load_tts_samples
* Add training script
* Remove range predictor and gaussian upsampling
* Add helper function
* Add vctk recipe
* Add conformer docs
* Fix linting in conformer.py
* Add docs
* remove duplicate import
* refactor args
* Fix bugs
* Remove emotion embedding
* remove unused arg
* Remove emotion embedding arg
* Remove emotion embedding arg
* fix style issues
* Fix bugs
* Fix bugs
* Add unittests
* make style
* fix formatter bug
* fix test
* Add pyworld compute pitch func
* Update requirements.txt
* Fix dataset bug
* Change layer norm to instance norm
* Add missing import
* Remove emotions.py
* remove ssim loss
* Add init layers func to aligner
* refactor model layers
* remove audio_config arg
* Rename loss func
* Rename to delightful-tts
* Rename loss func
* Remove unused modules
* refactor imports
* replace audio config with audio processor
* Add change sample rate option
* remove broken resample func
* update recipe
* fix style, add config docs
* fix tests and multispeaker embd dim
* remove pyworld
* Make style and fix inference
* Split tts tests
* Fixup
* Fixup
* Fixup
* Add argument names
* Set "random" speaker in the model Tortoise/Bark
* Use a diff f0_cache path for delightful tts
* Fix delightful speaker handling
* Fix lint
* Make style

---------

Co-authored-by: loganhart420 <loganartpersonal@gmail.com>
Co-authored-by: Eren Gölge <erogol@hotmail.com>
parent f24c5e0276
commit 6fdb88f8e2
@@ -0,0 +1,53 @@
name: tts-tests2

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, synchronize, reopened]
jobs:
  check_skip:
    runs-on: ubuntu-latest
    if: "! contains(github.event.head_commit.message, '[ci skip]')"
    steps:
      - run: echo "${{ github.event.head_commit.message }}"

  test:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.9, "3.10", "3.11"]
        experimental: [false]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          architecture: x64
          cache: 'pip'
          cache-dependency-path: 'requirements*'
      - name: check OS
        run: cat /etc/os-release
      - name: set ENV
        run: export TRAINER_TELEMETRY=0
      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y --no-install-recommends git make gcc
          sudo apt-get install espeak
          sudo apt-get install espeak-ng
          make system-deps
      - name: Install/upgrade Python setup deps
        run: python3 -m pip install --upgrade pip setuptools wheel
      - name: Replace scarf urls
        run: |
          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
      - name: Install TTS
        run: |
          python3 -m pip install .[all]
          python3 setup.py egg_info
      - name: Unit tests
        run: make test_tts2
Makefile
@@ -19,6 +19,9 @@ test_vocoder: ## run vocoder tests.

 test_tts: ## run tts tests.
 	nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests

+test_tts2: ## run tts2 tests.
+	nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2
+
 test_aux: ## run aux tests.
 	nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests
 	./run_bash_tests.sh
@@ -430,9 +430,9 @@ If you don't specify any models, then it uses LJSpeech based English model.
     if tts_path is not None:
         wav = synthesizer.tts(
             args.text,
-            args.speaker_idx,
-            args.language_idx,
-            args.speaker_wav,
+            speaker_name=args.speaker_idx,
+            language_name=args.language_idx,
+            speaker_wav=args.speaker_wav,
             reference_wav=args.reference_wav,
             style_wav=args.capacitron_style_wav,
             style_text=args.capacitron_style_text,
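The hunk above replaces positional arguments with keyword arguments. A quick sketch of why that matters (the `tts` function below is a hypothetical stand-in, not the real `Synthesizer.tts` signature): positional calls silently mis-bind when a parameter is inserted or reordered, while keyword calls either stay correct or fail loudly.

# Hypothetical stand-in for Synthesizer.tts, for illustration only.
def tts(text, speaker_name=None, language_name=None, speaker_wav=None):
    return f"text={text!r}, speaker={speaker_name!r}, language={language_name!r}"

# Positional: if a parameter is ever inserted before `speaker_name`,
# every later argument silently shifts into the wrong slot.
print(tts("hello", "spk1", "en"))

# Keyword: immune to reordering; renamed or removed parameters raise TypeError.
print(tts("hello", speaker_name="spk1", language_name="en"))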
@@ -0,0 +1,170 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig


@dataclass
class DelightfulTTSConfig(BaseTTSConfig):
    """
    Configuration class for the DelightfulTTS model.

    Attributes:
        model (str): Name of the model ("delightful_tts").
        audio (DelightfulTtsAudioConfig): Configuration for audio settings.
        model_args (DelightfulTtsArgs): Configuration for model arguments.
        use_attn_priors (bool): Whether to use attention priors.
        vocoder (VocoderConfig): Configuration for the vocoder.
        init_discriminator (bool): Whether to initialize the discriminator.
        steps_to_start_discriminator (int): Number of steps before the discriminator starts training.
        grad_clip (List[float]): Gradient clipping values.
        lr_gen (float): Learning rate for the GAN generator.
        lr_disc (float): Learning rate for the GAN discriminator.
        lr_scheduler_gen (str): Name of the learning rate scheduler for the generator.
        lr_scheduler_gen_params (dict): Parameters for the learning rate scheduler for the generator.
        lr_scheduler_disc (str): Name of the learning rate scheduler for the discriminator.
        lr_scheduler_disc_params (dict): Parameters for the learning rate scheduler for the discriminator.
        scheduler_after_epoch (bool): Whether to step the schedulers after each epoch.
        optimizer (str): Name of the optimizer.
        optimizer_params (dict): Parameters for the optimizer.
        ssim_loss_alpha (float): Alpha value for the SSIM loss.
        mel_loss_alpha (float): Alpha value for the mel loss.
        aligner_loss_alpha (float): Alpha value for the aligner loss.
        pitch_loss_alpha (float): Alpha value for the pitch loss.
        energy_loss_alpha (float): Alpha value for the energy loss.
        u_prosody_loss_alpha (float): Alpha value for the utterance prosody loss.
        p_prosody_loss_alpha (float): Alpha value for the phoneme prosody loss.
        dur_loss_alpha (float): Alpha value for the duration loss.
        char_dur_loss_alpha (float): Alpha value for the character duration loss.
        binary_align_loss_alpha (float): Alpha value for the binary alignment loss.
        binary_loss_warmup_epochs (int): Number of warm-up epochs for the binary loss.
        disc_loss_alpha (float): Alpha value for the discriminator loss.
        gen_loss_alpha (float): Alpha value for the generator loss.
        feat_loss_alpha (float): Alpha value for the feature loss.
        vocoder_mel_loss_alpha (float): Alpha value for the vocoder mel loss.
        multi_scale_stft_loss_alpha (float): Alpha value for the multi-scale STFT loss.
        multi_scale_stft_loss_params (dict): Parameters for the multi-scale STFT loss.
        return_wav (bool): Whether to return audio waveforms.
        use_weighted_sampler (bool): Whether to use a weighted sampler.
        weighted_sampler_attrs (dict): Attributes for the weighted sampler.
        weighted_sampler_multipliers (dict): Multipliers for the weighted sampler.
        r (int): Value for the `r` override.
        compute_f0 (bool): Whether to compute F0 values.
        f0_cache_path (str): Path to the F0 cache.
        attn_prior_cache_path (str): Path to the attention prior cache.
        num_speakers (int): Number of speakers.
        use_speaker_embedding (bool): Whether to use speaker embedding.
        speakers_file (str): Path to the speakers file.
        speaker_embedding_channels (int): Number of channels for the speaker embedding.
        language_ids_file (str): Path to the language IDs file.
    """

    model: str = "delightful_tts"

    # model specific params
    audio: DelightfulTtsAudioConfig = field(default_factory=DelightfulTtsAudioConfig)
    model_args: DelightfulTtsArgs = field(default_factory=DelightfulTtsArgs)
    use_attn_priors: bool = True

    # vocoder
    vocoder: VocoderConfig = field(default_factory=VocoderConfig)
    init_discriminator: bool = True

    # optimizer
    steps_to_start_discriminator: int = 200000
    grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
    lr_gen: float = 0.0002
    lr_disc: float = 0.0002
    lr_scheduler_gen: str = "ExponentialLR"
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # acoustic model loss params
    ssim_loss_alpha: float = 1.0
    mel_loss_alpha: float = 1.0
    aligner_loss_alpha: float = 1.0
    pitch_loss_alpha: float = 1.0
    energy_loss_alpha: float = 1.0
    u_prosody_loss_alpha: float = 0.5
    p_prosody_loss_alpha: float = 0.5
    dur_loss_alpha: float = 1.0
    char_dur_loss_alpha: float = 0.01
    binary_align_loss_alpha: float = 0.1
    binary_loss_warmup_epochs: int = 10

    # vocoder loss params
    disc_loss_alpha: float = 1.0
    gen_loss_alpha: float = 1.0
    feat_loss_alpha: float = 1.0
    vocoder_mel_loss_alpha: float = 10.0
    multi_scale_stft_loss_alpha: float = 2.5
    multi_scale_stft_loss_params: dict = field(
        default_factory=lambda: {
            "n_ffts": [1024, 2048, 512],
            "hop_lengths": [120, 240, 50],
            "win_lengths": [600, 1200, 240],
        }
    )

    # data loader params
    return_wav: bool = True
    use_weighted_sampler: bool = False
    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

    # overrides
    r: int = 1

    # dataset configs
    compute_f0: bool = True
    f0_cache_path: str = None
    attn_prior_cache_path: str = None

    # multi-speaker settings
    # use speaker embedding layer
    num_speakers: int = 0
    use_speaker_embedding: bool = False
    speakers_file: str = None
    speaker_embedding_channels: int = 256
    language_ids_file: str = None
    use_language_embedding: bool = False

    # use d-vectors
    use_d_vector_file: bool = False
    d_vector_file: str = None
    d_vector_dim: int = None

    # testing
    test_sentences: List[str] = field(
        default_factory=lambda: [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist.",
            "Prior to November 22, 1963.",
        ]
    )

    def __post_init__(self):
        # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
        if self.num_speakers > 0:
            self.model_args.num_speakers = self.num_speakers

        # speaker embedding settings
        if self.use_speaker_embedding:
            self.model_args.use_speaker_embedding = True
        if self.speakers_file:
            self.model_args.speakers_file = self.speakers_file

        # d-vector settings
        if self.use_d_vector_file:
            self.model_args.use_d_vector_file = True
        if self.d_vector_dim is not None and self.d_vector_dim > 0:
            self.model_args.d_vector_dim = self.d_vector_dim
        if self.d_vector_file:
            self.model_args.d_vector_file = self.d_vector_file
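As a usage sketch (not part of this diff; it assumes TTS is installed and that the config module lives at `TTS.tts.configs.delightful_tts_config`): `__post_init__` above copies the multi-speaker settings into `model_args`, which is where `model.init_multispeaker()` reads them. File paths below are hypothetical.

from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig

config = DelightfulTTSConfig(
    num_speakers=4,
    use_speaker_embedding=True,
    speakers_file="speakers.json",  # hypothetical path
)

# __post_init__ has propagated the speaker settings into the model args.
assert config.model_args.num_speakers == 4
assert config.model_args.use_speaker_embedding
assert config.model_args.speakers_file == "speakers.json"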
@@ -686,6 +686,7 @@ class F0Dataset:
         self,
         samples: Union[List[List], List[Dict]],
         ap: "AudioProcessor",
+        audio_config=None,  # pylint: disable=unused-argument
         verbose=False,
         cache_path: str = None,
         precompute_num_workers=0,
@@ -0,0 +1,563 @@
### credit: https://github.com/dunky11/voicesmith
from typing import Callable, Dict, Tuple

import torch
import torch.nn.functional as F
from coqpit import Coqpit
from torch import nn

from TTS.tts.layers.delightful_tts.conformer import Conformer
from TTS.tts.layers.delightful_tts.encoders import (
    PhonemeLevelProsodyEncoder,
    UtteranceLevelProsodyEncoder,
    get_mask_from_lengths,
)
from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor
from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding
from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProsodyPredictor
from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor
from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask


class AcousticModel(torch.nn.Module):
    def __init__(
        self,
        args: "ModelArgs",
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: "SpeakerManager" = None,
    ):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.speaker_manager = speaker_manager

        self.init_multispeaker(args)
        # self.set_embedding_dims()

        self.length_scale = (
            float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
        )

        self.emb_dim = args.n_hidden_conformer_encoder
        self.encoder = Conformer(
            dim=self.args.n_hidden_conformer_encoder,
            n_layers=self.args.n_layers_conformer_encoder,
            n_heads=self.args.n_heads_conformer_encoder,
            speaker_embedding_dim=self.embedded_speaker_dim,
            p_dropout=self.args.dropout_conformer_encoder,
            kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_encoder,
            lrelu_slope=self.args.lrelu_slope,
        )
        self.pitch_adaptor = PitchAdaptor(
            n_input=self.args.n_hidden_conformer_encoder,
            n_hidden=self.args.n_hidden_variance_adaptor,
            n_out=1,
            kernel_size=self.args.kernel_size_variance_adaptor,
            emb_kernel_size=self.args.emb_kernel_size_variance_adaptor,
            p_dropout=self.args.dropout_variance_adaptor,
            lrelu_slope=self.args.lrelu_slope,
        )
        self.energy_adaptor = EnergyAdaptor(
            channels_in=self.args.n_hidden_conformer_encoder,
            channels_hidden=self.args.n_hidden_variance_adaptor,
            channels_out=1,
            kernel_size=self.args.kernel_size_variance_adaptor,
            emb_kernel_size=self.args.emb_kernel_size_variance_adaptor,
            dropout=self.args.dropout_variance_adaptor,
            lrelu_slope=self.args.lrelu_slope,
        )

        self.aligner = AlignmentNetwork(
            in_query_channels=self.args.out_channels,
            in_key_channels=self.args.n_hidden_conformer_encoder,
        )

        self.duration_predictor = VariancePredictor(
            channels_in=self.args.n_hidden_conformer_encoder,
            channels=self.args.n_hidden_variance_adaptor,
            channels_out=1,
            kernel_size=self.args.kernel_size_variance_adaptor,
            p_dropout=self.args.dropout_variance_adaptor,
            lrelu_slope=self.args.lrelu_slope,
        )

        self.utterance_prosody_encoder = UtteranceLevelProsodyEncoder(
            num_mels=self.args.num_mels,
            ref_enc_filters=self.args.ref_enc_filters_reference_encoder,
            ref_enc_size=self.args.ref_enc_size_reference_encoder,
            ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder,
            ref_enc_strides=self.args.ref_enc_strides_reference_encoder,
            n_hidden=self.args.n_hidden_conformer_encoder,
            dropout=self.args.dropout_conformer_encoder,
            bottleneck_size_u=self.args.bottleneck_size_u_reference_encoder,
            token_num=self.args.token_num_reference_encoder,
        )

        self.utterance_prosody_predictor = PhonemeProsodyPredictor(
            hidden_size=self.args.n_hidden_conformer_encoder,
            kernel_size=self.args.predictor_kernel_size_reference_encoder,
            dropout=self.args.dropout_conformer_encoder,
            bottleneck_size=self.args.bottleneck_size_u_reference_encoder,
            lrelu_slope=self.args.lrelu_slope,
        )

        self.phoneme_prosody_encoder = PhonemeLevelProsodyEncoder(
            num_mels=self.args.num_mels,
            ref_enc_filters=self.args.ref_enc_filters_reference_encoder,
            ref_enc_size=self.args.ref_enc_size_reference_encoder,
            ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder,
            ref_enc_strides=self.args.ref_enc_strides_reference_encoder,
            n_hidden=self.args.n_hidden_conformer_encoder,
            dropout=self.args.dropout_conformer_encoder,
            bottleneck_size_p=self.args.bottleneck_size_p_reference_encoder,
            n_heads=self.args.n_heads_conformer_encoder,
        )

        self.phoneme_prosody_predictor = PhonemeProsodyPredictor(
            hidden_size=self.args.n_hidden_conformer_encoder,
            kernel_size=self.args.predictor_kernel_size_reference_encoder,
            dropout=self.args.dropout_conformer_encoder,
            bottleneck_size=self.args.bottleneck_size_p_reference_encoder,
            lrelu_slope=self.args.lrelu_slope,
        )

        self.u_bottle_out = nn.Linear(
            self.args.bottleneck_size_u_reference_encoder,
            self.args.n_hidden_conformer_encoder,
        )

        self.u_norm = nn.InstanceNorm1d(self.args.bottleneck_size_u_reference_encoder)
        self.p_bottle_out = nn.Linear(
            self.args.bottleneck_size_p_reference_encoder,
            self.args.n_hidden_conformer_encoder,
        )
        self.p_norm = nn.InstanceNorm1d(
            self.args.bottleneck_size_p_reference_encoder,
        )
        self.decoder = Conformer(
            dim=self.args.n_hidden_conformer_decoder,
            n_layers=self.args.n_layers_conformer_decoder,
            n_heads=self.args.n_heads_conformer_decoder,
            speaker_embedding_dim=self.embedded_speaker_dim,
            p_dropout=self.args.dropout_conformer_decoder,
            kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_decoder,
            lrelu_slope=self.args.lrelu_slope,
        )

        padding_idx = self.tokenizer.characters.pad_id
        self.src_word_emb = EmbeddingPadded(
            self.args.num_chars, self.args.n_hidden_conformer_encoder, padding_idx=padding_idx
        )
        self.to_mel = nn.Linear(
            self.args.n_hidden_conformer_decoder,
            self.args.num_mels,
        )

        self.energy_scaler = torch.nn.BatchNorm1d(1, affine=False, track_running_stats=True, momentum=None)
        self.energy_scaler.requires_grad_(False)

    def init_multispeaker(self, args: Coqpit):  # pylint: disable=unused-argument
        """Init for multi-speaker training."""
        self.embedded_speaker_dim = 0
        self.num_speakers = self.args.num_speakers
        self.audio_transform = None

        if self.speaker_manager:
            self.num_speakers = self.speaker_manager.num_speakers

        if self.args.use_speaker_embedding:
            self._init_speaker_embedding()

        if self.args.use_d_vector_file:
            self._init_d_vector()

    @staticmethod
    def _set_cond_input(aux_input: Dict):
        """Set the speaker conditioning input based on the multi-speaker mode."""
        sid, g, lid, durations = None, None, None, None
        if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
            sid = aux_input["speaker_ids"]
            if sid.ndim == 0:
                sid = sid.unsqueeze_(0)
        if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
            g = F.normalize(aux_input["d_vectors"])  # .unsqueeze_(-1)
            if g.ndim == 2:
                g = g  # .unsqueeze_(0)  # pylint: disable=self-assigning-variable

        if "durations" in aux_input and aux_input["durations"] is not None:
            durations = aux_input["durations"]

        return sid, g, lid, durations

    def get_aux_input(self, aux_input: Dict):
        sid, g, lid, _ = self._set_cond_input(aux_input)
        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}

    def _set_speaker_input(self, aux_input: Dict):
        d_vectors = aux_input.get("d_vectors", None)
        speaker_ids = aux_input.get("speaker_ids", None)

        if d_vectors is not None and speaker_ids is not None:
            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")

        if speaker_ids is not None and not hasattr(self, "emb_g"):
            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")

        g = speaker_ids if speaker_ids is not None else d_vectors
        return g

    # def set_embedding_dims(self):
    #     if self.embedded_speaker_dim > 0:
    #         self.embedding_dims = self.embedded_speaker_dim
    #     else:
    #         self.embedding_dims = 0

    def _init_speaker_embedding(self):
        # pylint: disable=attribute-defined-outside-init
        if self.num_speakers > 0:
            print(" > initialization of speaker-embedding layers.")
            self.embedded_speaker_dim = self.args.speaker_embedding_channels
            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

    def _init_d_vector(self):
        # pylint: disable=attribute-defined-outside-init
        if hasattr(self, "emb_g"):
            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
        self.embedded_speaker_dim = self.args.d_vector_dim

    @staticmethod
    def generate_attn(dr, x_mask, y_mask=None):
        """Generate an attention mask from the linear scale durations.

        Args:
            dr (Tensor): Linear scale durations.
            x_mask (Tensor): Mask for the input (character) sequence.
            y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
                if None. Defaults to None.

        Shapes
            - dr: :math:`(B, T_{en})`
            - x_mask: :math:`(B, T_{en})`
            - y_mask: :math:`(B, T_{de})`
        """
        # compute decode mask from the durations
        if y_mask is None:
            y_lengths = dr.sum(1).long()
            y_lengths[y_lengths < 1] = 1
            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
        return attn

    def _expand_encoder_with_durations(
        self,
        o_en: torch.FloatTensor,
        dr: torch.IntTensor,
        x_mask: torch.IntTensor,
        y_lengths: torch.IntTensor,
    ):
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
        attn = self.generate_attn(dr, x_mask, y_mask)
        o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en])
        return y_mask, o_en_ex, attn.transpose(1, 2)

    def _forward_aligner(
        self,
        x: torch.FloatTensor,
        y: torch.FloatTensor,
        x_mask: torch.IntTensor,
        y_mask: torch.IntTensor,
        attn_priors: torch.FloatTensor,
    ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Aligner forward pass.

        1. Compute a mask to apply to the attention map.
        2. Run the alignment network.
        3. Apply MAS to compute the hard alignment map.
        4. Compute the durations from the hard alignment map.

        Args:
            x (torch.FloatTensor): Input sequence.
            y (torch.FloatTensor): Output sequence.
            x_mask (torch.IntTensor): Input sequence mask.
            y_mask (torch.IntTensor): Output sequence mask.
            attn_priors (torch.FloatTensor): Prior for the aligner network map.

        Returns:
            Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
                Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
                hard alignment map.

        Shapes:
            - x: :math:`[B, T_en, C_en]`
            - y: :math:`[B, T_de, C_de]`
            - x_mask: :math:`[B, 1, T_en]`
            - y_mask: :math:`[B, 1, T_de]`
            - attn_priors: :math:`[B, T_de, T_en]`

            - aligner_durations: :math:`[B, T_en]`
            - aligner_soft: :math:`[B, T_de, T_en]`
            - aligner_logprob: :math:`[B, 1, T_de, T_en]`
            - aligner_mas: :math:`[B, T_de, T_en]`
        """
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)  # [B, 1, T_en, T_de]
        aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, attn_priors)
        aligner_mas = maximum_path(
            aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
        )
        aligner_durations = torch.sum(aligner_mas, -1).int()
        aligner_soft = aligner_soft.squeeze(1)  # [B, T_max2, T_max]
        aligner_mas = aligner_mas.transpose(1, 2)  # [B, T_max, T_max2] -> [B, T_max2, T_max]
        return aligner_durations, aligner_soft, aligner_logprob, aligner_mas

    def average_utterance_prosody(  # pylint: disable=no-self-use
        self, u_prosody_pred: torch.Tensor, src_mask: torch.Tensor
    ) -> torch.Tensor:
        lengths = ((~src_mask) * 1.0).sum(1)
        u_prosody_pred = u_prosody_pred.sum(1, keepdim=True) / lengths.view(-1, 1, 1)
        return u_prosody_pred

    def forward(
        self,
        tokens: torch.Tensor,
        src_lens: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        pitches: torch.Tensor,
        energies: torch.Tensor,
        attn_priors: torch.Tensor,
        use_ground_truth: bool = True,
        d_vectors: torch.Tensor = None,
        speaker_idx: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        sid, g, lid, _ = self._set_cond_input(  # pylint: disable=unused-variable
            {"d_vectors": d_vectors, "speaker_ids": speaker_idx}
        )  # pylint: disable=unused-variable

        src_mask = get_mask_from_lengths(src_lens)  # [B, T_src]
        mel_mask = get_mask_from_lengths(mel_lens)  # [B, T_mel]

        # Token embeddings
        token_embeddings = self.src_word_emb(tokens)  # [B, T_src, C_hidden]
        token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # Alignment network and durations
        aligner_durations, aligner_soft, aligner_logprob, aligner_mas = self._forward_aligner(
            x=token_embeddings,
            y=mels.transpose(1, 2),
            x_mask=~src_mask[:, None],
            y_mask=~mel_mask[:, None],
            attn_priors=attn_priors,
        )
        dr = aligner_durations  # [B, T_en]

        # Embeddings
        speaker_embedding = None
        if d_vectors is not None:
            speaker_embedding = g
        elif speaker_idx is not None:
            speaker_embedding = F.normalize(self.emb_g(sid))

        pos_encoding = positional_encoding(
            self.emb_dim,
            max(token_embeddings.shape[1], max(mel_lens)),
            device=token_embeddings.device,
        )
        encoder_outputs = self.encoder(
            token_embeddings,
            src_mask,
            speaker_embedding=speaker_embedding,
            encoding=pos_encoding,
        )

        u_prosody_ref = self.u_norm(self.utterance_prosody_encoder(mels=mels, mel_lens=mel_lens))
        u_prosody_pred = self.u_norm(
            self.average_utterance_prosody(
                u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask),
                src_mask=src_mask,
            )
        )

        if use_ground_truth:
            encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_ref)
        else:
            encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred)

        p_prosody_ref = self.p_norm(
            self.phoneme_prosody_encoder(
                x=encoder_outputs, src_mask=src_mask, mels=mels, mel_lens=mel_lens, encoding=pos_encoding
            )
        )
        p_prosody_pred = self.p_norm(self.phoneme_prosody_predictor(x=encoder_outputs, mask=src_mask))

        if use_ground_truth:
            encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_ref)
        else:
            encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred)

        encoder_outputs_res = encoder_outputs

        pitch_pred, avg_pitch_target, pitch_emb = self.pitch_adaptor.get_pitch_embedding_train(
            x=encoder_outputs,
            target=pitches,
            dr=dr,
            mask=src_mask,
        )

        energy_pred, avg_energy_target, energy_emb = self.energy_adaptor.get_energy_embedding_train(
            x=encoder_outputs,
            target=energies,
            dr=dr,
            mask=src_mask,
        )

        encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb
        log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask)

        mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
            o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None]
        )

        x = self.decoder(
            encoder_outputs_ex.transpose(1, 2),
            mel_mask,
            speaker_embedding=speaker_embedding,
            encoding=pos_encoding,
        )
        x = self.to_mel(x)

        dr = torch.log(dr + 1)

        dr_pred = torch.exp(log_duration_prediction) - 1
        alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask)  # [B, T_max, T_max2']

        return {
            "model_outputs": x,
            "pitch_pred": pitch_pred,
            "pitch_target": avg_pitch_target,
            "energy_pred": energy_pred,
            "energy_target": avg_energy_target,
            "u_prosody_pred": u_prosody_pred,
            "u_prosody_ref": u_prosody_ref,
            "p_prosody_pred": p_prosody_pred,
            "p_prosody_ref": p_prosody_ref,
            "alignments_dp": alignments_dp,
            "alignments": alignments,  # [B, T_de, T_en]
            "aligner_soft": aligner_soft,
            "aligner_mas": aligner_mas,
            "aligner_durations": aligner_durations,
            "aligner_logprob": aligner_logprob,
            "dr_log_pred": log_duration_prediction.squeeze(1),  # [B, T]
            "dr_log_target": dr.squeeze(1),  # [B, T]
            "spk_emb": speaker_embedding,
        }

    @torch.no_grad()
    def inference(
        self,
        tokens: torch.Tensor,
        speaker_idx: torch.Tensor,
        p_control: float = None,  # TODO  # pylint: disable=unused-argument
        d_control: float = None,  # TODO  # pylint: disable=unused-argument
        d_vectors: torch.Tensor = None,
        pitch_transform: Callable = None,
        energy_transform: Callable = None,
    ) -> torch.Tensor:
        src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
        src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device)  # pylint: disable=unused-variable
        sid, g, lid, _ = self._set_cond_input(  # pylint: disable=unused-variable
            {"d_vectors": d_vectors, "speaker_ids": speaker_idx}
        )  # pylint: disable=unused-variable

        token_embeddings = self.src_word_emb(tokens)
        token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # Embeddings
        speaker_embedding = None
        if d_vectors is not None:
            speaker_embedding = g
        elif speaker_idx is not None:
            speaker_embedding = F.normalize(self.emb_g(sid))

        pos_encoding = positional_encoding(
            self.emb_dim,
            token_embeddings.shape[1],
            device=token_embeddings.device,
        )
        encoder_outputs = self.encoder(
            token_embeddings,
            src_mask,
            speaker_embedding=speaker_embedding,
            encoding=pos_encoding,
        )

        u_prosody_pred = self.u_norm(
            self.average_utterance_prosody(
                u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask),
                src_mask=src_mask,
            )
        )
        encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred).expand_as(encoder_outputs)

        p_prosody_pred = self.p_norm(
            self.phoneme_prosody_predictor(
                x=encoder_outputs,
                mask=src_mask,
            )
        )
        encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred).expand_as(encoder_outputs)

        encoder_outputs_res = encoder_outputs

        pitch_emb_pred, pitch_pred = self.pitch_adaptor.get_pitch_embedding(
            x=encoder_outputs,
            mask=src_mask,
            pitch_transform=pitch_transform,
            pitch_mean=self.pitch_mean if hasattr(self, "pitch_mean") else None,
            pitch_std=self.pitch_std if hasattr(self, "pitch_std") else None,
        )

        energy_emb_pred, energy_pred = self.energy_adaptor.get_energy_embedding(
            x=encoder_outputs, mask=src_mask, energy_transform=energy_transform
        )
        encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb_pred + energy_emb_pred

        log_duration_pred = self.duration_predictor(
            x=encoder_outputs_res.detach(), mask=src_mask
        )  # [B, C_hidden, T_src] -> [B, T_src]
        duration_pred = (torch.exp(log_duration_pred) - 1) * (~src_mask) * self.length_scale  # -> [B, T_src]
        duration_pred[duration_pred < 1] = 1.0  # -> [B, T_src]
        duration_pred = torch.round(duration_pred)  # -> [B, T_src]
        mel_lens = duration_pred.sum(1)  # -> [B,]

        _, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
            o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
        )

        mel_mask = get_mask_from_lengths(
            torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device)
        )

        if encoder_outputs_ex.shape[1] > pos_encoding.shape[1]:
            encoding = positional_encoding(self.emb_dim, encoder_outputs_ex.shape[2], device=tokens.device)

        # [B, C_hidden, T_src], [B, 1, T_src], [B, C_emb], [B, T_src, C_hidden] -> [B, C_hidden, T_src]
        x = self.decoder(
            encoder_outputs_ex.transpose(1, 2),
            mel_mask,
            speaker_embedding=speaker_embedding,
            encoding=encoding,
        )
        x = self.to_mel(x)
        outputs = {
            "model_outputs": x,
            "alignments": alignments,
            # "pitch": pitch_emb_pred,
            "durations": duration_pred,
            "pitch": pitch_pred,
            "energy": energy_pred,
            "spk_emb": speaker_embedding,
        }
        return outputs
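`_expand_encoder_with_durations` above builds a hard alignment from per-token durations and repeats each encoder frame with the einsum contraction `"kmn,kjm->kjn"`. A self-contained toy illustration of that repeat-by-duration behaviour (plain tensor ops, independent of the model's helpers):

import torch

# Toy encoder output: batch 1, 2 channels, 3 tokens.
o_en = torch.tensor([[[1.0, 2.0, 3.0],
                      [10.0, 20.0, 30.0]]])  # [B, C, T_en]
dr = torch.tensor([[2, 1, 3]])               # per-token durations, [B, T_en]

# Hard alignment: attn[b, t_en, t_de] = 1 where decoder frame t_de belongs to token t_en.
t_de_total = int(dr.sum())
attn = torch.zeros(1, 3, t_de_total)
start = 0
for t_en, d in enumerate(dr[0].tolist()):
    attn[0, t_en, start : start + d] = 1.0
    start += d

# The same contraction the model uses.
o_en_ex = torch.einsum("kmn,kjm->kjn", attn, o_en)  # [B, C, T_de]
print(o_en_ex[0, 0])  # tensor([1., 1., 2., 3., 3., 3.]): each token repeated dr times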
@@ -0,0 +1,450 @@
### credit: https://github.com/dunky11/voicesmith
import math
from typing import Tuple

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F

from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d
from TTS.tts.layers.delightful_tts.networks import GLUActivation


def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)


class Conformer(nn.Module):
    def __init__(
        self,
        dim: int,
        n_layers: int,
        n_heads: int,
        speaker_embedding_dim: int,
        p_dropout: float,
        kernel_size_conv_mod: int,
        lrelu_slope: float,
    ):
        """
        A Transformer variant that integrates both CNN and Transformer components.
        Conformer proposes a novel combination of self-attention and convolution, in which self-attention
        learns the global interaction while the convolutions efficiently capture the local correlations.

        Args:
            dim (int): Number of dimensions for the model.
            n_layers (int): Number of model layers.
            n_heads (int): The number of attention heads.
            speaker_embedding_dim (int): Number of speaker embedding dimensions.
            p_dropout (float): Probability of dropout.
            kernel_size_conv_mod (int): Size of kernels for convolution modules.

        Inputs: inputs, mask
            - **inputs** (batch, time, dim): Tensor containing input vector
            - **encoding** (batch, time, dim): Positional embedding tensor
            - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
        Returns:
            - **outputs** (batch, time, dim): Tensor produced by Conformer Encoder.
        """
        super().__init__()
        d_k = d_v = dim // n_heads
        self.layer_stack = nn.ModuleList(
            [
                ConformerBlock(
                    dim,
                    n_heads,
                    d_k,
                    d_v,
                    kernel_size_conv_mod=kernel_size_conv_mod,
                    dropout=p_dropout,
                    speaker_embedding_dim=speaker_embedding_dim,
                    lrelu_slope=lrelu_slope,
                )
                for _ in range(n_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        speaker_embedding: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Shapes:
            - x: :math:`[B, T_src, C]`
            - mask: :math:`[B]`
            - speaker_embedding: :math:`[B, C]`
            - encoding: :math:`[B, T_max2, C]`
        """

        attn_mask = mask.view((mask.shape[0], 1, 1, mask.shape[1]))
        for enc_layer in self.layer_stack:
            x = enc_layer(
                x,
                mask=mask,
                slf_attn_mask=attn_mask,
                speaker_embedding=speaker_embedding,
                encoding=encoding,
            )
        return x


class ConformerBlock(torch.nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        d_k: int,  # pylint: disable=unused-argument
        d_v: int,  # pylint: disable=unused-argument
        kernel_size_conv_mod: int,
        speaker_embedding_dim: int,
        dropout: float,
        lrelu_slope: float = 0.3,
    ):
        """
        A Conformer block is composed of four modules stacked together:
        a feed-forward module, a self-attention module, a convolution module,
        and a second feed-forward module at the end. The block starts with two feed-forward
        modules sandwiching the multi-headed self-attention module and the conv module.

        Args:
            d_model (int): The dimension of the model.
            n_head (int): The number of attention heads.
            kernel_size_conv_mod (int): Size of kernels for convolution modules.
            speaker_embedding_dim (int): Number of speaker embedding dimensions.
            emotion_embedding_dim (int): Number of emotion embedding dimensions.
            dropout (float): Probability of dropout.

        Inputs: inputs, mask
            - **inputs** (batch, time, dim): Tensor containing input vector
            - **encoding** (batch, time, dim): Positional embedding tensor
            - **slf_attn_mask** (batch, 1, 1, time1): Tensor containing indices to be masked in self attention module
            - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
        Returns:
            - **outputs** (batch, time, dim): Tensor produced by the Conformer Block.
        """
        super().__init__()
        if isinstance(speaker_embedding_dim, int):
            self.conditioning = Conv1dGLU(
                d_model=d_model,
                kernel_size=kernel_size_conv_mod,
                padding=kernel_size_conv_mod // 2,
                embedding_dim=speaker_embedding_dim,
            )

        self.ff = FeedForward(d_model=d_model, dropout=dropout, kernel_size=3, lrelu_slope=lrelu_slope)
        self.conformer_conv_1 = ConformerConvModule(
            d_model, kernel_size=kernel_size_conv_mod, dropout=dropout, lrelu_slope=lrelu_slope
        )
        self.ln = nn.LayerNorm(d_model)
        self.slf_attn = ConformerMultiHeadedSelfAttention(d_model=d_model, num_heads=n_head, dropout_p=dropout)
        self.conformer_conv_2 = ConformerConvModule(
            d_model, kernel_size=kernel_size_conv_mod, dropout=dropout, lrelu_slope=lrelu_slope
        )

    def forward(
        self,
        x: torch.Tensor,
        speaker_embedding: torch.Tensor,
        mask: torch.Tensor,
        slf_attn_mask: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Shapes:
            - x: :math:`[B, T_src, C]`
            - mask: :math:`[B]`
            - slf_attn_mask: :math:`[B, 1, 1, T_src]`
            - speaker_embedding: :math:`[B, C]`
            - emotion_embedding: :math:`[B, C]`
            - encoding: :math:`[B, T_max2, C]`
        """
        if speaker_embedding is not None:
            x = self.conditioning(x, embeddings=speaker_embedding)
        x = self.ff(x) + x
        x = self.conformer_conv_1(x) + x
        res = x
        x = self.ln(x)
        x, _ = self.slf_attn(query=x, key=x, value=x, mask=slf_attn_mask, encoding=encoding)
        x = x + res
        x = x.masked_fill(mask.unsqueeze(-1), 0)

        x = self.conformer_conv_2(x) + x
        return x


class FeedForward(nn.Module):
    def __init__(
        self,
        d_model: int,
        kernel_size: int,
        dropout: float,
        lrelu_slope: float,
        expansion_factor: int = 4,
    ):
        """
        Feed-forward module for the conformer block.

        Args:
            d_model (int): The dimension of the model.
            kernel_size (int): Size of the kernels for conv layers.
            dropout (float): Probability of dropout.
            expansion_factor (int): The factor by which to project the number of channels.
            lrelu_slope (int): The negative slope factor for the leaky relu activation.

        Inputs: inputs
            - **inputs** (batch, time, dim): Tensor containing input vector
        Returns:
            - **outputs** (batch, time, dim): Tensor produced by the feed-forward module.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(d_model)
        self.conv_1 = nn.Conv1d(
            d_model,
            d_model * expansion_factor,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.act = nn.LeakyReLU(lrelu_slope)
        self.conv_2 = nn.Conv1d(d_model * expansion_factor, d_model, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            x: :math:`[B, T, C]`
        """
        x = self.ln(x)
        x = x.permute((0, 2, 1))
        x = self.conv_1(x)
        x = x.permute((0, 2, 1))
        x = self.act(x)
        x = self.dropout(x)
        x = x.permute((0, 2, 1))
        x = self.conv_2(x)
        x = x.permute((0, 2, 1))
        x = self.dropout(x)
        x = 0.5 * x
        return x


class ConformerConvModule(nn.Module):
    def __init__(
        self,
        d_model: int,
        expansion_factor: int = 2,
        kernel_size: int = 7,
        dropout: float = 0.1,
        lrelu_slope: float = 0.3,
    ):
        """
        Convolution module for the conformer. It starts with a gating mechanism:
        a pointwise convolution and a gated linear unit (GLU). This is followed
        by a single 1-D depthwise convolution layer. Group norm is deployed just after the convolution
        to help with training. It also contains an expansion factor to project the number of channels.

        Args:
            d_model (int): The dimension of the model.
            expansion_factor (int): The factor by which to project the number of channels.
            kernel_size (int): Size of kernels for convolution modules.
            dropout (float): Probability of dropout.
            lrelu_slope (float): The slope coefficient for leaky relu activation.

        Inputs: inputs
            - **inputs** (batch, time, dim): Tensor containing input vector
        Returns:
            - **outputs** (batch, time, dim): Tensor produced by the conv module.

        """
        super().__init__()
        inner_dim = d_model * expansion_factor
        self.ln_1 = nn.LayerNorm(d_model)
        self.conv_1 = PointwiseConv1d(d_model, inner_dim * 2)
        self.conv_act = GLUActivation(slope=lrelu_slope)
        self.depthwise = DepthWiseConv1d(
            inner_dim,
            inner_dim,
            kernel_size=kernel_size,
            padding=calc_same_padding(kernel_size)[0],
        )
        self.ln_2 = nn.GroupNorm(1, inner_dim)
        self.activation = nn.LeakyReLU(lrelu_slope)
        self.conv_2 = PointwiseConv1d(inner_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            x: :math:`[B, T, C]`
        """
        x = self.ln_1(x)
        x = x.permute(0, 2, 1)
        x = self.conv_1(x)
        x = self.conv_act(x)
        x = self.depthwise(x)
        x = self.ln_2(x)
        x = self.activation(x)
        x = self.conv_2(x)
        x = x.permute(0, 2, 1)
        x = self.dropout(x)
        return x


class ConformerMultiHeadedSelfAttention(nn.Module):
    """
    Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
    the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
    module to generalize better on different input lengths and the resulting encoder is more robust to the variance of
    the utterance length. Conformer uses pre-norm residual units with dropout, which helps training
    and regularizing deeper models.
    Args:
        d_model (int): The dimension of the model
        num_heads (int): The number of attention heads.
        dropout_p (float): Probability of dropout
    Inputs: inputs, mask
        - **inputs** (batch, time, dim): Tensor containing input vector
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs** (batch, time, dim): Tensor produced by relative multi-headed self-attention module.
    """

    def __init__(self, d_model: int, num_heads: int, dropout_p: float):
        super().__init__()
        self.attention = RelativeMultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        encoding: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_length, _ = key.size()  # pylint: disable=unused-variable
        encoding = encoding[:, : key.shape[1]]
        encoding = encoding.repeat(batch_size, 1, 1)
        outputs, attn = self.attention(query, key, value, pos_embedding=encoding, mask=mask)
        outputs = self.dropout(outputs)
        return outputs, attn


class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
    Args:
        d_model (int): The dimension of the model
        num_heads (int): The number of attention heads.
    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing query vector
        - **key** (batch, time, dim): Tensor containing key vector
        - **value** (batch, time, dim): Tensor containing value vector
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs**: Tensor produced by relative multi-head attention module.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model, bias=False)
        self.value_proj = nn.Linear(d_model, d_model, bias=False)
        self.pos_proj = nn.Linear(d_model, d_model, bias=False)

        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_embedding: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = query.shape[0]
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
        u_bias = self.u_bias.expand_as(query)
        v_bias = self.v_bias.expand_as(query)
        a = (query + u_bias).transpose(1, 2)
        content_score = a @ key.transpose(2, 3)
        b = (query + v_bias).transpose(1, 2)
        pos_score = b @ pos_embedding.permute(0, 2, 3, 1)
        pos_score = self._relative_shift(pos_score)

        score = content_score + pos_score
        score = score * (1.0 / self.sqrt_dim)

        score.masked_fill_(mask, -1e9)

        attn = F.softmax(score, -1)

        context = (attn @ value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:  # pylint: disable=no-self-use
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = torch.zeros((batch_size, num_heads, seq_length1, 1), device=pos_score.device)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
        return pos_score


class MultiHeadAttention(nn.Module):
    """
    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    """

    def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)
        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        # score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim**0.5)
        scores = F.softmax(scores, dim=3)
        # out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
        return out
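`calc_same_padding` above returns a (left, right) pair so that both odd and even kernel sizes can preserve sequence length; `ConformerConvModule` only uses the first element, which is exact for its odd default kernel. A standalone check (the helper is duplicated here so the snippet runs without the package):

import torch
import torch.nn as nn
import torch.nn.functional as F

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

x = torch.randn(1, 4, 50)  # [B, C, T]
for k in (3, 7, 8):
    left, right = calc_same_padding(k)
    conv = nn.Conv1d(4, 4, k, padding=0, groups=4)  # depthwise, like DepthWiseConv1d
    y = conv(F.pad(x, (left, right)))
    print(k, (left, right), y.shape)  # T stays 50 for every kernel size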
@ -0,0 +1,670 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn # pylint: disable=consider-using-from-import
|
||||
import torch.nn.functional as F
|
||||
|
||||
from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor
|
||||
|
||||
|
||||
def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
|
||||
pad = kernel_size // 2
|
||||
return (pad, pad - (kernel_size + 1) % 2)
|
||||
|
||||
|
||||
class ConvNorm(nn.Module):
|
||||
"""A 1-dimensional convolutional layer with optional weight normalization.
|
||||
|
||||
This layer wraps a 1D convolutional layer from PyTorch and applies
|
||||
optional weight normalization. The layer can be used in a similar way to
|
||||
the convolutional layers in PyTorch's `torch.nn` module.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input signal.
|
||||
out_channels (int): The number of channels in the output signal.
|
||||
kernel_size (int, optional): The size of the convolving kernel.
|
||||
Defaults to 1.
|
||||
stride (int, optional): The stride of the convolution. Defaults to 1.
|
||||
padding (int, optional): Zero-padding added to both sides of the input.
|
||||
If `None`, the padding will be calculated so that the output has
|
||||
the same length as the input. Defaults to `None`.
|
||||
dilation (int, optional): Spacing between kernel elements. Defaults to 1.
|
||||
bias (bool, optional): If `True`, add bias after convolution. Defaults to `True`.
|
||||
w_init_gain (str, optional): The weight initialization function to use.
|
||||
Can be either 'linear' or 'relu'. Defaults to 'linear'.
|
||||
use_weight_norm (bool, optional): If `True`, apply weight normalization
|
||||
to the convolutional weights. Defaults to `False`.
|
||||
|
||||
Shapes:
|
||||
- Input: :math:`[N, D, T]`
|
||||
|
||||
- Output: :math:`[N, out_dim, T]` where `out_dim` is the number of output dimensions.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=None,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
w_init_gain="linear",
|
||||
use_weight_norm=False,
|
||||
):
|
||||
super(ConvNorm, self).__init__() # pylint: disable=super-with-arguments
|
||||
if padding is None:
|
||||
assert kernel_size % 2 == 1
|
||||
padding = int(dilation * (kernel_size - 1) / 2)
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation = dilation
|
||||
self.use_weight_norm = use_weight_norm
|
||||
conv_fn = nn.Conv1d
|
||||
self.conv = conv_fn(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias,
|
||||
)
|
||||
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
|
||||
if self.use_weight_norm:
|
||||
self.conv = nn.utils.weight_norm(self.conv)
|
||||
|
||||
def forward(self, signal, mask=None):
|
||||
conv_signal = self.conv(signal)
|
||||
if mask is not None:
|
||||
# always re-zero output if mask is
|
||||
# available to match zero-padding
|
||||
conv_signal = conv_signal * mask
|
||||
return conv_signal
|
||||
|
||||
|
||||
class ConvLSTMLinear(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        n_layers=2,
        n_channels=256,
        kernel_size=3,
        p_dropout=0.1,
        lstm_type="bilstm",
        use_linear=True,
    ):
        super(ConvLSTMLinear, self).__init__()  # pylint: disable=super-with-arguments
        self.out_dim = out_dim
        self.lstm_type = lstm_type
        self.use_linear = use_linear
        self.dropout = nn.Dropout(p=p_dropout)

        convolutions = []
        for i in range(n_layers):
            conv_layer = ConvNorm(
                in_dim if i == 0 else n_channels,
                n_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain="relu",
            )
            conv_layer = nn.utils.weight_norm(conv_layer.conv, name="weight")
            convolutions.append(conv_layer)

        self.convolutions = nn.ModuleList(convolutions)

        if not self.use_linear:
            n_channels = out_dim

        if self.lstm_type != "":
            use_bilstm = False
            lstm_channels = n_channels
            if self.lstm_type == "bilstm":
                use_bilstm = True
                lstm_channels = int(n_channels // 2)

            self.bilstm = nn.LSTM(n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm)
            lstm_norm_fn_pntr = nn.utils.spectral_norm
            self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
            if self.lstm_type == "bilstm":
                self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")

        if self.use_linear:
            self.dense = nn.Linear(n_channels, out_dim)

    def run_padded_sequence(self, context, lens):
        context_embedded = []
        for b_ind in range(context.size()[0]):  # TODO: speed up
            curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
            for conv in self.convolutions:
                curr_context = self.dropout(F.relu(conv(curr_context)))
            context_embedded.append(curr_context[0].transpose(0, 1))
        context = nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
        return context

    def run_unsorted_inputs(self, fn, context, lens):  # pylint: disable=no-self-use
        lens_sorted, ids_sorted = torch.sort(lens, descending=True)
        unsort_ids = [0] * lens.size(0)
        for i in range(len(ids_sorted)):  # pylint: disable=consider-using-enumerate
            unsort_ids[ids_sorted[i]] = i
        lens_sorted = lens_sorted.long().cpu()

        context = context[ids_sorted]
        context = nn.utils.rnn.pack_padded_sequence(context, lens_sorted, batch_first=True)
        context = fn(context)[0]
        context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]

        # map back to the original (unsorted) batch indices
        context = context[unsort_ids]
        return context

    def forward(self, context, lens):
        if context.size()[0] > 1:
            context = self.run_padded_sequence(context, lens)
            # to (B, D, T)
            context = context.transpose(1, 2)
        else:
            for conv in self.convolutions:
                context = self.dropout(F.relu(conv(context)))

        if self.lstm_type != "":
            context = context.transpose(1, 2)
            self.bilstm.flatten_parameters()
            if lens is not None:
                context = self.run_unsorted_inputs(self.bilstm, context, lens)
            else:
                context = self.bilstm(context)[0]
            context = context.transpose(1, 2)

        x_hat = context
        if self.use_linear:
            x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)

        return x_hat


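# Shape sketch (illustrative): ConvLSTMLinear consumes channel-first features
# and returns the projected sequence in the same layout.
#   >>> net = ConvLSTMLinear(in_dim=80, out_dim=128)
#   >>> lens = torch.tensor([100, 60])
#   >>> net(torch.randn(2, 80, 100), lens).shape
#   torch.Size([2, 128, 100])

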
class DepthWiseConv1d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class PointwiseConv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class BSConv1d(nn.Module):
    """Blueprint separable convolution: a pointwise (1x1) convolution followed
    by a depthwise convolution. See https://arxiv.org/pdf/2003.13549.pdf"""

    def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int):
        super().__init__()
        self.pointwise = nn.Conv1d(channels_in, channels_out, kernel_size=1)
        self.depthwise = nn.Conv1d(
            channels_out,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
            groups=channels_out,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.pointwise(x)
        x2 = self.depthwise(x1)
        return x2


class BSConv2d(nn.Module):
    """2D blueprint separable convolution. See https://arxiv.org/pdf/2003.13549.pdf"""

    def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int):
        super().__init__()
        self.pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        self.depthwise = nn.Conv2d(
            channels_out,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
            groups=channels_out,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.pointwise(x)
        x2 = self.depthwise(x1)
        return x2


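# Why blueprint-separable (illustrative): for C=256 channels and kernel k=9,
# a dense Conv1d holds C*C*k = 589,824 weights, while BSConv1d holds
# C*C = 65,536 (pointwise) plus C*k = 2,304 (depthwise) weights.
#   >>> dense = nn.Conv1d(256, 256, 9, padding=4)
#   >>> bs = BSConv1d(256, 256, kernel_size=9, padding=4)
#   >>> sum(p.numel() for p in bs.parameters()) < sum(p.numel() for p in dense.parameters())
#   True

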
class Conv1dGLU(nn.Module):
    """Gated convolution block conditioned on an utterance embedding, from Deep Voice 3."""

    def __init__(self, d_model: int, kernel_size: int, padding: int, embedding_dim: int):
        super().__init__()
        self.conv = BSConv1d(d_model, 2 * d_model, kernel_size=kernel_size, padding=padding)
        self.embedding_proj = nn.Linear(embedding_dim, d_model)
        self.register_buffer("sqrt", torch.sqrt(torch.FloatTensor([0.5])).squeeze(0))
        self.softsign = torch.nn.Softsign()

    def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        x = x.permute((0, 2, 1))
        residual = x
        x = self.conv(x)
        splitdim = 1
        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
        embeddings = self.embedding_proj(embeddings).unsqueeze(2)
        softsign = self.softsign(embeddings)
        softsign = softsign.expand_as(a)
        a = a + softsign
        x = a * torch.sigmoid(b)
        x = x + residual
        x = x * self.sqrt
        x = x.permute((0, 2, 1))
        return x


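# Example (illustrative): the gate halves the doubled conv channels, so input
# and output shapes match.
#   >>> glu = Conv1dGLU(d_model=256, kernel_size=5, padding=2, embedding_dim=512)
#   >>> glu(torch.randn(2, 50, 256), torch.randn(2, 512)).shape
#   torch.Size([2, 50, 256])

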
class ConvTransposed(nn.Module):
    """
    A convolution over channel-last input.

    This layer transposes the input so that a `BSConv1d` can be applied to a
    `[B, T, C]` tensor, then transposes the result back. Despite the name, it
    is not a transposed (de-)convolution; "transposed" refers to the dimension
    order of the input it accepts.

    Args:
        in_channels (int): The number of channels in the input tensor.
        out_channels (int): The number of channels in the output tensor.
        kernel_size (int): The size of the convolutional kernel. Default: 1.
        padding (int): The number of padding elements to add to the input tensor. Default: 0.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        padding: int = 0,
    ):
        super().__init__()
        self.conv = BSConv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)
        return x


class DepthwiseConvModule(nn.Module):
    def __init__(self, dim: int, kernel_size: int = 7, expansion: int = 4, lrelu_slope: float = 0.3):
        super().__init__()
        padding = calc_same_padding(kernel_size)
        self.depthwise = nn.Conv1d(
            dim,
            dim * expansion,
            kernel_size=kernel_size,
            padding=padding[0],
            groups=dim,
        )
        self.act = nn.LeakyReLU(lrelu_slope)
        self.out = nn.Conv1d(dim * expansion, dim, 1, 1, 0)
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln(x)
        x = x.permute((0, 2, 1))
        x = self.depthwise(x)
        x = self.act(x)
        x = self.out(x)
        x = x.permute((0, 2, 1))
        return x


class AddCoords(nn.Module):
    def __init__(self, rank: int, with_r: bool = False):
        super().__init__()
        self.rank = rank
        self.with_r = with_r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.rank == 1:
            batch_size_shape, channel_in_shape, dim_x = x.shape  # pylint: disable=unused-variable
            xx_range = torch.arange(dim_x, dtype=torch.int32)
            xx_channel = xx_range[None, None, :]

            xx_channel = xx_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1
            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1)

            xx_channel = xx_channel.to(x.device)
            out = torch.cat([x, xx_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 2:
            batch_size_shape, channel_in_shape, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32)
            yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32)

            xx_range = torch.arange(dim_y, dtype=torch.int32)
            yy_range = torch.arange(dim_x, dtype=torch.int32)
            xx_range = xx_range[None, None, :, None]
            yy_range = yy_range[None, None, :, None]

            xx_channel = torch.matmul(xx_range, xx_ones)
            yy_channel = torch.matmul(yy_range, yy_ones)

            # transpose y
            yy_channel = yy_channel.permute(0, 1, 3, 2)

            xx_channel = xx_channel.float() / (dim_y - 1)
            yy_channel = yy_channel.float() / (dim_x - 1)

            xx_channel = xx_channel * 2 - 1
            yy_channel = yy_channel * 2 - 1

            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
            yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)

            xx_channel = xx_channel.to(x.device)
            yy_channel = yy_channel.to(x.device)

            out = torch.cat([x, xx_channel, yy_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 3:
            batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32)
            yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32)
            zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32)

            xy_range = torch.arange(dim_y, dtype=torch.int32)
            xy_range = xy_range[None, None, None, :, None]

            yz_range = torch.arange(dim_z, dtype=torch.int32)
            yz_range = yz_range[None, None, None, :, None]

            zx_range = torch.arange(dim_x, dtype=torch.int32)
            zx_range = zx_range[None, None, None, :, None]

            xy_channel = torch.matmul(xy_range, xx_ones)
            xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2)

            yz_channel = torch.matmul(yz_range, yy_ones)
            yz_channel = yz_channel.permute(0, 1, 3, 4, 2)
            yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4)

            zx_channel = torch.matmul(zx_range, zz_ones)
            zx_channel = zx_channel.permute(0, 1, 4, 2, 3)
            zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3)

            xx_channel = xx_channel.to(x.device)
            yy_channel = yy_channel.to(x.device)
            zz_channel = zz_channel.to(x.device)
            out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(
                    torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) + torch.pow(zz_channel - 0.5, 2)
                )
                out = torch.cat([out, rr], dim=1)
        else:
            raise NotImplementedError

        return out


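# Example (illustrative): rank-1 AddCoords appends one normalized coordinate
# channel, plus a radius channel when with_r=True.
#   >>> addc = AddCoords(rank=1, with_r=True)
#   >>> addc(torch.randn(2, 80, 100)).shape
#   torch.Size([2, 82, 100])

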
class CoordConv1d(nn.modules.conv.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        with_r: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.rank = 1
        self.addcoords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv1d(
            in_channels + self.rank + int(with_r),
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.addcoords(x)
        x = self.conv(x)
        return x


class CoordConv2d(nn.modules.conv.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        with_r: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.rank = 2
        self.addcoords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv2d(
            in_channels + self.rank + int(with_r),
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.addcoords(x)
        x = self.conv(x)
        return x


class LVCBlock(torch.nn.Module):
    """Location-variable convolution block."""

    def __init__(  # pylint: disable=dangerous-default-value
        self,
        in_channels,
        cond_channels,
        stride,
        dilations=[1, 3, 9, 27],
        lReLU_slope=0.2,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = len(dilations)
        self.conv_kernel_size = conv_kernel_size

        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=len(dilations),
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
        )

        self.convt_pre = nn.Sequential(
            nn.LeakyReLU(lReLU_slope),
            nn.utils.weight_norm(
                nn.ConvTranspose1d(
                    in_channels,
                    in_channels,
                    2 * stride,
                    stride=stride,
                    padding=stride // 2 + stride % 2,
                    output_padding=stride % 2,
                )
            ),
        )

        self.conv_blocks = nn.ModuleList()
        for dilation in dilations:
            self.conv_blocks.append(
                nn.Sequential(
                    nn.LeakyReLU(lReLU_slope),
                    nn.utils.weight_norm(
                        nn.Conv1d(
                            in_channels,
                            in_channels,
                            conv_kernel_size,
                            padding=dilation * (conv_kernel_size - 1) // 2,
                            dilation=dilation,
                        )
                    ),
                    nn.LeakyReLU(lReLU_slope),
                )
            )

    def forward(self, x, c):
        """Forward propagation of the location-variable convolutions.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, stride * in_length)
        """
        _, in_channels, _ = x.shape  # (B, c_g, L')

        x = self.convt_pre(x)  # (B, c_g, stride * L')
        kernels, bias = self.kernel_predictor(c)

        for i, conv in enumerate(self.conv_blocks):
            output = conv(x)  # (B, c_g, stride * L')

            k = kernels[:, i, :, :, :, :]  # (B, 2 * c_g, c_g, kernel_size, cond_length)
            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)

            output = self.location_variable_convolution(
                output, k, b, hop_size=self.cond_hop_length
            )  # (B, 2 * c_g, stride * L'): LVC
            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
                output[:, in_channels:, :]
            )  # (B, c_g, stride * L'): GAU

        return x

    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):  # pylint: disable=no-self-use
        """Perform a location-variable convolution on the input sequence (x) using the local convolution kernel.

        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of the convolution.
            hop_size (int): the hop size of the conditioning sequence.

        Returns:
            (Tensor): the output sequence after the local convolution (batch, out_channels, in_length).
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2 * padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2 * padding)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = x.unfold(
            3, dilation, dilation
        )  # (batch, in_channels, kernel_length, (hop_size + 2 * padding) / dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2 * padding) / dilation)
        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o.to(memory_format=torch.channels_last_3d)
        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
        o = o + bias
        o = o.contiguous().view(batch, out_channels, -1)

        return o

    def remove_weight_norm(self):
        self.kernel_predictor.remove_weight_norm()
        nn.utils.remove_weight_norm(self.convt_pre[1])
        for block in self.conv_blocks:
            nn.utils.remove_weight_norm(block[1])


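# How the LVC unfold works (illustrative sketch): the upsampled signal is cut
# into kernel_length windows of hop_size frames, and each window is convolved
# with its own predicted kernel through the einsum above.
#   >>> lvc = LVCBlock(in_channels=32, cond_channels=80, stride=4, cond_hop_length=256)
#   >>> x, c = torch.randn(1, 32, 4096), torch.randn(1, 80, 64)
#   >>> lvc(x, c).shape  # time axis is upsampled by `stride`
#   torch.Size([1, 32, 16384])

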
@ -0,0 +1,261 @@
from typing import List, Tuple

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F

from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
from TTS.tts.layers.delightful_tts.networks import STL


def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    batch_size = lengths.shape[0]
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
    return mask


def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
    return torch.ceil(lens / stride).int()


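# Example (illustrative): True marks padded positions.
#   >>> get_mask_from_lengths(torch.tensor([3, 1]))
#   tensor([[False, False, False],
#           [False,  True,  True]])

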
class ReferenceEncoder(nn.Module):
    """
    Reference encoder for the utterance and phoneme prosody encoders, made up
    of convolutional and RNN layers.

    Args:
        num_mels (int): Number of mel channels in the input spectrogram.
        ref_enc_filters (List[int]): List of channel sizes for the encoder layers.
        ref_enc_size (int): Size of the kernel for the conv layers.
        ref_enc_strides (List[int]): List of strides to use for the conv layers.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing the mel vector
        - **lengths** (batch): Tensor containing the mel lengths.
    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the reference encoder.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
    ):
        super().__init__()

        n_mel_channels = num_mels
        self.n_mel_channels = n_mel_channels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels] + ref_enc_filters
        strides = [1] + ref_enc_strides
        # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[0 + 1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            )
        ]
        convs2 = [
            nn.Conv1d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=ref_enc_size,
                stride=strides[i],
                padding=ref_enc_size // 2,
            )
            for i in range(1, K)
        ]
        convs.extend(convs2)
        self.convs = nn.ModuleList(convs)

        self.norms = nn.ModuleList([nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)])

        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        inputs --- [N, n_mels, timesteps]
        outputs --- [N, E//2]
        """

        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
        x = x.masked_fill(mel_masks, 0)
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = F.leaky_relu(x, 0.3)  # [N, 128, Ty//2^K, n_mels//2^K]
            x = norm(x)

        for _ in range(2):
            mel_lens = stride_lens(mel_lens)

        mel_masks = get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))
        x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False)

        self.gru.flatten_parameters()
        x, memory = self.gru(x)  # memory --- [N, Ty, E//2], out --- [1, N, E//2]
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x, memory, mel_masks

    def calculate_channels(  # pylint: disable=no-self-use
        self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
    ) -> int:
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


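# Usage sketch (illustrative; the filter/stride values below are assumptions
# mirroring a typical config with two stride-2 layers, matching the two
# stride_lens() calls in forward):
#   >>> enc = ReferenceEncoder(num_mels=80, ref_enc_filters=[32, 32, 64, 64, 128],
#   ...                        ref_enc_size=3, ref_enc_strides=[1, 2, 1, 2, 1],
#   ...                        ref_enc_gru_size=128)
#   >>> out, memory, masks = enc(torch.randn(2, 80, 100), torch.tensor([100, 80]))
#   >>> out.shape  # time axis downsampled 4x by the two stride-2 convs
#   torch.Size([2, 25, 128])

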
class UtteranceLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        bottleneck_size_u: int,
        token_num: int,
    ):
        """
        Encoder to extract prosody from an utterance. It is made up of a reference encoder
        with a couple of linear layers and a style token layer with dropout.

        Args:
            num_mels (int): Number of mel channels in the input spectrogram.
            ref_enc_filters (List[int]): List of channel sizes for the ref encoder layers.
            ref_enc_size (int): Size of the kernel for the ref encoder conv layers.
            ref_enc_strides (List[int]): List of strides to use for the ref encoder conv layers.
            ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
            dropout (float): Probability of dropout.
            n_hidden (int): Size of hidden layers.
            bottleneck_size_u (int): Size of the bottleneck layer.
            token_num (int): Number of style tokens.

        Inputs: inputs, lengths
            - **inputs** (batch, dim, time): Tensor containing the mel vector
            - **lengths** (batch): Tensor containing the mel lengths.
        Returns:
            - **outputs** (batch, 1, dim): Tensor produced by the utterance-level prosody encoder.
        """
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_u

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2)
        self.stl = STL(n_hidden=n_hidden, token_num=token_num)
        self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            mels: :math:`[B, C, T]`
            mel_lens: :math:`[B]`

            out --- [N, 1, bottleneck_size_u]
        """
        _, embedded_prosody, _ = self.encoder(mels, mel_lens)

        # Bottleneck
        embedded_prosody = self.encoder_prj(embedded_prosody)

        # Style token attention
        out = self.encoder_bottleneck(self.stl(embedded_prosody))
        out = self.dropout(out)

        out = out.view((-1, 1, out.shape[3]))
        return out


class PhonemeLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        n_heads: int,
        bottleneck_size_p: int,
    ):
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_p

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden)
        self.attention = ConformerMultiHeadedSelfAttention(
            d_model=n_hidden,
            num_heads=n_heads,
            dropout_p=dropout,
        )
        self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        x --- [N, seq_len, encoder_embedding_dim]
        mels --- [N, Ty/r, n_mels*r], r=1
        out --- [N, seq_len, bottleneck_size]
        attn --- [N, seq_len, ref_len], Ty/r = ref_len
        """
        embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens)

        # Bottleneck
        embedded_prosody = self.encoder_prj(embedded_prosody)

        attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1))
        x, _ = self.attention(
            query=x,
            key=embedded_prosody,
            value=embedded_prosody,
            mask=attn_mask,
            encoding=encoding,
        )
        x = self.encoder_bottleneck(x)
        x = x.masked_fill(src_mask.unsqueeze(-1), 0.0)
        return x

@ -0,0 +1,82 @@
from typing import Callable, Tuple

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import

from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
from TTS.tts.utils.helpers import average_over_durations


class EnergyAdaptor(nn.Module):  # pylint: disable=abstract-method
    """Variance adaptor with an added 1D conv layer. Used to
    get energy embeddings.

    Args:
        channels_in (int): Number of in channels for conv layers.
        channels_hidden (int): Number of hidden channels for conv layers.
        channels_out (int): Number of out channels.
        kernel_size (int): Size of the kernel for the conv layers.
        dropout (float): Probability of dropout.
        lrelu_slope (float): Slope for the leaky relu.
        emb_kernel_size (int): Size of the kernel for the energy embedding.

    Inputs: inputs, mask
        - **inputs** (batch, time1, dim): Tensor containing input vector
        - **target** (batch, 1, time2): Tensor containing the energy target
        - **dr** (batch, time1): Tensor containing aligner durations vector
        - **mask** (batch, time1): Tensor containing indices to be masked
    Returns:
        - **energy prediction** (batch, 1, time1): Tensor produced by the energy predictor
        - **energy embedding** (batch, channels, time1): Tensor produced by the energy adaptor
        - **average energy target (train only)** (batch, 1, time1): Tensor produced after averaging over durations
    """

    def __init__(
        self,
        channels_in: int,
        channels_hidden: int,
        channels_out: int,
        kernel_size: int,
        dropout: float,
        lrelu_slope: float,
        emb_kernel_size: int,
    ):
        super().__init__()
        self.energy_predictor = VariancePredictor(
            channels_in=channels_in,
            channels=channels_hidden,
            channels_out=channels_out,
            kernel_size=kernel_size,
            p_dropout=dropout,
            lrelu_slope=lrelu_slope,
        )
        self.energy_emb = nn.Conv1d(
            1,
            channels_hidden,
            kernel_size=emb_kernel_size,
            padding=int((emb_kernel_size - 1) / 2),
        )

    def get_energy_embedding_train(
        self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Shapes:
            x: :math:`[B, T_src, C]`
            target: :math:`[B, 1, T_max2]`
            dr: :math:`[B, T_src]`
            mask: :math:`[B, T_src]`
        """
        energy_pred = self.energy_predictor(x, mask)
        energy_pred.unsqueeze_(1)
        avg_energy_target = average_over_durations(target, dr)
        energy_emb = self.energy_emb(avg_energy_target)
        return energy_pred, avg_energy_target, energy_emb

    def get_energy_embedding(
        self, x: torch.Tensor, mask: torch.Tensor, energy_transform: Callable
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        energy_pred = self.energy_predictor(x, mask)
        energy_pred.unsqueeze_(1)
        if energy_transform is not None:
            # NOTE: `self.pitch_mean` / `self.pitch_std` are not defined on this
            # module, so passing a non-None `energy_transform` will raise an
            # AttributeError as written.
            energy_pred = energy_transform(energy_pred, (~mask).sum(dim=(1, 2)), self.pitch_mean, self.pitch_std)
        energy_emb_pred = self.energy_emb(energy_pred)
        return energy_emb_pred, energy_pred

@ -0,0 +1,125 @@
import torch.nn as nn  # pylint: disable=consider-using-from-import


class KernelPredictor(nn.Module):
    """Kernel predictor for the location-variable convolutions.

    Args:
        cond_channels (int): number of channels for the conditioning sequence
        conv_in_channels (int): number of channels for the input sequence
        conv_out_channels (int): number of channels for the output sequence
        conv_layers (int): number of layers
    """

    def __init__(  # pylint: disable=dangerous-default-value
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
    ):
        super().__init__()

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers  # l_w
        kpnet_bias_channels = conv_out_channels * conv_layers  # l_b

        self.input_conv = nn.Sequential(
            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
            getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
        )

        self.residual_convs = nn.ModuleList()
        padding = (kpnet_conv_size - 1) // 2
        for _ in range(3):
            self.residual_convs.append(
                nn.Sequential(
                    nn.Dropout(kpnet_dropout),
                    nn.utils.weight_norm(
                        nn.Conv1d(
                            kpnet_hidden_channels,
                            kpnet_hidden_channels,
                            kpnet_conv_size,
                            padding=padding,
                            bias=True,
                        )
                    ),
                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
                    nn.utils.weight_norm(
                        nn.Conv1d(
                            kpnet_hidden_channels,
                            kpnet_hidden_channels,
                            kpnet_conv_size,
                            padding=padding,
                            bias=True,
                        )
                    ),
                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
                )
            )
        self.kernel_conv = nn.utils.weight_norm(
            nn.Conv1d(
                kpnet_hidden_channels,
                kpnet_kernel_channels,
                kpnet_conv_size,
                padding=padding,
                bias=True,
            )
        )
        self.bias_conv = nn.utils.weight_norm(
            nn.Conv1d(
                kpnet_hidden_channels,
                kpnet_bias_channels,
                kpnet_conv_size,
                padding=padding,
                bias=True,
            )
        )

    def forward(self, c):
        """
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        """
        batch, _, cond_length = c.shape
        c = self.input_conv(c)
        for residual_conv in self.residual_convs:
            residual_conv.to(c.device)
            c = c + residual_conv(c)
        k = self.kernel_conv(c)
        b = self.bias_conv(c)
        kernels = k.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_in_channels,
            self.conv_out_channels,
            self.conv_kernel_size,
            cond_length,
        )
        bias = b.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_out_channels,
            cond_length,
        )

        return kernels, bias

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.input_conv[0])
        nn.utils.remove_weight_norm(self.kernel_conv)
        nn.utils.remove_weight_norm(self.bias_conv)
        for block in self.residual_convs:
            nn.utils.remove_weight_norm(block[1])
            nn.utils.remove_weight_norm(block[3])

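# Shape sketch (illustrative; assumes torch is in scope): one conv pass over
# the conditioning sequence yields per-frame kernels and biases for every LVC
# layer at once.
#   >>> kp = KernelPredictor(cond_channels=80, conv_in_channels=32,
#   ...                      conv_out_channels=64, conv_layers=4)
#   >>> kernels, bias = kp(torch.randn(1, 80, 100))
#   >>> kernels.shape, bias.shape
#   (torch.Size([1, 4, 32, 64, 3, 100]), torch.Size([1, 4, 64, 100]))
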
@ -0,0 +1,219 @@
import math
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F

from TTS.tts.layers.delightful_tts.conv_layers import ConvNorm


def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor:
    assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..."
    # Kaiming initialization
    return torch.randn(shape) * np.sqrt(2 / shape[1])


def positional_encoding(d_model: int, length: int, device: torch.device) -> torch.Tensor:
    pe = torch.zeros(length, d_model, device=device)
    position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0)
    return pe


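# The table above is the standard sinusoidal encoding (illustrative):
#   PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
#   >>> positional_encoding(256, 100, torch.device("cpu")).shape
#   torch.Size([1, 100, 256])

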
class BottleneckLayer(nn.Module):
    """
    Bottleneck layer for reducing the dimensionality of a tensor.

    Args:
        in_dim: The number of input dimensions.
        reduction_factor: The factor by which to reduce the number of dimensions.
        norm: The normalization method to use. Can be "weightnorm" or "instancenorm".
        non_linearity: The non-linearity to use. Can be "relu" or "leakyrelu".
        kernel_size: The size of the convolutional kernel.
        use_partial_padding: Whether to use partial padding with the convolutional kernel.

    Shape:
        - Input: :math:`[N, in_dim]` where `N` is the batch size and `in_dim` is the number of input dimensions.
        - Output: :math:`[N, out_dim]` where `out_dim` is the number of output dimensions.
    """

    def __init__(
        self,
        in_dim,
        reduction_factor,
        norm="weightnorm",
        non_linearity="relu",
        kernel_size=3,
        use_partial_padding=False,  # pylint: disable=unused-argument
    ):
        super(BottleneckLayer, self).__init__()  # pylint: disable=super-with-arguments

        self.reduction_factor = reduction_factor
        reduced_dim = int(in_dim / reduction_factor)
        self.out_dim = reduced_dim
        if self.reduction_factor > 1:
            fn = ConvNorm(in_dim, reduced_dim, kernel_size=kernel_size, use_weight_norm=(norm == "weightnorm"))
            if norm == "instancenorm":
                fn = nn.Sequential(fn, nn.InstanceNorm1d(reduced_dim, affine=True))

            self.projection_fn = fn
            self.non_linearity = nn.ReLU()
            if non_linearity == "leakyrelu":
                self.non_linearity = nn.LeakyReLU()

    def forward(self, x):
        if self.reduction_factor > 1:
            x = self.projection_fn(x)
            x = self.non_linearity(x)
        return x


class GLUActivation(nn.Module):
    """Gated Linear Unit (GLU) activation.

    The input is split in half along the channel dimension: one half carries
    the values and the other half, passed through a LeakyReLU, acts as the
    gate.
    """

    def __init__(self, slope: float):
        super().__init__()
        self.lrelu = nn.LeakyReLU(slope)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, gate = x.chunk(2, dim=1)
        x = out * self.lrelu(gate)
        return x


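# Example (illustrative): the channel dimension is halved by the gating.
#   >>> glu = GLUActivation(slope=0.3)
#   >>> glu(torch.randn(2, 512, 50)).shape
#   torch.Size([2, 256, 50])

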
class StyleEmbedAttention(nn.Module):
    def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query: torch.Tensor, key_soft: torch.Tensor) -> torch.Tensor:
        values = self.W_value(key_soft)
        split_size = self.num_units // self.num_heads
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)

        out_soft = scores_soft = None
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key_soft)  # [N, T_k, num_units]

        # [h, N, T_q, num_units/h]
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)
        # [h, N, T_k, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)

        # score = softmax(QK^T / (d_k ** 0.5))
        scores_soft = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores_soft = scores_soft / (self.key_dim**0.5)
        scores_soft = F.softmax(scores_soft, dim=3)

        # out = score * V, [h, N, T_q, num_units/h]
        out_soft = torch.matmul(scores_soft, values)
        out_soft = torch.cat(torch.split(out_soft, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out_soft  # , scores_soft


class EmbeddingPadded(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        super().__init__()
        padding_mult = torch.ones((num_embeddings, 1), dtype=torch.int64)
        padding_mult[padding_idx] = 0
        self.register_buffer("padding_mult", padding_mult)
        self.embeddings = nn.parameter.Parameter(initialize_embeddings((num_embeddings, embedding_dim)))

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        embeddings_zeroed = self.embeddings * self.padding_mult
        x = F.embedding(idx, embeddings_zeroed)
        return x


class EmbeddingProjBlock(nn.Module):
    def __init__(self, embedding_dim: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(embedding_dim, embedding_dim),
                nn.LeakyReLU(0.3),
                nn.Linear(embedding_dim, embedding_dim),
                nn.LeakyReLU(0.3),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = x
        for layer in self.layers:
            x = layer(x)
        x = x + res
        return x


class LinearNorm(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        return x


class STL(nn.Module):
    """
    A PyTorch module for the Style Token Layer (STL), introduced with the Global
    Style Tokens in "Style Tokens: Unsupervised Style Modeling, Control and
    Transfer in End-to-End Speech Synthesis"
    (https://arxiv.org/abs/1803.09017)

    The STL applies a multi-headed attention mechanism over the learned style tokens,
    using the input prosody embedding as the query and the style tokens as the keys
    and values. The output of the attention mechanism is used as the style embedding.

    Args:
        token_num (int): The number of style tokens.
        n_hidden (int): Number of hidden dimensions.
    """

    def __init__(self, n_hidden: int, token_num: int):
        super(STL, self).__init__()  # pylint: disable=super-with-arguments

        num_heads = 1
        E = n_hidden
        self.token_num = token_num
        self.embed = nn.Parameter(torch.FloatTensor(self.token_num, E // num_heads))
        d_q = E // 2
        d_k = E // num_heads
        self.attention = StyleEmbedAttention(query_dim=d_q, key_dim=d_k, num_units=E, num_heads=num_heads)

        torch.nn.init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        N = x.size(0)
        query = x.unsqueeze(1)  # [N, 1, E//2]

        keys_soft = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E // num_heads]

        # Weighted sum over the style tokens
        emotion_embed_soft = self.attention(query, keys_soft)

        return emotion_embed_soft

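# Usage sketch (illustrative): the query is half the hidden size, the output
# is the full hidden size.
#   >>> stl = STL(n_hidden=256, token_num=32)
#   >>> stl(torch.randn(8, 128)).shape
#   torch.Size([8, 1, 256])
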
@ -0,0 +1,65 @@
import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import

from TTS.tts.layers.delightful_tts.conv_layers import ConvTransposed


class PhonemeProsodyPredictor(nn.Module):
    """Non-parallel prosody predictor inspired by: https://arxiv.org/pdf/2102.00851.pdf

    It consists of 2 layers of 1D convolutions, each followed by a leaky ReLU
    activation, layer norm and dropout, and finally a linear layer.

    Args:
        hidden_size (int): Size of hidden channels.
        kernel_size (int): Kernel size for the conv layers.
        dropout (float): Probability of dropout.
        bottleneck_size (int): Bottleneck size for the last linear layer.
        lrelu_slope (float): Slope of the leaky relu.
    """

    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        dropout: float,
        bottleneck_size: int,
        lrelu_slope: float,
    ):
        super().__init__()
        self.d_model = hidden_size
        self.layers = nn.ModuleList(
            [
                ConvTransposed(
                    self.d_model,
                    self.d_model,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                ),
                nn.LeakyReLU(lrelu_slope),
                nn.LayerNorm(self.d_model),
                nn.Dropout(dropout),
                ConvTransposed(
                    self.d_model,
                    self.d_model,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                ),
                nn.LeakyReLU(lrelu_slope),
                nn.LayerNorm(self.d_model),
                nn.Dropout(dropout),
            ]
        )
        self.predictor_bottleneck = nn.Linear(self.d_model, bottleneck_size)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            x: :math:`[B, T, D]`
            mask: :math:`[B, T]`
        """
        mask = mask.unsqueeze(2)
        for layer in self.layers:
            x = layer(x)
        x = x.masked_fill(mask, 0.0)
        x = self.predictor_bottleneck(x)
        return x

@ -0,0 +1,88 @@
from typing import Callable, Tuple

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import

from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
from TTS.tts.utils.helpers import average_over_durations


class PitchAdaptor(nn.Module):  # pylint: disable=abstract-method
    """Module to get pitch embeddings via a pitch predictor.

    Args:
        n_input (int): Number of pitch predictor input channels.
        n_hidden (int): Number of pitch predictor hidden channels.
        n_out (int): Number of pitch predictor out channels.
        kernel_size (int): Size of the kernel for conv layers.
        emb_kernel_size (int): Size of the kernel for the pitch embedding.
        p_dropout (float): Probability of dropout.
        lrelu_slope (float): Slope for the leaky relu.

    Inputs: inputs, mask
        - **inputs** (batch, time1, dim): Tensor containing input vector
        - **target** (batch, 1, time2): Tensor containing the pitch target
        - **dr** (batch, time1): Tensor containing aligner durations vector
        - **mask** (batch, time1): Tensor containing indices to be masked
    Returns:
        - **pitch prediction** (batch, 1, time1): Tensor produced by the pitch predictor
        - **pitch embedding** (batch, channels, time1): Tensor produced by the pitch adaptor
        - **average pitch target (train only)** (batch, 1, time1): Tensor produced after averaging over durations
    """

    def __init__(
        self,
        n_input: int,
        n_hidden: int,
        n_out: int,
        kernel_size: int,
        emb_kernel_size: int,
        p_dropout: float,
        lrelu_slope: float,
    ):
        super().__init__()
        self.pitch_predictor = VariancePredictor(
            channels_in=n_input,
            channels=n_hidden,
            channels_out=n_out,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            lrelu_slope=lrelu_slope,
        )
        self.pitch_emb = nn.Conv1d(
            1,
            n_input,
            kernel_size=emb_kernel_size,
            padding=int((emb_kernel_size - 1) / 2),
        )

    def get_pitch_embedding_train(
        self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Shapes:
            x: :math:`[B, T_src, C]`
            target: :math:`[B, 1, T_max2]`
            dr: :math:`[B, T_src]`
            mask: :math:`[B, T_src]`
        """
        pitch_pred = self.pitch_predictor(x, mask)  # [B, T_src, C_hidden], [B, T_src] --> [B, T_src]
        pitch_pred.unsqueeze_(1)  # --> [B, 1, T_src]
        avg_pitch_target = average_over_durations(target, dr)  # [B, 1, T_mel], [B, T_src] --> [B, 1, T_src]
        pitch_emb = self.pitch_emb(avg_pitch_target)  # [B, 1, T_src] --> [B, C_hidden, T_src]
        return pitch_pred, avg_pitch_target, pitch_emb

    def get_pitch_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pitch_transform: Callable,
        pitch_mean: torch.Tensor,
        pitch_std: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pitch_pred = self.pitch_predictor(x, mask)
        if pitch_transform is not None:
            pitch_pred = pitch_transform(pitch_pred, (~mask).sum(), pitch_mean, pitch_std)
        pitch_pred.unsqueeze_(1)
        pitch_emb_pred = self.pitch_emb(pitch_pred)
        return pitch_emb_pred, pitch_pred

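# Training-time flow (illustrative): the ground-truth pitch is averaged per
# input token with the aligner durations, then embedded.
#   >>> pa = PitchAdaptor(n_input=256, n_hidden=256, n_out=1, kernel_size=5,
#   ...                   emb_kernel_size=3, p_dropout=0.1, lrelu_slope=0.3)
#   >>> x, mask = torch.randn(2, 10, 256), torch.zeros(2, 10, dtype=torch.bool)
#   >>> target, dr = torch.randn(2, 1, 40), torch.full((2, 10), 4)
#   >>> pred, avg, emb = pa.get_pitch_embedding_train(x, target, dr, mask)
#   >>> pred.shape, avg.shape, emb.shape
#   (torch.Size([2, 1, 10]), torch.Size([2, 1, 10]), torch.Size([2, 256, 10]))
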
@ -0,0 +1,68 @@
import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import

from TTS.tts.layers.delightful_tts.conv_layers import ConvTransposed


class VariancePredictor(nn.Module):
    """
    A network of two 1D convolutions with leaky ReLU activations, each followed
    by layer normalization and dropout, with a final linear layer that projects
    the hidden states onto the output sequence.

    Args:
        channels_in (int): Number of in channels for conv layers.
        channels (int): Number of hidden channels for conv layers.
        channels_out (int): Number of out channels for the last linear layer.
        kernel_size (int): Size of the kernel for the conv layers.
        p_dropout (float): Probability of dropout.
        lrelu_slope (float): Slope for the leaky relu.

    Inputs: inputs, mask
        - **inputs** (batch, time, dim): Tensor containing input vector
        - **mask** (batch, time): Tensor containing indices to be masked
    Returns:
        - **outputs** (batch, time): Tensor produced by the last linear layer.
    """

    def __init__(
        self, channels_in: int, channels: int, channels_out: int, kernel_size: int, p_dropout: float, lrelu_slope: float
    ):
        super().__init__()

        self.layers = nn.ModuleList(
            [
                ConvTransposed(
                    channels_in,
                    channels,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                ),
                nn.LeakyReLU(lrelu_slope),
                nn.LayerNorm(channels),
                nn.Dropout(p_dropout),
                ConvTransposed(
                    channels,
                    channels,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                ),
                nn.LeakyReLU(lrelu_slope),
                nn.LayerNorm(channels),
                nn.Dropout(p_dropout),
            ]
        )

        self.linear_layer = nn.Linear(channels, channels_out)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            x: :math:`[B, T_src, C]`
            mask: :math:`[B, T_src]`
        """
        for layer in self.layers:
            x = layer(x)
        x = self.linear_layer(x)
        x = x.squeeze(-1)
        x = x.masked_fill(mask, 0.0)
        return x

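# Example (illustrative): with channels_out=1 the predictor emits one scalar
# per input frame, zeroed at masked positions.
#   >>> vp = VariancePredictor(channels_in=256, channels=256, channels_out=1,
#   ...                        kernel_size=5, p_dropout=0.1, lrelu_slope=0.3)
#   >>> vp(torch.randn(2, 10, 256), torch.zeros(2, 10, dtype=torch.bool)).shape
#   torch.Size([2, 10])
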
@ -57,6 +57,15 @@ class AlignmentNetwork(torch.nn.Module):
            nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
        )

        self.init_layers()

    def init_layers(self):
        torch.nn.init.xavier_uniform_(self.key_layer[0].weight, gain=torch.nn.init.calculate_gain("relu"))
        torch.nn.init.xavier_uniform_(self.key_layer[2].weight, gain=torch.nn.init.calculate_gain("linear"))
        torch.nn.init.xavier_uniform_(self.query_layer[0].weight, gain=torch.nn.init.calculate_gain("relu"))
        torch.nn.init.xavier_uniform_(self.query_layer[2].weight, gain=torch.nn.init.calculate_gain("linear"))
        torch.nn.init.xavier_uniform_(self.query_layer[4].weight, gain=torch.nn.init.calculate_gain("linear"))

    def forward(
        self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
    ) -> Tuple[torch.tensor, torch.tensor]:

@ -75,7 +84,9 @@ class AlignmentNetwork(torch.nn.Module):
        attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)
        if attn_prior is not None:
            attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8)

        if mask is not None:
            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))

        attn = self.softmax(attn_logp)
        return attn, attn_logp

@ -214,6 +214,7 @@ class Bark(BaseTTS):
            as latents used at inference.

        """
        speaker_id = "random" if speaker_id is None else speaker_id
        voice_dirs = self._set_voice_dirs(voice_dirs)
        history_prompt = load_voice(self, speaker_id, voice_dirs)
        outputs = self.generate_audio(text, history_prompt=history_prompt, **kwargs)

@ -439,3 +439,21 @@ class BaseTTS(BaseTrainerModel):
        trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
        print(f" > `language_ids.json` is saved to {output_path}.")
        print(" > `language_ids_file` is updated in the config.json.")


class BaseTTSE2E(BaseTTS):
    def _set_model_args(self, config: Coqpit):
        self.config = config
        if "Config" in config.__class__.__name__:
            num_chars = (
                self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
            )
            self.config.model_args.num_chars = num_chars
            self.config.num_chars = num_chars
            self.args = config.model_args
            self.args.num_chars = num_chars
        elif "Args" in config.__class__.__name__:
            self.args = config
            self.args.num_chars = self.args.num_chars
        else:
            raise ValueError("config must be either a *Config or *Args")

File diff suppressed because it is too large
@ -513,9 +513,13 @@ class Tortoise(BaseTTS):
            as latents used at inference.

        """

        speaker_id = "random" if speaker_id is None else speaker_id

        if voice_dirs is not None:
            voice_dirs = [voice_dirs]
            voice_samples, conditioning_latents = load_voice(speaker_id, voice_dirs)

        else:
            voice_samples, conditioning_latents = load_voice(speaker_id)

@ -1,5 +1,6 @@
import numpy as np
import torch
from scipy.stats import betabinom
from torch.nn import functional as F

try:
@ -233,3 +234,25 @@ def maximum_path_numpy(value, mask, max_neg_val=None):
    path = path * mask.astype(np.float32)
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path


def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
    P, M = phoneme_count, mel_count
    x = np.arange(0, P)
    mel_text_probs = []
    for i in range(1, M + 1):
        a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
        rv = betabinom(P, a, b)
        mel_i_prob = rv.pmf(x)
        mel_text_probs.append(mel_i_prob)
    return np.array(mel_text_probs)


def compute_attn_prior(x_len, y_len, scaling_factor=1.0):
    """Compute attention priors for the alignment network."""
    attn_prior = beta_binomial_prior_distribution(
        x_len,
        y_len,
        scaling_factor,
    )
    return attn_prior  # [y_len, x_len]

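# Intuition (illustrative): row i of the prior is a Beta-Binomial pmf over the
# phoneme axis whose mass drifts from the first to the last phoneme as the mel
# frame index grows, encoding a near-diagonal alignment.
#   >>> prior = compute_attn_prior(x_len=6, y_len=20)
#   >>> prior.shape  # [y_len, x_len]
#   (20, 6)
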
@ -361,12 +361,12 @@ class Synthesizer(object):
        if not reference_wav:  # not voice conversion
            for sen in sens:
                if hasattr(self.tts_model, "synthesize"):
                    outputs = self.tts_model.synthesize(
                        text=sen,
                        config=self.tts_config,
                        speaker_id=speaker_name,
                        voice_dirs=self.voice_dir,
                        d_vector=speaker_embedding,
                        **kwargs,
                    )
                else:

@ -0,0 +1,84 @@
import os
|
||||
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models.delightful_tts import DelightfulTTS, DelightfulTtsArgs, VocoderConfig
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.utils.audio.processor import AudioProcessor
|
||||
|
||||
data_path = ""
|
||||
output_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
dataset_config = BaseDatasetConfig(
|
||||
dataset_name="ljspeech", formatter="ljspeech", meta_file_train="metadata.csv", path=data_path
|
||||
)
|
||||
|
||||
audio_config = DelightfulTtsAudioConfig()
|
||||
model_args = DelightfulTtsArgs()
|
||||
|
||||
vocoder_config = VocoderConfig()
|
||||
|
||||
delightful_tts_config = DelightfulTTSConfig(
|
||||
run_name="delightful_tts_ljspeech",
|
||||
run_description="Train like in delightful tts paper.",
|
||||
model_args=model_args,
|
||||
audio=audio_config,
|
||||
vocoder=vocoder_config,
|
||||
batch_size=32,
|
||||
eval_batch_size=16,
|
||||
num_loader_workers=10,
|
||||
num_eval_loader_workers=10,
|
||||
precompute_num_workers=10,
|
||||
batch_group_size=2,
|
||||
compute_input_seq_cache=True,
|
||||
compute_f0=True,
|
||||
f0_cache_path=os.path.join(output_path, "f0_cache"),
|
||||
run_eval=True,
|
||||
test_delay_epochs=-1,
|
||||
epochs=1000,
|
||||
text_cleaner="english_cleaners",
|
||||
use_phonemes=True,
|
||||
phoneme_language="en-us",
|
||||
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
|
||||
print_step=50,
|
||||
print_eval=False,
|
||||
mixed_precision=True,
|
||||
output_path=output_path,
|
||||
datasets=[dataset_config],
|
||||
start_by_longest=False,
|
||||
eval_split_size=0.1,
|
||||
binary_align_loss_alpha=0.0,
|
||||
use_attn_priors=False,
|
||||
lr_gen=4e-1,
|
||||
lr=4e-1,
|
||||
lr_disc=4e-1,
|
||||
max_text_len=130,
|
||||
)
|
||||
|
||||
tokenizer, config = TTSTokenizer.init_from_config(delightful_tts_config)
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
|
||||
|
||||
train_samples, eval_samples = load_tts_samples(
|
||||
dataset_config,
|
||||
eval_split=True,
|
||||
eval_split_max_size=config.eval_split_max_size,
|
||||
eval_split_size=config.eval_split_size,
|
||||
)
|
||||
|
||||
model = DelightfulTTS(ap=ap, config=config, tokenizer=tokenizer, speaker_manager=None)
|
||||
|
||||
trainer = Trainer(
|
||||
TrainerArgs(),
|
||||
config,
|
||||
output_path,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
)
|
||||
|
||||
trainer.fit()
|
|
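A hedged usage note for the recipe above; the script file name is an assumption, and the multi-GPU launcher is only available if the installed trainer package ships the `trainer.distribute` module:

# Single GPU (assuming the recipe is saved as train_delightful_tts.py):
#   CUDA_VISIBLE_DEVICES="0" python train_delightful_tts.py
# Multi-GPU, if the trainer's launcher is available in your version:
#   python -m trainer.distribute --script train_delightful_tts.py --gpus "0,1"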
@@ -0,0 +1,86 @@
import os

from clearml import Task
from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTTSE2e, VocoderConfig
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio.processor import AudioProcessor

task = Task.init(project_name="delightful-tts", task_name="vctk")
data_path = "/raid/datasets/vctk_v092_48khz_removed_silence_silero_vad"
output_path = os.path.dirname(os.path.abspath(__file__))


dataset_config = BaseDatasetConfig(dataset_name="vctk", meta_file_train="", path=data_path, language="en-us")

audio_config = DelightfulTtsAudioConfig()

model_args = DelightfulTtsArgs()

vocoder_config = VocoderConfig()

vctk_config = DelightfulTTSConfig(
    run_name="delightful_tts_e2e_vctk",
    run_description="Train like in the DelightfulTTS paper.",
    model_args=model_args,
    audio=audio_config,
    vocoder=vocoder_config,
    batch_size=32,
    eval_batch_size=16,
    num_loader_workers=10,
    num_eval_loader_workers=10,
    precompute_num_workers=40,
    compute_input_seq_cache=True,
    compute_f0=True,
    f0_cache_path=os.path.join(output_path, "f0_cache"),
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    print_step=50,
    print_eval=False,
    mixed_precision=True,
    output_path=output_path,
    datasets=[dataset_config],
    start_by_longest=True,
    binary_align_loss_alpha=0.0,
    use_attn_priors=False,
    max_text_len=60,
    steps_to_start_discriminator=10000,
)

tokenizer, config = TTSTokenizer.init_from_config(vctk_config)

ap = AudioProcessor.init_from_config(config)


train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_max_size=config.eval_split_max_size,
    eval_split_size=config.eval_split_size,
)


speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers


model = DelightfulTTSE2e(
    ap=ap, config=config, tokenizer=tokenizer, speaker_manager=speaker_manager, emotion_manager=None
)

trainer = Trainer(
    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)

trainer.fit()
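A hedged sanity check one might append after fitting the `SpeakerManager`, assuming the VCTK formatter fills `speaker_name` for every sample and that the manager exposes `name_to_id` as in other multi-speaker recipes:

# Illustrative only; the counts depend on the local dataset copy.
print(f"num_speakers: {speaker_manager.num_speakers}")
print(f"sample ids: {list(speaker_manager.name_to_id.items())[:3]}")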
@@ -0,0 +1,98 @@
import glob
import json
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


audio_config = DelightfulTtsAudioConfig()
model_args = DelightfulTtsArgs(
    use_speaker_embedding=False, d_vector_dim=256, use_d_vector_file=True, speaker_embedding_channels=256
)

vocoder_config = VocoderConfig()

config = DelightfulTTSConfig(
    model_args=model_args,
    audio=audio_config,
    vocoder=vocoder_config,
    batch_size=2,
    eval_batch_size=8,
    compute_f0=True,
    run_eval=True,
    test_delay_epochs=-1,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    f0_cache_path="tests/data/ljspeech/f0_cache_delightful/",  # delightful f0 cache is incompatible with other models
    epochs=1,
    print_step=1,
    print_eval=True,
    binary_align_loss_alpha=0.0,
    use_attn_priors=False,
    test_sentences=["Be a voice, not an echo."],
    output_path=output_path,
    use_speaker_embedding=False,
    use_d_vector_file=True,
    d_vector_file="tests/data/ljspeech/speakers.json",
    d_vector_dim=256,
    speaker_embedding_channels=256,
)

# activate multispeaker d-vec mode
config.model_args.use_speaker_embedding = False
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.d_vector_dim = 256


config.save_json(config_path)

command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)

run_cli(command_train)

# Find the latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
speaker_id = "ljspeech-1"
continue_speakers_path = config.d_vector_file

out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
# Check the integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
    config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0

# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --config_path {continue_config_path} --speakers_file_path {continue_speakers_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)

# Restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
shutil.rmtree("tests/data/ljspeech/f0_cache_delightful/")
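The test leans on `tests/data/ljspeech/speakers.json`. As a hedged sketch only (the schema below is assumed from the d-vector conventions used elsewhere in the codebase, not stated in this diff), such a file maps each clip key to a speaker name and an embedding of `d_vector_dim` floats:

import json

# Hypothetical minimal entry; real files hold one entry per audio clip.
entry = {
    "LJ001-0001": {
        "name": "ljspeech-1",
        "embedding": [0.0] * 256,  # d_vector_dim floats
    }
}
print(json.dumps(entry)[:60])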
@@ -0,0 +1,92 @@
import glob
import json
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


audio_config = DelightfulTtsAudioConfig()
model_args = DelightfulTtsArgs(use_speaker_embedding=False)

vocoder_config = VocoderConfig()

config = DelightfulTTSConfig(
    model_args=model_args,
    audio=audio_config,
    vocoder=vocoder_config,
    batch_size=2,
    eval_batch_size=8,
    compute_f0=True,
    run_eval=True,
    test_delay_epochs=-1,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    f0_cache_path="tests/data/ljspeech/f0_cache_delightful/",  # delightful f0 cache is incompatible with other models
    epochs=1,
    print_step=1,
    print_eval=True,
    binary_align_loss_alpha=0.0,
    use_attn_priors=False,
    test_sentences=["Be a voice, not an echo."],
    output_path=output_path,
    num_speakers=4,
    use_speaker_embedding=True,
)

# activate multispeaker speaker-embedding mode
config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = None
config.model_args.d_vector_dim = 256


config.save_json(config_path)

command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.dataset_name ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)

run_cli(command_train)

# Find the latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech"
# Check the integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
    config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0

# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)

# Restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
shutil.rmtree("tests/data/ljspeech/f0_cache_delightful/")
@@ -0,0 +1,91 @@
import torch

from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig
from TTS.tts.utils.helpers import rand_segments
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

args = DelightfulTtsArgs()
v_args = VocoderConfig()


config = DelightfulTTSConfig(
    model_args=args,
    # compute_f0=True,
    # f0_cache_path=os.path.join(output_path, "f0_cache"),
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    # phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
)

tokenizer, config = TTSTokenizer.init_from_config(config)


def test_acoustic_model():
    dummy_tokens = torch.rand((1, 41)).long().to(device)
    dummy_text_lens = torch.tensor([41]).to(device)
    dummy_spec = torch.rand((1, 100, 207)).to(device)
    dummy_spec_lens = torch.tensor([207]).to(device)
    dummy_pitch = torch.rand((1, 1, 207)).long().to(device)
    dummy_energy = torch.rand((1, 1, 207)).long().to(device)

    args.out_channels = 100
    args.num_mels = 100

    acoustic_model = AcousticModel(args=args, tokenizer=tokenizer, speaker_manager=None).to(device)

    output = acoustic_model(
        tokens=dummy_tokens,
        src_lens=dummy_text_lens,
        mel_lens=dummy_spec_lens,
        mels=dummy_spec,
        pitches=dummy_pitch,
        energies=dummy_energy,
        attn_priors=None,
        d_vectors=None,
        speaker_idx=None,
    )
    assert list(output["model_outputs"].shape) == [1, 207, 100]
    output["model_outputs"].sum().backward()


def test_hifi_decoder():
    dummy_input = torch.rand((1, 207, 100)).to(device)
    dummy_spec_lens = torch.tensor([207]).to(device)

    waveform_decoder = HifiganGenerator(
        100,
        1,
        v_args.resblock_type_decoder,
        v_args.resblock_dilation_sizes_decoder,
        v_args.resblock_kernel_sizes_decoder,
        v_args.upsample_kernel_sizes_decoder,
        v_args.upsample_initial_channel_decoder,
        v_args.upsample_rates_decoder,
        inference_padding=0,
        cond_channels=0,
        conv_pre_weight_norm=False,
        conv_post_weight_norm=False,
        conv_post_bias=False,
    ).to(device)

    vocoder_input_slices, slice_ids = rand_segments(  # pylint: disable=unused-variable
        x=dummy_input.transpose(1, 2),
        x_lengths=dummy_spec_lens,
        segment_size=32,
        let_short_samples=True,
        pad_short=True,
    )

    outputs = waveform_decoder(x=vocoder_input_slices.detach())
    assert list(outputs.shape) == [1, 1, 8192]
    outputs.sum().backward()
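Why the decoder assertion expects 8192 samples: a hedged arithmetic check, assuming `upsample_rates_decoder` defaults to `[8, 8, 2, 2]` (HiFiGAN-style rates), whose product is the frames-to-samples hop:

segment_frames = 32
upsample_rates = [8, 8, 2, 2]  # assumed VocoderConfig default
hop = 1
for rate in upsample_rates:
    hop *= rate
assert segment_frames * hop == 8192  # matches the shape assertion above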
@@ -0,0 +1,97 @@
import glob
import json
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

audio_config = DelightfulTtsAudioConfig()
model_args = DelightfulTtsArgs()

vocoder_config = VocoderConfig()


config = DelightfulTTSConfig(
    audio=audio_config,
    batch_size=2,
    eval_batch_size=8,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    f0_cache_path="tests/data/ljspeech/f0_cache_delightful/",  # delightful f0 cache is incompatible with other models
    run_eval=True,
    test_delay_epochs=-1,
    binary_align_loss_alpha=0.0,
    epochs=1,
    print_step=1,
    use_attn_priors=False,
    print_eval=True,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
    use_speaker_embedding=False,
)
config.save_json(config_path)

# train the model for one epoch on CPU (no GPU is made visible)
command_train = (
    f"CUDA_VISIBLE_DEVICES='cpu' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs -1"
)

run_cli(command_train)

# Find the latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")

# Check the integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
    config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == -1

# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)

# Restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
shutil.rmtree("tests/data/ljspeech/f0_cache_delightful/")