mirror of https://github.com/coqui-ai/TTS.git
106 lines | 3.6 KiB | Python
import datetime
|
|
import importlib
|
|
import pickle
|
|
|
|
import fsspec
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
|
|
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
    """Pickle the model weights and training state to ``output_path``.

    Args:
        model: TF model whose ``weights`` list is serialized.
        optimizer: optimizer object stored as-is in the state dict.
        current_step: global training step.
        epoch: current epoch number.
        r: decoder reduction factor.
        output_path: destination path (any fsspec-supported filesystem).
        **kwargs: extra entries merged into the saved state dict
            (may override the default keys).
    """
    checkpoint_state = dict(
        model=model.weights,
        optimizer=optimizer,
        step=current_step,
        epoch=epoch,
        date=datetime.date.today().strftime("%B %d, %Y"),
        r=r,
    )
    checkpoint_state.update(kwargs)
    with fsspec.open(output_path, "wb") as f:
        pickle.dump(checkpoint_state, f)
|
|
|
|
|
|
def load_checkpoint(model, checkpoint_path):
    """Restore pickled variable values into ``model`` from ``checkpoint_path``.

    Variables are matched by name; if a name is missing, a fallback lookup
    prefixes it with the checkpoint's model class name (older checkpoints
    saved fully-qualified variable names). Also restores the decoder's
    reduction factor ``r`` when present.

    Returns:
        The model with restored weights.
    """
    with fsspec.open(checkpoint_path, "rb") as f:
        checkpoint = pickle.load(f)
    # Saved variable name -> numpy value.
    saved_values = {var.name: var.numpy() for var in checkpoint["model"]}
    for model_var in model.weights:
        var_name = model_var.name
        if var_name in saved_values:
            value = saved_values[var_name]
        else:
            # Fall back to "<ClassName>/<var_name>" naming; the class name is
            # taken from the first saved variable's prefix.
            class_name = next(iter(saved_values)).split("/")[0]
            value = saved_values[f"{class_name}/{var_name}"]
        tf.keras.backend.set_value(model_var, value)
    if "r" in checkpoint:
        model.decoder.set_r(checkpoint["r"])
    return model
|
|
|
|
|
|
def sequence_mask(sequence_length, max_len=None):
    """Build a boolean mask of shape (batch, max_len) marking valid positions.

    Args:
        sequence_length: 1D array-like of per-example sequence lengths.
        max_len: width of the mask; defaults to ``sequence_length.max()``.

    Returns:
        Boolean numpy array ``mask`` where ``mask[b, t] = t < sequence_length[b]``
        (B x T_max).
    """
    # The previous implementation created a numpy array with np.empty and then
    # called torch-only methods (.unsqueeze, .type_as, .size) on it, so it
    # always raised AttributeError. Use plain numpy broadcasting instead.
    lengths = np.asarray(sequence_length)
    if max_len is None:
        max_len = lengths.max()
    time_steps = np.arange(max_len)
    # (1, T_max) < (B, 1) broadcasts to B x T_max.
    return time_steps[None, :] < lengths[:, None]
|
|
|
|
|
|
# @tf.custom_gradient
def check_gradient(x, grad_clip):
    """Clip ``x`` to an L2 norm of at most ``grad_clip``.

    Args:
        x: gradient tensor to clip.
        grad_clip: maximum allowed L2 norm (scalar).

    Returns:
        Tuple of (clipped tensor, L2 norm of the un-clipped input).
    """
    x_normed = tf.clip_by_norm(x, grad_clip)
    # BUG FIX: the original computed tf.norm(grad_clip) -- the norm of the
    # scalar threshold, which is just abs(grad_clip). Report the norm of the
    # incoming gradient instead (mirrors torch's clip_grad_norm_ return value).
    grad_norm = tf.norm(x)
    return x_normed, grad_norm
|
|
|
|
|
|
def count_parameters(model, c):
    """Return the total number of parameters in ``model``.

    ``count_params`` raises ``RuntimeError`` when the model has not been
    built yet; in that case a dummy forward pass (shaped from config ``c``)
    is run first to create the variables, then the count is retried.
    """
    try:
        return model.count_params()
    except RuntimeError:
        batch = 8
        dummy_chars = tf.convert_to_tensor(np.random.rand(batch, 128).astype("int32"))
        lengths = np.random.randint(100, 129, (batch,))
        lengths[-1] = 128  # guarantee one full-length example
        dummy_lengths = tf.convert_to_tensor(lengths.astype("int32"))
        dummy_mel = tf.convert_to_tensor(
            np.random.rand(batch, 2 * c.r, c.audio["num_mels"]).astype("float32")
        )
        dummy_speakers = np.random.randint(0, 5, (batch,)) if c.use_speaker_embedding else None
        _ = model(dummy_chars, dummy_lengths, dummy_mel, speaker_ids=dummy_speakers)
        return model.count_params()
|
|
|
|
|
|
def setup_model(num_chars, num_speakers, c, enable_tflite=False):
    """Instantiate the TF TTS model named in the config.

    Args:
        num_chars: size of the input character vocabulary.
        num_speakers: number of speakers for multi-speaker models.
        c: config object providing model name and hyperparameters.
        enable_tflite: build the model in TFLite-compatible mode.

    Returns:
        The constructed model instance (currently only tacotron2).

    Raises:
        NotImplementedError: if the config requests the ``tacotron`` model,
            which has no TF implementation.
    """
    print(" > Using model: {}".format(c.model))
    model_module = importlib.import_module("TTS.tts.tf.models." + c.model.lower())
    MyModel = getattr(model_module, c.model)
    # BUG FIX: the original used `c.model.lower() in "tacotron"`, a substring
    # test that also matched any substring such as "taco"; an equality check
    # is what was intended.
    if c.model.lower() == "tacotron":
        raise NotImplementedError(" [!] Tacotron model is not ready.")
    # tacotron2
    model = MyModel(
        num_chars=num_chars,
        num_speakers=num_speakers,
        r=c.r,
        out_channels=c.audio["num_mels"],
        decoder_output_dim=c.audio["num_mels"],
        attn_type=c.attention_type,
        attn_win=c.windowing,
        attn_norm=c.attention_norm,
        prenet_type=c.prenet_type,
        prenet_dropout=c.prenet_dropout,
        forward_attn=c.use_forward_attn,
        trans_agent=c.transition_agent,
        forward_attn_mask=c.forward_attn_mask,
        location_attn=c.location_attn,
        attn_K=c.attention_heads,
        separate_stopnet=c.separate_stopnet,
        bidirectional_decoder=c.bidirectional_decoder,
        enable_tflite=enable_tflite,
    )
    return model
|