mirror of https://github.com/coqui-ai/TTS.git
build inference graph for melgan tf
This commit is contained in:
parent
1dc7456dc4
commit
963ffbd003
|
@ -83,6 +83,7 @@ class MelganGenerator(tf.keras.models.Model): # pylint: disable=too-many-ancest
|
|||
# self.model_layers = tf.keras.models.Sequential(self.initial_layer + self.upsample_layers + self.final_layers, name="layers")
|
||||
self.model_layers = self.initial_layer + self.upsample_layers + self.final_layers
|
||||
|
||||
@tf.function(experimental_relax_shapes=True)
|
||||
def call(self, c, training=False):
|
||||
"""
|
||||
c : B x C x T
|
||||
|
@ -94,10 +95,15 @@ class MelganGenerator(tf.keras.models.Model): # pylint: disable=too-many-ancest
|
|||
def inference(self, c):
|
||||
c = tf.transpose(c, perm=[0, 2, 1])
|
||||
c = tf.expand_dims(c, 2)
|
||||
c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
|
||||
# FIXME: TF had no replicate padding as in Torch
|
||||
# c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
|
||||
o = c
|
||||
for layer in self.model_layers:
|
||||
o = layer(o)
|
||||
# o = self.model_layers(c)
|
||||
o = tf.transpose(o, perm=[0, 3, 2, 1])
|
||||
return o[:, :, 0, :]
|
||||
return o[:, :, 0, :]
|
||||
|
||||
def build_inference(self):
|
||||
x = tf.random.uniform((1, self.in_channels, 4), dtype=tf.float32)
|
||||
self(x, training=False)
|
|
@ -37,7 +37,8 @@ class MultibandMelganGenerator(MelganGenerator): # pylint: disable=too-many-anc
|
|||
def inference(self, c):
|
||||
c = tf.transpose(c, perm=[0, 2, 1])
|
||||
c = tf.expand_dims(c, 2)
|
||||
c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
|
||||
# FIXME: TF had no replicate padding as in Torch
|
||||
# c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
|
||||
o = c
|
||||
for layer in self.model_layers:
|
||||
o = layer(o)
|
||||
|
|
Loading…
Reference in New Issue