fixing rebase issues

erogol 2020-08-05 18:33:22 +02:00
parent 5c752799ae
commit fe081d4f7c
5 changed files with 24 additions and 24 deletions

View File

@@ -303,7 +303,7 @@ class Decoder(nn.Module):
         self.separate_stopnet = separate_stopnet
         self.query_dim = 256
         # memory -> |Prenet| -> processed_memory
-        prenet_dim = memory_dim * self.memory_size if self.use_memory_queue else memory_dim
+        prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels
         self.prenet = Prenet(
             prenet_dim,
             prenet_type,
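
Note on the change above: the decoder's per-frame input size was renamed from memory_dim to frame_channels, and the prenet width still scales with the memory queue length. A minimal sketch of that sizing rule, with placeholder values (frame_channels=80, memory_size=5) that are illustrative rather than this repo's defaults:

    # Placeholder values for illustration only.
    frame_channels = 80     # channels per decoder frame (e.g. mel bands)
    memory_size = 5         # past frames held in the memory queue
    use_memory_queue = True

    # With a memory queue the prenet consumes memory_size past frames
    # concatenated along the channel axis; otherwise a single frame.
    prenet_dim = frame_channels * memory_size if use_memory_queue else frame_channels
    assert prenet_dim == 400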

View File

@@ -141,7 +141,6 @@ class Decoder(nn.Module):
         location_attn (bool): if true, use location sensitive attention.
         attn_K (int): number of attention heads for GravesAttention.
         separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
-        speaker_embedding_dim (int): size of speaker embedding vector, for multi-speaker training.
     """
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
@@ -156,7 +155,6 @@ class Decoder(nn.Module):
         self.separate_stopnet = separate_stopnet
         self.max_decoder_steps = 1000
         self.stop_threshold = 0.5
-        self.speaker_embedding_dim = speaker_embedding_dim
         # model dimensions
         self.query_dim = 1024

View File

@@ -53,27 +53,16 @@
     "max_seq_len": 300,
     "log_dir": "tests/outputs/",
-<<<<<<< HEAD
-    "use_speaker_embedding": false,
-    "use_gst": false,
-    "gst": {
-        "gst_style_input": null,
-        "gst_embedding_dim": 512,
-        "gst_num_heads": 4,
-        "gst_style_tokens": 10
-    }
-=======
     // MULTI-SPEAKER and GST
     "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning.
     "use_gst": true, // use global style tokens
     "gst": { // gst parameter if gst is enabled
         "gst_style_input": null, // Condition the style input either on a
                                  // -> wave file [path to wave] or
                                  // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
"gst_style_input": null, // Condition the style input either on a
// -> wave file [path to wave] or
// -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
                                 // with the dictionary being len(dict) <= len(gst_style_tokens).
         "gst_embedding_dim": 512,
"gst_embedding_dim": 512,
"gst_num_heads": 4,
"gst_style_tokens": 10
}
>>>>>>> travis unit tests fix and add Tacotron and Tacotron 2 GST and MultiSpeaker Tests
     }
 }
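
The gst_style_input comment above allows either a wave file path or a token-weight dictionary. As a hedged illustration of the dictionary mode only, the sketch below shows one plausible way such a dictionary could be combined with learned style tokens; the repo's actual GST layer may work differently (e.g. via attention), and token_embeddings here is a random stand-in:

    import torch

    # Hypothetical sketch: config values plus a random stand-in for the
    # learned style token embeddings.
    gst_style_tokens = 10
    gst_embedding_dim = 512
    token_embeddings = torch.randn(gst_style_tokens, gst_embedding_dim)

    style_input = {"0": 0.15, "1": 0.15, "5": -0.15}
    assert len(style_input) <= gst_style_tokens  # constraint from the config comment

    # Weighted sum over the referenced tokens yields one style embedding.
    style_embedding = torch.zeros(gst_embedding_dim)
    for token_id, weight in style_input.items():
        style_embedding = style_embedding + weight * token_embeddings[int(token_id)]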

View File

@@ -83,6 +83,19 @@
     "use_phonemes": false, // use phonemes instead of raw characters. It is suggested for better pronunciation.
     "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
     "text_cleaner": "phoneme_cleaners",
-    "use_speaker_embedding": false // whether to use additional embeddings for separate speakers
+    "use_speaker_embedding": false, // whether to use additional embeddings for separate speakers
+    // MULTI-SPEAKER and GST
+    "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning.
+    "use_gst": true, // use global style tokens
+    "gst": { // gst parameter if gst is enabled
+        "gst_style_input": null, // Condition the style input either on a
+                                 // -> wave file [path to wave] or
+                                 // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
+                                 // with the dictionary being len(dict) <= len(gst_style_tokens).
+        "gst_embedding_dim": 512,
+        "gst_num_heads": 4,
+        "gst_style_tokens": 10
+    }
 }

View File

@@ -35,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase):
         input_lengths = torch.randint(100, 129, (8, )).long().to(device)
         input_lengths[-1] = 128
         mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
-        linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
+        linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
         mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
         speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
@@ -53,7 +53,7 @@ class TacotronTrainTest(unittest.TestCase):
         model = Tacotron(
             num_chars=32,
             num_speakers=5,
-            postnet_output_dim=c.audio['num_freq'],
+            postnet_output_dim=c.audio['fft_size'],
             decoder_output_dim=c.audio['num_mels'],
             r=c.r,
             memory_size=c.memory_size
@@ -97,7 +97,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
         input_lengths = torch.randint(100, 129, (8, )).long().to(device)
         input_lengths[-1] = 128
         mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
-        linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device)
+        linear_spec = torch.rand(8, 120, c.audio['fft_size']).to(device)
         mel_lengths = torch.randint(20, 120, (8, )).long().to(device)
         mel_lengths[-1] = 120
         stop_targets = torch.zeros(8, 120, 1).float().to(device)
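
A note on the num_freq -> fft_size switch in these tests: the assertions only need the last dimension of linear_spec to match postnet_output_dim, so passing c.audio['fft_size'] for both keeps the test self-consistent. For orientation, a quick check (placeholder fft_size value, modern torch.fft API) of how a real-FFT bin count relates to fft_size:

    import torch

    # A real-valued FFT of size N yields N // 2 + 1 frequency bins, so
    # num_freq and fft_size are related but distinct quantities.
    fft_size = 1024  # placeholder, not necessarily this repo's setting
    frame = torch.randn(fft_size)
    n_bins = torch.fft.rfft(frame).shape[0]
    assert n_bins == fft_size // 2 + 1  # 513 for fft_size = 1024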