diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 4add9fbf..2c1c2bc6 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -890,9 +890,7 @@ class Vits(BaseTTS): spec_segment_size = spec_segment_size * int(self.interpolate_factor) # interpolate z if needed if self.args.interpolate_z: - z = torch.nn.functional.interpolate( - z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest" - ).squeeze(0) + z = torch.nn.functional.interpolate(z, scale_factor=[self.interpolate_factor], mode="linear").squeeze(0) # recompute the mask if needed if y_lengths is not None and y_mask is not None: y_mask = (