docs(tts.models.vits): clarify use of discriminator/generator

[ci skip]
This commit is contained in:
Enno Hermann 2024-03-12 18:06:50 +01:00
parent 0c6c20f52f
commit 89a061f1d1
1 changed files with 6 additions and 3 deletions

View File

@ -1233,7 +1233,7 @@ class Vits(BaseTTS):
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks.
Returns:
Tuple[Dict, Dict]: Model ouputs and computed losses.
@ -1651,13 +1651,16 @@ class Vits(BaseTTS):
def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
It returns 2 optimizers in a list. First one is for the discriminator
and the second one is for the generator.
Returns:
List: optimizers.
"""
# select generator parameters
optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
# select generator parameters
gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters