From 0971bc236ea41e22970764b11dacefcd8f2273b8 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 7 Nov 2024 00:33:54 +0100 Subject: [PATCH] refactor: use external package for monotonic alignment --- .../layers/delightful_tts/acoustic_model.py | 3 +- TTS/tts/models/align_tts.py | 3 +- TTS/tts/models/forward_tts.py | 3 +- TTS/tts/models/glow_tts.py | 3 +- TTS/tts/models/vits.py | 3 +- TTS/tts/utils/helpers.py | 74 ------------------- TTS/tts/utils/monotonic_align/__init__.py | 0 TTS/tts/utils/monotonic_align/core.pyx | 47 ------------ pyproject.toml | 1 + 9 files changed, 11 insertions(+), 126 deletions(-) delete mode 100644 TTS/tts/utils/monotonic_align/__init__.py delete mode 100644 TTS/tts/utils/monotonic_align/core.pyx diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py index 83989f9b..3c0e3a3a 100644 --- a/TTS/tts/layers/delightful_tts/acoustic_model.py +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -5,6 +5,7 @@ from typing import Callable, Dict, Tuple import torch import torch.nn.functional as F from coqpit import Coqpit +from monotonic_alignment_search import maximum_path from torch import nn from TTS.tts.layers.delightful_tts.conformer import Conformer @@ -19,7 +20,7 @@ from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProso from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor from TTS.tts.layers.generic.aligner import AlignmentNetwork -from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.helpers import generate_path, sequence_mask logger = logging.getLogger(__name__) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 2d27a578..1c3d5758 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -3,6 +3,7 @@ from typing import Dict, List, Union import torch from coqpit import Coqpit +from monotonic_alignment_search import maximum_path from torch import nn from trainer.io import load_fsspec @@ -12,7 +13,7 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.helpers import generate_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 4b74462d..e7bc8637 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit +from monotonic_alignment_search import maximum_path from torch import nn from torch.cuda.amp.autocast_mode import autocast from trainer.io import load_fsspec @@ -14,7 +15,7 @@ from TTS.tts.layers.generic.aligner import AlignmentNetwork from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask +from TTS.tts.utils.helpers import average_over_durations, generate_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 64954d28..5ea69865 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit +from monotonic_alignment_search import maximum_path from torch import nn from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F @@ -13,7 +14,7 @@ from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.layers.glow_tts.decoder import Decoder from TTS.tts.layers.glow_tts.encoder import Encoder from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.helpers import generate_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.tokenizer import TTSTokenizer diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index b014e4fd..af803a0f 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -11,6 +11,7 @@ import torch.distributed as dist import torchaudio from coqpit import Coqpit from librosa.filters import mel as librosa_mel_fn +from monotonic_alignment_search import maximum_path from torch import nn from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F @@ -28,7 +29,7 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint -from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.helpers import generate_path, rand_segments, segment, sequence_mask from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 7429d0fc..d1722501 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -3,13 +3,6 @@ import torch from scipy.stats import betabinom from torch.nn import functional as F -try: - from TTS.tts.utils.monotonic_align.core import maximum_path_c - - CYTHON = True -except ModuleNotFoundError: - CYTHON = False - class StandardScaler: """StandardScaler for mean-scale normalization with the given mean and scale values.""" @@ -168,73 +161,6 @@ def generate_path(duration, mask): return path -def maximum_path(value, mask): - if CYTHON: - return maximum_path_cython(value, mask) - return maximum_path_numpy(value, mask) - - -def maximum_path_cython(value, mask): - """Cython optimised version. - Shapes: - - value: :math:`[B, T_en, T_de]` - - mask: :math:`[B, T_en, T_de]` - """ - value = value * mask - device = value.device - dtype = value.dtype - value = value.data.cpu().numpy().astype(np.float32) - path = np.zeros_like(value).astype(np.int32) - mask = mask.data.cpu().numpy() - - t_x_max = mask.sum(1)[:, 0].astype(np.int32) - t_y_max = mask.sum(2)[:, 0].astype(np.int32) - maximum_path_c(path, value, t_x_max, t_y_max) - return torch.from_numpy(path).to(device=device, dtype=dtype) - - -def maximum_path_numpy(value, mask, max_neg_val=None): - """ - Monotonic alignment search algorithm - Numpy-friendly version. It's about 4 times faster than torch version. - value: [b, t_x, t_y] - mask: [b, t_x, t_y] - """ - if max_neg_val is None: - max_neg_val = -np.inf # Patch for Sphinx complaint - value = value * mask - - device = value.device - dtype = value.dtype - value = value.cpu().detach().numpy() - mask = mask.cpu().detach().numpy().astype(bool) - - b, t_x, t_y = value.shape - direction = np.zeros(value.shape, dtype=np.int64) - v = np.zeros((b, t_x), dtype=np.float32) - x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1) - for j in range(t_y): - v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1] - v1 = v - max_mask = v1 >= v0 - v_max = np.where(max_mask, v1, v0) - direction[:, :, j] = max_mask - - index_mask = x_range <= j - v = np.where(index_mask, v_max + value[:, :, j], max_neg_val) - direction = np.where(mask, direction, 1) - - path = np.zeros(value.shape, dtype=np.float32) - index = mask[:, :, 0].sum(1).astype(np.int64) - 1 - index_range = np.arange(b) - for j in reversed(range(t_y)): - path[index_range, index, j] = 1 - index = index + direction[index_range, index, j] - 1 - path = path * mask.astype(np.float32) - path = torch.from_numpy(path).to(device=device, dtype=dtype) - return path - - def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): P, M = phoneme_count, mel_count x = np.arange(0, P) diff --git a/TTS/tts/utils/monotonic_align/__init__.py b/TTS/tts/utils/monotonic_align/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/TTS/tts/utils/monotonic_align/core.pyx b/TTS/tts/utils/monotonic_align/core.pyx deleted file mode 100644 index 091fcc3a..00000000 --- a/TTS/tts/utils/monotonic_align/core.pyx +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -cimport cython -cimport numpy as np - -from cython.parallel import prange - - -@cython.boundscheck(False) -@cython.wraparound(False) -cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: - cdef int x - cdef int y - cdef float v_prev - cdef float v_cur - cdef float tmp - cdef int index = t_x - 1 - - for y in range(t_y): - for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): - if x == y: - v_cur = max_neg_val - else: - v_cur = value[x, y-1] - if x == 0: - if y == 0: - v_prev = 0. - else: - v_prev = max_neg_val - else: - v_prev = value[x-1, y-1] - value[x, y] = max(v_cur, v_prev) + value[x, y] - - for y in range(t_y - 1, -1, -1): - path[index, y] = 1 - if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): - index = index - 1 - - -@cython.boundscheck(False) -@cython.wraparound(False) -cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: - cdef int b = values.shape[0] - - cdef int i - for i in prange(b, nogil=True): - maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/pyproject.toml b/pyproject.toml index 23387fd3..d13e2145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies = [ # Coqui stack "coqui-tts-trainer>=0.1.4,<0.2.0", "coqpit>=0.0.16", + "monotonic-alignment-search>=0.1.0", # Gruut + supported languages "gruut[de,es,fr]>=2.4.0", # Tortoise