# %%
import sys
sys.path.append('/home/erogol/Projects')
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# %%
import argparse
import numpy as np
import torch
import tensorflow as tf
from fuzzywuzzy import fuzz

from TTS.utils.text.symbols import phonemes, symbols
from TTS.utils.generic_utils import setup_model
from TTS.utils.io import load_config
from TTS.tf.models.tacotron2 import Tacotron2
from TTS.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name
from TTS.tf.utils.generic_utils import save_checkpoint

# Command-line interface.
# All three paths are mandatory. An omitted option would default to None and
# be passed on as-is to load_config / torch.load / save_checkpoint further
# down; for --output_path that only happens after the whole conversion and
# comparison has already run. Making them required lets argparse reject the
# call up front with a usage message.
parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path',
                    type=str,
                    required=True,
                    help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
                    type=str,
                    required=True,
                    help='Path to config file of torch model.')
parser.add_argument('--output_path',
                    type=str,
                    required=True,
                    help='path to save TF model weights.')
args = parser.parse_args()

# Read the configuration of the torch model from the path given on the
# command line.
config_path = args.config_path
c = load_config(config_path)
# Both models are built with zero speakers.
num_speakers = 0

# The input vocabulary is either the phoneme set or the plain symbol set,
# selected by the `use_phonemes` flag of the config.
if c.use_phonemes:
    num_chars = len(phonemes)
else:
    num_chars = len(symbols)

# Build the torch model and restore its weights from the checkpoint, mapping
# all tensors to the CPU.
model = setup_model(num_chars, num_speakers, c)
checkpoint = torch.load(args.torch_model_path,
                        map_location=torch.device('cpu'))
state_dict = checkpoint['model']
model.load_state_dict(state_dict)

# Gather the constructor arguments of the TF Tacotron2 model. The reduction
# factor `r` is read from the torch decoder; everything else comes from the
# config (both output dimensions use the number of mel bins).
tacotron2_kwargs = {
    'num_chars': num_chars,
    'num_speakers': num_speakers,
    'r': model.decoder.r,
    'postnet_output_dim': c.audio['num_mels'],
    'decoder_output_dim': c.audio['num_mels'],
    'attn_type': c.attention_type,
    'attn_win': c.windowing,
    'attn_norm': c.attention_norm,
    'prenet_type': c.prenet_type,
    'prenet_dropout': c.prenet_dropout,
    'forward_attn': c.use_forward_attn,
    'trans_agent': c.transition_agent,
    'forward_attn_mask': c.forward_attn_mask,
    'location_attn': c.location_attn,
    'attn_K': c.attention_heads,
    'separate_stopnet': c.separate_stopnet,
    'bidirectional_decoder': c.bidirectional_decoder,
}

# init tf model
model_tf = Tacotron2(**tacotron2_kwargs)

# Manual TF -> torch variable-name mapping for the layers that are not
# captured by the fuzzy name matching performed further below.
# Each entry is (tf_variable_name, torch_state_dict_key). The two LSTM bias
# entries pair one TF variable with a tuple of two torch keys (the `bias_ih`
# and `bias_hh` tensors) - presumably merged by transfer_weights_torch_to_tf;
# TODO confirm.
# TODO: set layer names so that this manual matching can be removed.
# NOTE(review): `common_sufix` is not referenced anywhere else in the visible
# part of this script.
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
var_map = [
    ('tacotron2/embedding/embeddings:0', 'embedding.weight'),
    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
     'encoder.lstm.weight_ih_l0'),
    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
     'encoder.lstm.weight_hh_l0'),
    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
     'encoder.lstm.weight_ih_l0_reverse'),
    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
     'encoder.lstm.weight_hh_l0_reverse'),
    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
     ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
     ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
    ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
    ('decoder/linear_projection/kernel:0',
     'decoder.linear_projection.linear_layer.weight'),
    ('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight')
]

# %%
# get tf_model graph
# The TF model is called once on dummy inputs before its weights are listed
# (presumably so that the Keras variables exist - TODO confirm).
input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
mel_pred = model_tf(input_ids, training=False)

# get tf variables
tf_vars = model_tf.weights

# match variable names with fuzzy logic
# Every TF variable without a manual entry in var_map is paired with the
# remaining torch name whose string similarity to the converted TF name is
# highest; a torch name is removed from the candidates once it is used.
torch_var_names = list(state_dict.keys())
tf_var_names = [we.name for we in tf_vars]
# TF names that already have a mapping. Kept as a set and updated in step
# with var_map so the membership test below does not rebuild a list from
# var_map on every iteration.
mapped_tf_names = {entry[0] for entry in var_map}
for tf_name in tf_var_names:
    # skip re-mapped layer names
    if tf_name in mapped_tf_names:
        continue
    # Without this guard np.argmax would fail on an empty sequence with a
    # message that does not say which variable could not be matched.
    if not torch_var_names:
        raise ValueError(
            "No torch variable left to match TF variable '{}'.".format(
                tf_name))
    tf_name_edited = convert_tf_name(tf_name)
    ratios = [
        fuzz.ratio(torch_name, tf_name_edited)
        for torch_name in torch_var_names
    ]
    max_idx = int(np.argmax(ratios))
    matching_name = torch_var_names.pop(max_idx)
    var_map.append((tf_name, matching_name))
    mapped_tf_names.add(tf_name)

# %%
# print variable match
from pprint import pprint
# Complete TF -> torch mapping: the manual entries followed by the fuzzy
# matches appended by the matching loop.
pprint(var_map)
# Torch names that were not consumed by the fuzzy matching. Names that appear
# only in the manual entries of var_map are never removed from this list, so
# they are printed here as well.
pprint(torch_var_names)

# pass weights
# var_map is handed over as a dict keyed by TF variable name; the returned
# value replaces tf_vars.
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)

# Compare TF and TORCH models
# %%
# check embedding outputs
# The torch model is switched to eval mode and both embedding layers are fed
# the same random integer ids.
model.eval()
input_ids = torch.randint(0, 24, (1, 128)).long()

o_t = model.embedding(input_ids)
o_tf = model_tf.embedding(input_ids.detach().numpy())
# Summed absolute difference between the two embedding outputs; it doubles
# as the assertion message on failure.
embedding_diff = abs(o_t.detach().numpy() - o_tf.numpy()).sum()
assert embedding_diff < 1e-5, embedding_diff

# compare encoder outputs
# Both encoders consume the torch embedding output (transposed for torch).
oo_en = model.encoder.inference(o_t.transpose(1, 2))
ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
encoder_diff = compare_torch_tf(oo_en, ooo_en)
assert encoder_diff < 1e-5

#pylint: disable=redefined-builtin
# compare decoder.attention_rnn
# The same random input is fed to the attention RNN of both models. The torch
# decoder state is set up from the encoder output via _init_states, while the
# TF side receives freshly built initial states and uses entry 2 of them.
# NOTE(review): the sizes 768 / 512 / 128 are hard-coded; presumably they
# match the layer sizes of the converted model - confirm.
inp = torch.rand([1, 768])
inp_tf = inp.numpy()
model.decoder._init_states(oo_en, mask=None)  #pylint: disable=protected-access
output, cell_state = model.decoder.attention_rnn(inp)
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf,
                                                         states[2],
                                                         training=False)
assert compare_torch_tf(output, output_tf).mean() < 1e-5

# The torch attention-RNN output is reused as the attention query; the
# attention inputs are a new random tensor shared by both models.
query = output
inputs = torch.rand([1, 128, 512])
query_tf = query.detach().numpy()
inputs_tf = inputs.numpy()

# compare decoder.attention
# Torch side: initialise the attention states, preprocess the inputs, then
# compute the location attention terms and the context.
model.decoder.attention.init_states(inputs)
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(
    query, processes_inputs)
context = model.decoder.attention(query, inputs, processes_inputs, None)

# TF side: take the last entry of the initial decoder states as the attention
# states, register the inputs through process_values, then compute the same
# quantities. `attention_states` is rebound by the final call.
attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states)
context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False)

assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
assert compare_torch_tf(context, context_tf) < 1e-5

# compare decoder.decoder_rnn
# Note: `input` shadows the builtin of the same name here (the script
# disables pylint's redefined-builtin warning for this).
# The torch decoder state is re-initialised from the encoder output and its
# hidden/cell tensors are passed in explicitly; the TF side uses entry 3 of
# freshly built initial states.
input = torch.rand([1, 1536])
input_tf = input.numpy()
model.decoder._init_states(oo_en, mask=None)  #pylint: disable=protected-access
output, cell_state = model.decoder.decoder_rnn(
    input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf,
                                                       states[3],
                                                       training=False)
# NOTE(review): `input_tf` is just `input.numpy()`, so this first check
# compares the input with its own conversion rather than any model output -
# confirm whether it is still needed.
assert abs(input - input_tf).mean() < 1e-5
assert compare_torch_tf(output, output_tf).mean() < 1e-5

# compare decoder.linear_projection
# A new random input is drawn and passed through the projection layer of
# both models.
input = torch.rand([1, 1536])
input_tf = input.numpy()
output = model.decoder.linear_projection(input)
output_tf = model_tf.decoder.linear_projection(input_tf, training=False)
assert compare_torch_tf(output, output_tf) < 1e-5

# compare decoder outputs
# Both decoders are limited to the same number of steps and run on their own
# framework's encoder output.
max_steps = 100
model.decoder.max_decoder_steps = max_steps
model_tf.decoder.set_max_decoder_steps(max_steps)
output, align, stop = model.decoder.inference(oo_en)
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
decoder_diff = compare_torch_tf(output.transpose(1, 2), output_tf)
assert decoder_diff < 1e-4

# compare the whole model output
# End-to-end inference of both models on the same input ids.
outputs_torch = model.inference(input_ids)
outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
# Report the mean absolute difference on index 0 of the second axis of the
# first output.
first_torch = outputs_torch[0].numpy()[:, 0]
first_tf = outputs_tf[0].numpy()[:, 0]
print(abs(first_torch - first_tf).mean())
# Third output, compared at position 50 of its second axis only.
slice_diff = compare_torch_tf(outputs_torch[2][:, 50, :],
                              outputs_tf[2][:, 50, :])
assert slice_diff < 1e-5
# First output, compared in full.
full_diff = compare_torch_tf(outputs_torch[0], outputs_tf[0])
assert full_diff < 1e-4

# %%
# save tf model
# The step, epoch and `r` values stored in the torch checkpoint are carried
# over to the TF checkpoint, which is written to the requested output path.
ckpt_step = checkpoint['step']
ckpt_epoch = checkpoint['epoch']
ckpt_r = checkpoint['r']
save_checkpoint(model_tf, None, ckpt_step, ckpt_epoch, ckpt_r,
                args.output_path)
print(' > Model conversion is successfully completed :).')