From 5840d89802dfaf4ff03a382d972628afd9647280 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 3 Jan 2022 15:03:34 +0000 Subject: [PATCH] Keep proj_dim in speaker encoder models --- TTS/speaker_encoder/models/lstm.py | 1 + TTS/speaker_encoder/models/resnet.py | 1 + 2 files changed, 2 insertions(+) diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py index 7ac08514..ec394cdb 100644 --- a/TTS/speaker_encoder/models/lstm.py +++ b/TTS/speaker_encoder/models/lstm.py @@ -49,6 +49,7 @@ class LSTMSpeakerEncoder(nn.Module): self.use_lstm_with_projection = use_lstm_with_projection self.use_torch_spec = use_torch_spec self.audio_config = audio_config + self.proj_dim = proj_dim layers = [] # choise LSTM layer diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index 643449c8..d6c3dad4 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -95,6 +95,7 @@ class ResNetSpeakerEncoder(nn.Module): self.log_input = log_input self.use_torch_spec = use_torch_spec self.audio_config = audio_config + self.proj_dim = proj_dim self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU(inplace=True)