diff --git a/train.py b/train.py
index 3b6ff866..1ccfaead 100644
--- a/train.py
+++ b/train.py
@@ -108,6 +108,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
                 speaker_ids.append(speaker_mapping[speaker_name])
             speaker_ids = torch.LongTensor(speaker_ids)
 
+            if len(speaker_mapping) > c.num_speakers:
+                raise ValueError("It seems there are at least {} speakers in "
+                                 "your dataset, while 'num_speakers' is set to {}. "
+                                 "Found the following speakers: {}".format(len(speaker_mapping),
+                                                                           c.num_speakers,
+                                                                           list(speaker_mapping)))
+
         # set stop targets view, we predict a single stop token per r frames prediction
         stop_targets = stop_targets.view(text_input.shape[0],
                                          stop_targets.size(1) // c.r, -1)
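
For reviewers, a minimal sketch (not part of the patch) of how the new guard fires when the dataset contains more speakers than the config declares. The `SimpleNamespace` config and the speaker names below are made-up stand-ins for `c` and `speaker_mapping` in `train.py`:

```python
from types import SimpleNamespace

# Hypothetical values, for illustration only: config declares 2 speakers,
# but the dataset's speaker mapping actually contains 3.
c = SimpleNamespace(num_speakers=2)
speaker_mapping = {"p225": 0, "p226": 1, "p227": 2}

# Same check as the added hunk above.
if len(speaker_mapping) > c.num_speakers:
    raise ValueError("It seems there are at least {} speakers in "
                     "your dataset, while 'num_speakers' is set to {}. "
                     "Found the following speakers: {}".format(len(speaker_mapping),
                                                               c.num_speakers,
                                                               list(speaker_mapping)))
# -> ValueError: It seems there are at least 3 speakers in your dataset, while
#    'num_speakers' is set to 2. Found the following speakers: ['p225', 'p226', 'p227']
```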