mirror of https://github.com/coqui-ai/TTS.git
use native amp
This commit is contained in:
parent
8a820930c6
commit
1229554c42
|
@ -26,8 +26,7 @@ from TTS.utils.audio import AudioProcessor
|
|||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict,
|
||||
set_amp_context)
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
|
|
|
@ -28,8 +28,7 @@ from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
|
|||
init_distributed, reduce_tensor)
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict,
|
||||
set_amp_context)
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
|
|
|
@ -17,8 +17,7 @@ from TTS.utils.console_logger import ConsoleLogger
|
|||
from TTS.utils.distribute import init_distributed
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict,
|
||||
set_amp_context)
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
|
|
|
@ -65,7 +65,7 @@
|
|||
"eval_batch_size":16,
|
||||
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
|
||||
"gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
|
||||
"apex_amp_level": null, // level of optimization with NVIDIA's apex feature for automatic mixed FP16/FP32 precision (AMP), NOTE: currently only O1 is supported, and use "O1" to activate.
|
||||
"mixed_precision": true, // level of optimization with NVIDIA's apex feature for automatic mixed FP16/FP32 precision (AMP), NOTE: currently only O1 is supported, and use "O1" to activate.
|
||||
|
||||
// LOSS SETTINGS
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
|
|
|
@ -4,22 +4,10 @@ import os
|
|||
import shutil
|
||||
import subprocess
|
||||
import contextlib
|
||||
import platform
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def set_amp_context(mixed_precision):
|
||||
if mixed_precision:
|
||||
cm = torch.cuda.amp.autocast()
|
||||
else:
|
||||
# if platform.python_version() <= "3.6.0":
|
||||
cm = contextlib.suppress()
|
||||
# else:
|
||||
# cm = contextlib.nullcontext()
|
||||
return cm
|
||||
|
||||
|
||||
def get_git_branch():
|
||||
try:
|
||||
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
||||
|
|
Loading…
Reference in New Issue