diff --git a/TTS/__init__.py b/TTS/__init__.py
index 5162d4ec..53d10f1f 100644
--- a/TTS/__init__.py
+++ b/TTS/__init__.py
@@ -1,6 +1,6 @@
 import os

-with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f:
+with open(os.path.join(os.path.dirname(__file__), "VERSION"), 'r', encoding='utf-8') as f:
     version = f.read().strip()

 __version__ = version
diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py
index 9637fe75..3a5c067e 100644
--- a/TTS/bin/compute_attention_masks.py
+++ b/TTS/bin/compute_attention_masks.py
@@ -158,7 +158,7 @@ Example run:

     # ourput metafile
     metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
-    with open(metafile, "w") as f:
+    with open(metafile, "w", encoding="utf-8") as f:
         for p in file_paths:
             f.write(f"{p[0]}|{p[1]}\n")
     print(f" >> Metafile created: {metafile}")
diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index 0eee3083..6ec99fac 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -215,7 +215,7 @@ def extract_spectrograms(
                 wav = ap.inv_melspectrogram(mel)
                 ap.save_wav(wav, wav_gl_path)

-    with open(os.path.join(output_path, metada_name), "w") as f:
+    with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
         for data in export_metadata:
             f.write(f"{data[0]}|{data[1]+'.npy'}\n")

diff --git a/TTS/model.py b/TTS/model.py
index aefb925e..cfd1ec62 100644
--- a/TTS/model.py
+++ b/TTS/model.py
@@ -23,35 +23,31 @@ class BaseModel(nn.Module, ABC):
     """

     @abstractmethod
-    def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
+    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
         """Forward pass for the model mainly used in training.

-        You can be flexible here and use different number of arguments and argument names since it is mostly used by
-        `train_step()` in training whitout exposing it to the out of the class.
+        You can be flexible here and use different number of arguments and argument names since it is intended to be
+        used by `train_step()` without exposing it out of the model.

         Args:
-            text (torch.Tensor): Input text character sequence ids.
+            input (torch.Tensor): Input tensor.
             aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs.
-                for the model.

         Returns:
-            Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
+            Dict: Model outputs. Main model output must be named as "model_outputs".
         """
         outputs_dict = {"model_outputs": None}
         ...
         return outputs_dict

     @abstractmethod
-    def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
         """Forward pass for inference.

-        After the model is trained this is the only function that connects the model the out world.
-
-        This function must only take a `text` input and a dictionary that has all the other model specific inputs.
         We don't use `*kwargs` since it is problematic with the TorchScript API.

         Args:
-            text (torch.Tensor): [description]
+            input (torch.Tensor): [description]
             aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.

         Returns:
diff --git a/TTS/speaker_encoder/losses.py b/TTS/speaker_encoder/losses.py
index ac7e62bf..8ba917b7 100644
--- a/TTS/speaker_encoder/losses.py
+++ b/TTS/speaker_encoder/losses.py
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn


 # adapted from https://github.com/cvqluu/GE2E-Loss
diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index f121631b..fcc850d7 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-import torch.nn as nn
+from torch import nn

 from TTS.utils.io import load_fsspec

diff --git a/TTS/speaker_encoder/utils/prepare_voxceleb.py b/TTS/speaker_encoder/utils/prepare_voxceleb.py
index 05a65bea..b93baf9e 100644
--- a/TTS/speaker_encoder/utils/prepare_voxceleb.py
+++ b/TTS/speaker_encoder/utils/prepare_voxceleb.py
@@ -94,7 +94,8 @@ def download_and_extract(directory, subset, urls):
         extract_path = zip_filepath.strip(".zip")

         # check zip file md5sum
-        md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest()
+        with open(zip_filepath, "rb") as f_zip:
+            md5 = hashlib.md5(f_zip.read()).hexdigest()
         if md5 != MD5SUM[subset]:
             raise ValueError("md5sum of %s mismatch" % zip_filepath)

diff --git a/TTS/trainer.py b/TTS/trainer.py
index b22595b3..3c0784f7 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -631,13 +631,13 @@ class Trainer:
         outputs = outputs_per_optimizer

         # update avg runtime stats
-        keep_avg_update = dict()
+        keep_avg_update = {}
         keep_avg_update["avg_loader_time"] = loader_time
         keep_avg_update["avg_step_time"] = step_time
         self.keep_avg_train.update_values(keep_avg_update)

         # update avg loss stats
-        update_eval_values = dict()
+        update_eval_values = {}
         for key, value in loss_dict.items():
             update_eval_values["avg_" + key] = value
         self.keep_avg_train.update_values(update_eval_values)
@@ -797,7 +797,7 @@ class Trainer:
         loss_dict = self._detach_loss_dict(loss_dict)

         # update avg stats
-        update_eval_values = dict()
+        update_eval_values = {}
         for key, value in loss_dict.items():
             update_eval_values["avg_" + key] = value
         self.keep_avg_eval.update_values(update_eval_values)
@@ -977,12 +977,13 @@ class Trainer:
         def __init__(self, print_to_terminal=True):
             self.print_to_terminal = print_to_terminal
             self.terminal = sys.stdout
-            self.log = open(log_file, "a")
+            self.log_file = log_file

         def write(self, message):
             if self.print_to_terminal:
                 self.terminal.write(message)
-            self.log.write(message)
+            with open(self.log_file, "a", encoding="utf-8") as f:
+                f.write(message)

         def flush(self):
             # this flush method is needed for python 3 compatibility.
diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py
index cbae78a7..a2520751 100644
--- a/TTS/tts/datasets/__init__.py
+++ b/TTS/tts/datasets/__init__.py
@@ -66,7 +66,7 @@ def load_meta_data(datasets, eval_split=True):

 def load_attention_mask_meta_data(metafile_path):
     """Load meta data file created by compute_attention_masks.py"""
-    with open(metafile_path, "r") as f:
+    with open(metafile_path, "r", encoding="utf-8") as f:
         lines = f.readlines()

     meta_data = []
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index c057c51e..eee407a8 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -19,7 +19,7 @@ def tweb(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "tweb"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
@@ -33,7 +33,7 @@ def mozilla(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "mozilla"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = cols[1].strip()
@@ -77,7 +77,7 @@ def mailabs(root_path, meta_files=None):
             continue
         speaker_name = speaker_name_match.group("speaker_name")
         print(" | > {}".format(csv_file))
-        with open(txt_file, "r") as ttf:
+        with open(txt_file, "r", encoding="utf-8") as ttf:
             for line in ttf:
                 cols = line.split("|")
                 if meta_files is None:
@@ -102,7 +102,7 @@
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
-            text = cols[1]
+            text = cols[2]
             items.append([text, wav_file, speaker_name])
     return items

@@ -116,7 +116,7 @@
         for idx, line in enumerate(ttf):
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
-            text = cols[1]
+            text = cols[2]
             items.append([text, wav_file, f"ljspeech-{idx}"])
     return items

@@ -158,7 +158,7 @@ def css10(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "ljspeech"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, cols[0])
@@ -172,7 +172,7 @@ def nancy(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "nancy"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             utt_id = line.split()[1]
             text = line[line.find('"') + 1 : line.rfind('"') - 1]
@@ -185,7 +185,7 @@ def common_voice(root_path, meta_file):
     """Normalize the common voice meta data file to TTS format."""
     txt_file = os.path.join(root_path, meta_file)
     items = []
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             if line.startswith("client_id"):
                 continue
@@ -208,7 +208,7 @@ def libri_tts(root_path, meta_files=None):

     for meta_file in meta_files:
         _meta_file = os.path.basename(meta_file).split(".")[0]
-        with open(meta_file, "r") as ttf:
+        with open(meta_file, "r", encoding="utf-8") as ttf:
             for line in ttf:
                 cols = line.split("\t")
                 file_name = cols[0]
@@ -245,7 +245,7 @@ def brspeech(root_path, meta_file):
     """BRSpeech 3.0 beta"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             if line.startswith("wav_filename"):
                 continue
@@ -268,7 +268,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
         if isinstance(test_speakers, list):  # if is list ignore this speakers ids
             if speaker_id in test_speakers:
                 continue
-        with open(meta_file) as file_text:
+        with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
         items.append([text, wav_file, "VCTK_" + speaker_id])
@@ -295,7 +295,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
 def mls(root_path, meta_files=None):
     """http://www.openslr.org/94/"""
     items = []
-    with open(os.path.join(root_path, meta_files), "r") as meta:
+    with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
         for line in meta:
             file, text = line.split("\t")
             text = text[:-1]
@@ -329,7 +329,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):

     # if not exists meta file, crawl recursively for 'wav' files
     if meta_file is not None:
-        with open(str(meta_file), "r") as f:
+        with open(str(meta_file), "r", encoding="utf-8") as f:
             return [x.strip().split("|") for x in f.readlines()]

     elif not cache_to.exists():
@@ -346,12 +346,12 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
             text = None  # VoxCel does not provide transciptions, and they are not needed for training the SE
             meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
             cnt += 1
-        with open(str(cache_to), "w") as f:
+        with open(str(cache_to), "w", encoding="utf-8") as f:
             f.write("".join(meta_data))

     if cnt < expected_count:
         raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")

-    with open(str(cache_to), "r") as f:
+    with open(str(cache_to), "r", encoding="utf-8") as f:
         return [x.strip().split("|") for x in f.readlines()]

@@ -367,7 +367,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "baker"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             wav_name, text = line.rstrip("\n").split("|")
             wav_path = os.path.join(root_path, "clips_22", wav_name)
@@ -380,7 +380,7 @@ def kokoro(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "kokoro"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py
index 75631e0e..9e6b69ac 100644
--- a/TTS/tts/layers/generic/transformer.py
+++ b/TTS/tts/layers/generic/transformer.py
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn


 class FFTransformer(nn.Module):
diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py
index 0d3ed039..01a81e0b 100644
--- a/TTS/tts/layers/tacotron/gst_layers.py
+++ b/TTS/tts/layers/tacotron/gst_layers.py
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn


 class GST(nn.Module):
diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py
index 47b5ea7e..bddaf449 100644
--- a/TTS/tts/layers/tacotron/tacotron.py
+++ b/TTS/tts/layers/tacotron/tacotron.py
@@ -388,8 +388,8 @@ class Decoder(nn.Module):
         decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))

         # Pass through the decoder RNNs
-        for idx in range(len(self.decoder_rnns)):
-            self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx])
+        for idx, decoder_rnn in enumerate(self.decoder_rnns):
+            self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx])
             # Residual connection
             decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
         decoder_output = decoder_input
diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py
index fb2fa697..2aa84cb2 100644
--- a/TTS/tts/models/align_tts.py
+++ b/TTS/tts/models/align_tts.py
@@ -2,8 +2,8 @@ from dataclasses import dataclass, field
 from typing import Dict, Tuple

 import torch
-import torch.nn as nn
 from coqpit import Coqpit
+from torch import nn

 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 3b9d82f9..4c0b4667 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -435,7 +435,7 @@ class Vits(BaseTTS):
                 attn_durations,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
             )
-            loss_duration = loss_duration/ torch.sum(x_mask)
+            loss_duration = loss_duration / torch.sum(x_mask)
         else:
             attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
             log_durations = self.duration_predictor(
@@ -579,7 +579,7 @@ class Vits(BaseTTS):
                     scores_disc_fake=outputs["scores_disc_fake"],
                     feats_disc_fake=outputs["feats_disc_fake"],
                     feats_disc_real=outputs["feats_disc_real"],
-                    loss_duration=outputs["loss_duration"]
+                    loss_duration=outputs["loss_duration"],
                 )

         elif optimizer_idx == 1:
diff --git a/TTS/utils/distribute.py b/TTS/utils/distribute.py
index b1cb4420..a51ef766 100644
--- a/TTS/utils/distribute.py
+++ b/TTS/utils/distribute.py
@@ -18,46 +18,3 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):

     # Initialize distributed communication
     dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
-
-
-def apply_gradient_allreduce(module):
-
-    # sync model parameters
-    for p in module.state_dict().values():
-        if not torch.is_tensor(p):
-            continue
-        dist.broadcast(p, 0)
-
-    def allreduce_params():
-        if module.needs_reduction:
-            module.needs_reduction = False
-            # bucketing params based on value types
-            buckets = {}
-            for param in module.parameters():
-                if param.requires_grad and param.grad is not None:
-                    tp = type(param.data)
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
-                coalesced /= dist.get_world_size()
-                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                    buf.copy_(synced)
-
-    for param in list(module.parameters()):
-
-        def allreduce_hook(*_):
-            Variable._execution_engine.queue_callback(allreduce_params)  # pylint: disable=protected-access
-
-        if param.requires_grad:
-            param.register_hook(allreduce_hook)
-
-    def set_needs_reduction(self, *_):
-        self.needs_reduction = True
-
-    module.register_forward_hook(set_needs_reduction)
-    return module
diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py
index 83cd4233..24b905f9 100644
--- a/TTS/vocoder/layers/wavegrad.py
+++ b/TTS/vocoder/layers/wavegrad.py
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 from torch.nn.utils import weight_norm


diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index a1e16150..4ce743b3 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -1,8 +1,8 @@
 # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
 from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, weight_norm

 from TTS.utils.io import load_fsspec
diff --git a/TTS/vocoder/models/melgan_multiscale_discriminator.py b/TTS/vocoder/models/melgan_multiscale_discriminator.py
index 33e0a688..b4909f37 100644
--- a/TTS/vocoder/models/melgan_multiscale_discriminator.py
+++ b/TTS/vocoder/models/melgan_multiscale_discriminator.py
@@ -40,8 +40,8 @@ class MelganMultiscaleDiscriminator(nn.Module):
         )

     def forward(self, x):
-        scores = list()
-        feats = list()
+        scores = []
+        feats = []
         for disc in self.discriminators:
             score, feat = disc(x)
             scores.append(score)
diff --git a/TTS/vocoder/models/univnet_discriminator.py b/TTS/vocoder/models/univnet_discriminator.py
index d99b2760..d6b0e5d5 100644
--- a/TTS/vocoder/models/univnet_discriminator.py
+++ b/TTS/vocoder/models/univnet_discriminator.py
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 from torch.nn.utils import spectral_norm, weight_norm

 from TTS.utils.audio import TorchSTFT
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index 8a968019..9b0d6837 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -5,9 +5,9 @@ from typing import Dict, List, Tuple

 import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from coqpit import Coqpit
+from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

diff --git a/notebooks/dataset_analysis/analyze.py b/notebooks/dataset_analysis/analyze.py
index 6c6bc582..9ba42fb9 100644
--- a/notebooks/dataset_analysis/analyze.py
+++ b/notebooks/dataset_analysis/analyze.py
@@ -43,7 +43,7 @@ def process_meta_data(path):
     meta_data = {}

     # load meta data
-    with open(path, "r") as f:
+    with open(path, "r", encoding="utf-8") as f:
         data = csv.reader(f, delimiter="|")
         for row in data:
             frames = int(row[2])
@@ -92,7 +92,7 @@ def save_training(file_path, meta_data):
         rows.append(d["row"] + "\n")
     random.shuffle(rows)

-    with open(file_path, "w+") as f:
+    with open(file_path, "w+", encoding="utf-8") as f:
         for row in rows:
             f.write(row)

@@ -156,7 +156,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):

     phonemes = {}

-    with open(train_path, "r") as f:
+    with open(train_path, "r", encoding="utf-8") as f:
         data = csv.reader(f, delimiter="|")
         phonemes["None"] = 0
         for row in data:
@@ -174,9 +174,9 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
                 phonemes["None"] += 1

     x, y = [], []
-    for key in phonemes:
-        x.append(key)
-        y.append(phonemes[key])
+    for k, v in phonemes.items():
+        x.append(k)
+        y.append(v)

     plt.figure()
     plt.rcParams["figure.figsize"] = (50, 20)
diff --git a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py
index 8bff0ab7..7cf52f89 100644
--- a/recipes/ljspeech/vits_tts/train_vits.py
+++ b/recipes/ljspeech/vits_tts/train_vits.py
@@ -29,7 +29,7 @@ config = VitsConfig(
     run_name="vits_ljspeech",
     batch_size=48,
     eval_batch_size=16,
-    batch_group_size=0,
+    batch_group_size=5,
     num_loader_workers=4,
     num_eval_loader_workers=4,
     run_eval=True,
diff --git a/requirements.dev.txt b/requirements.dev.txt
index afb5ebe6..c995f9e6 100644
--- a/requirements.dev.txt
+++ b/requirements.dev.txt
@@ -2,4 +2,4 @@ black
 coverage
 isort
 nose
-pylint==2.8.3
+pylint==2.10.2