mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #19 from eginhard/fix-vits-comments
docs(tts.models.vits): clarify use of discriminator/generator
This commit is contained in:
commit
eaa7283244
|
@ -1233,7 +1233,7 @@ class Vits(BaseTTS):
|
|||
Args:
|
||||
batch (Dict): Input tensors.
|
||||
criterion (nn.Module): Loss layer designed for the model.
|
||||
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
|
||||
optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
||||
|
@ -1651,13 +1651,16 @@ class Vits(BaseTTS):
|
|||
|
||||
def get_optimizer(self) -> List:
|
||||
"""Initiate and return the GAN optimizers based on the config parameters.
|
||||
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
|
||||
|
||||
It returns 2 optimizers in a list. First one is for the discriminator
|
||||
and the second one is for the generator.
|
||||
|
||||
Returns:
|
||||
List: optimizers.
|
||||
"""
|
||||
# select generator parameters
|
||||
optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
|
||||
|
||||
# select generator parameters
|
||||
gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
|
||||
optimizer1 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
||||
|
|
Loading…
Reference in New Issue