mirror of https://github.com/coqui-ai/TTS.git
maximum_path_numpy and CYTHON adabtable import
This commit is contained in:
parent
877f0bbfba
commit
660d61aeeb
|
@ -2,7 +2,13 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from TTS.tts.utils.generic_utils import sequence_mask
|
from TTS.tts.utils.generic_utils import sequence_mask
|
||||||
from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c
|
|
||||||
|
try:
|
||||||
|
# TODO: fix pypi cython installation problem.
|
||||||
|
from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c
|
||||||
|
CYTHON = True
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
CYTHON = False
|
||||||
|
|
||||||
|
|
||||||
def convert_pad_shape(pad_shape):
|
def convert_pad_shape(pad_shape):
|
||||||
|
@ -32,6 +38,12 @@ def generate_path(duration, mask):
|
||||||
|
|
||||||
|
|
||||||
def maximum_path(value, mask):
|
def maximum_path(value, mask):
|
||||||
|
if CYTHON:
|
||||||
|
return maximum_path_cython(value, mask)
|
||||||
|
return maximum_path_numpy(value, mask)
|
||||||
|
|
||||||
|
|
||||||
|
def maximum_path_cython(value, mask):
|
||||||
""" Cython optimised version.
|
""" Cython optimised version.
|
||||||
value: [b, t_x, t_y]
|
value: [b, t_x, t_y]
|
||||||
mask: [b, t_x, t_y]
|
mask: [b, t_x, t_y]
|
||||||
|
@ -47,3 +59,45 @@ def maximum_path(value, mask):
|
||||||
t_y_max = mask.sum(2)[:, 0].astype(np.int32)
|
t_y_max = mask.sum(2)[:, 0].astype(np.int32)
|
||||||
maximum_path_c(path, value, t_x_max, t_y_max)
|
maximum_path_c(path, value, t_x_max, t_y_max)
|
||||||
return torch.from_numpy(path).to(device=device, dtype=dtype)
|
return torch.from_numpy(path).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def maximum_path_numpy(value, mask, max_neg_val=None):
|
||||||
|
"""
|
||||||
|
Monotonic alignment search algorithm
|
||||||
|
Numpy-friendly version. It's about 4 times faster than torch version.
|
||||||
|
value: [b, t_x, t_y]
|
||||||
|
mask: [b, t_x, t_y]
|
||||||
|
"""
|
||||||
|
if max_neg_val is None:
|
||||||
|
max_neg_val = -np.inf # Patch for Sphinx complaint
|
||||||
|
value = value * mask
|
||||||
|
|
||||||
|
device = value.device
|
||||||
|
dtype = value.dtype
|
||||||
|
value = value.cpu().detach().numpy()
|
||||||
|
mask = mask.cpu().detach().numpy().astype(np.bool)
|
||||||
|
|
||||||
|
b, t_x, t_y = value.shape
|
||||||
|
direction = np.zeros(value.shape, dtype=np.int64)
|
||||||
|
v = np.zeros((b, t_x), dtype=np.float32)
|
||||||
|
x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
|
||||||
|
for j in range(t_y):
|
||||||
|
v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
|
||||||
|
v1 = v
|
||||||
|
max_mask = v1 >= v0
|
||||||
|
v_max = np.where(max_mask, v1, v0)
|
||||||
|
direction[:, :, j] = max_mask
|
||||||
|
|
||||||
|
index_mask = x_range <= j
|
||||||
|
v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
|
||||||
|
direction = np.where(mask, direction, 1)
|
||||||
|
|
||||||
|
path = np.zeros(value.shape, dtype=np.float32)
|
||||||
|
index = mask[:, :, 0].sum(1).astype(np.int64) - 1
|
||||||
|
index_range = np.arange(b)
|
||||||
|
for j in reversed(range(t_y)):
|
||||||
|
path[index_range, index, j] = 1
|
||||||
|
index = index + direction[index_range, index, j] - 1
|
||||||
|
path = path * mask.astype(np.float32)
|
||||||
|
path = torch.from_numpy(path).to(device=device, dtype=dtype)
|
||||||
|
return path
|
||||||
|
|
Loading…
Reference in New Issue