mirror of https://github.com/coqui-ai/TTS.git
update vocoder torch to tf conversion
This commit is contained in:
parent
042cde15d6
commit
7721d4b230
|
@ -100,13 +100,18 @@ for i in range(1, len(model.layers)):
|
|||
diff = compare_torch_tf(out_torch, out_tf_)
|
||||
assert diff < 1e-5, diff
|
||||
|
||||
dummy_input_torch = torch.ones((1, 80, 10))
|
||||
torch.manual_seed(0)
|
||||
dummy_input_torch = torch.rand((1, 80, 100))
|
||||
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
|
||||
model.inference_padding = 0
|
||||
model_tf.inference_padding = 0
|
||||
output_torch = model.inference(dummy_input_torch)
|
||||
output_tf = model_tf(dummy_input_tf, training=False)
|
||||
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
|
||||
output_torch, output_tf)
|
||||
|
||||
# save tf model
|
||||
save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
|
||||
args.output_path)
|
||||
print(' > Model conversion is successfully completed :).')
|
||||
|
||||
|
|
Loading…
Reference in New Issue