From a755328e4965f8b4d9ff033853e1048b25141865 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 20 Jun 2024 14:05:19 +0200 Subject: [PATCH] refactor(freevc): remove duplicate sequence_mask --- TTS/vc/models/freevc.py | 3 ++- TTS/vc/modules/freevc/commons.py | 11 +---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index ec7cc0e0..36f4017c 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -14,6 +14,7 @@ from torch.nn.utils.parametrize import remove_parametrizations import TTS.vc.modules.freevc.commons as commons import TTS.vc.modules.freevc.modules as modules +from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.io import load_fsspec from TTS.vc.configs.freevc_config import FreeVCConfig @@ -80,7 +81,7 @@ class Encoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, x, x_lengths, g=None): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x = self.pre(x) * x_mask x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py index e7813513..89872800 100644 --- a/TTS/vc/modules/freevc/commons.py +++ b/TTS/vc/modules/freevc/commons.py @@ -3,7 +3,7 @@ import math import torch from torch.nn import functional as F -from TTS.tts.utils.helpers import convert_pad_shape +from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask def init_weights(m, mean=0.0, std=0.01): @@ -115,20 +115,11 @@ def shift_1d(x): return x -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - def generate_path(duration, mask): """ duration: [b, 1, t_x] mask: [b, 1, t_y, t_x] """ - device = duration.device - b, _, t_y, t_x = mask.shape cum_duration = torch.cumsum(duration, -1)