mirror of https://github.com/coqui-ai/TTS.git
style update #3
This commit is contained in:
parent 18d9ec8036 · commit 87ee6ceb57
@@ -25,12 +25,11 @@ import subprocess
 import sys
 import zipfile
 
+import pandas
 import soundfile as sf
 import tensorflow as tf
 from absl import logging
 
-import pandas
-
 gfile = tf.compat.v1.gfile
 
 SUBSETS = {
@@ -1,6 +1,8 @@
 import numpy as np
-cimport numpy as np
+
 cimport cython
+cimport numpy as np
+
 from cython.parallel import prange
 
 
@@ -6,13 +6,12 @@ import random
 from statistics import StatisticsError, mean, median, mode, stdev
 
 import matplotlib.pyplot as plt
-
 import seaborn as sns
 from text.cmudict import CMUDict
 
 
 def get_audio_seconds(frames):
-    return (frames*12.5)/1000
+    return (frames * 12.5) / 1000
 
 
 def append_data_statistics(meta_data):
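A note on the helper reformatted above: it converts a frame count to seconds assuming 12.5 ms of audio per frame; that constant comes straight from the function body, and nothing else in the commit documents it. A quick self-contained check of the arithmetic:

    # 12.5 ms per frame: 800 frames -> (800 * 12.5) / 1000 = 10.0 seconds
    print((800 * 12.5) / 1000)  # 10.0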
@@ -29,9 +28,7 @@ def append_data_statistics(meta_data):
         median_audio_len = median(audio_len_list)
 
         try:
-            std = stdev(
-                d["audio_len"] for d in data
-            )
+            std = stdev(d["audio_len"] for d in data)
         except StatisticsError:
             std = 0
 
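The collapsed call above works because statistics.stdev accepts any iterable, including a bare generator expression, so the multi-line wrapping was pure noise. A minimal self-contained check (the sample records are made up for illustration):

    from statistics import stdev

    data = [{"audio_len": 1.0}, {"audio_len": 2.0}, {"audio_len": 4.0}]
    print(stdev(d["audio_len"] for d in data))  # ~1.5275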
@@ -46,24 +43,22 @@ def process_meta_data(path):
     meta_data = {}
 
     # load meta data
-    with open(path, 'r') as f:
-        data = csv.reader(f, delimiter='|')
+    with open(path, "r") as f:
+        data = csv.reader(f, delimiter="|")
         for row in data:
             frames = int(row[2])
             utt = row[3]
             audio_len = get_audio_seconds(frames)
             char_count = len(utt)
             if not meta_data.get(char_count):
-                meta_data[char_count] = {
-                    "data": []
-                }
+                meta_data[char_count] = {"data": []}
 
             meta_data[char_count]["data"].append(
                 {
                     "utt": utt,
                     "frames": frames,
                     "audio_len": audio_len,
-                    "row": "{}|{}|{}|{}".format(row[0], row[1], row[2], row[3])
+                    "row": "{}|{}|{}|{}".format(row[0], row[1], row[2], row[3]),
                 }
             )
 
@@ -74,30 +69,30 @@ def process_meta_data(path):
 
 def get_data_points(meta_data):
     x = meta_data
-    y_avg = [meta_data[d]['mean'] for d in meta_data]
-    y_mode = [meta_data[d]['mode'] for d in meta_data]
-    y_median = [meta_data[d]['median'] for d in meta_data]
-    y_std = [meta_data[d]['std'] for d in meta_data]
-    y_num_samples = [len(meta_data[d]['data']) for d in meta_data]
+    y_avg = [meta_data[d]["mean"] for d in meta_data]
+    y_mode = [meta_data[d]["mode"] for d in meta_data]
+    y_median = [meta_data[d]["median"] for d in meta_data]
+    y_std = [meta_data[d]["std"] for d in meta_data]
+    y_num_samples = [len(meta_data[d]["data"]) for d in meta_data]
     return {
         "x": x,
         "y_avg": y_avg,
         "y_mode": y_mode,
         "y_median": y_median,
         "y_std": y_std,
-        "y_num_samples": y_num_samples
+        "y_num_samples": y_num_samples,
     }
 
 
 def save_training(file_path, meta_data):
     rows = []
     for char_cnt in meta_data:
-        data = meta_data[char_cnt]['data']
+        data = meta_data[char_cnt]["data"]
         for d in data:
-            rows.append(d['row'] + "\n")
+            rows.append(d["row"] + "\n")
 
     random.shuffle(rows)
-    with open(file_path, 'w+') as f:
+    with open(file_path, "w+") as f:
         for row in rows:
             f.write(row)
 
@@ -108,15 +103,15 @@ def plot(meta_data, save_path=None):
         save = True
 
     graph_data = get_data_points(meta_data)
-    x = graph_data['x']
-    y_avg = graph_data['y_avg']
-    y_std = graph_data['y_std']
-    y_mode = graph_data['y_mode']
-    y_median = graph_data['y_median']
-    y_num_samples = graph_data['y_num_samples']
+    x = graph_data["x"]
+    y_avg = graph_data["y_avg"]
+    y_std = graph_data["y_std"]
+    y_mode = graph_data["y_mode"]
+    y_median = graph_data["y_median"]
+    y_num_samples = graph_data["y_num_samples"]
 
     plt.figure()
-    plt.plot(x, y_avg, 'ro')
+    plt.plot(x, y_avg, "ro")
     plt.xlabel("character lengths", fontsize=30)
     plt.ylabel("avg seconds", fontsize=30)
     if save:
@@ -124,7 +119,7 @@ def plot(meta_data, save_path=None):
         plt.savefig(os.path.join(save_path, name))
 
     plt.figure()
-    plt.plot(x, y_mode, 'ro')
+    plt.plot(x, y_mode, "ro")
     plt.xlabel("character lengths", fontsize=30)
     plt.ylabel("mode seconds", fontsize=30)
     if save:
@@ -132,7 +127,7 @@ def plot(meta_data, save_path=None):
         plt.savefig(os.path.join(save_path, name))
 
     plt.figure()
-    plt.plot(x, y_median, 'ro')
+    plt.plot(x, y_median, "ro")
     plt.xlabel("character lengths", fontsize=30)
     plt.ylabel("median seconds", fontsize=30)
     if save:
@@ -140,7 +135,7 @@ def plot(meta_data, save_path=None):
         plt.savefig(os.path.join(save_path, name))
 
     plt.figure()
-    plt.plot(x, y_std, 'ro')
+    plt.plot(x, y_std, "ro")
     plt.xlabel("character lengths", fontsize=30)
     plt.ylabel("standard deviation", fontsize=30)
     if save:
@@ -148,7 +143,7 @@ def plot(meta_data, save_path=None):
         plt.savefig(os.path.join(save_path, name))
 
     plt.figure()
-    plt.plot(x, y_num_samples, 'ro')
+    plt.plot(x, y_num_samples, "ro")
     plt.xlabel("character lengths", fontsize=30)
     plt.ylabel("number of samples", fontsize=30)
     if save:
@@ -161,8 +156,8 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
 
     phonemes = {}
 
-    with open(train_path, 'r') as f:
-        data = csv.reader(f, delimiter='|')
+    with open(train_path, "r") as f:
+        data = csv.reader(f, delimiter="|")
         phonemes["None"] = 0
         for row in data:
             words = row[3].split()
@@ -194,15 +189,12 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        '--train_file_path', required=True,
-        help='this is the path to the train.txt file that the preprocess.py script creates'
-    )
-    parser.add_argument(
-        '--save_to', help='path to save charts of data to'
-    )
-    parser.add_argument(
-        '--cmu_dict_path', help='give cmudict-0.7b to see phoneme distribution'
+        "--train_file_path",
+        required=True,
+        help="this is the path to the train.txt file that the preprocess.py script creates",
     )
+    parser.add_argument("--save_to", help="path to save charts of data to")
+    parser.add_argument("--cmu_dict_path", help="give cmudict-0.7b to see phoneme distribution")
     args = parser.parse_args()
     meta_data = process_meta_data(args.train_file_path)
     plt.rcParams["figure.figsize"] = (10, 5)
@@ -213,5 +205,6 @@ def main():
 
     plt.show()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
@@ -26,3 +26,8 @@ exclude = '''
   # the root of the project
 )
 '''
+
+[tool.isort]
+line_length = 120
+profile = "black"
+multi_line_output = 3
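The new [tool.isort] table pins isort to Black-compatible behavior (profile = "black", multi_line_output 3) at the same 120-character line length used for the reformatting throughout this commit; that keeps the two formatters from fighting over import wrapping. As a rough sketch of checking a snippet against this style via Black's Python API — the invocation and the sample string are assumptions, since the commit records only the resulting diffs:

    # Minimal sketch, assuming a reasonably recent black release is installed.
    import black

    mode = black.Mode(line_length=120)  # mirrors line_length = 120 above
    src = "x = { 'data': [ ] }"  # hypothetical unformatted input
    print(black.format_str(src, mode=mode), end="")  # -> x = {"data": []}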
@@ -18,10 +18,12 @@ bokeh==1.4.0
 pysbd
 # pyworld
 soundfile
-nose==1.3.7
-cardboardlint==1.3.0
-pylint==2.5.3
 gdown
 umap-learn==0.4.6
 cython
 pyyaml
+# quality and style
+nose
+black
+isort
+pylint==2.7.4
@@ -10,7 +10,7 @@ OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
 WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
 
 os.makedirs(OUT_PATH, exist_ok=True)
-conf = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
+conf = load_config(os.path.join(get_tests_input_path(), "test_config.json"))
 
 
 # pylint: disable=protected-access
@@ -20,7 +20,7 @@ class TestAudio(unittest.TestCase):
         self.ap = AudioProcessor(**conf.audio)
 
     def test_audio_synthesis(self):
-        """ 1. load wav
+        """1. load wav
         2. set normalization parameters
         3. extract mel-spec
         4. invert to wav and save the output
@@ -35,23 +35,24 @@ class TestAudio(unittest.TestCase):
             wav = self.ap.load_wav(WAV_FILE)
             mel = self.ap.melspectrogram(wav)
             wav_ = self.ap.inv_melspectrogram(mel)
-            file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav"\
-                .format(max_norm, signal_norm, symmetric_norm, clip_norm)
+            file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav".format(
+                max_norm, signal_norm, symmetric_norm, clip_norm
+            )
             print(" | > Creating wav file at : ", file_name)
             self.ap.save_wav(wav_, OUT_PATH + file_name)
 
         # maxnorm = 1.0
-        _test(1., False, False, False)
-        _test(1., True, False, False)
-        _test(1., True, True, False)
-        _test(1., True, False, True)
-        _test(1., True, True, True)
+        _test(1.0, False, False, False)
+        _test(1.0, True, False, False)
+        _test(1.0, True, True, False)
+        _test(1.0, True, False, True)
+        _test(1.0, True, True, True)
         # maxnorm = 4.0
-        _test(4., False, False, False)
-        _test(4., True, False, False)
-        _test(4., True, True, False)
-        _test(4., True, False, True)
-        _test(4., True, True, True)
+        _test(4.0, False, False, False)
+        _test(4.0, True, False, False)
+        _test(4.0, True, True, False)
+        _test(4.0, True, False, True)
+        _test(4.0, True, True, True)
 
     def test_normalize(self):
         """Check normalization and denormalization for range values and consistency """
@@ -67,7 +68,9 @@ class TestAudio(unittest.TestCase):
         self.ap.clip_norm = False
         self.ap.max_norm = 4.0
         x_norm = self.ap.normalize(x)
-        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}")
+        print(
+            f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}"
+        )
         assert (x_old - x).sum() == 0
         # check value range
         assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
@@ -81,8 +84,9 @@ class TestAudio(unittest.TestCase):
         self.ap.clip_norm = True
         self.ap.max_norm = 4.0
         x_norm = self.ap.normalize(x)
-        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}")
+        print(
+            f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}"
+        )
 
         assert (x_old - x).sum() == 0
         # check value range
@@ -97,13 +101,14 @@ class TestAudio(unittest.TestCase):
         self.ap.clip_norm = False
         self.ap.max_norm = 4.0
         x_norm = self.ap.normalize(x)
-        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}")
+        print(
+            f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}"
+        )
 
         assert (x_old - x).sum() == 0
         # check value range
         assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
-        assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min()  #pylint: disable=invalid-unary-operand-type
+        assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min()  # pylint: disable=invalid-unary-operand-type
         assert x_norm.min() <= 0, x_norm.min()
         # check denorm.
         x_ = self.ap.denormalize(x_norm)
@@ -114,13 +119,14 @@ class TestAudio(unittest.TestCase):
         self.ap.clip_norm = True
         self.ap.max_norm = 4.0
         x_norm = self.ap.normalize(x)
-        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}")
+        print(
+            f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}"
+        )
 
         assert (x_old - x).sum() == 0
         # check value range
         assert x_norm.max() <= self.ap.max_norm, x_norm.max()
-        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()  #pylint: disable=invalid-unary-operand-type
+        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()  # pylint: disable=invalid-unary-operand-type
         assert x_norm.min() <= 0, x_norm.min()
         # check denorm.
         x_ = self.ap.denormalize(x_norm)
@@ -130,8 +136,9 @@ class TestAudio(unittest.TestCase):
         self.ap.symmetric_norm = False
         self.ap.max_norm = 1.0
         x_norm = self.ap.normalize(x)
-        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}")
+        print(
+            f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}"
+        )
 
         assert (x_old - x).sum() == 0
         assert x_norm.max() <= self.ap.max_norm, x_norm.max()
@@ -143,22 +150,23 @@ class TestAudio(unittest.TestCase):
         self.ap.symmetric_norm = True
         self.ap.max_norm = 1.0
         x_norm = self.ap.normalize(x)
-        print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}")
+        print(
+            f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}"
+        )
 
         assert (x_old - x).sum() == 0
         assert x_norm.max() <= self.ap.max_norm, x_norm.max()
-        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()  #pylint: disable=invalid-unary-operand-type
+        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()  # pylint: disable=invalid-unary-operand-type
         assert x_norm.min() < 0, x_norm.min()
         x_ = self.ap.denormalize(x_norm)
         assert (x - x_).sum() < 1e-3
 
     def test_scaler(self):
-        scaler_stats_path = os.path.join(get_tests_input_path(), 'scale_stats.npy')
-        conf.audio['stats_path'] = scaler_stats_path
-        conf.audio['preemphasis'] = 0.0
-        conf.audio['do_trim_silence'] = True
-        conf.audio['signal_norm'] = True
+        scaler_stats_path = os.path.join(get_tests_input_path(), "scale_stats.npy")
+        conf.audio["stats_path"] = scaler_stats_path
+        conf.audio["preemphasis"] = 0.0
+        conf.audio["do_trim_silence"] = True
+        conf.audio["signal_norm"] = True
 
         ap = AudioProcessor(**conf.audio)
         mel_mean, mel_std, linear_mean, linear_std, _ = ap.load_stats(scaler_stats_path)
@@ -9,99 +9,99 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 def test_encoder():
     input_dummy = torch.rand(8, 14, 37).to(device)
-    input_lengths = torch.randint(31, 37, (8, )).long().to(device)
+    input_lengths = torch.randint(31, 37, (8,)).long().to(device)
     input_lengths[-1] = 37
-    input_mask = torch.unsqueeze(
-        sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
+    input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
     # relative positional transformer encoder
-    layer = Encoder(out_channels=11,
-                    in_hidden_channels=14,
-                    encoder_type='relative_position_transformer',
-                    encoder_params={
-                        'hidden_channels_ffn': 768,
-                        'num_heads': 2,
-                        "kernel_size": 3,
-                        "dropout_p": 0.1,
-                        "num_layers": 6,
-                        "rel_attn_window_size": 4,
-                        "input_length": None
-                    }).to(device)
+    layer = Encoder(
+        out_channels=11,
+        in_hidden_channels=14,
+        encoder_type="relative_position_transformer",
+        encoder_params={
+            "hidden_channels_ffn": 768,
+            "num_heads": 2,
+            "kernel_size": 3,
+            "dropout_p": 0.1,
+            "num_layers": 6,
+            "rel_attn_window_size": 4,
+            "input_length": None,
+        },
+    ).to(device)
     output = layer(input_dummy, input_mask)
     assert list(output.shape) == [8, 11, 37]
     # residual conv bn encoder
-    layer = Encoder(out_channels=11,
-                    in_hidden_channels=14,
-                    encoder_type='residual_conv_bn',
-                    encoder_params={
-                        "kernel_size": 4,
-                        "dilations": 4 * [1, 2, 4] + [1],
-                        "num_conv_blocks": 2,
-                        "num_res_blocks": 13
-                    }).to(device)
+    layer = Encoder(
+        out_channels=11,
+        in_hidden_channels=14,
+        encoder_type="residual_conv_bn",
+        encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
+    ).to(device)
     output = layer(input_dummy, input_mask)
     assert list(output.shape) == [8, 11, 37]
     # FFTransformer encoder
-    layer = Encoder(out_channels=14,
-                    in_hidden_channels=14,
-                    encoder_type='fftransformer',
-                    encoder_params={
-                        "hidden_channels_ffn": 31,
-                        "num_heads": 2,
-                        "num_layers": 2,
-                        "dropout_p": 0.1
-                    }).to(device)
+    layer = Encoder(
+        out_channels=14,
+        in_hidden_channels=14,
+        encoder_type="fftransformer",
+        encoder_params={"hidden_channels_ffn": 31, "num_heads": 2, "num_layers": 2, "dropout_p": 0.1},
+    ).to(device)
     output = layer(input_dummy, input_mask)
     assert list(output.shape) == [8, 14, 37]
 
 
 def test_decoder():
     input_dummy = torch.rand(8, 128, 37).to(device)
-    input_lengths = torch.randint(31, 37, (8, )).long().to(device)
+    input_lengths = torch.randint(31, 37, (8,)).long().to(device)
     input_lengths[-1] = 37
 
-    input_mask = torch.unsqueeze(
-        sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
+    input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
     # residual bn conv decoder
     layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
     output = layer(input_dummy, input_mask)
     assert list(output.shape) == [8, 11, 37]
     # transformer decoder
-    layer = Decoder(out_channels=11,
-                    in_hidden_channels=128,
-                    decoder_type='relative_position_transformer',
-                    decoder_params={
-                        'hidden_channels_ffn': 128,
-                        'num_heads': 2,
-                        "kernel_size": 3,
-                        "dropout_p": 0.1,
-                        "num_layers": 8,
-                        "rel_attn_window_size": 4,
-                        "input_length": None
-                    }).to(device)
+    layer = Decoder(
+        out_channels=11,
+        in_hidden_channels=128,
+        decoder_type="relative_position_transformer",
+        decoder_params={
+            "hidden_channels_ffn": 128,
+            "num_heads": 2,
+            "kernel_size": 3,
+            "dropout_p": 0.1,
+            "num_layers": 8,
+            "rel_attn_window_size": 4,
+            "input_length": None,
+        },
+    ).to(device)
     output = layer(input_dummy, input_mask)
     assert list(output.shape) == [8, 11, 37]
     # wavenet decoder
-    layer = Decoder(out_channels=11,
-                    in_hidden_channels=128,
-                    decoder_type='wavenet',
-                    decoder_params={
-                        "num_blocks": 12,
-                        "hidden_channels": 192,
-                        "kernel_size": 5,
-                        "dilation_rate": 1,
-                        "num_layers": 4,
-                        "dropout_p": 0.05
-                    }).to(device)
+    layer = Decoder(
+        out_channels=11,
+        in_hidden_channels=128,
+        decoder_type="wavenet",
+        decoder_params={
+            "num_blocks": 12,
+            "hidden_channels": 192,
+            "kernel_size": 5,
+            "dilation_rate": 1,
+            "num_layers": 4,
+            "dropout_p": 0.05,
+        },
+    ).to(device)
     output = layer(input_dummy, input_mask)
     # FFTransformer decoder
-    layer = Decoder(out_channels=11,
-                    in_hidden_channels=128,
-                    decoder_type='fftransformer',
-                    decoder_params={
-                        'hidden_channels_ffn': 31,
-                        'num_heads': 2,
-                        "dropout_p": 0.1,
-                        "num_layers": 2,
-                    }).to(device)
+    layer = Decoder(
+        out_channels=11,
+        in_hidden_channels=128,
+        decoder_type="fftransformer",
+        decoder_params={
+            "hidden_channels_ffn": 31,
+            "num_heads": 2,
+            "dropout_p": 0.1,
+            "num_layers": 2,
+        },
+    ).to(device)
     output = layer(input_dummy, input_mask)
     assert list(output.shape) == [8, 11, 37]
@@ -11,13 +11,13 @@ from TTS.tts.models.glow_tts import GlowTTS
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
 
-#pylint: disable=unused-variable
+# pylint: disable=unused-variable
 
 torch.manual_seed(1)
 use_cuda = torch.cuda.is_available()
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-c = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
+c = load_config(os.path.join(get_tests_input_path(), "test_config.json"))
 
 ap = AudioProcessor(**c.audio)
 WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
@@ -32,11 +32,11 @@ class GlowTTSTrainTest(unittest.TestCase):
     @staticmethod
     def test_train_step():
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8, )).long().to(device)
+        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128
-        mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
-        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
-        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
+        mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
 
         criterion = GlowTTSLoss()
 
@@ -47,27 +47,28 @@ class GlowTTSTrainTest(unittest.TestCase):
             hidden_channels_dec=48,
             hidden_channels_dp=32,
             out_channels=80,
-            encoder_type='rel_pos_transformer',
+            encoder_type="rel_pos_transformer",
             encoder_params={
-                'kernel_size': 3,
-                'dropout_p': 0.1,
-                'num_layers': 6,
-                'num_heads': 2,
-                'hidden_channels_ffn': 16,  # 4 times the hidden_channels
-                'input_length': None
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 6,
+                "num_heads": 2,
+                "hidden_channels_ffn": 16,  # 4 times the hidden_channels
+                "input_length": None,
             },
             use_encoder_prenet=True,
             num_flow_blocks_dec=12,
             kernel_size_dec=5,
             dilation_rate=1,
             num_block_layers=4,
-            dropout_p_dec=0.,
+            dropout_p_dec=0.0,
             num_speakers=0,
             c_in_channels=0,
             num_splits=4,
             num_squeeze=1,
             sigmoid_scale=False,
-            mean_only=False).to(device)
+            mean_only=False,
+        ).to(device)
 
         # reference model to compare model weights
         model_ref = GlowTTS(
@@ -76,38 +77,37 @@ class GlowTTSTrainTest(unittest.TestCase):
             hidden_channels_dec=48,
             hidden_channels_dp=32,
             out_channels=80,
-            encoder_type='rel_pos_transformer',
+            encoder_type="rel_pos_transformer",
            encoder_params={
-                'kernel_size': 3,
-                'dropout_p': 0.1,
-                'num_layers': 6,
-                'num_heads': 2,
-                'hidden_channels_ffn': 16,  # 4 times the hidden_channels
-                'input_length': None
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 6,
+                "num_heads": 2,
+                "hidden_channels_ffn": 16,  # 4 times the hidden_channels
+                "input_length": None,
             },
             use_encoder_prenet=True,
             num_flow_blocks_dec=12,
             kernel_size_dec=5,
             dilation_rate=1,
             num_block_layers=4,
-            dropout_p_dec=0.,
+            dropout_p_dec=0.0,
             num_speakers=0,
             c_in_channels=0,
             num_splits=4,
             num_squeeze=1,
             sigmoid_scale=False,
-            mean_only=False).to(device)
+            mean_only=False,
+        ).to(device)
 
         model.train()
-        print(" > Num parameters for GlowTTS model:%s" %
-              (count_parameters(model)))
+        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
 
         # pass the state to ref model
         model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
 
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             assert (param - param_ref).sum() == 0, param
             count += 1
 
@@ -115,18 +115,17 @@ class GlowTTSTrainTest(unittest.TestCase):
         for _ in range(5):
             optimizer.zero_grad()
             z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, None)
-            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
-                                  o_dur_log, o_total_dur, input_lengths)
-            loss = loss_dict['loss']
+                input_dummy, input_lengths, mel_spec, mel_lengths, None
+            )
+            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, input_lengths)
+            loss = loss_dict["loss"]
             loss.backward()
             optimizer.step()
 
         # check parameter changes
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
-            assert (param != param_ref).any(
-            ), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref)
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
             count += 1
@@ -10,7 +10,7 @@ from TTS.tts.utils.generic_utils import sequence_mask
 
 
 class PrenetTests(unittest.TestCase):
-    def test_in_out(self):  #pylint: disable=no-self-use
+    def test_in_out(self):  # pylint: disable=no-self-use
         layer = Prenet(128, out_features=[256, 128])
         dummy_input = T.rand(4, 128)
 
@@ -22,7 +22,7 @@ class PrenetTests(unittest.TestCase):
 
 class CBHGTests(unittest.TestCase):
     def test_in_out(self):
-        #pylint: disable=attribute-defined-outside-init
+        # pylint: disable=attribute-defined-outside-init
         layer = self.cbhg = CBHG(
             128,
             K=8,
@@ -30,7 +30,8 @@ class CBHGTests(unittest.TestCase):
             conv_projections=[160, 128],
             highway_features=80,
             gru_features=80,
-            num_highways=4)
+            num_highways=4,
+        )
         # B x D x T
         dummy_input = T.rand(4, 128, 8)
 
@@ -53,26 +54,27 @@ class DecoderTests(unittest.TestCase):
             attn_norm="sigmoid",
             attn_K=5,
             attn_type="original",
-            prenet_type='original',
+            prenet_type="original",
             prenet_dropout=True,
             forward_attn=True,
             trans_agent=True,
             forward_attn_mask=True,
             location_attn=True,
-            separate_stopnet=True)
+            separate_stopnet=True,
+        )
         dummy_input = T.rand(4, 8, 256)
         dummy_memory = T.rand(4, 2, 80)
 
-        output, alignment, stop_tokens = layer(
-            dummy_input, dummy_memory, mask=None)
+        output, alignment, stop_tokens = layer(dummy_input, dummy_memory, mask=None)
 
         assert output.shape[0] == 4
         assert output.shape[1] == 80, "size not {}".format(output.shape[1])
         assert output.shape[2] == 2, "size not {}".format(output.shape[2])
         assert stop_tokens.shape[0] == 4
 
 
 class EncoderTests(unittest.TestCase):
-    def test_in_out(self):  #pylint: disable=no-self-use
+    def test_in_out(self):  # pylint: disable=no-self-use
         layer = Encoder(128)
         dummy_input = T.rand(4, 8, 128)
 
@@ -85,7 +87,7 @@ class EncoderTests(unittest.TestCase):
 
 
 class L1LossMaskedTests(unittest.TestCase):
-    def test_in_out(self):  #pylint: disable=no-self-use
+    def test_in_out(self):  # pylint: disable=no-self-use
         # test input == target
         layer = L1LossMasked(seq_len_norm=False)
         dummy_input = T.ones(4, 8, 128).float()
@@ -105,16 +107,14 @@ class L1LossMaskedTests(unittest.TestCase):
         dummy_input = T.ones(4, 8, 128).float()
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.arange(5, 9)).long()
-        mask = (
-            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 1.0, "1.0 vs {}".format(output.item())
 
         dummy_input = T.rand(4, 8, 128).float()
         dummy_target = dummy_input.detach()
         dummy_length = (T.arange(5, 9)).long()
-        mask = (
-            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 0, "0 vs {}".format(output.item())
 
@@ -138,22 +138,20 @@ class L1LossMaskedTests(unittest.TestCase):
         dummy_input = T.ones(4, 8, 128).float()
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.arange(5, 9)).long()
-        mask = (
-            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
 
         dummy_input = T.rand(4, 8, 128).float()
         dummy_target = dummy_input.detach()
         dummy_length = (T.arange(5, 9)).long()
-        mask = (
-            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 0, "0 vs {}".format(output.item())
 
 
 class SSIMLossTests(unittest.TestCase):
-    def test_in_out(self):  #pylint: disable=no-self-use
+    def test_in_out(self):  # pylint: disable=no-self-use
         # test input == target
         layer = SSIMLoss()
         dummy_input = T.ones(4, 8, 128).float()
@@ -173,16 +171,14 @@ class SSIMLossTests(unittest.TestCase):
         dummy_input = T.ones(4, 8, 128).float()
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.arange(5, 9)).long()
-        mask = (
-            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
 
         dummy_input = T.rand(4, 8, 128).float()
         dummy_target = dummy_input.detach()
         dummy_length = (T.arange(5, 9)).long()
-        mask = (
-            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 0, "0 vs {}".format(output.item())
 
@@ -206,15 +202,13 @@ class SSIMLossTests(unittest.TestCase):
         dummy_input = T.ones(4, 8, 128).float()
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.arange(5, 9)).long()
-        mask = (
-            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
 
         dummy_input = T.rand(4, 8, 128).float()
         dummy_target = dummy_input.detach()
         dummy_length = (T.arange(5, 9)).long()
-        mask = (
-            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 0, "0 vs {}".format(output.item())
@@ -12,11 +12,11 @@ from TTS.tts.datasets.preprocess import ljspeech
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
 
-#pylint: disable=unused-variable
+# pylint: disable=unused-variable
 
 OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
 os.makedirs(OUTPATH, exist_ok=True)
-c = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
+c = load_config(os.path.join(get_tests_input_path(), "test_config.json"))
 ok_ljspeech = os.path.exists(c.data_path)
 
 DATA_EXIST = True
@@ -33,25 +33,27 @@ class TestTTSDataset(unittest.TestCase):
         self.ap = AudioProcessor(**c.audio)
 
     def _create_dataloader(self, batch_size, r, bgs):
-        items = ljspeech(c.data_path, 'metadata.csv')
+        items = ljspeech(c.data_path, "metadata.csv")
         dataset = TTSDataset.MyDataset(
             r,
             c.text_cleaner,
             compute_linear_spec=True,
             ap=self.ap,
             meta_data=items,
-            tp=c.characters if 'characters' in c.keys() else None,
+            tp=c.characters if "characters" in c.keys() else None,
             batch_group_size=bgs,
             min_seq_len=c.min_seq_len,
             max_seq_len=float("inf"),
-            use_phonemes=False)
+            use_phonemes=False,
+        )
         dataloader = DataLoader(
             dataset,
             batch_size=batch_size,
             shuffle=False,
             collate_fn=dataset.collate_fn,
             drop_last=True,
-            num_workers=c.num_loader_workers)
+            num_workers=c.num_loader_workers,
+        )
         return dataloader, dataset
 
     def test_loader(self):
@@ -72,18 +74,17 @@ class TestTTSDataset(unittest.TestCase):
 
                 neg_values = text_input[text_input < 0]
                 check_count = len(neg_values)
-                assert check_count == 0, \
-                    " !! Negative values in text_input: {}".format(check_count)
+                assert check_count == 0, " !! Negative values in text_input: {}".format(check_count)
                 # TODO: more assertion here
                 assert isinstance(speaker_name[0], str)
                 assert linear_input.shape[0] == c.batch_size
                 assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
                 assert mel_input.shape[0] == c.batch_size
-                assert mel_input.shape[2] == c.audio['num_mels']
+                assert mel_input.shape[2] == c.audio["num_mels"]
                 # check normalization ranges
                 if self.ap.symmetric_norm:
                     assert mel_input.max() <= self.ap.max_norm
-                    assert mel_input.min() >= -self.ap.max_norm  #pylint: disable=invalid-unary-operand-type
+                    assert mel_input.min() >= -self.ap.max_norm  # pylint: disable=invalid-unary-operand-type
                     assert mel_input.min() < 0
                 else:
                     assert mel_input.max() <= self.ap.max_norm
@@ -134,7 +135,7 @@ class TestTTSDataset(unittest.TestCase):
 
                 # check mel_spec consistency
                 wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
-                mel = self.ap.melspectrogram(wav).astype('float32')
+                mel = self.ap.melspectrogram(wav).astype("float32")
                 mel = torch.FloatTensor(mel).contiguous()
                 mel_dl = mel_input[0]
                 # NOTE: Below needs to check == 0 but due to an unknown reason
@@ -145,15 +146,14 @@ class TestTTSDataset(unittest.TestCase):
                 # check mel-spec correctness
                 mel_spec = mel_input[0].cpu().numpy()
                 wav = self.ap.inv_melspectrogram(mel_spec.T)
-                self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav')
-                shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader.wav')
+                self.ap.save_wav(wav, OUTPATH + "/mel_inv_dataloader.wav")
+                shutil.copy(item_idx[0], OUTPATH + "/mel_target_dataloader.wav")
 
                 # check linear-spec
                 linear_spec = linear_input[0].cpu().numpy()
                 wav = self.ap.inv_spectrogram(linear_spec.T)
-                self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
-                shutil.copy(item_idx[0],
-                            OUTPATH + '/linear_target_dataloader.wav')
+                self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
+                shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav")
 
                 # check the last time step to be zero padded
                 assert linear_input[0, -1].sum() != 0
@@ -202,8 +202,8 @@ class TestTTSDataset(unittest.TestCase):
                 # check the second itme in the batch
                 assert linear_input[1 - idx, -1].sum() == 0
                 assert mel_input[1 - idx, -1].sum() == 0
-                assert stop_target[1, mel_lengths[1]-1] == 1
-                assert stop_target[1, mel_lengths[1]:].sum() == 0
+                assert stop_target[1, mel_lengths[1] - 1] == 1
+                assert stop_target[1, mel_lengths[1] :].sum() == 0
                 assert len(mel_lengths.shape) == 1
 
                 # check batch zero-frame conditions (zero-frame disabled)
@@ -6,12 +6,11 @@ from TTS.tts.datasets.preprocess import common_voice
 
 
 class TestPreprocessors(unittest.TestCase):
-
-    def test_common_voice_preprocessor(self):  #pylint: disable=no-self-use
+    def test_common_voice_preprocessor(self):  # pylint: disable=no-self-use
         root_path = get_tests_input_path()
         meta_file = "common_voice.tsv"
         items = common_voice(root_path, meta_file)
-        assert items[0][0] == 'The applicants are invited for coffee and visa is given immediately.'
+        assert items[0][0] == "The applicants are invited for coffee and visa is given immediately."
         assert items[0][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
 
         assert items[-1][0] == "Competition for limited resources has also resulted in some local conflicts."
@@ -17,9 +17,7 @@ class SpeakerEncoderTests(unittest.TestCase):
     def test_in_out(self):
        dummy_input = T.rand(4, 20, 80)  # B x T x D
         dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
-        model = SpeakerEncoder(
-            input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3
-        )
+        model = SpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
         # computing d vectors
         output = model.forward(dummy_input)
         assert output.shape[0] == 4
@@ -36,9 +34,7 @@ class SpeakerEncoderTests(unittest.TestCase):
         output_norm = T.nn.functional.normalize(output, dim=1, p=2)
         assert_diff = (output_norm - output).sum().item()
         assert output.type() == "torch.FloatTensor"
-        assert (
-            abs(assert_diff) < 1e-4
-        ), f" [!] output_norm has wrong values - {assert_diff}"
+        assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
         # compute d for a given batch
         dummy_input = T.rand(1, 240, 80)  # B x T x D
         output = model.compute_embedding(dummy_input, num_frames=160, overlap=0.5)
@@ -74,6 +70,7 @@ class GE2ELossTests(unittest.TestCase):
         output = loss.forward(dummy_input)
         assert output.item() < 0.005
 
+
 class AngleProtoLossTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
@@ -103,6 +100,7 @@ class AngleProtoLossTests(unittest.TestCase):
         output = loss.forward(dummy_input)
         assert output.item() < 0.005
 
+
 # class LoaderTest(unittest.TestCase):
 #     def test_output(self):
 #         items = libri_tts("/home/erogol/Data/Libri-TTS/train-clean-360/")
@@ -10,11 +10,10 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 def test_duration_predictor():
     input_dummy = torch.rand(8, 128, 27).to(device)
-    input_lengths = torch.randint(20, 27, (8, )).long().to(device)
+    input_lengths = torch.randint(20, 27, (8,)).long().to(device)
     input_lengths[-1] = 27
 
-    x_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)),
-                             1).to(device)
+    x_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
 
     layer = DurationPredictor(hidden_channels=128).to(device)
 
@@ -29,7 +28,7 @@ def test_speedy_speech():
     T_de = 74
 
     x_dummy = torch.randint(0, 7, (B, T_en)).long().to(device)
-    x_lengths = torch.randint(31, T_en, (B, )).long().to(device)
+    x_lengths = torch.randint(31, T_en, (B,)).long().to(device)
     x_lengths[-1] = T_en
 
     # set durations. max total duration should be equal to T_de
@ -53,34 +52,18 @@ def test_speedy_speech():
|
||||||
assert list(o_dr.shape) == [B, T_en]
|
assert list(o_dr.shape) == [B, T_en]
|
||||||
|
|
||||||
# with speaker embedding
|
# with speaker embedding
|
||||||
model = SpeedySpeech(num_chars,
|
model = SpeedySpeech(num_chars, out_channels=80, hidden_channels=128, num_speakers=10, c_in_channels=256).to(device)
|
||||||
out_channels=80,
|
model.forward(x_dummy, x_lengths, y_lengths, durations, g=torch.randint(0, 10, (B,)).to(device))
|
||||||
hidden_channels=128,
|
|
||||||
num_speakers=10,
|
|
||||||
c_in_channels=256).to(device)
|
|
||||||
model.forward(x_dummy,
|
|
||||||
x_lengths,
|
|
||||||
y_lengths,
|
|
||||||
durations,
|
|
||||||
g=torch.randint(0, 10, (B,)).to(device))
|
|
||||||
|
|
||||||
assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
|
assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
|
||||||
assert list(attn.shape) == [B, T_de, T_en]
|
assert list(attn.shape) == [B, T_de, T_en]
|
||||||
assert list(o_dr.shape) == [B, T_en]
|
assert list(o_dr.shape) == [B, T_en]
|
||||||
|
|
||||||
|
|
||||||
# with speaker external embedding
|
# with speaker external embedding
|
||||||
model = SpeedySpeech(num_chars,
|
model = SpeedySpeech(
|
||||||
out_channels=80,
|
num_chars, out_channels=80, hidden_channels=128, num_speakers=10, external_c=True, c_in_channels=256
|
||||||
hidden_channels=128,
|
).to(device)
|
||||||
num_speakers=10,
|
model.forward(x_dummy, x_lengths, y_lengths, durations, g=torch.rand((B, 256)).to(device))
|
||||||
external_c=True,
|
|
||||||
c_in_channels=256).to(device)
|
|
||||||
model.forward(x_dummy,
|
|
||||||
x_lengths,
|
|
||||||
y_lengths,
|
|
||||||
durations,
|
|
||||||
g=torch.rand((B, 256)).to(device))
|
|
||||||
|
|
||||||
assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
|
assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
|
||||||
assert list(attn.shape) == [B, T_de, T_en]
|
assert list(attn.shape) == [B, T_de, T_en]
|
||||||
|
|
|
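The sequence_mask(input_lengths, input_dummy.size(2)) call above builds the boolean mask that keeps only valid timesteps per batch item, so padded positions drop out of the loss. A minimal sketch of the pattern, assuming the standard broadcast-comparison implementation:

import torch

def sequence_mask(lengths, max_len=None):
    # True at positions < lengths[b]; result shape is B x max_len
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

print(sequence_mask(torch.tensor([2, 4]), 5))
# [[ True,  True, False, False, False],
#  [ True,  True,  True,  True, False]]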
@@ -4,5 +4,5 @@ from TTS.tts.utils.text import phonemes


class SymbolsTest(unittest.TestCase):
    def test_uniqueness(self):  # pylint: disable=no-self-use
        assert sorted(phonemes) == sorted(list(set(phonemes))), " {} vs {} ".format(len(phonemes), len(set(phonemes)))
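The uniqueness assert above only reports the two list lengths when it fails. A small helper like this (hypothetical, not part of the test) would name the duplicated symbols instead:

from collections import Counter

def find_duplicates(symbols):
    # return the symbols that occur more than once
    return [s for s, n in Counter(symbols).items() if n > 1]

print(find_duplicates("abca"))  # ['a']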
@@ -14,8 +14,8 @@ class SynthesizerTest(unittest.TestCase):
    def _create_random_model(self):
        # pylint: disable=global-statement
        global symbols, phonemes
        config = load_config(os.path.join(get_tests_output_path(), "dummy_model_config.json"))
        if "characters" in config.keys():
            symbols, phonemes = make_symbols(**config.characters)

        num_chars = len(phonemes) if config.use_phonemes else len(symbols)

@@ -25,11 +25,11 @@ class SynthesizerTest(unittest.TestCase):

    def test_in_out(self):
        self._create_random_model()
        config = load_config(os.path.join(get_tests_input_path(), "server_config.json"))
        tts_root_path = get_tests_output_path()
        config["tts_checkpoint"] = os.path.join(tts_root_path, config["tts_checkpoint"])
        config["tts_config"] = os.path.join(tts_root_path, config["tts_config"])
        synthesizer = Synthesizer(config["tts_checkpoint"], config["tts_config"], None, None)
        synthesizer.tts("Better this test works!!")

    def test_split_into_sentences(self):

@@ -38,20 +38,48 @@ class SynthesizerTest(unittest.TestCase):
        # pylint: disable=attribute-defined-outside-init
        self.seg = Synthesizer.get_segmenter("en")
        sis = Synthesizer.split_into_sentences
        assert sis(self, "Hello. Two sentences") == ["Hello.", "Two sentences"]
        assert sis(self, "He went to meet the adviser from Scott, Waltman & Co. next morning.") == [
            "He went to meet the adviser from Scott, Waltman & Co. next morning."
        ]
        assert sis(self, "Let's run it past Sarah and co. They'll want to see this.") == [
            "Let's run it past Sarah and co.",
            "They'll want to see this.",
        ]
        assert sis(self, "Where is Bobby Jr.'s rabbit?") == ["Where is Bobby Jr.'s rabbit?"]
        assert sis(self, "Please inform the U.K. authorities right away.") == [
            "Please inform the U.K. authorities right away."
        ]
        assert sis(self, "Were David and co. at the event?") == ["Were David and co. at the event?"]
        assert sis(self, "paging dr. green, please come to theatre four immediately.") == [
            "paging dr. green, please come to theatre four immediately."
        ]
        assert sis(self, "The email format is Firstname.Lastname@example.com. I think you reversed them.") == [
            "The email format is Firstname.Lastname@example.com.",
            "I think you reversed them.",
        ]
        assert sis(
            self,
            "The demo site is: https://top100.example.com/subsection/latestnews.html. Please send us your feedback.",
        ) == [
            "The demo site is: https://top100.example.com/subsection/latestnews.html.",
            "Please send us your feedback.",
        ]
        assert sis(self, "Scowling at him, 'You are not done yet!' she yelled.") == [
            "Scowling at him, 'You are not done yet!' she yelled."
        ]  # with the final lowercase "she" we see it's all one sentence
        assert sis(self, "Hey!! So good to see you.") == ["Hey!!", "So good to see you."]
        assert sis(self, "He went to Yahoo! but I don't know the division.") == [
            "He went to Yahoo! but I don't know the division."
        ]
        assert sis(self, "If you can't remember a quote, “at least make up a memorable one that's plausible...\"") == [
            "If you can't remember a quote, “at least make up a memorable one that's plausible...\""
        ]
        assert sis(self, "The address is not google.com.") == ["The address is not google.com."]
        assert sis(self, "1.) The first item 2.) The second item") == ["1.) The first item", "2.) The second item"]
        assert sis(self, "1) The first item 2) The second item") == ["1) The first item", "2) The second item"]
        assert sis(self, "a. The first item b. The second item c. The third list item") == [
            "a. The first item",
            "b. The second item",
            "c. The third list item",
        ]
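The sentence-splitting expectations above (abbreviations like "Co.", URLs, e-mail addresses, list markers) are exactly the cases naive period-splitting gets wrong. Synthesizer.get_segmenter("en") wraps a rule-based segmenter for this; assuming it is backed by the pysbd package, the same cases can be probed directly:

import pysbd

seg = pysbd.Segmenter(language="en", clean=True)
print(seg.segment("Hello. Two sentences"))
# ['Hello.', 'Two sentences']
print(seg.segment("Please inform the U.K. authorities right away."))
# ['Please inform the U.K. authorities right away.']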
@@ -11,13 +11,13 @@ from TTS.tts.models.tacotron2 import Tacotron2
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config

# pylint: disable=unused-variable

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = load_config(os.path.join(get_tests_input_path(), "test_config.json"))

ap = AudioProcessor(**c.audio)
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")

@@ -26,20 +26,19 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
class TacotronTrainTest(unittest.TestCase):
    def test_train_step(self):  # pylint: disable=no-self-use
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked(seq_len_norm=False).to(device)

@@ -48,14 +47,14 @@ class TacotronTrainTest(unittest.TestCase):
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(5):
            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
            )
            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
            optimizer.zero_grad()

@@ -66,13 +65,12 @@ class TacotronTrainTest(unittest.TestCase):
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1
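MSELossMasked(seq_len_norm=False) used above scores only the valid region of each padded target sequence. A minimal sketch of the idea, assuming a plain mask-and-average reduction rather than the library's exact one:

import torch

def masked_mse(pred, target, lengths):
    # pred/target: B x T x D; mask out padded timesteps before averaging
    mask = torch.arange(pred.size(1), device=lengths.device)[None, :] < lengths[:, None]
    mask = mask.unsqueeze(-1).float()  # B x T x 1, broadcasts over D
    return ((pred - target) ** 2 * mask).sum() / (mask.sum() * pred.size(2))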
@@ -80,20 +78,19 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_embeddings = torch.rand(8, 55).to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked(seq_len_norm=False).to(device)

@@ -102,14 +99,14 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(5):
            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings
            )
            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
            optimizer.zero_grad()

@@ -120,39 +117,46 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class TacotronGSTTrainTest(unittest.TestCase):
    # pylint: disable=no-self-use
    def test_train_step(self):
        # with random gst mel style
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron2(
            num_chars=24,
            r=c.r,
            num_speakers=5,
            gst=True,
            gst_embedding_dim=c.gst["gst_embedding_dim"],
            gst_num_heads=c.gst["gst_num_heads"],
            gst_style_tokens=c.gst["gst_style_tokens"],
        ).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0

@@ -162,7 +166,8 @@ class TacotronGSTTrainTest(unittest.TestCase):
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(10):
            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
            )
            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
            optimizer.zero_grad()

@@ -177,36 +182,45 @@ class TacotronGSTTrainTest(unittest.TestCase):
            # ignore pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            name, param = name_param
            if name == "gst_layer.encoder.recurrence.weight_hh_l0":
                # print(param.grad)
                continue
            assert (param != param_ref).any(), "param {} {} with shape {} not updated!! \n{}\n{}".format(
                name, count, param.shape, param, param_ref
            )
            count += 1

        # with file gst style
        mel_spec = (
            torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :30].unsqueeze(0).transpose(1, 2).to(device)
        )
        mel_spec = mel_spec.repeat(8, 1, 1)
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron2(
            num_chars=24,
            r=c.r,
            num_speakers=5,
            gst=True,
            gst_embedding_dim=c.gst["gst_embedding_dim"],
            gst_num_heads=c.gst["gst_num_heads"],
            gst_style_tokens=c.gst["gst_style_tokens"],
        ).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0

@@ -216,7 +230,8 @@ class TacotronGSTTrainTest(unittest.TestCase):
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(10):
            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
            )
            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
            optimizer.zero_grad()

@@ -231,47 +246,57 @@ class TacotronGSTTrainTest(unittest.TestCase):
            # ignore pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            name, param = name_param
            if name == "gst_layer.encoder.recurrence.weight_hh_l0":
                # print(param.grad)
                continue
            assert (param != param_ref).any(), "param {} {} with shape {} not updated!! \n{}\n{}".format(
                name, count, param.shape, param, param_ref
            )
            count += 1


class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_embeddings = torch.rand(8, 55).to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
        criterion = MSELossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron2(
            num_chars=24,
            r=c.r,
            num_speakers=5,
            speaker_embedding_dim=55,
            gst=True,
            gst_embedding_dim=c.gst["gst_embedding_dim"],
            gst_num_heads=c.gst["gst_num_heads"],
            gst_style_tokens=c.gst["gst_style_tokens"],
            gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
        ).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(5):
            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings
            )
            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
            optimizer.zero_grad()

@@ -282,14 +307,13 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
            optimizer.step()
        # check parameter changes
        count = 0
        for name_param, param_ref in zip(model.named_parameters(), model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            name, param = name_param
            if name == "gst_layer.encoder.recurrence.weight_hh_l0":
                continue
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1
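Every train test above builds stop targets the same way: frames at or past a sequence's length are labeled 1.0, then frames are folded into groups of r (the decoder reduction factor) so one stop label is emitted per decoder step. A per-sequence sketch of that pattern (the tests' in-place loop applies each length across the whole batch, which effectively keys the targets to the shortest sequence):

import torch

def make_stop_targets(mel_lengths, max_len, r):
    # 1.0 from each sequence's end onward, one label per group of r frames
    steps = torch.arange(max_len)[None, :]                  # 1 x T
    stop = (steps >= mel_lengths[:, None]).float()          # B x T
    stop = stop.view(mel_lengths.size(0), max_len // r, r)  # B x T//r x r
    return (stop.sum(2) > 0.0).float()                      # B x T//r

print(make_stop_targets(torch.tensor([4, 6]), 6, 2))
# [[0., 0., 1.],
#  [0., 0., 0.]]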
@@ -10,48 +10,51 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model
from TTS.utils.io import load_config

tf.get_logger().setLevel("INFO")


# pylint: disable=unused-variable

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = load_config(os.path.join(get_tests_input_path(), "test_config.json"))


class TacotronTFTrainTest(unittest.TestCase):

    @staticmethod
    def generate_dummy_inputs():
        chars_seq = torch.randint(0, 24, (8, 128)).long().to(device)
        chars_seq_lengths = torch.randint(100, 128, (8,)).long().to(device)
        chars_seq_lengths = torch.sort(chars_seq_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

        chars_seq = tf.convert_to_tensor(chars_seq.cpu().numpy())
        chars_seq_lengths = tf.convert_to_tensor(chars_seq_lengths.cpu().numpy())
        mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
        return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths, stop_targets, speaker_ids

    def test_train_step(self):
        """ test forward pass """
        (
            chars_seq,
            chars_seq_lengths,
            mel_spec,
            mel_postnet_spec,
            mel_lengths,
            stop_targets,
            speaker_ids,
        ) = self.generate_dummy_inputs()

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        model = Tacotron2(num_chars=24, r=c.r, num_speakers=5)

@@ -68,15 +71,23 @@ class TacotronTFTrainTest(unittest.TestCase):
        # inference pass
        output = model(chars_seq, training=False)

    def test_forward_attention(
        self,
    ):
        (
            chars_seq,
            chars_seq_lengths,
            mel_spec,
            mel_postnet_spec,
            mel_lengths,
            stop_targets,
            speaker_ids,
        ) = self.generate_dummy_inputs()

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, forward_attn=True)

@@ -93,16 +104,19 @@ class TacotronTFTrainTest(unittest.TestCase):
        # inference pass
        output = model(chars_seq, training=False)

    def test_tflite_conversion(
        self,
    ):  # pylint:disable=no-self-use
        model = Tacotron2(
            num_chars=24,
            num_speakers=0,
            r=3,
            postnet_output_dim=80,
            decoder_output_dim=80,
            attn_type="original",
            attn_win=False,
            attn_norm="sigmoid",
            prenet_type="original",
            prenet_dropout=True,
            forward_attn=False,
            trans_agent=False,

@@ -111,27 +125,30 @@ class TacotronTFTrainTest(unittest.TestCase):
            attn_K=0,
            separate_stopnet=True,
            bidirectional_decoder=False,
            enable_tflite=True,
        )
        model.build_inference()
        convert_tacotron2_to_tflite(model, output_path="test_tacotron2.tflite", experimental_converter=True)
        # init tflite model
        tflite_model = load_tflite_model("test_tacotron2.tflite")
        # fake input
        inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)  # pylint:disable=unexpected-keyword-arg
        # run inference
        # get input and output details
        input_details = tflite_model.get_input_details()
        output_details = tflite_model.get_output_details()
        # reshape input tensor for the new input shape
        tflite_model.resize_tensor_input(
            input_details[0]["index"], inputs.shape
        )  # pylint:disable=unexpected-keyword-arg
        tflite_model.allocate_tensors()
        detail = input_details[0]
        input_shape = detail["shape"]
        tflite_model.set_tensor(detail["index"], inputs)
        # run the tflite_model
        tflite_model.invoke()
        # collect outputs
        decoder_output = tflite_model.get_tensor(output_details[0]["index"])
        postnet_output = tflite_model.get_tensor(output_details[1]["index"])
        # remove tflite binary
        os.remove("test_tacotron2.tflite")
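The tflite test above walks the standard tf.lite.Interpreter call sequence: resize the input for a dynamic shape, allocate tensors, set inputs, invoke, and read outputs by index. The same flow as a reusable helper for an arbitrary .tflite file (path and input are placeholders):

import tensorflow as tf

def run_tflite(model_path, inputs):
    interpreter = tf.lite.Interpreter(model_path=model_path)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # dynamic input shapes must be fixed before allocating tensors
    interpreter.resize_tensor_input(input_details[0]["index"], inputs.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]["index"], inputs)
    interpreter.invoke()
    return [interpreter.get_tensor(d["index"]) for d in output_details]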
@@ -11,13 +11,13 @@ from TTS.tts.models.tacotron import Tacotron
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config

# pylint: disable=unused-variable

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = load_config(os.path.join(get_tests_input_path(), "test_config.json"))

ap = AudioProcessor(**c.audio)
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")

@@ -32,147 +32,140 @@ class TacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = L1LossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron(
            num_chars=32,
            num_speakers=5,
            postnet_output_dim=c.audio["fft_size"],
            decoder_output_dim=c.audio["num_mels"],
            r=c.r,
            memory_size=c.memory_size,
        ).to(
            device
        )  # FIXME: missing num_speakers parameter to Tacotron ctor
        model.train()
        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for _ in range(5):
            mel_out, linear_out, align, stop_tokens = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
            )
            optimizer.zero_grad()
            loss = criterion(mel_out, mel_spec, mel_lengths)
            stop_loss = criterion_st(stop_tokens, stop_targets)
            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class MultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_embeddings = torch.rand(8, 55).to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = L1LossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron(
            num_chars=32,
            num_speakers=5,
            postnet_output_dim=c.audio["fft_size"],
            decoder_output_dim=c.audio["num_mels"],
            r=c.r,
            memory_size=c.memory_size,
            speaker_embedding_dim=55,
        ).to(
            device
        )  # FIXME: missing num_speakers parameter to Tacotron ctor
        model.train()
        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for _ in range(5):
            mel_out, linear_out, align, stop_tokens = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings
            )
            optimizer.zero_grad()
            loss = criterion(mel_out, mel_spec, mel_lengths)
            stop_loss = criterion_st(stop_tokens, stop_targets)
            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1
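The GST test below conditions the model on a style reference by converting a wav file to a mel spectrogram with the test's AudioProcessor. The processor's exact parameters come from the loaded config; a rough librosa-based equivalent, with assumed FFT, hop, and mel settings, looks like:

import librosa
import numpy as np

def style_mel(wav_path, sr=22050, n_fft=1024, hop_length=256, num_mels=80):
    # log-mel spectrogram, num_mels x T, usable as a GST style reference
    wav, _ = librosa.load(wav_path, sr=sr)
    mel = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=num_mels)
    return np.log(np.clip(mel, 1e-5, None))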
class TacotronGSTTrainTest(unittest.TestCase):
|
class TacotronGSTTrainTest(unittest.TestCase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def test_train_step():
|
def test_train_step():
|
||||||
# with random gst mel style
|
# with random gst mel style
|
||||||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||||
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
|
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
||||||
input_lengths[-1] = 128
|
input_lengths[-1] = 128
|
||||||
mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
|
mel_spec = torch.rand(8, 120, c.audio["num_mels"]).to(device)
|
||||||
linear_spec = torch.rand(8, 120, c.audio['fft_size']).to(device)
|
linear_spec = torch.rand(8, 120, c.audio["fft_size"]).to(device)
|
||||||
mel_lengths = torch.randint(20, 120, (8, )).long().to(device)
|
mel_lengths = torch.randint(20, 120, (8,)).long().to(device)
|
||||||
mel_lengths[-1] = 120
|
mel_lengths[-1] = 120
|
||||||
stop_targets = torch.zeros(8, 120, 1).float().to(device)
|
stop_targets = torch.zeros(8, 120, 1).float().to(device)
|
||||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
||||||
|
|
||||||
for idx in mel_lengths:
|
for idx in mel_lengths:
|
||||||
stop_targets[:, int(idx.item()):, 0] = 1.0
|
stop_targets[:, int(idx.item()) :, 0] = 1.0
|
||||||
|
|
||||||
stop_targets = stop_targets.view(input_dummy.shape[0],
|
stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
|
||||||
stop_targets.size(1) // c.r, -1)
|
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||||
stop_targets = (stop_targets.sum(2) >
|
|
||||||
0.0).unsqueeze(2).float().squeeze()
|
|
||||||
|
|
||||||
criterion = L1LossMasked(seq_len_norm=False).to(device)
|
criterion = L1LossMasked(seq_len_norm=False).to(device)
|
||||||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||||
|
@ -180,65 +173,64 @@ class TacotronGSTTrainTest(unittest.TestCase):
             num_chars=32,
             num_speakers=5,
             gst=True,
-            gst_embedding_dim=c.gst['gst_embedding_dim'],
-            gst_num_heads=c.gst['gst_num_heads'],
-            gst_style_tokens=c.gst['gst_style_tokens'],
-            postnet_output_dim=c.audio['fft_size'],
-            decoder_output_dim=c.audio['num_mels'],
+            gst_embedding_dim=c.gst["gst_embedding_dim"],
+            gst_num_heads=c.gst["gst_num_heads"],
+            gst_style_tokens=c.gst["gst_style_tokens"],
+            postnet_output_dim=c.audio["fft_size"],
+            decoder_output_dim=c.audio["num_mels"],
             r=c.r,
-            memory_size=c.memory_size
-        ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
+            memory_size=c.memory_size,
+        ).to(
+            device
+        )  # FIXME: missing num_speakers parameter to Tacotron ctor
         model.train()
         # print(model)
-        print(" > Num parameters for Tacotron GST model:%s" %
-              (count_parameters(model)))
+        print(" > Num parameters for Tacotron GST model:%s" % (count_parameters(model)))
         model_ref = copy.deepcopy(model)
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             assert (param - param_ref).sum() == 0, param
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for _ in range(10):
             mel_out, linear_out, align, stop_tokens = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
+                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
+            )
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
             stop_loss = criterion_st(stop_tokens, stop_targets)
-            loss = loss + criterion(linear_out, linear_spec,
-                                    mel_lengths) + stop_loss
+            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
             loss.backward()
             optimizer.step()
         # check parameter changes
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             # ignore pre-higway layer since it works conditional
-            assert (param != param_ref).any(
-            ), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref)
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
             count += 1

         # with file gst style
-        mel_spec = torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :120].unsqueeze(0).transpose(1, 2).to(device)
+        mel_spec = (
+            torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :120].unsqueeze(0).transpose(1, 2).to(device)
+        )
         mel_spec = mel_spec.repeat(8, 1, 1)

         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8, )).long().to(device)
+        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128
-        linear_spec = torch.rand(8, mel_spec.size(1), c.audio['fft_size']).to(device)
-        mel_lengths = torch.randint(20, mel_spec.size(1), (8, )).long().to(device)
+        linear_spec = torch.rand(8, mel_spec.size(1), c.audio["fft_size"]).to(device)
+        mel_lengths = torch.randint(20, mel_spec.size(1), (8,)).long().to(device)
         mel_lengths[-1] = mel_spec.size(1)
         stop_targets = torch.zeros(8, mel_spec.size(1), 1).float().to(device)
-        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
+        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

         for idx in mel_lengths:
-            stop_targets[:, int(idx.item()):, 0] = 1.0
+            stop_targets[:, int(idx.item()) :, 0] = 1.0

-        stop_targets = stop_targets.view(input_dummy.shape[0],
-                                         stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) >
-                        0.0).unsqueeze(2).float().squeeze()
+        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

         criterion = L1LossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
@ -246,113 +238,109 @@ class TacotronGSTTrainTest(unittest.TestCase):
             num_chars=32,
             num_speakers=5,
             gst=True,
-            gst_embedding_dim=c.gst['gst_embedding_dim'],
-            gst_num_heads=c.gst['gst_num_heads'],
-            gst_style_tokens=c.gst['gst_style_tokens'],
-            postnet_output_dim=c.audio['fft_size'],
-            decoder_output_dim=c.audio['num_mels'],
+            gst_embedding_dim=c.gst["gst_embedding_dim"],
+            gst_num_heads=c.gst["gst_num_heads"],
+            gst_style_tokens=c.gst["gst_style_tokens"],
+            postnet_output_dim=c.audio["fft_size"],
+            decoder_output_dim=c.audio["num_mels"],
             r=c.r,
-            memory_size=c.memory_size
-        ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
+            memory_size=c.memory_size,
+        ).to(
+            device
+        )  # FIXME: missing num_speakers parameter to Tacotron ctor
         model.train()
         # print(model)
-        print(" > Num parameters for Tacotron GST model:%s" %
-              (count_parameters(model)))
+        print(" > Num parameters for Tacotron GST model:%s" % (count_parameters(model)))
         model_ref = copy.deepcopy(model)
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             assert (param - param_ref).sum() == 0, param
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for _ in range(10):
             mel_out, linear_out, align, stop_tokens = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
+                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
+            )
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
             stop_loss = criterion_st(stop_tokens, stop_targets)
-            loss = loss + criterion(linear_out, linear_spec,
-                                    mel_lengths) + stop_loss
+            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
             loss.backward()
             optimizer.step()
         # check parameter changes
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             # ignore pre-higway layer since it works conditional
-            assert (param != param_ref).any(
-            ), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref)
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
             count += 1

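The "file gst style" branch above conditions the model on a real mel computed from WAV_FILE rather than a random tensor. A minimal sketch of that preprocessing on its own, assuming an AudioProcessor `ap` built from the test config as elsewhere in these tests:

    wav = ap.load_wav(WAV_FILE)                          # 1D float waveform
    mel = ap.melspectrogram(wav)                         # (num_mels, T)
    style_mel = torch.FloatTensor(mel)[:, :120]          # clip to 120 frames
    style_mel = style_mel.unsqueeze(0).transpose(1, 2)   # (1, 120, num_mels)
    style_mel = style_mel.repeat(8, 1, 1)                # tile to the batch size
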
 class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
     @staticmethod
     def test_train_step():
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8, )).long().to(device)
+        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128
-        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
-        linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
-        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
+        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
+        linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
         mel_lengths[-1] = mel_spec.size(1)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
         speaker_embeddings = torch.rand(8, 55).to(device)

         for idx in mel_lengths:
-            stop_targets[:, int(idx.item()):, 0] = 1.0
+            stop_targets[:, int(idx.item()) :, 0] = 1.0

-        stop_targets = stop_targets.view(input_dummy.shape[0],
-                                         stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) >
-                        0.0).unsqueeze(2).float().squeeze()
+        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

         criterion = L1LossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
         model = Tacotron(
             num_chars=32,
             num_speakers=5,
-            postnet_output_dim=c.audio['fft_size'],
-            decoder_output_dim=c.audio['num_mels'],
+            postnet_output_dim=c.audio["fft_size"],
+            decoder_output_dim=c.audio["num_mels"],
             gst=True,
-            gst_embedding_dim=c.gst['gst_embedding_dim'],
-            gst_num_heads=c.gst['gst_num_heads'],
-            gst_style_tokens=c.gst['gst_style_tokens'],
-            gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'],
+            gst_embedding_dim=c.gst["gst_embedding_dim"],
+            gst_num_heads=c.gst["gst_num_heads"],
+            gst_style_tokens=c.gst["gst_style_tokens"],
+            gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
             r=c.r,
             memory_size=c.memory_size,
             speaker_embedding_dim=55,
-        ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
+        ).to(
+            device
+        )  # FIXME: missing num_speakers parameter to Tacotron ctor
         model.train()
-        print(" > Num parameters for Tacotron model:%s" %
-              (count_parameters(model)))
+        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
         model_ref = copy.deepcopy(model)
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             assert (param - param_ref).sum() == 0, param
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for _ in range(5):
             mel_out, linear_out, align, stop_tokens = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths,
-                speaker_embeddings=speaker_embeddings)
+                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings
+            )
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
             stop_loss = criterion_st(stop_tokens, stop_targets)
-            loss = loss + criterion(linear_out, linear_spec,
-                                    mel_lengths) + stop_loss
+            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
             loss.backward()
             optimizer.step()
         # check parameter changes
         count = 0
-        for name_param, param_ref in zip(model.named_parameters(),
-                                         model_ref.parameters()):
+        for name_param, param_ref in zip(model.named_parameters(), model_ref.parameters()):
             # ignore pre-higway layer since it works conditional
             # if count not in [145, 59]:
             name, param = name_param
-            if name == 'gst_layer.encoder.recurrence.weight_hh_l0':
+            if name == "gst_layer.encoder.recurrence.weight_hh_l0":
                 continue
-            assert (param != param_ref).any(
-            ), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref)
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
             count += 1

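Every train-step test in this file builds its stop targets the same way. The sketch below isolates that folding, with an assumed reduction factor r = 5 and 120 decoder frames (illustrative values, not the config's):

    import torch

    r, T = 5, 120
    mel_lengths = torch.randint(20, T, (8,)).long()
    stop_targets = torch.zeros(8, T, 1).float()
    for idx in mel_lengths:
        stop_targets[:, int(idx.item()) :, 0] = 1.0      # 1.0 marks padded frames

    # fold r consecutive frames into one decoder step ...
    stop_targets = stop_targets.view(8, T // r, -1)      # (8, T // r, r)
    # ... and flag a step as "stop" if any of its frames is padded
    stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()  # (8, T // r)
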
@ -17,5 +17,5 @@ def test_currency() -> None:


 def test_expand_numbers() -> None:
-    assert phoneme_cleaners("-1") == 'minus one'
-    assert phoneme_cleaners("1") == 'one'
+    assert phoneme_cleaners("-1") == "minus one"
+    assert phoneme_cleaners("1") == "one"

@ -7,7 +7,8 @@ from tests import get_tests_input_path, get_tests_path
 from TTS.tts.utils.text import *
 from TTS.utils.io import load_config

-conf = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
+conf = load_config(os.path.join(get_tests_input_path(), "test_config.json"))

+
 def test_phoneme_to_sequence():
@ -18,7 +19,7 @@ def test_phoneme_to_sequence():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
-    gt = 'ɹiːsənt ɹᵻsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!'
+    gt = "ɹiːsənt ɹᵻsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!"
     assert text_hat == text_hat_with_params == gt

     # multiple punctuations
@ -87,6 +88,7 @@ def test_phoneme_to_sequence():
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt

+
 def test_phoneme_to_sequence_with_blank_token():

     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
@ -105,7 +107,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ?'
+    gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ?"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@ -116,7 +118,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ'
+    gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@ -127,7 +129,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = 'biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!'
+    gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@ -138,7 +140,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ.'
+    gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ."
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@ -165,9 +167,10 @@ def test_phoneme_to_sequence_with_blank_token():
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt

+
 def test_text2phone():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
-    gt = 'ɹ|iː|s|ə|n|t| |ɹ|ᵻ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|ŋ|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ᵻ|s|p|ɑː|n|s|ᵻ|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|ʊ|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!'
+    gt = "ɹ|iː|s|ə|n|t| |ɹ|ᵻ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|ŋ|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ᵻ|s|p|ɑː|n|s|ᵻ|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|ʊ|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
     lang = "en-us"
     ph = text2phone(text, lang)
     assert gt == ph

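The round trip these tests exercise can be reproduced directly. A hedged usage sketch, assuming the espeak backend that text2phone depends on is installed, `conf` is the test config loaded above, and the cleaner list is ["phoneme_cleaners"] (the actual text_cleaner value sits outside these hunks):

    from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme, text2phone

    text = "Be a voice, not an echo."
    print(text2phone(text, "en-us"))  # pipe-separated IPA, like the gt strings above

    sequence = phoneme_to_sequence(text, ["phoneme_cleaners"], "en-us", tp=conf.characters, add_blank=True)
    text_hat = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
    print(text_hat)                   # plain IPA string decoded back from the sequence
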
@ -13,17 +13,20 @@ file_path = os.path.dirname(os.path.realpath(__file__))
 OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
 os.makedirs(OUTPATH, exist_ok=True)

-C = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
+C = load_config(os.path.join(get_tests_input_path(), "test_config.json"))

 test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
 ok_ljspeech = os.path.exists(test_data_path)


-def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers):
-    ''' run dataloader with given parameters and check conditions '''
+def gan_dataset_case(
+    batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers
+):
+    """ run dataloader with given parameters and check conditions """
     ap = AudioProcessor(**C.audio)
     _, train_items = load_wav_data(test_data_path, 10)
-    dataset = GANDataset(ap,
+    dataset = GANDataset(
+        ap,
         train_items,
         seq_len=seq_len,
         hop_len=hop_len,
@ -31,13 +34,11 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us
         conv_pad=conv_pad,
         return_segments=return_segments,
         use_noise_augment=use_noise_augment,
-        use_cache=use_cache)
-    loader = DataLoader(dataset=dataset,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        num_workers=num_workers,
-                        pin_memory=True,
-                        drop_last=True)
+        use_cache=use_cache,
+    )
+    loader = DataLoader(
+        dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True
+    )

     max_iter = 10
     count_iter = 0
@ -61,8 +62,8 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us
             mel = ap.melspectrogram(audio)
             # the first 2 and the last 2 frames are skipped due to the padding
             # differences in stft
-            max_diff = abs((feat - mel[:, :feat1.shape[-1]])[:, 2:-2]).max()
-            assert max_diff <= 0, f' [!] {max_diff}'
+            max_diff = abs((feat - mel[:, : feat1.shape[-1]])[:, 2:-2]).max()
+            assert max_diff <= 0, f" [!] {max_diff}"

             count_iter += 1
             # if count_iter == max_iter:
@ -79,17 +80,17 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us


 def test_parametrized_gan_dataset():
-    ''' test dataloader with different parameters '''
+    """ test dataloader with different parameters """
     params = [
-        [32, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, True, 0],
-        [32, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, True, 4],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, True, 0],
-        [1, C.audio['hop_length'], C.audio['hop_length'], 0, True, True, True, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, True, True, True, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, False, True, True, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, True, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, False, 0],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, False, False, False, 0],
+        [32, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, False, True, 0],
+        [32, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, False, True, 4],
+        [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, True, True, 0],
+        [1, C.audio["hop_length"], C.audio["hop_length"], 0, True, True, True, 0],
+        [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 2, True, True, True, 0],
+        [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, False, True, True, 0],
+        [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, False, True, 0],
+        [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, True, False, 0],
+        [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, False, False, False, 0],
     ]
     for param in params:
         print(param)

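The for-loop body is cut off by the hunk boundary; presumably each row of the parameter matrix is splatted into gan_dataset_case positionally, along these lines (a sketch, not shown in the diff):

    for param in params:
        print(param)
        # batch_size, seq_len, hop_len, conv_pad, return_segments,
        # use_noise_augment, use_cache, num_workers
        gan_dataset_case(*param)
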
@ -14,7 +14,7 @@ os.makedirs(OUT_PATH, exist_ok=True)

 WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")

-C = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
+C = load_config(os.path.join(get_tests_input_path(), "test_config.json"))
 ap = AudioProcessor(**C.audio)


@ -45,7 +45,8 @@ def test_multiscale_stft_loss():
     stft_loss = MultiScaleSTFTLoss(
         [ap.fft_size // 2, ap.fft_size, ap.fft_size * 2],
         [ap.hop_length // 2, ap.hop_length, ap.hop_length * 2],
-        [ap.win_length // 2, ap.win_length, ap.win_length * 2])
+        [ap.win_length // 2, ap.win_length, ap.win_length * 2],
+    )
     wav = ap.load_wav(WAV_FILE)
     wav = torch.from_numpy(wav[None, :]).float()
     loss_m, loss_sc = stft_loss(wav, wav)

@ -1,8 +1,10 @@
 import numpy as np
 import torch

-from TTS.vocoder.models.parallel_wavegan_discriminator import (ParallelWaveganDiscriminator,
-                                                               ResidualParallelWaveganDiscriminator)
+from TTS.vocoder.models.parallel_wavegan_discriminator import (
+    ParallelWaveganDiscriminator,
+    ResidualParallelWaveganDiscriminator,
+)


 def test_pwgan_disciminator():
@ -15,7 +17,8 @@ def test_pwgan_disciminator():
         dilation_factor=1,
         nonlinear_activation="LeakyReLU",
         nonlinear_activation_params={"negative_slope": 0.2},
-        bias=True)
+        bias=True,
+    )
     dummy_x = torch.rand((4, 1, 64 * 256))
     output = model(dummy_x)
     assert np.all(output.shape == (4, 1, 64 * 256))
@ -35,7 +38,8 @@ def test_redisual_pwgan_disciminator():
         dropout=0.0,
         bias=True,
         nonlinear_activation="LeakyReLU",
-        nonlinear_activation_params={"negative_slope": 0.2})
+        nonlinear_activation_params={"negative_slope": 0.2},
+    )
     dummy_x = torch.rand((4, 1, 64 * 256))
     output = model(dummy_x)
     assert np.all(output.shape == (4, 1, 64 * 256))

@ -18,7 +18,8 @@ def test_pwgan_generator():
         dropout=0.0,
         bias=True,
         use_weight_norm=True,
-        upsample_factors=[4, 4, 4, 4])
+        upsample_factors=[4, 4, 4, 4],
+    )
     dummy_c = torch.rand((2, 80, 5))
     output = model(dummy_c)
     assert np.all(output.shape == (2, 1, 5 * 256)), output.shape

@ -23,5 +23,4 @@ def test_pqmf():
     print(w2_.max())
     print(w2_.min())
     print(w2_.mean())
-    sf.write(os.path.join(get_tests_output_path(), 'pqmf_output.wav'),
-             w2_.flatten().detach(), sr)
+    sf.write(os.path.join(get_tests_output_path(), "pqmf_output.wav"), w2_.flatten().detach(), sr)

@ -5,14 +5,12 @@ from TTS.vocoder.models.random_window_discriminator import RandomWindowDiscrimin


 def test_rwd():
-    layer = RandomWindowDiscriminator(cond_channels=80,
-                                      window_sizes=(512, 1024, 2048, 4096,
-                                                    8192),
-                                      cond_disc_downsample_factors=[
-                                          (8, 4, 2, 2, 2), (8, 4, 2, 2),
-                                          (8, 4, 2), (8, 4), (4, 2, 2)
-                                      ],
-                                      hop_length=256)
+    layer = RandomWindowDiscriminator(
+        cond_channels=80,
+        window_sizes=(512, 1024, 2048, 4096, 8192),
+        cond_disc_downsample_factors=[(8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)],
+        hop_length=256,
+    )
     x = torch.rand([4, 1, 22050])
     c = torch.rand([4, 80, 22050 // 256])

@ -24,5 +24,4 @@ def test_pqmf():
     print(w2_.max())
     print(w2_.min())
     print(w2_.mean())
-    sf.write(os.path.join(get_tests_output_path(), 'tf_pqmf_output.wav'),
-             w2_.flatten(), sr)
+    sf.write(os.path.join(get_tests_output_path(), "tf_pqmf_output.wav"), w2_.flatten(), sr)

@ -14,8 +14,7 @@ file_path = os.path.dirname(os.path.realpath(__file__))
 OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
 os.makedirs(OUTPATH, exist_ok=True)

-C = load_config(os.path.join(get_tests_input_path(),
-                             "test_vocoder_wavernn_config.json"))
+C = load_config(os.path.join(get_tests_input_path(), "test_vocoder_wavernn_config.json"))

 test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
 test_mel_feat_path = os.path.join(test_data_path, "mel")
@ -33,19 +32,14 @@ def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_wor
     C.data_path = test_data_path

     preprocess_wav_files(test_data_path, C, ap)
-    _, train_items = load_wav_feat_data(
-        test_data_path, test_mel_feat_path, 5)
+    _, train_items = load_wav_feat_data(test_data_path, test_mel_feat_path, 5)

-    dataset = WaveRNNDataset(ap=ap,
-                             items=train_items,
-                             seq_len=seq_len,
-                             hop_len=hop_len,
-                             pad=pad,
-                             mode=mode,
-                             mulaw=mulaw
-                             )
+    dataset = WaveRNNDataset(
+        ap=ap, items=train_items, seq_len=seq_len, hop_len=hop_len, pad=pad, mode=mode, mulaw=mulaw
+    )
     # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-    loader = DataLoader(dataset,
+    loader = DataLoader(
+        dataset,
         shuffle=True,
         collate_fn=dataset.collate,
         batch_size=batch_size,
@ -59,10 +53,8 @@ def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_wor
     try:
         for data in loader:
             x_input, mels, _ = data
-            expected_feat_shape = (ap.num_mels,
-                                   (x_input.shape[-1] // hop_len) + (pad * 2))
-            assert np.all(
-                mels.shape[1:] == expected_feat_shape), f" [!] {mels.shape} vs {expected_feat_shape}"
+            expected_feat_shape = (ap.num_mels, (x_input.shape[-1] // hop_len) + (pad * 2))
+            assert np.all(mels.shape[1:] == expected_feat_shape), f" [!] {mels.shape} vs {expected_feat_shape}"

             assert (mels.shape[2] - pad * 2) * hop_len == x_input.shape[1]
             count_iter += 1
@ -77,15 +69,15 @@ def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_wor


 def test_parametrized_wavernn_dataset():
-    ''' test dataloader with different parameters '''
+    """ test dataloader with different parameters """
     params = [
-        [16, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, 10, True, 0],
-        [16, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, "mold", False, 4],
-        [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, 9, False, 0],
-        [1, C.audio['hop_length'], C.audio['hop_length'], 2, 10, True, 0],
-        [1, C.audio['hop_length'], C.audio['hop_length'], 2, "mold", False, 0],
-        [1, C.audio['hop_length'] * 5, C.audio['hop_length'], 4, 10, False, 2],
-        [1, C.audio['hop_length'] * 5, C.audio['hop_length'], 2, "mold", False, 0],
+        [16, C.audio["hop_length"] * 10, C.audio["hop_length"], 2, 10, True, 0],
+        [16, C.audio["hop_length"] * 10, C.audio["hop_length"], 2, "mold", False, 4],
+        [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 2, 9, False, 0],
+        [1, C.audio["hop_length"], C.audio["hop_length"], 2, 10, True, 0],
+        [1, C.audio["hop_length"], C.audio["hop_length"], 2, "mold", False, 0],
+        [1, C.audio["hop_length"] * 5, C.audio["hop_length"], 4, 10, False, 2],
+        [1, C.audio["hop_length"] * 5, C.audio["hop_length"], 2, "mold", False, 0],
     ]
     for param in params:
         print(param)

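The two assertions inside wavernn_dataset_case are inverse views of one relation between samples and mel frames. A standalone check with hypothetical hop_len/pad values:

    hop_len, pad, num_mels = 256, 2, 80                 # hypothetical values
    x_len = 10 * hop_len                                # samples in one batch item
    n_frames = (x_len // hop_len) + (pad * 2)           # frames incl. pad on both sides
    expected_feat_shape = (num_mels, n_frames)
    assert (n_frames - pad * 2) * hop_len == x_len      # the inverse relation asserted above
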
@ -75,12 +75,12 @@ def test_wavegrad_forward():
     c = torch.rand(32, 80, 20)
     noise_scale = torch.rand(32)

-    model = Wavegrad(in_channels=80,
+    model = Wavegrad(
+        in_channels=80,
         out_channels=1,
         upsample_factors=[5, 5, 3, 2, 2],
-        upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
-                            [1, 2, 4, 8], [1, 2, 4, 8],
-                            [1, 2, 4, 8]])
+        upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]],
+    )
     o = model.forward(x, c, noise_scale)

     assert o.shape[0] == 32

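The shapes in test_wavegrad_forward line up through the upsample factors: the five factors multiply out to the assumed samples-per-mel-frame ratio (an inference from the test inputs, not a statement of the model's internals):

    upsample_factors = [5, 5, 3, 2, 2]
    samples_per_frame = 1
    for f in upsample_factors:
        samples_per_frame *= f
    assert samples_per_frame == 300
    # so c = torch.rand(32, 80, 20) would pair with a noisy input of 32 x 1 x (20 * 300) samples
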
@ -6,7 +6,7 @@ from torch import optim

 from TTS.vocoder.models.wavegrad import Wavegrad

-#pylint: disable=unused-variable
+# pylint: disable=unused-variable

 torch.manual_seed(1)
 use_cuda = torch.cuda.is_available()
@ -20,19 +20,19 @@ class WavegradTrainTest(unittest.TestCase):
         mel_spec = torch.rand(8, 80, 20).to(device)

         criterion = torch.nn.L1Loss().to(device)
-        model = Wavegrad(in_channels=80,
+        model = Wavegrad(
+            in_channels=80,
             out_channels=1,
             upsample_factors=[5, 5, 3, 2, 2],
-            upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
-                                [1, 2, 4, 8], [1, 2, 4, 8],
-                                [1, 2, 4, 8]])
+            upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]],
+        )

-        model_ref = Wavegrad(in_channels=80,
+        model_ref = Wavegrad(
+            in_channels=80,
             out_channels=1,
             upsample_factors=[5, 5, 3, 2, 2],
-            upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
-                                [1, 2, 4, 8], [1, 2, 4, 8],
-                                [1, 2, 4, 8]])
+            upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]],
+        )
         model.train()
         model.to(device)
         betas = np.linspace(1e-6, 1e-2, 1000)
@ -40,8 +40,7 @@ class WavegradTrainTest(unittest.TestCase):
         model_ref.load_state_dict(model.state_dict())
         model_ref.to(device)
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             assert (param - param_ref).sum() == 0, param
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=0.001)
@ -53,11 +52,10 @@ class WavegradTrainTest(unittest.TestCase):
             optimizer.step()
         # check parameter changes
         count = 0
-        for param, param_ref in zip(model.parameters(),
-                                    model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             # ignore pre-higway layer since it works conditional
             # if count not in [145, 59]:
-            assert (param != param_ref).any(
-            ), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref)
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
             count += 1

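The same checking idiom recurs in every train test touched by this commit: deep-copy the model before training, take a few optimizer steps, then require every parameter to move. As a self-contained sketch on a toy module (not the project's code):

    import copy

    import torch
    from torch import nn, optim

    model = nn.Linear(4, 4)
    model_ref = copy.deepcopy(model)          # frozen snapshot of the initial weights
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for _ in range(5):
        loss = model(torch.rand(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    for count, (param, param_ref) in enumerate(zip(model.parameters(), model_ref.parameters())):
        # every parameter must have moved away from the snapshot
        assert (param != param_ref).any(), "param {} with shape {} not updated!!".format(count, param.shape)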