mirror of https://github.com/coqui-ai/TTS.git
docs(tts.models.vits): clarify use of discriminator/generator
[ci skip]
This commit is contained in:
parent
0c6c20f52f
commit
89a061f1d1
|
@ -1233,7 +1233,7 @@ class Vits(BaseTTS):
|
||||||
Args:
|
Args:
|
||||||
batch (Dict): Input tensors.
|
batch (Dict): Input tensors.
|
||||||
criterion (nn.Module): Loss layer designed for the model.
|
criterion (nn.Module): Loss layer designed for the model.
|
||||||
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
|
optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
||||||
|
@ -1651,13 +1651,16 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
def get_optimizer(self) -> List:
|
def get_optimizer(self) -> List:
|
||||||
"""Initiate and return the GAN optimizers based on the config parameters.
|
"""Initiate and return the GAN optimizers based on the config parameters.
|
||||||
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
|
|
||||||
|
It returns 2 optimizers in a list. First one is for the discriminator
|
||||||
|
and the second one is for the generator.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List: optimizers.
|
List: optimizers.
|
||||||
"""
|
"""
|
||||||
# select generator parameters
|
|
||||||
optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
|
optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
|
||||||
|
|
||||||
|
# select generator parameters
|
||||||
gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
|
gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
|
||||||
optimizer1 = get_optimizer(
|
optimizer1 = get_optimizer(
|
||||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
||||||
|
|
Loading…
Reference in New Issue