Style extract_tts_spectrogram.py

This commit is contained in:
Eren Gölge 2021-09-10 08:21:21 +00:00
parent 1de010acd4
commit 3c740d4893
1 changed files with 13 additions and 9 deletions

View File

@ -76,14 +76,14 @@ def set_filename(wav_path, out_path):
def format_data(data): def format_data(data):
# setup input data # setup input data
text_input = data['text'] text_input = data["text"]
text_lengths = data['text_lengths'] text_lengths = data["text_lengths"]
mel_input = data['mel'] mel_input = data["mel"]
mel_lengths = data['mel_lengths'] mel_lengths = data["mel_lengths"]
item_idx = data['item_idxs'] item_idx = data["item_idxs"]
d_vectors = data['d_vectors'] d_vectors = data["d_vectors"]
speaker_ids = data['speaker_ids'] speaker_ids = data["speaker_ids"]
attn_mask = data['attns'] attn_mask = data["attns"]
avg_text_length = torch.mean(text_lengths.float()) avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float()) avg_spec_length = torch.mean(mel_lengths.float())
@ -132,7 +132,11 @@ def inference(
elif d_vectors is not None: elif d_vectors is not None:
speaker_c = d_vectors speaker_c = d_vectors
outputs = model.inference_with_MAS( outputs = model.inference_with_MAS(
text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids} text_input,
text_lengths,
mel_input,
mel_lengths,
aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
) )
model_output = outputs["model_outputs"] model_output = outputs["model_outputs"]
model_output = model_output.transpose(1, 2).detach().cpu().numpy() model_output = model_output.transpose(1, 2).detach().cpu().numpy()