build inference graph for melgan tf

This commit is contained in:
erogol 2020-07-08 10:29:40 +02:00
parent 1dc7456dc4
commit 963ffbd003
2 changed files with 10 additions and 3 deletions

View File

@ -83,6 +83,7 @@ class MelganGenerator(tf.keras.models.Model): # pylint: disable=too-many-ancest
# self.model_layers = tf.keras.models.Sequential(self.initial_layer + self.upsample_layers + self.final_layers, name="layers")
self.model_layers = self.initial_layer + self.upsample_layers + self.final_layers
@tf.function(experimental_relax_shapes=True)
def call(self, c, training=False):
"""
c : B x C x T
@ -94,10 +95,15 @@ class MelganGenerator(tf.keras.models.Model): # pylint: disable=too-many-ancest
def inference(self, c):
c = tf.transpose(c, perm=[0, 2, 1])
c = tf.expand_dims(c, 2)
c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
# FIXME: TF had no replicate padding as in Torch
# c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
o = c
for layer in self.model_layers:
o = layer(o)
# o = self.model_layers(c)
o = tf.transpose(o, perm=[0, 3, 2, 1])
return o[:, :, 0, :]
return o[:, :, 0, :]
def build_inference(self):
x = tf.random.uniform((1, self.in_channels, 4), dtype=tf.float32)
self(x, training=False)

View File

@ -37,7 +37,8 @@ class MultibandMelganGenerator(MelganGenerator): # pylint: disable=too-many-anc
def inference(self, c):
c = tf.transpose(c, perm=[0, 2, 1])
c = tf.expand_dims(c, 2)
c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
# FIXME: TF had no replicate padding as in Torch
# c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
o = c
for layer in self.model_layers:
o = layer(o)