mirror of https://github.com/coqui-ai/TTS.git

Fixup overflow (#2218)

* Update overflow config
* Pulling shuffle and drop_last from config
* Print training stats for overflow

parent ecea43ec81 · commit a9167cf239
TTS/tts/configs/overflow_config.py

```diff
@@ -169,8 +169,9 @@ class OverflowConfig(BaseTTSConfig):  # The classname has to be camel case
     lr_scheduler: str = None
 
     # overrides
-    min_seq_len: int = 3
-    max_seq_len: int = 500
+    min_text_len: int = 10
+    max_text_len: int = 500
+    min_audio_len: int = 512
 
     # testing
     test_sentences: List[str] = field(
```
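The old `min_seq_len`/`max_seq_len` pair is replaced by separate bounds on text length and audio length. As a rough sketch of what such bounds mean when filtering a dataset (the `samples` structure and key names below are invented for illustration, not Coqui's loader internals):

```python
# Hypothetical sketch of length-based filtering with the new bounds.
# Key names ("text", "num_audio_samples") are assumptions for this example.
MIN_TEXT_LEN, MAX_TEXT_LEN = 10, 500
MIN_AUDIO_LEN = 512

def keep(sample: dict) -> bool:
    """Keep a sample only if its text and audio lengths fall in range."""
    text_ok = MIN_TEXT_LEN <= len(sample["text"]) <= MAX_TEXT_LEN
    audio_ok = sample["num_audio_samples"] >= MIN_AUDIO_LEN
    return text_ok and audio_ok

samples = [
    {"text": "hi", "num_audio_samples": 1024},                           # dropped: text too short
    {"text": "a sentence that is long enough", "num_audio_samples": 2048},  # kept
]
filtered = [s for s in samples if keep(s)]
print(len(filtered))  # 1
```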
TTS/tts/configs/shared_configs.py

```diff
@@ -230,6 +230,13 @@ class BaseTTSConfig(BaseTrainingConfig):
             If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
             Defaults to False.
 
+        shuffle (bool):
+            If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True.
+
+        drop_last (bool):
+            If True, the data loader will drop the last batch if it is not complete. It helps to prevent
+            issues that emerge from the partial batch statistics. Defaults to True.
+
         add_blank (bool):
             Add blank characters between each other two characters. It improves performance for some models at expense
             of slower run-time due to the longer input sequence.
```
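The new `drop_last` entry documents plain PyTorch `DataLoader` behavior: with `drop_last=True` the incomplete final batch is discarded, keeping per-batch statistics comparable. A self-contained demonstration:

```python
# drop_last in plain PyTorch: 10 samples with batch_size=4 yield batches of
# size 4, 4, 2; enabling drop_last discards the trailing partial batch of 2.
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10).float())
kept = DataLoader(data, batch_size=4, drop_last=False)
trimmed = DataLoader(data, batch_size=4, drop_last=True)

print([len(b[0]) for b in kept])     # [4, 4, 2]
print([len(b[0]) for b in trimmed])  # [4, 4]
```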
TTS/tts/configs/shared_configs.py

```diff
@@ -309,6 +316,8 @@ class BaseTTSConfig(BaseTrainingConfig):
     precompute_num_workers: int = 0
     use_noise_augment: bool = False
     start_by_longest: bool = False
+    shuffle: bool = False
+    drop_last: bool = False
     # dataset
     datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
     # optimizer
```
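With the fields in place, the loader knobs can be set like any other attribute of a config derived from `BaseTTSConfig`. A minimal sketch, assuming the usual Coqui import path for the Overflow config and coqpit-style keyword construction:

```python
# Minimal sketch: overriding the new loader knobs on an Overflow config.
from TTS.tts.configs.overflow_config import OverflowConfig

config = OverflowConfig(
    batch_size=32,
    shuffle=True,    # only takes effect when no custom sampler is used
    drop_last=True,  # discard the incomplete final batch
)
print(config.shuffle, config.drop_last)
```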
TTS/tts/models/base_tts.py

```diff
@@ -345,9 +345,9 @@ class BaseTTS(BaseTrainerModel):
             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
-                shuffle=True,  # if there is no other sampler
+                shuffle=config.shuffle if sampler is None else False,  # if there is no other sampler
                 collate_fn=dataset.collate_fn,
-                drop_last=False,  # setting this False might cause issues in AMP training.
+                drop_last=config.drop_last,  # setting this False might cause issues in AMP training.
                 sampler=sampler,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
```
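The conditional mirrors `DataLoader`'s own constraint: `shuffle=True` cannot be combined with a custom `sampler`, so the config flag is honored only when no sampler is supplied. A condensed, runnable illustration of the same decision:

```python
# DataLoader forbids shuffle=True together with a custom sampler, so the
# shuffle flag is applied only when sampler is None (same logic as the diff).
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
sampler = SequentialSampler(dataset)  # stand-in for a weighted/distributed sampler
config_shuffle = True                 # what config.shuffle would carry

loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=config_shuffle if sampler is None else False,
    sampler=sampler,
)
print(len(list(loader)))  # 4 batches, drawn in sampler order
```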
TTS/tts/models/overflow.py

```diff
@@ -159,6 +159,15 @@ class Overflow(BaseTTS):
 
         return outputs
 
+    @staticmethod
+    def _training_stats(batch):
+        stats = {}
+        stats["avg_text_length"] = batch["text_lengths"].float().mean()
+        stats["avg_spec_length"] = batch["mel_lengths"].float().mean()
+        stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean()
+        stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean()
+        return stats
+
     def train_step(self, batch: dict, criterion: nn.Module):
         text_input = batch["text_input"]
         text_lengths = batch["text_lengths"]
```
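The two occupancy statistics measure how much of each padded batch tensor holds real data: every sequence length is divided by the longest length in the batch, and the ratios are averaged. A standalone check of the arithmetic:

```python
# Occupancy as computed above: lengths [10, 20, 40] give ratios
# 0.25, 0.5, 1.0 against the batch max of 40, averaging to ~0.583.
import torch

text_lengths = torch.tensor([10, 20, 40])
occupancy = (text_lengths.float() / text_lengths.float().max()).mean()
print(occupancy.item())  # 0.5833...
```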
```diff
@@ -171,9 +180,10 @@ class Overflow(BaseTTS):
             mels=mel_input,
             mel_len=mel_lengths,
         )
-        loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum()))
-
+        loss_dict = criterion(outputs["log_probs"])
+
+        # for printing useful statistics on terminal
+        loss_dict.update(self._training_stats(batch))
         return outputs, loss_dict
 
     def eval_step(self, batch: Dict, criterion: nn.Module):
```
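The dropped divisor had turned the sequence-level log-probability into a per-step value before it reached the criterion. A worked check of what that normalization computed (numbers invented for illustration):

```python
# What the removed normalization did: sequence NLL divided by the
# total number of mel frames plus text tokens. Values are made up.
import torch

log_probs = torch.tensor(-4200.0)       # summed log-likelihood for the batch
mel_lengths = torch.tensor([380, 420])  # mel frames per utterance
text_lengths = torch.tensor([95, 105])  # text tokens per utterance

per_step = log_probs / (mel_lengths.sum() + text_lengths.sum())
print(per_step.item())  # -4200 / (800 + 200) = -4.2
```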