Fixup overflow (#2218)

* Update overflow config

* Pulling shuffle and drop_last from config

* Print training stats for overflow
Eren Gölge 2022-12-15 00:56:48 +01:00 committed by GitHub
parent ecea43ec81
commit a9167cf239
4 changed files with 26 additions and 6 deletions


@@ -169,8 +169,9 @@ class OverflowConfig(BaseTTSConfig):  # The classname has to be camel case
     lr_scheduler: str = None
     # overrides
-    min_seq_len: int = 3
-    max_seq_len: int = 500
+    min_text_len: int = 10
+    max_text_len: int = 500
+    min_audio_len: int = 512
     # testing
     test_sentences: List[str] = field(
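
These overrides swap the generic sequence-length bounds for text-specific ones and add a minimum audio length, so samples with fewer than 10 or more than 500 text characters, or fewer than 512 audio frames, are dropped before batching. A minimal sketch of the filtering these fields imply; the sample dicts and the keep_sample helper are illustrative, not the library's actual dataset code:

    # Hypothetical filter mirroring min_text_len / max_text_len / min_audio_len.
    def keep_sample(sample, min_text_len=10, max_text_len=500, min_audio_len=512):
        return (
            min_text_len <= len(sample["text"]) <= max_text_len
            and sample["audio_len"] >= min_audio_len
        )

    samples = [
        {"text": "Hi.", "audio_len": 9000},              # dropped: text too short
        {"text": "A full sentence.", "audio_len": 300},  # dropped: audio too short
        {"text": "A full sentence.", "audio_len": 9000}, # kept
    ]
    filtered = [s for s in samples if keep_sample(s)]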


@@ -230,6 +230,13 @@ class BaseTTSConfig(BaseTrainingConfig):
         If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
         Defaults to False.
+    shuffle (bool):
+        If True, the data loader will shuffle the dataset when there is no sampler defined. Defaults to False.
+    drop_last (bool):
+        If True, the data loader will drop the last batch if it is not complete. It helps to prevent
+        issues that emerge from partial-batch statistics. Defaults to False.
     add_blank (bool):
         Add blank characters between every two characters. It improves performance for some models at the expense
         of slower run-time due to the longer input sequence.
@@ -309,6 +316,8 @@ class BaseTTSConfig(BaseTrainingConfig):
     precompute_num_workers: int = 0
     use_noise_augment: bool = False
     start_by_longest: bool = False
+    shuffle: bool = False
+    drop_last: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # optimizer
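
Since both knobs now live on BaseTTSConfig, any model config that inherits from it, including OverflowConfig above, exposes them. A minimal usage sketch, assuming the package's usual import path:

    from TTS.tts.configs.overflow_config import OverflowConfig

    # Both default to False; enable them explicitly per experiment.
    config = OverflowConfig(shuffle=True, drop_last=True)
    print(config.shuffle, config.drop_last)  # True True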


@@ -345,9 +345,9 @@ class BaseTTS(BaseTrainerModel):
             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
-                shuffle=True,  # if there is no other sampler
+                shuffle=config.shuffle if sampler is None else False,  # if there is no other sampler
                 collate_fn=dataset.collate_fn,
-                drop_last=False,  # setting this False might cause issues in AMP training.
+                drop_last=config.drop_last,  # setting this False might cause issues in AMP training.
                 sampler=sampler,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
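
The sampler guard mirrors a PyTorch constraint: DataLoader raises a ValueError when shuffle=True is passed together with a custom sampler, so config.shuffle can only take effect when no sampler is defined. A standalone sketch of that behavior:

    import torch
    from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    sampler = SequentialSampler(dataset)

    # Combining shuffle=True with an explicit sampler is rejected by PyTorch.
    try:
        DataLoader(dataset, shuffle=True, sampler=sampler)
    except ValueError as err:
        print(err)  # sampler option is mutually exclusive with shuffle

    # Without a sampler, the configured shuffle flag can be honored.
    loader = DataLoader(dataset, shuffle=True, sampler=None)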


@@ -159,6 +159,15 @@ class Overflow(BaseTTS):

         return outputs

+    @staticmethod
+    def _training_stats(batch):
+        stats = {}
+        stats["avg_text_length"] = batch["text_lengths"].float().mean()
+        stats["avg_spec_length"] = batch["mel_lengths"].float().mean()
+        stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean()
+        stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean()
+        return stats
+
     def train_step(self, batch: dict, criterion: nn.Module):
         text_input = batch["text_input"]
         text_lengths = batch["text_lengths"]

@@ -171,9 +180,10 @@ class Overflow(BaseTTS):
             mels=mel_input,
             mel_len=mel_lengths,
         )
-        loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum()))
+        loss_dict = criterion(outputs["log_probs"])
+        # for printing useful statistics on terminal
+        loss_dict.update(self._training_stats(batch))
         return outputs, loss_dict

     def eval_step(self, batch: Dict, criterion: nn.Module):
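
The occupancy statistics quantify padding waste in a batch: a value near 1.0 means all lengths are close to the batch maximum, while a low value means many padded positions. A self-contained sketch of the same computation on dummy lengths (the tensors stand in for batch["text_lengths"] and batch["mel_lengths"]):

    import torch

    text_lengths = torch.tensor([10.0, 20.0, 40.0])
    mel_lengths = torch.tensor([200.0, 400.0, 500.0])

    # occupancy = mean(length / max length in batch); 1.0 means no padding.
    print((text_lengths / text_lengths.max()).mean())  # ~0.58
    print((mel_lengths / mel_lengths.max()).mean())    # ~0.73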