mirror of https://github.com/coqui-ai/TTS.git
update vocoder torch to tf conversion
This commit is contained in:
parent
042cde15d6
commit
7721d4b230
|
@ -100,13 +100,18 @@ for i in range(1, len(model.layers)):
|
||||||
diff = compare_torch_tf(out_torch, out_tf_)
|
diff = compare_torch_tf(out_torch, out_tf_)
|
||||||
assert diff < 1e-5, diff
|
assert diff < 1e-5, diff
|
||||||
|
|
||||||
dummy_input_torch = torch.ones((1, 80, 10))
|
torch.manual_seed(0)
|
||||||
|
dummy_input_torch = torch.rand((1, 80, 100))
|
||||||
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
|
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
|
||||||
|
model.inference_padding = 0
|
||||||
|
model_tf.inference_padding = 0
|
||||||
output_torch = model.inference(dummy_input_torch)
|
output_torch = model.inference(dummy_input_torch)
|
||||||
output_tf = model_tf(dummy_input_tf, training=False)
|
output_tf = model_tf(dummy_input_tf, training=False)
|
||||||
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
|
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
|
||||||
output_torch, output_tf)
|
output_torch, output_tf)
|
||||||
|
|
||||||
# save tf model
|
# save tf model
|
||||||
save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
|
save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
|
||||||
args.output_path)
|
args.output_path)
|
||||||
print(' > Model conversion is successfully completed :).')
|
print(' > Model conversion is successfully completed :).')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue