mirror of https://github.com/coqui-ai/TTS.git
Fixup overflow (#2218)
* Update overflow config
* Pulling shuffle and drop_last from config
* Print training stats for overflow
parent ecea43ec81
commit a9167cf239
TTS/tts/configs/overflow_config.py:

```diff
@@ -169,8 +169,9 @@ class OverflowConfig(BaseTTSConfig):  # The classname has to be camel case
     lr_scheduler: str = None

     # overrides
-    min_seq_len: int = 3
-    max_seq_len: int = 500
+    min_text_len: int = 10
+    max_text_len: int = 500
+    min_audio_len: int = 512

     # testing
     test_sentences: List[str] = field(
```
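The rename replaces Overflow's generic sequence-length limits with explicit text- and audio-length filters. A minimal sketch of setting them, assuming the usual dataclass-style construction of coqui-ai/TTS configs (the import path follows the repo layout; the values mirror the new defaults):

```python
# Sketch only: assumes OverflowConfig builds like the other dataclass-style
# configs in coqui-ai/TTS; the field names come from the diff above.
from TTS.tts.configs.overflow_config import OverflowConfig

config = OverflowConfig(
    min_text_len=10,    # skip samples with fewer than 10 input characters
    max_text_len=500,   # skip samples with more than 500 input characters
    min_audio_len=512,  # skip clips shorter than 512 audio samples
)
print(config.min_text_len, config.max_text_len, config.min_audio_len)
```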
TTS/tts/configs/shared_configs.py:

```diff
@@ -230,6 +230,13 @@ class BaseTTSConfig(BaseTrainingConfig):
             If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
             Defaults to False.

+        shuffle (bool):
+            If True, the data loader will shuffle the dataset when there is no sampler defined. Defaults to True.
+
+        drop_last (bool):
+            If True, the data loader will drop the last batch if it is not complete. It helps to prevent
+            issues that emerge from the partial batch statistics. Defaults to True.
+
         add_blank (bool):
             Add blank characters between every two characters. It improves performance for some models at expense
             of slower run-time due to the longer input sequence.
```
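The drop_last behavior documented here is plain PyTorch rather than anything TTS-specific; a self-contained illustration of what discarding the trailing partial batch means:

```python
# Plain-PyTorch illustration of drop_last: the final partial batch is
# discarded so every batch has identical size and stable batch statistics.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))

keep = DataLoader(dataset, batch_size=4, drop_last=False)
drop = DataLoader(dataset, batch_size=4, drop_last=True)

print(len(keep))  # 3 -> batches of 4, 4, and 2
print(len(drop))  # 2 -> the trailing batch of 2 is dropped
```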
TTS/tts/configs/shared_configs.py:

```diff
@@ -309,6 +316,8 @@ class BaseTTSConfig(BaseTrainingConfig):
     precompute_num_workers: int = 0
     use_noise_augment: bool = False
     start_by_longest: bool = False
+    shuffle: bool = False
+    drop_last: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # optimizer
```
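Because the flags live on the shared BaseTTSConfig, every model config inherits them and can override them per run; a short hypothetical sketch using OverflowConfig as the example subclass:

```python
# Sketch only: any config inheriting BaseTTSConfig picks up the new loader
# flags; OverflowConfig is used here purely as an example subclass.
from TTS.tts.configs.overflow_config import OverflowConfig

config = OverflowConfig(shuffle=True, drop_last=True)
print(config.shuffle, config.drop_last)  # True True
```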
TTS/tts/models/base_tts.py:

```diff
@@ -345,9 +345,9 @@ class BaseTTS(BaseTrainerModel):
             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
-                shuffle=True,  # if there is no other sampler
+                shuffle=config.shuffle if sampler is None else False,  # if there is no other sampler
                 collate_fn=dataset.collate_fn,
-                drop_last=False,  # setting this False might cause issues in AMP training.
+                drop_last=config.drop_last,  # setting this False might cause issues in AMP training.
                 sampler=sampler,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
```
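The `sampler is None` guard matters because PyTorch's DataLoader rejects shuffle=True together with a custom sampler; a standalone demonstration of both the failure and the patched logic:

```python
# Standalone PyTorch demo: DataLoader raises ValueError when shuffle=True is
# combined with a custom sampler, which is why the patch forces shuffle off.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
sampler = RandomSampler(dataset)

try:
    DataLoader(dataset, batch_size=2, shuffle=True, sampler=sampler)
except ValueError as err:
    print(err)  # "sampler option is mutually exclusive with shuffle"

# Mirror of the patched call site: shuffle only when no sampler exists.
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True if sampler is None else False,
    sampler=sampler,
)
print(len(loader))  # 4 batches of 2, drawn in the sampler's order
```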
TTS/tts/models/overflow.py:

```diff
@@ -159,6 +159,15 @@ class Overflow(BaseTTS):

         return outputs

+    @staticmethod
+    def _training_stats(batch):
+        stats = {}
+        stats["avg_text_length"] = batch["text_lengths"].float().mean()
+        stats["avg_spec_length"] = batch["mel_lengths"].float().mean()
+        stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean()
+        stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean()
+        return stats
+
     def train_step(self, batch: dict, criterion: nn.Module):
         text_input = batch["text_input"]
         text_lengths = batch["text_lengths"]
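On a toy batch (lengths made up for illustration), the occupancy stats read as the mean fraction of the padded batch tensor that real data fills, so values near 1.0 mean little padding waste:

```python
# Toy batch showing what the new stats report; the lengths are made up.
import torch

batch = {
    "text_lengths": torch.tensor([10, 20, 40]),
    "mel_lengths": torch.tensor([100, 300, 400]),
}

text = batch["text_lengths"].float()
mel = batch["mel_lengths"].float()

print(text.mean())                 # avg_text_length:         23.33
print(mel.mean())                  # avg_spec_length:        266.67
print((text / text.max()).mean())  # avg_text_batch_occupancy: 0.58
print((mel / mel.max()).mean())    # avg_spec_batch_occupancy: 0.67
```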
TTS/tts/models/overflow.py:

```diff
@@ -171,9 +180,10 @@ class Overflow(BaseTTS):
             mels=mel_input,
             mel_len=mel_lengths,
         )
-        loss_dict = criterion(outputs["log_probs"])
+        loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum()))

+        # for printing useful statistics on terminal
+        loss_dict.update(self._training_stats(batch))
         return outputs, loss_dict

     def eval_step(self, batch: Dict, criterion: nn.Module):
```
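Dividing the summed log-probability by the total number of mel frames and text tokens turns the raw objective into a per-step value, so the reported loss stays comparable across batches of different total length; a sketch with purely illustrative numbers:

```python
# Illustrative numbers only; this is not the exact Overflow criterion, just
# the shape of the normalization applied above.
import torch

log_probs = torch.tensor(-42000.0)       # summed log-likelihood of a batch
mel_lengths = torch.tensor([400, 380])   # mel frames per utterance
text_lengths = torch.tensor([60, 55])    # text tokens per utterance

steps = mel_lengths.sum() + text_lengths.sum()  # 895 total steps
print(log_probs / steps)  # ~ -46.9 per step, stable across batch sizes
```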