diff --git a/TTS/tts/tf/layers/tacotron/common_layers.py b/TTS/tts/tf/layers/tacotron/common_layers.py
new file mode 100644
index 00000000..ad18b9fc
--- /dev/null
+++ b/TTS/tts/tf/layers/tacotron/common_layers.py
@@ -0,0 +1,288 @@
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.python.ops import math_ops
+# from tensorflow_addons.seq2seq import BahdanauAttention
+
+# NOTE: linter has a problem with the current TF release
+#pylint: disable=no-value-for-parameter
+#pylint: disable=unexpected-keyword-arg
+
+class Linear(keras.layers.Layer):
+    def __init__(self, units, use_bias, **kwargs):
+        super(Linear, self).__init__(**kwargs)
+        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
+        self.activation = keras.layers.ReLU()
+
+    def call(self, x):
+        """
+        shapes:
+            x: B x T x C
+        """
+        return self.activation(self.linear_layer(x))
+
+
+class LinearBN(keras.layers.Layer):
+    def __init__(self, units, use_bias, **kwargs):
+        super(LinearBN, self).__init__(**kwargs)
+        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
+        self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization')
+        self.activation = keras.layers.ReLU()
+
+    def call(self, x, training=None):
+        """
+        shapes:
+            x: B x T x C
+        """
+        out = self.linear_layer(x)
+        out = self.batch_normalization(out, training=training)
+        return self.activation(out)
+
+
+class Prenet(keras.layers.Layer):
+    def __init__(self,
+                 prenet_type,
+                 prenet_dropout,
+                 units,
+                 bias,
+                 **kwargs):
+        super(Prenet, self).__init__(**kwargs)
+        self.prenet_type = prenet_type
+        self.prenet_dropout = prenet_dropout
+        self.linear_layers = []
+        if prenet_type == "bn":
+            self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
+        elif prenet_type == "original":
+            self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
+        else:
+            raise RuntimeError(' [!] Unknown prenet type.')
+        if prenet_dropout:
+            self.dropout = keras.layers.Dropout(rate=0.5)
+
+    def call(self, x, training=None):
+        """
+        shapes:
+            x: B x T x C
+        """
+        for linear in self.linear_layers:
+            if self.prenet_dropout:
+                x = self.dropout(linear(x), training=training)
+            else:
+                x = linear(x)
+        return x
+
+
+def _sigmoid_norm(score):
+    attn_weights = tf.nn.sigmoid(score)
+    attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
+    return attn_weights
+
+
+class Attention(keras.layers.Layer):
+    """TODO: implement forward_attention
+    TODO: location sensitive attention
+    TODO: implement attention windowing """
+    def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
+                 loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
+                 use_trans_agent, use_forward_attn_mask, **kwargs):
+        super(Attention, self).__init__(**kwargs)
+        self.use_loc_attn = use_loc_attn
+        self.loc_attn_n_filters = loc_attn_n_filters
+        self.loc_attn_kernel_size = loc_attn_kernel_size
+        self.use_windowing = use_windowing
+        self.norm = norm
+        self.use_forward_attn = use_forward_attn
+        self.use_trans_agent = use_trans_agent
+        self.use_forward_attn_mask = use_forward_attn_mask
+        self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
+        self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
+        self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
+        if use_loc_attn:
+            self.location_conv1d = keras.layers.Conv1D(
+                filters=loc_attn_n_filters,
+                kernel_size=loc_attn_kernel_size,
+                padding='same',
+                use_bias=False,
+                name='location_layer/location_conv1d')
+            self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
+        if norm == 'softmax':
+            self.norm_func = tf.nn.softmax
+        elif norm == 'sigmoid':
+            self.norm_func = _sigmoid_norm
+        else:
+            raise ValueError("Unknown value for attention norm type")
+
+    def init_states(self, batch_size, value_length):
+        states = []
+        if self.use_loc_attn:
+            attention_cum = tf.zeros([batch_size, value_length])
+            attention_old = tf.zeros([batch_size, value_length])
+            states = [attention_cum, attention_old]
+        if self.use_forward_attn:
+            alpha = tf.concat([
+                tf.ones([batch_size, 1]),
+                tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
+            ], 1)
+            states.append(alpha)
+        return tuple(states)
+
+    def process_values(self, values):
+        """ cache values for decoder iterations """
+        #pylint: disable=attribute-defined-outside-init
+        self.processed_values = self.inputs_layer(values)
+        self.values = values
+
+    def get_loc_attn(self, query, states):
+        """ compute location attention, query layer and
+        unnorm. attention weights"""
+        attention_cum, attention_old = states[:2]
+        attn_cat = tf.stack([attention_old, attention_cum], axis=2)
+
+        processed_query = self.query_layer(tf.expand_dims(query, 1))
+        processed_attn = self.location_dense(self.location_conv1d(attn_cat))
+        score = self.v(
+            tf.nn.tanh(self.processed_values + processed_query +
+                       processed_attn))
+        score = tf.squeeze(score, axis=2)
+        return score, processed_query
+
+    def get_attn(self, query):
+        """ compute query layer and unnormalized attention weights """
+        processed_query = self.query_layer(tf.expand_dims(query, 1))
+        score = self.v(tf.nn.tanh(self.processed_values + processed_query))
+        score = tf.squeeze(score, axis=2)
+        return score, processed_query
+
+    def apply_score_masking(self, score, mask): #pylint: disable=no-self-use
+        """ ignore sequence paddings """
+        padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
+        # Bias so padding positions do not contribute to attention distribution.
+        score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
+        return score
+
+    def apply_forward_attention(self, alignment, alpha): #pylint: disable=no-self-use
+        # forward attention
+        fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
+        # compute transition potentials
+        new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
+        # renormalize attention weights
+        new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
+        return new_alpha
+
+    def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
+        states = []
+        if self.use_loc_attn:
+            states = [old_states[0] + scores_norm, attn_weights]
+        if self.use_forward_attn:
+            states.append(new_alpha)
+        return tuple(states)
+
+    def call(self, query, states):
+        """
+        shapes:
+            query: B x D
+        """
+        if self.use_loc_attn:
+            score, _ = self.get_loc_attn(query, states)
+        else:
+            score, _ = self.get_attn(query)
+
+        # TODO: masking
+        # if mask is not None:
+        #     self.apply_score_masking(score, mask)
+        # attn_weights shape == (batch_size, max_length, 1)
+
+        # normalize attention scores
+        scores_norm = self.norm_func(score)
+        attn_weights = scores_norm
+
+        # apply forward attention
+        new_alpha = None
+        if self.use_forward_attn:
+            new_alpha = self.apply_forward_attention(attn_weights, states[-1])
+            attn_weights = new_alpha
+
+        # update states tuple
+        # states = (cum_attn_weights, attn_weights, new_alpha)
+        states = self.update_states(states, scores_norm, attn_weights, new_alpha)
+
+        # context_vector shape after sum == (batch_size, hidden_size)
+        context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
+        context_vector = tf.squeeze(context_vector, axis=1)
+        return context_vector, attn_weights, states
+
+
+# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
+#     dtype = processed_query.dtype
+#     num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
+#     return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])
+
+
+# class LocationSensitiveAttention(BahdanauAttention):
+#     def __init__(self,
+#                  units,
+#                  memory=None,
+#                  memory_sequence_length=None,
+#                  normalize=False,
+#                  probability_fn="softmax",
+#                  kernel_initializer="glorot_uniform",
+#                  dtype=None,
+#                  name="LocationSensitiveAttention",
+#                  location_attention_filters=32,
+#                  location_attention_kernel_size=31):
+
+#         super(LocationSensitiveAttention,
+#               self).__init__(units=units,
+#                              memory=memory,
+#                              memory_sequence_length=memory_sequence_length,
+#                              normalize=normalize,
+#                              probability_fn='softmax',  ## parent module default
+#                              kernel_initializer=kernel_initializer,
+#                              dtype=dtype,
+#                              name=name)
+#         if probability_fn == 'sigmoid':
+#             self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
+#         self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
+#         self.location_dense = keras.layers.Dense(units, use_bias=False)
+#         # self.v = keras.layers.Dense(1, use_bias=True)
+
+#     def _location_sensitive_score(self, processed_query, keys, processed_loc):
+#         processed_query = tf.expand_dims(processed_query, 1)
+#         return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])
+
+#     def _location_sensitive(self, alignment_cum, alignment_old):
+#         alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
+#         return self.location_dense(self.location_conv(alignment_cat))
+
+#     def _sigmoid_normalization(self, score):
+#         return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)
+
+#     # def _apply_masking(self, score, mask):
+#     #     padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
+#     #     # Bias so padding positions do not contribute to attention distribution.
+#     #     score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
+#     #     return score
+
+#     def _calculate_attention(self, query, state):
+#         alignment_cum, alignment_old = state[:2]
+#         processed_query = self.query_layer(
+#             query) if self.query_layer else query
+#         processed_loc = self._location_sensitive(alignment_cum, alignment_old)
+#         score = self._location_sensitive_score(
+#             processed_query,
+#             self.keys,
+#             processed_loc)
+#         alignment = self.probability_fn(score, state)
+#         alignment_cum = alignment_cum + alignment
+#         state[0] = alignment_cum
+#         state[1] = alignment
+#         return alignment, state
+
+#     def compute_context(self, alignments):
+#         expanded_alignments = tf.expand_dims(alignments, 1)
+#         context = tf.matmul(expanded_alignments, self.values)
+#         context = tf.squeeze(context, [1])
+#         return context
+
+#     # def call(self, query, state):
+#     #     alignment, next_state = self._calculate_attention(query, state)
+#     #     return alignment, next_state
diff --git a/TTS/tts/tf/layers/tacotron/tacotron2.py b/TTS/tts/tf/layers/tacotron/tacotron2.py
new file mode 100644
index 00000000..094d9e4a
--- /dev/null
+++ b/TTS/tts/tf/layers/tacotron/tacotron2.py
@@ -0,0 +1,302 @@
+import tensorflow as tf
+from tensorflow import keras
+from TTS.tts.tf.utils.tf_utils import shape_list
+from TTS.tts.tf.layers.tacotron.common_layers import Prenet, Attention
+
+
+# NOTE: linter has a problem with the current TF release
+#pylint: disable=no-value-for-parameter
+#pylint: disable=unexpected-keyword-arg
+class ConvBNBlock(keras.layers.Layer):
+    def __init__(self, filters, kernel_size, activation, **kwargs):
+        super(ConvBNBlock, self).__init__(**kwargs)
+        self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d')
+        self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization')
+        self.dropout = keras.layers.Dropout(rate=0.5, name='dropout')
+        self.activation = keras.layers.Activation(activation, name='activation')
+
+    def call(self, x, training=None):
+        o = self.convolution1d(x)
+        o = self.batch_normalization(o, training=training)
+        o = self.activation(o)
+        o = self.dropout(o, training=training)
+        return o
+
+
+class Postnet(keras.layers.Layer):
+    def __init__(self, output_filters, num_convs, **kwargs):
+        super(Postnet, self).__init__(**kwargs)
+        self.convolutions = []
+        self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0'))
+        for idx in range(1, num_convs - 1):
+            self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}'))
+        self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}'))
+
+    def call(self, x, training=None):
+        o = x
+        for layer in self.convolutions:
+            o = layer(o, training=training)
+        return o
+
+
+class Encoder(keras.layers.Layer):
+    def __init__(self, output_input_dim, **kwargs):
+        super(Encoder, self).__init__(**kwargs)
+        self.convolutions = []
+        for idx in range(3):
+            self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}'))
+        self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm')
+
+    def call(self, x, training=None):
+        o = x
+        for layer in self.convolutions:
+            o = layer(o, training=training)
+        o = self.lstm(o)
+        return o
+
+
+class Decoder(keras.layers.Layer):
+    #pylint: disable=unused-argument
+    def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
+                 prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
+                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs):
+        super(Decoder, self).__init__(**kwargs)
+        self.frame_dim = frame_dim
+        self.r_init = tf.constant(r, dtype=tf.int32)
+        self.r = tf.constant(r, dtype=tf.int32)
+        self.output_dim = r * self.frame_dim
+        self.separate_stopnet = separate_stopnet
+        self.enable_tflite = enable_tflite
+
+        # layer constants
+        self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
+        self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
+
+        # model dimensions
+        self.query_dim = 1024
+        self.decoder_rnn_dim = 1024
+        self.prenet_dim = 256
+        self.attn_dim = 128
+        self.p_attention_dropout = 0.1
+        self.p_decoder_dropout = 0.1
+
+        self.prenet = Prenet(prenet_type,
+                             prenet_dropout,
+                             [self.prenet_dim, self.prenet_dim],
+                             bias=False,
+                             name='prenet')
+        self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', )
+        self.attention_rnn_dropout = keras.layers.Dropout(0.5)
+
+        # TODO: implement other attn options
+        self.attention = Attention(attn_dim=self.attn_dim,
+                                   use_loc_attn=True,
+                                   loc_attn_n_filters=32,
+                                   loc_attn_kernel_size=31,
+                                   use_windowing=False,
+                                   norm=attn_norm,
+                                   use_forward_attn=use_forward_attn,
+                                   use_trans_agent=use_trans_agent,
+                                   use_forward_attn_mask=use_forward_attn_mask,
+                                   name='attention')
+        self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn')
+        self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
+        self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer')
+        self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer')
+
+
+    def set_max_decoder_steps(self, new_max_steps):
+        self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)
+
+    def set_r(self, new_r):
+        self.r = tf.constant(new_r, dtype=tf.int32)
+        self.output_dim = self.frame_dim * new_r
+
+    def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
+        zero_frame = tf.zeros([batch_size, self.frame_dim])
+        zero_context = tf.zeros([batch_size, memory_dim])
+        attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
+        decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
+        attention_states = self.attention.init_states(batch_size, memory_length)
+        return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states
+
+    def step(self, prenet_next, states,
+             memory_seq_length=None, training=None):
+        _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
+        attention_rnn_input = tf.concat([prenet_next, context_next], -1)
+        attention_rnn_output, attention_rnn_state = \
+            self.attention_rnn(attention_rnn_input,
+                               attention_rnn_state, training=training)
+        attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
+        context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
+        decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
+        decoder_rnn_output, decoder_rnn_state = \
+            self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training)
+        decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
+        linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
+        output_frame = self.linear_projection(linear_projection_input, training=training)
+        stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
+        stopnet_output = self.stopnet(stopnet_input, training=training)
+        output_frame = output_frame[:, :self.r * self.frame_dim]
+        states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states)
+        return output_frame, stopnet_output, states, attention
+
+    def decode(self, memory, states, frames, memory_seq_length=None):
+        B, _, _ = shape_list(memory)
+        num_iter = shape_list(frames)[1] // self.r
+        # init states
+        frame_zero = tf.expand_dims(states[0], 1)
+        frames = tf.concat([frame_zero, frames], axis=1)
+        outputs = tf.TensorArray(dtype=tf.float32, size=num_iter)
+        attentions = tf.TensorArray(dtype=tf.float32, size=num_iter)
+        stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter)
+        # pre-computes
+        self.attention.process_values(memory)
+        prenet_output = self.prenet(frames, training=True)
+        step_count = tf.constant(0, dtype=tf.int32)
+
+        def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
+            prenet_next = prenet_output[:, step]
+            output, stop_token, states, attention = self.step(prenet_next,
+                                                              states,
+                                                              memory_seq_length)
+            outputs = outputs.write(step, output)
+            attentions = attentions.write(step, attention)
+            stop_tokens = stop_tokens.write(step, stop_token)
+            return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
+        _, memory, _, states, outputs, stop_tokens, attentions = \
+            tf.while_loop(lambda *arg: True,
+                          _body,
+                          loop_vars=(step_count, memory, prenet_output,
+                                     states, outputs, stop_tokens, attentions),
+                          parallel_iterations=32,
+                          swap_memory=True,
+                          maximum_iterations=num_iter)
+
+        outputs = outputs.stack()
+        attentions = attentions.stack()
+        stop_tokens = stop_tokens.stack()
+        outputs = tf.transpose(outputs, [1, 0, 2])
+        attentions = tf.transpose(attentions, [1, 0, 2])
+        stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
+        stop_tokens = tf.squeeze(stop_tokens, axis=2)
+        outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
+        return outputs, stop_tokens, attentions
+
+    def decode_inference(self, memory, states):
+        B, _, _ = shape_list(memory)
+        # init states
+        outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
+        attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
+        stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
+
+        # pre-computes
+        self.attention.process_values(memory)
+
+        # iter vars
+        stop_flag = tf.constant(False, dtype=tf.bool)
+        step_count = tf.constant(0, dtype=tf.int32)
+
+        def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
+            frame_next = states[0]
+            prenet_next = self.prenet(frame_next, training=False)
+            output, stop_token, states, attention = self.step(prenet_next,
+                                                              states,
+                                                              None,
+                                                              training=False)
+            stop_token = tf.math.sigmoid(stop_token)
+            outputs = outputs.write(step, output)
+            attentions = attentions.write(step, attention)
+            stop_tokens = stop_tokens.write(step, stop_token)
+            stop_flag = tf.greater(stop_token, self.stop_thresh)
+            stop_flag = tf.reduce_all(stop_flag)
+            return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag
+
+        cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
+        _, memory, states, outputs, stop_tokens, attentions, stop_flag = \
+            tf.while_loop(cond,
+                          _body,
+                          loop_vars=(step_count, memory, states, outputs,
+                                     stop_tokens, attentions, stop_flag),
+                          parallel_iterations=32,
+                          swap_memory=True,
+                          maximum_iterations=self.max_decoder_steps)
+
+        outputs = outputs.stack()
+        attentions = attentions.stack()
+        stop_tokens = stop_tokens.stack()
+
+        outputs = tf.transpose(outputs, [1, 0, 2])
+        attentions = tf.transpose(attentions, [1, 0, 2])
+        stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
+        stop_tokens = tf.squeeze(stop_tokens, axis=2)
+        outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
+        return outputs, stop_tokens, attentions
+
+    def decode_inference_tflite(self, memory, states):
+        """Inference with TF-Lite compatibility. It assumes
+        batch_size is 1"""
+        # init states
+        # dynamic_shape is not supported in TFLite
+        outputs = tf.TensorArray(dtype=tf.float32,
+                                 size=self.max_decoder_steps,
+                                 element_shape=tf.TensorShape(
+                                     [self.output_dim]),
+                                 clear_after_read=False,
+                                 dynamic_size=False)
+        # stop_flags = tf.TensorArray(dtype=tf.bool,
+        #                             size=self.max_decoder_steps,
+        #                             element_shape=tf.TensorShape(
+        #                                 []),
+        #                             clear_after_read=False,
+        #                             dynamic_size=False)
+        attentions = ()
+        stop_tokens = ()
+
+        # pre-computes
+        self.attention.process_values(memory)
+
+        # iter vars
+        stop_flag = tf.constant(False, dtype=tf.bool)
+        step_count = tf.constant(0, dtype=tf.int32)
+
+        def _body(step, memory, states, outputs, stop_flag):
+            frame_next = states[0]
+            prenet_next = self.prenet(frame_next, training=False)
+            output, stop_token, states, _ = self.step(prenet_next,
+                                                      states,
+                                                      None,
+                                                      training=False)
+            stop_token = tf.math.sigmoid(stop_token)
+            stop_flag = tf.greater(stop_token, self.stop_thresh)
+            stop_flag = tf.reduce_all(stop_flag)
+            # stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
+
+            outputs = outputs.write(step, tf.reshape(output, [-1]))
+            return step + 1, memory, states, outputs, stop_flag
+
+        cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
+        step_count, memory, states, outputs, stop_flag = \
+            tf.while_loop(cond,
+                          _body,
+                          loop_vars=(step_count, memory, states, outputs,
+                                     stop_flag),
+                          parallel_iterations=32,
+                          swap_memory=True,
+                          maximum_iterations=self.max_decoder_steps)
+
+
+        outputs = outputs.stack()
+        outputs = tf.gather(outputs, tf.range(step_count))  # pylint: disable=no-value-for-parameter
+        outputs = tf.expand_dims(outputs, axis=[0])
+        outputs = tf.transpose(outputs, [1, 0, 2])
+        outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
+        return outputs, stop_tokens, attentions
+
+
+    def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
+        if training:
+            return self.decode(memory, states, frames, memory_seq_length)
+        if self.enable_tflite:
+            return self.decode_inference_tflite(memory, states)
+        return self.decode_inference(memory, states)
diff --git a/TTS/tts/tf/models/tacotron2.py b/TTS/tts/tf/models/tacotron2.py
index 9d470b09..882af517 100644
--- a/TTS/tts/tf/models/tacotron2.py
+++ b/TTS/tts/tf/models/tacotron2.py
@@ -1,7 +1,7 @@
 import tensorflow as tf
 from tensorflow import keras
 
-from TTS.tts.tf.layers.tacotron2 import Encoder, Decoder, Postnet
+from TTS.tts.tf.layers.tacotron.tacotron2 import Encoder, Decoder, Postnet
 from TTS.tts.tf.utils.tf_utils import shape_list
 
 
diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py
index c7a2b91a..4e1c6d43 100644
--- a/TTS/tts/utils/text/cleaners.py
+++ b/TTS/tts/utils/text/cleaners.py
@@ -94,6 +94,7 @@ def basic_turkish_cleaners(text):
     text = collapse_whitespace(text)
     return text
 
+
 def english_cleaners(text):
     '''Pipeline for English text, including number and abbreviation expansion.'''
     text = convert_to_ascii(text)
@@ -106,6 +107,7 @@ def english_cleaners(text):
     text = collapse_whitespace(text)
     return text
 
+
 def french_cleaners(text):
     '''Pipeline for French text. There is no need to expand numbers, phonemizer already does that'''
     text = expand_abbreviations(text, lang='fr')
@@ -115,6 +117,7 @@ def french_cleaners(text):
     text = collapse_whitespace(text)
     return text
 
+
 def portuguese_cleaners(text):
     '''Basic pipeline for Portuguese text. There is no need to expand abbreviation and
     numbers, phonemizer already does that'''
@@ -124,11 +127,13 @@ def portuguese_cleaners(text):
     text = collapse_whitespace(text)
     return text
 
+
 def chinese_mandarin_cleaners(text: str) -> str:
     '''Basic pipeline for chinese'''
     text = replace_numbers_to_characters_in_text(text)
    return text
 
+
 def phoneme_cleaners(text):
     '''Pipeline for phonemes mode, including number and abbreviation expansion.'''
     text = expand_numbers(text)
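Reviewer note, not part of the patch: the snippet below is a minimal smoke-test sketch for the new Decoder layer added in TTS/tts/tf/layers/tacotron/tacotron2.py. The class, constructor parameters, and build_decoder_initial_states() come from the patch; every concrete value (80 mel channels, 512-dim encoder outputs, r=2, 50 input tokens, 'sigmoid' attention norm) is an illustrative assumption, not something the patch mandates, and it assumes the TF 2.x / tf.keras versions this code targets rather than Keras 3.

import tensorflow as tf

from TTS.tts.tf.layers.tacotron.tacotron2 import Decoder

# assumed, illustrative dimensions: 1 utterance, 50 encoder steps, 512-dim memory, 80 mels, r=2
batch_size, text_len, enc_dim, mel_dim, r = 1, 50, 512, 80, 2

# attn_type, use_attn_win, attn_K and speaker_emb_dim are unused by this layer
# (see the pylint disable in the patch), so placeholder values are fine here.
decoder = Decoder(frame_dim=mel_dim, r=r, attn_type='original', use_attn_win=False,
                  attn_norm='sigmoid', prenet_type='original', prenet_dropout=True,
                  use_forward_attn=False, use_trans_agent=False,
                  use_forward_attn_mask=False, use_location_attn=True, attn_K=0,
                  separate_stopnet=True, speaker_emb_dim=0, enable_tflite=False,
                  name='decoder')

# random stand-ins for encoder outputs (memory) and teacher-forcing mel frames
memory = tf.random.normal([batch_size, text_len, enc_dim])
frames = tf.random.normal([batch_size, 6 * r, mel_dim])

# states = (zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states)
states = decoder.build_decoder_initial_states(batch_size, enc_dim, text_len)

# teacher-forced pass through Decoder.decode(); inference would use training=False
outputs, stop_tokens, alignments = decoder(memory, states, frames=frames,
                                           memory_seq_length=None, training=True)
print(outputs.shape, stop_tokens.shape, alignments.shape)  # (1, 12, 80) (1, 6) (1, 6, 50)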