mirror of https://github.com/coqui-ai/TTS.git
Ruff autofix C41
commit 64bb41f4fa
parent 449820ec7d
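The hunks below come from Ruff's flake8-comprehensions autofixes (the C4 rule family referenced in the commit message). They replace redundant `list()`/`set()`/`dict()` wrappers, `map(lambda ...)` calls, and comprehensions that merely rebuild an existing iterable with their idiomatic equivalents. A minimal sketch of the rewrite patterns seen in this diff (illustrative names only, not taken from the repository):

# Illustrative sketch, not part of the commit: the kinds of rewrites the autofix applies.
phonemes = "a|b|c"

# Before: redundant intermediate conversions and map(lambda ...) calls.
unique_before = set(list(phonemes))                      # set(list(x))       -> set(x)
doubled_before = list(map(lambda ch: ch * 2, phonemes))  # list(map(lambda))  -> list comprehension
pairs_before = {k: v for k, v in zip("ab", "xy")}        # dict comp over zip -> dict(zip(...))
empty_before = dict({})                                  # dict({})           -> {}

# After: the equivalent idiomatic forms.
unique_after = set(phonemes)
doubled_after = [ch * 2 for ch in phonemes]
pairs_after = dict(zip("ab", "xy"))
empty_after = {}

assert unique_before == unique_after
assert doubled_before == doubled_after
assert pairs_before == pairs_after
assert empty_before == empty_after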
@@ -13,7 +13,7 @@ from TTS.tts.utils.text.phonemizers import Gruut
 def compute_phonemes(item):
     text = item["text"]
     ph = phonemizer.phonemize(text).replace("|", "")
-    return set(list(ph))
+    return set(ph)


 def main():
@@ -71,7 +71,7 @@ def plot_transition_probabilities_to_numpy(states, transition_probabilities, out
     ax.set_title("Transition probability of state")
     ax.set_xlabel("hidden state")
     ax.set_ylabel("probability")
-    ax.set_xticks([i for i in range(len(transition_probabilities))])  # pylint: disable=unnecessary-comprehension
+    ax.set_xticks(list(range(len(transition_probabilities))))
     ax.set_xticklabels([int(x) for x in states], rotation=90)
     plt.tight_layout()
     if not output_fig:
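The xticks change above swaps a comprehension that only re-enumerates a `range(...)` for a direct `list(range(...))` call, which also lets the pylint suppression go. A tiny check with made-up values:

transition_probabilities = [0.1, 0.5, 0.4]  # hypothetical values, just for the check
ticks_old = [i for i in range(len(transition_probabilities))]
ticks_new = list(range(len(transition_probabilities)))
assert ticks_old == ticks_new == [0, 1, 2]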
@@ -126,7 +126,7 @@ class CLVP(nn.Module):
         text_latents = self.to_text_latent(text_latents)
         speech_latents = self.to_speech_latent(speech_latents)

-        text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
+        text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents))

         temp = self.temperature.exp()

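As in the hunk above, most of these fixes turn `map(lambda ...)` over a fixed tuple into a generator expression; both yield a single-use iterator, so tuple unpacking behaves the same. A sketch with plain numbers standing in for tensors (hypothetical `normalize` helper, not the real `F.normalize`):

def normalize(t):  # stand-in for the tensor normalization
    return t / abs(t)

a, b = map(lambda t: normalize(t), (3.0, -2.0))
c, d = (normalize(t) for t in (3.0, -2.0))
assert (a, b) == (c, d) == (1.0, -1.0)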
@@ -972,7 +972,7 @@ class GaussianDiffusion:
             assert False  # not currently supported for this type of diffusion.
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
             model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
-            terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
+            terms.update(dict(zip(model_output_keys, model_outputs)))
             model_output = terms[gd_out_key]
             if self.model_var_type in [
                 ModelVarType.LEARNED,
@@ -37,7 +37,7 @@ def route_args(router, args, depth):
     for key in matched_keys:
         val = args[key]
         for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
-            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
+            new_f_args, new_g_args = (({key: val} if route else {}) for route in routes)
             routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
     return routed_args

@@ -152,7 +152,7 @@ class Attention(nn.Module):
         softmax = torch.softmax

         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
+        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv)

         q = q * self.scale

@@ -84,7 +84,7 @@ def init_zero_(layer):


 def pick_and_pop(keys, d):
-    values = list(map(lambda key: d.pop(key), keys))
+    values = [d.pop(key) for key in keys]
     return dict(zip(keys, values))


@@ -107,7 +107,7 @@ def group_by_key_prefix(prefix, d):

 def groupby_prefix_and_trim(prefix, d):
     kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
-    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())))
+    kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())}
     return kwargs_without_prefix, kwargs


@@ -428,7 +428,7 @@ class ShiftTokens(nn.Module):
         feats_per_shift = x.shape[-1] // segments
         splitted = x.split(feats_per_shift, dim=-1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
-        segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
+        segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)]
         x = torch.cat((*segments_to_shift, *rest), dim=-1)
         return self.fn(x, **kwargs)

@@ -635,7 +635,7 @@ class Attention(nn.Module):
         v = self.to_v(v_input)

         if not collab_heads:
-            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+            q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
         else:
             q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
             k = rearrange(k, "b n d -> b () n d")
@@ -650,9 +650,9 @@ class Attention(nn.Module):

         if exists(rotary_pos_emb) and not has_context:
             l = rotary_pos_emb.shape[-1]
-            (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
-            ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
-            q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
+            (ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v))
+            ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl))
+            q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr)))

         input_mask = None
         if any(map(exists, (mask, context_mask))):
@@ -664,7 +664,7 @@ class Attention(nn.Module):
             input_mask = q_mask * k_mask

         if self.num_mem_kv > 0:
-            mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
+            mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v))
             k = torch.cat((mem_k, k), dim=-2)
             v = torch.cat((mem_v, v), dim=-2)
             if exists(input_mask):
@@ -964,9 +964,7 @@ class AttentionLayers(nn.Module):
             seq_len = x.shape[1]
             if past_key_values is not None:
                 seq_len += past_key_values[0][0].shape[-2]
-            max_rotary_emb_length = max(
-                list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]
-            )
+            max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

         present_key_values = []
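The same rewrite lets the wrapped three-line `max(...)` call above collapse to one line. A stand-in with plain integers instead of tensor shapes (the `exists()` branch is omitted here; values are hypothetical):

seq_len, expected_seq_len = 4, 10
mem_lengths = [3, 0, 5]  # hypothetical stand-ins for m.shape[1]
old_style = max(list(map(lambda m: m + seq_len, mem_lengths)) + [expected_seq_len])
new_style = max([m + seq_len for m in mem_lengths] + [expected_seq_len])
assert old_style == new_style == 10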
@@ -1200,7 +1198,7 @@ class TransformerWrapper(nn.Module):

         res = [out]
         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             res.append(attn_maps)
         if use_cache:
             res.append(intermediates.past_key_values)
@@ -1249,7 +1247,7 @@ class ContinuousTransformerWrapper(nn.Module):

         res = [out]
         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             res.append(attn_maps)
         if use_cache:
             res.append(intermediates.past_key_values)
@@ -260,7 +260,7 @@ class DiscreteVAE(nn.Module):
         dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
         dec_chans = [dec_init_chan, *dec_chans]

-        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
+        enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans))

         pad = (kernel_size - 1) // 2
         for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
@@ -306,9 +306,9 @@ class DiscreteVAE(nn.Module):
         if not self.normalization is not None:
             return images

-        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
+        means, stds = (torch.as_tensor(t).to(images) for t in self.normalization)
         arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
-        means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
+        means, stds = (rearrange(t, arrange) for t in (means, stds))
         images = images.clone()
         images.sub_(means).div_(stds)
         return images
@@ -1948,8 +1948,7 @@ class VitsCharacters(BaseCharacters):
     def _create_vocab(self):
         self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
         self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
-        # pylint: disable=unnecessary-comprehension
-        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
+        self._id_to_char = dict(enumerate(self.vocab))

     @staticmethod
     def init_from_config(config: Coqpit):
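Here and in the vocabulary classes below, the `{idx: char for idx, char in enumerate(...)}` comprehension becomes `dict(enumerate(...))`, which builds the same mapping without the pylint suppression. A quick check with an illustrative vocabulary (not the real one):

vocab = ["<pad>", "a", "b", "c"]  # hypothetical vocabulary
by_comprehension = {idx: char for idx, char in enumerate(vocab)}
by_constructor = dict(enumerate(vocab))
assert by_comprehension == by_constructor == {0: "<pad>", 1: "a", 2: "b", 3: "c"}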
@@ -1996,4 +1995,4 @@ class FairseqVocab(BaseVocabulary):
         self.blank = self._vocab[0]
         self.pad = " "
         self._char_to_id = {s: i for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
-        self._id_to_char = {i: s for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
+        self._id_to_char = dict(enumerate(self._vocab))
@@ -59,7 +59,7 @@ class LanguageManager(BaseIDManager):
                 languages.add(dataset["language"])
             else:
                 raise ValueError(f"Dataset {dataset['name']} has no language specified.")
-        return {name: i for i, name in enumerate(sorted(list(languages)))}
+        return {name: i for i, name in enumerate(sorted(languages))}

     def set_language_ids_from_config(self, c: Coqpit) -> None:
         """Set language IDs from config samples.
@@ -193,7 +193,7 @@ class EmbeddingManager(BaseIDManager):
         embeddings = load_file(file_path)
         speakers = sorted({x["name"] for x in embeddings.values()})
         name_to_id = {name: i for i, name in enumerate(speakers)}
-        clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys())))
+        clip_ids = list(set(clip_name for clip_name in embeddings.keys()))
         # cache embeddings_by_names for fast inference using a bigger speakers.json
         embeddings_by_names = {}
         for x in embeddings.values():
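Dropping the inner `sorted()` above does not change which clip ids end up in the result, since building a set discards ordering anyway; the order of the final `list(...)` was unspecified in both versions. An illustrative check with made-up clip names:

embeddings = {"clip_b": {}, "clip_a": {}, "clip_c": {}}  # hypothetical data
old_ids = set(sorted(clip_name for clip_name in embeddings.keys()))
new_ids = set(clip_name for clip_name in embeddings.keys())
assert old_ids == new_ids == {"clip_a", "clip_b", "clip_c"}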
@@ -87,9 +87,7 @@ class BaseVocabulary:
         if vocab is not None:
             self._vocab = vocab
             self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
-            self._id_to_char = {
-                idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
-            }
+            self._id_to_char = dict(enumerate(self._vocab))

     @staticmethod
     def init_from_config(config, **kwargs):
@@ -269,9 +267,7 @@ class BaseCharacters:
     def vocab(self, vocab):
         self._vocab = vocab
         self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
-        self._id_to_char = {
-            idx: char for idx, char in enumerate(self.vocab)  # pylint: disable=unnecessary-comprehension
-        }
+        self._id_to_char = dict(enumerate(self.vocab))

     @property
     def num_chars(self):
@@ -350,8 +350,8 @@ def hira2kata(text: str) -> str:
     return text.replace("う゛", "ヴ")


-_SYMBOL_TOKENS = set(list("・、。?!"))
-_NO_YOMI_TOKENS = set(list("「」『』―()[][] …"))
+_SYMBOL_TOKENS = set("・、。?!")
+_NO_YOMI_TOKENS = set("「」『』―()[][] …")
 _TAGGER = MeCab.Tagger()


@@ -278,7 +278,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
             },
         )

-        batch = dict({})
+        batch = {}
         batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
         batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
         batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
@@ -266,7 +266,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
             },
         )

-        batch = dict({})
+        batch = {}
         batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
         batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
         batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]