coqui-tts/TTS/tts/utils/fast_speech.py

34 lines
1.3 KiB
Python

import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
class DurationCalculator():
    """Derive per-input-token durations and attention scores from alignments.

    Every method works on attention matrices indexed as
    ``att_w[output_step, input_token]``; each output step is assigned to the
    input token that receives its largest attention weight.

    Attributes:
        K: Width of the last axis produced by ``calculate_scores`` (number of
            score slots kept per input token). ``None`` means "use the longest
            duration found in the batch", so no score is cropped.
    """

    # Class-level default so instances created without arguments, and
    # subclasses that override ``K`` as a class attribute, keep working.
    K = None

    def __init__(self, K=None):
        """Create a calculator.

        Args:
            K: Optional fixed number of score slots per input token. When
                omitted, the class-level ``K`` is left in effect.
        """
        if K is not None:
            self.K = K

    def calculate_durations(self, att_ws, ilens, olens):
        """Calculate durations from the given alignment matrices.

        Args:
            att_ws: Iterable of alignment matrices, one per utterance
                (assumed (batch, outs, ins) when a single tensor is passed).
            ilens: Iterable of valid input lengths, one per utterance.
            olens: Iterable of valid output lengths, one per utterance.

        Returns:
            Tensor of shape (batch, max(ilens)) holding, for each input token,
            the number of output steps assigned to it; shorter utterances are
            zero-padded on the right.
        """
        durations = [
            self._calculate_duration(att_w, ilen, olen)
            for att_w, ilen, olen in zip(att_ws, ilens, olens)
        ]
        return pad_sequence(durations, batch_first=True)

    @staticmethod
    def _calculate_duration(att_w, ilen, olen):
        """Count the output steps assigned to each input token.

        Args:
            att_w: Alignment of a single utterance, indexed (outs, ins).
                NOTE(review): the code slices two leading axes, so a 2-D
                matrix is assumed here -- confirm against callers.
            ilen: Number of valid input tokens.
            olen: Number of valid output steps.

        Returns:
            Integer tensor of shape (ilen,).
        """
        # Compute the winning input token per output step once, instead of
        # re-running argmax for every input token.
        attended = att_w[:olen, :ilen].argmax(-1)
        return torch.stack([attended.eq(i).sum() for i in range(ilen)])

    def calculate_scores(self, att_ws, ilens, olens):
        """Calculate attention scores per duration step.

        Args:
            att_ws: Iterable of alignment matrices, one per utterance.
            ilens: Iterable of valid input lengths, one per utterance.
            olens: Iterable of valid output lengths, one per utterance.

        Returns:
            Tensor of shape (batch, max(ilens), K). Entry ``[b, i, :]`` holds
            the attention weights of the output steps assigned to input token
            ``i`` of utterance ``b``, zero-padded (or cropped) to ``K`` slots.
        """
        # Materialize once: the batch may be traversed twice below.
        samples = list(zip(att_ws, ilens, olens))
        k = self.K
        if k is None:
            # No fixed width configured: use the longest duration in the batch
            # so every utterance shares the same trailing dimension.
            k = max(
                (
                    int(self._calculate_duration(att_w, ilen, olen).max())
                    for att_w, ilen, olen in samples
                ),
                default=0,
            )
        scores = [
            self._calculate_scores(att_w, ilen, olen, k)
            for att_w, ilen, olen in samples
        ]
        # pad_sequence pads the token axis; the trailing K axis already matches.
        return pad_sequence(scores, batch_first=True, padding_value=0)

    @staticmethod
    def _calculate_scores(att_w, ilen, olen, k):
        """Collect, per input token, the peak attention weights assigned to it.

        Args:
            att_w: Alignment of a single utterance, indexed (outs, ins).
            ilen: Number of valid input tokens.
            olen: Number of valid output steps.
            k: Number of score slots per input token.

        Returns:
            Tensor of shape (ilen, k).
        """
        # For each output step: its largest weight and which input token got it.
        values, idxs = att_w[:olen, :ilen].max(-1)
        scores = []
        for i in range(ilen):
            vals = values[idxs == i]
            # Right-pad with zeros up to k slots; a negative amount crops.
            scores.append(torch.nn.functional.pad(vals, (0, k - vals.shape[0])))
        return torch.stack(scores)