From 09bbec073dd056dbe366cde187ccc6a05736b6dd Mon Sep 17 00:00:00 2001 From: Rafael Valle Date: Thu, 3 May 2018 15:16:57 -0700 Subject: [PATCH] adding python files --- audio_processing.py | 93 ++++++++ data_utils.py | 101 +++++++++ distributed.py | 120 ++++++++++ fp16_optimizer.py | 381 +++++++++++++++++++++++++++++++ hparams.py | 91 ++++++++ layers.py | 80 +++++++ logger.py | 48 ++++ loss_function.py | 19 ++ loss_scaler.py | 132 +++++++++++ model.py | 541 ++++++++++++++++++++++++++++++++++++++++++++ multiproc.py | 23 ++ plotting_utils.py | 61 +++++ stft.py | 140 ++++++++++++ train.py | 272 ++++++++++++++++++++++ utils.py | 32 +++ 15 files changed, 2134 insertions(+) create mode 100644 audio_processing.py create mode 100644 data_utils.py create mode 100644 distributed.py create mode 100644 fp16_optimizer.py create mode 100644 hparams.py create mode 100644 layers.py create mode 100644 logger.py create mode 100644 loss_function.py create mode 100644 loss_scaler.py create mode 100644 model.py create mode 100644 multiproc.py create mode 100644 plotting_utils.py create mode 100644 stft.py create mode 100644 train.py create mode 100644 utils.py diff --git a/audio_processing.py b/audio_processing.py new file mode 100644 index 0000000..b5af7f7 --- /dev/null +++ b/audio_processing.py @@ -0,0 +1,93 @@ +import torch +import numpy as np +from scipy.signal import get_window +import librosa.util as librosa_util + + +def window_sumsquare(window, n_frames, hop_length=200, win_length=800, + n_fft=800, dtype=np.float32, norm=None): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. + + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm)**2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C diff --git a/data_utils.py b/data_utils.py new file mode 100644 index 0000000..786a282 --- /dev/null +++ b/data_utils.py @@ -0,0 +1,101 @@ +import random +import torch +import torch.utils.data + +import layers +from utils import load_wav_to_torch, load_filepaths_and_text +from text import text_to_sequence + + +class TextMelLoader(torch.utils.data.Dataset): + """ + 1) loads audio,text pairs + 2) normalizes text and converts them to sequences of one-hot vectors + 3) computes mel-spectrograms from audio files. + """ + def __init__(self, audiopaths_and_text, hparams, shuffle=True): + self.audiopaths_and_text = load_filepaths_and_text( + audiopaths_and_text, hparams.sort_by_length) + self.text_cleaners = hparams.text_cleaners + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.stft = layers.TacotronSTFT( + hparams.filter_length, hparams.hop_length, hparams.win_length, + hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, + hparams.mel_fmax) + random.seed(1234) + if shuffle: + random.shuffle(self.audiopaths_and_text) + + def get_mel_text_pair(self, audiopath_and_text): + # separate filename and text + audiopath, text = audiopath_and_text[0], audiopath_and_text[1] + text = self.get_text(text) + mel = self.get_mel(audiopath) + return (text, mel) + + def get_mel(self, filename): + audio = load_wav_to_torch(filename, self.sampling_rate) + audio_norm = audio / self.max_wav_value + audio_norm = audio_norm.unsqueeze(0) + audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) + melspec = self.stft.mel_spectrogram(audio_norm) + melspec = torch.squeeze(melspec, 0) + return melspec + + def get_text(self, text): + text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners)) + return text_norm + + def __getitem__(self, index): + return self.get_mel_text_pair(self.audiopaths_and_text[index]) + + def __len__(self): + return len(self.audiopaths_and_text) + + +class TextMelCollate(): + """ Zero-pads model inputs and targets based on number of frames per setep + """ + def __init__(self, n_frames_per_step): + self.n_frames_per_step = n_frames_per_step + + def __call__(self, batch): + """Collate's training batch from normalized text and mel-spectrogram + PARAMS + ------ + batch: [text_normalized, mel_normalized] + """ + # Right zero-pad all one-hot text sequences to max input length + input_lengths, ids_sorted_decreasing = torch.sort( + torch.LongTensor([len(x[0]) for x in batch]), + dim=0, descending=True) + max_input_len = input_lengths[0] + + text_padded = torch.LongTensor(len(batch), max_input_len) + text_padded.zero_() + for i in range(len(ids_sorted_decreasing)): + text = batch[ids_sorted_decreasing[i]][0] + text_padded[i, :text.size(0)] = text + + # Right zero-pad mel-spec with extra single zero vector to mark the end + num_mels = batch[0][1].size(0) + max_target_len = max([x[1].size(1) for x in batch]) + 1 + if max_target_len % self.n_frames_per_step != 0: + max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step + assert max_target_len % self.n_frames_per_step == 0 + + # include mel padded and gate padded + mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) + mel_padded.zero_() + gate_padded = torch.FloatTensor(len(batch), max_target_len) + gate_padded.zero_() + output_lengths = torch.LongTensor(len(batch)) + for i in range(len(ids_sorted_decreasing)): + mel = batch[ids_sorted_decreasing[i]][1] + mel_padded[i, :, :mel.size(1)] = mel + gate_padded[i, mel.size(1):] = 1 + output_lengths[i] = mel.size(1) + + return text_padded, input_lengths, mel_padded, gate_padded, \ + output_lengths diff --git a/distributed.py b/distributed.py new file mode 100644 index 0000000..ebe3b5b --- /dev/null +++ b/distributed.py @@ -0,0 +1,120 @@ +import torch +import torch.distributed as dist +from torch.nn.modules import Module + +def _flatten_dense_tensors(tensors): + """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of + same dense type. + Since inputs are dense, the resulting tensor will be a concatenated 1D + buffer. Element-wise operation on this buffer will be equivalent to + operating individually. + Arguments: + tensors (Iterable[Tensor]): dense tensors to flatten. + Returns: + A contiguous 1D buffer containing input tensors. + """ + if len(tensors) == 1: + return tensors[0].contiguous().view(-1) + flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) + return flat + +def _unflatten_dense_tensors(flat, tensors): + """View a flat buffer using the sizes of tensors. Assume that tensors are of + same dense type, and that flat is given by _flatten_dense_tensors. + Arguments: + flat (Tensor): flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): dense tensors whose sizes will be used to + unflatten flat. + Returns: + Unflattened dense tensors with sizes same as tensors and values from + flat. + """ + outputs = [] + offset = 0 + for tensor in tensors: + numel = tensor.numel() + outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) + offset += numel + return tuple(outputs) + + +''' +This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py +launcher included with this example. It assumes that your run is using multiprocess with 1 +GPU/process, that the model is on the correct device, and that torch.set_device has been +used to set the device. + +Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, +and will be allreduced at the finish of the backward pass. +''' +class DistributedDataParallel(Module): + + def __init__(self, module): + super(DistributedDataParallel, self).__init__() + #fallback for PyTorch 0.3 + if not hasattr(dist, '_backend'): + self.warn_on_half = True + else: + self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False + + self.module = module + + for p in self.module.state_dict().values(): + if not torch.is_tensor(p): + continue + dist.broadcast(p, 0) + + def allreduce_params(): + if(self.needs_reduction): + self.needs_reduction = False + buckets = {} + for param in self.module.parameters(): + if param.requires_grad and param.grad is not None: + tp = type(param.data) + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + if self.warn_on_half: + if torch.cuda.HalfTensor in buckets: + print("WARNING: gloo dist backend for half parameters may be extremely slow." + + " It is recommended to use the NCCL backend in this case. This currently requires" + + "PyTorch built from top of tree master.") + self.warn_on_half = False + + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced) + coalesced /= dist.get_world_size() + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + for param in list(self.module.parameters()): + def allreduce_hook(*unused): + param._execution_engine.queue_callback(allreduce_params) + if param.requires_grad: + param.register_hook(allreduce_hook) + + def forward(self, *inputs, **kwargs): + self.needs_reduction = True + return self.module(*inputs, **kwargs) + + ''' + def _sync_buffers(self): + buffers = list(self.module._all_buffers()) + if len(buffers) > 0: + # cross-node buffer sync + flat_buffers = _flatten_dense_tensors(buffers) + dist.broadcast(flat_buffers, 0) + for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): + buf.copy_(synced) + def train(self, mode=True): + # Clear NCCL communicator and CUDA event cache of the default group ID, + # These cache will be recreated at the later call. This is currently a + # work-around for a potential NCCL deadlock. + if dist._backend == dist.dist_backend.NCCL: + dist._clear_group_cache() + super(DistributedDataParallel, self).train(mode) + self.module.train(mode) + ''' diff --git a/fp16_optimizer.py b/fp16_optimizer.py new file mode 100644 index 0000000..7f1c57a --- /dev/null +++ b/fp16_optimizer.py @@ -0,0 +1,381 @@ +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from loss_scaler import DynamicLossScaler, LossScaler + +FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + +def fp32_to_fp16(val): + """Convert fp32 `val` to fp16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.half() + return val + return conversion_helper(val, half_conversion) + +def fp16_to_fp32(val): + """Convert fp16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + return conversion_helper(val, float_conversion) + +class FP16_Module(nn.Module): + def __init__(self, module): + super(FP16_Module, self).__init__() + self.add_module('module', module.half()) + + def forward(self, *inputs, **kwargs): + return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) + +class FP16_Optimizer(object): + """ + FP16_Optimizer is designed to wrap an existing PyTorch optimizer, + and enable an fp16 model to be trained using a master copy of fp32 weights. + + Args: + optimizer (torch.optim.optimizer): Existing optimizer containing initialized fp16 parameters. Internally, FP16_Optimizer replaces the passed optimizer's fp16 parameters with new fp32 parameters copied from the original ones. FP16_Optimizer also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy after each step. + static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale fp16 gradients computed by the model. Scaled gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so static_loss_scale should not affect learning rate. + dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any static_loss_scale option. + + """ + + def __init__(self, optimizer, static_loss_scale=1.0, dynamic_loss_scale=False): + if not torch.cuda.is_available: + raise SystemError('Cannot use fp16 without CUDA') + + self.fp16_param_groups = [] + self.fp32_param_groups = [] + self.fp32_flattened_groups = [] + for i, param_group in enumerate(optimizer.param_groups): + print("FP16_Optimizer processing param group {}:".format(i)) + fp16_params_this_group = [] + fp32_params_this_group = [] + for param in param_group['params']: + if param.requires_grad: + if param.type() == 'torch.cuda.HalfTensor': + print("FP16_Optimizer received torch.cuda.HalfTensor with {}" + .format(param.size())) + fp16_params_this_group.append(param) + elif param.type() == 'torch.cuda.FloatTensor': + print("FP16_Optimizer received torch.cuda.FloatTensor with {}" + .format(param.size())) + fp32_params_this_group.append(param) + else: + raise TypeError("Wrapped parameters must be either " + "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " + "Received {}".format(param.type())) + + fp32_flattened_this_group = None + if len(fp16_params_this_group) > 0: + fp32_flattened_this_group = _flatten_dense_tensors( + [param.detach().data.clone().float() for param in fp16_params_this_group]) + + fp32_flattened_this_group = Variable(fp32_flattened_this_group, requires_grad = True) + + fp32_flattened_this_group.grad = fp32_flattened_this_group.new( + *fp32_flattened_this_group.size()) + + # python's lovely list concatenation via + + if fp32_flattened_this_group is not None: + param_group['params'] = [fp32_flattened_this_group] + fp32_params_this_group + else: + param_group['params'] = fp32_params_this_group + + self.fp16_param_groups.append(fp16_params_this_group) + self.fp32_param_groups.append(fp32_params_this_group) + self.fp32_flattened_groups.append(fp32_flattened_this_group) + + # print("self.fp32_flattened_groups = ", self.fp32_flattened_groups) + # print("self.fp16_param_groups = ", self.fp16_param_groups) + + self.optimizer = optimizer.__class__(optimizer.param_groups) + + # self.optimizer.load_state_dict(optimizer.state_dict()) + + self.param_groups = self.optimizer.param_groups + + if dynamic_loss_scale: + self.dynamic_loss_scale = True + self.loss_scaler = DynamicLossScaler() + else: + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(static_loss_scale) + + self.overflow = False + self.first_closure_call_this_step = True + + def zero_grad(self): + """ + Zero fp32 and fp16 parameter grads. + """ + self.optimizer.zero_grad() + for fp16_group in self.fp16_param_groups: + for param in fp16_group: + if param.grad is not None: + param.grad.detach_() # This does appear in torch.optim.optimizer.zero_grad(), + # but I'm not sure why it's needed. + param.grad.zero_() + + def _check_overflow(self): + params = [] + for group in self.fp16_param_groups: + for param in group: + params.append(param) + for group in self.fp32_param_groups: + for param in group: + params.append(param) + self.overflow = self.loss_scaler.has_overflow(params) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + def _copy_grads_fp16_to_fp32(self): + for fp32_group, fp16_group in zip(self.fp32_flattened_groups, self.fp16_param_groups): + if len(fp16_group) > 0: + # This might incur one more deep copy than is necessary. + fp32_group.grad.data.copy_( + _flatten_dense_tensors([fp16_param.grad.data for fp16_param in fp16_group])) + + def _downscale_fp32(self): + if self.loss_scale != 1.0: + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + param.grad.data.mul_(1./self.loss_scale) + + def clip_fp32_grads(self, clip=-1): + if not self.overflow: + fp32_params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + fp32_params.append(param) + if clip > 0: + return torch.nn.utils.clip_grad_norm(fp32_params, clip) + + def _copy_params_fp32_to_fp16(self): + for fp16_group, fp32_group in zip(self.fp16_param_groups, self.fp32_flattened_groups): + if len(fp16_group) > 0: + for fp16_param, fp32_data in zip(fp16_group, + _unflatten_dense_tensors(fp32_group.data, fp16_group)): + fp16_param.data.copy_(fp32_data) + + def state_dict(self): + """ + Returns a dict containing the current state of this FP16_Optimizer instance. + This dict contains attributes of FP16_Optimizer, as well as the state_dict + of the contained Pytorch optimizer. + + Untested. + """ + state_dict = {} + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + return state_dict + + def load_state_dict(self, state_dict): + """ + Loads a state_dict created by an earlier call to state_dict. + + Untested. + """ + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + self.first_closure_call_this_step = state_dict['first_closure_call_this_step'] + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + + def step(self, closure=None): # could add clip option. + """ + If no closure is supplied, step should be called after fp16_optimizer_obj.backward(loss). + step updates the fp32 master copy of parameters using the optimizer supplied to + FP16_Optimizer's constructor, then copies the updated fp32 params into the fp16 params + originally referenced by Fp16_Optimizer's constructor, so the user may immediately run + another forward pass using their model. + + If a closure is supplied, step may be called without a prior call to self.backward(loss). + However, the user should take care that any loss.backward() call within the closure + has been replaced by fp16_optimizer_obj.backward(loss). + + Args: + closure (optional): Closure that will be supplied to the underlying optimizer originally passed to FP16_Optimizer's constructor. closure should call zero_grad on the FP16_Optimizer object, compute the loss, call .backward(loss), and return the loss. + + Closure example:: + + # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an + # existing pytorch optimizer. + for input, target in dataset: + def closure(): + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + optimizer.backward(loss) + return loss + optimizer.step(closure) + + .. note:: + The only changes that need to be made compared to + `ordinary optimizer closures`_ are that "optimizer" itself should be an instance of + FP16_Optimizer, and that the call to loss.backward should be replaced by + optimizer.backward(loss). + + .. warning:: + Currently, calling step with a closure is not compatible with dynamic loss scaling. + + .. _`ordinary optimizer closures`: + http://pytorch.org/docs/master/optim.html#optimizer-step-closure + """ + if closure is not None and isinstance(self.loss_scaler, DynamicLossScaler): + raise TypeError("Using step with a closure is currently not " + "compatible with dynamic loss scaling.") + + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + print("OVERFLOW! Skipping step. Attempted loss scale: {}".format(scale)) + return + + if closure is not None: + self._step_with_closure(closure) + else: + self.optimizer.step() + + self._copy_params_fp32_to_fp16() + + return + + def _step_with_closure(self, closure): + def wrapped_closure(): + if self.first_closure_call_this_step: + """ + We expect that the fp16 params are initially fresh on entering self.step(), + so _copy_params_fp32_to_fp16() is unnecessary the first time wrapped_closure() + is called within self.optimizer.step(). + """ + self.first_closure_call_this_step = False + else: + """ + If self.optimizer.step() internally calls wrapped_closure more than once, + it may update the fp32 params after each call. However, self.optimizer + doesn't know about the fp16 params at all. If the fp32 params get updated, + we can't rely on self.optimizer to refresh the fp16 params. We need + to handle that manually: + """ + self._copy_params_fp32_to_fp16() + + """ + Our API expects the user to give us ownership of the backward() call by + replacing all calls to loss.backward() with optimizer.backward(loss). + This requirement holds whether or not the call to backward() is made within + a closure. + If the user is properly calling optimizer.backward(loss) within "closure," + calling closure() here will give the fp32 master params fresh gradients + for the optimizer to play with, + so all wrapped_closure needs to do is call closure() and return the loss. + """ + temp_loss = closure() + return temp_loss + + self.optimizer.step(wrapped_closure) + + self.first_closure_call_this_step = True + + def backward(self, loss, update_fp32_grads=True): + """ + fp16_optimizer_obj.backward performs the following conceptual operations: + + fp32_loss = loss.float() (see first Note below) + + scaled_loss = fp32_loss*loss_scale + + scaled_loss.backward(), which accumulates scaled gradients into the .grad attributes of the + fp16 model's leaves. + + fp16 grads are then copied to the stored fp32 params' .grad attributes (see second Note). + + Finally, fp32 grads are divided by loss_scale. + + In this way, after fp16_optimizer_obj.backward, the fp32 parameters have fresh gradients, + and fp16_optimizer_obj.step may be called. + + .. note:: + Converting the loss to fp32 before applying the loss scale provides some + additional safety against overflow if the user has supplied an fp16 value. + However, for maximum overflow safety, the user should + compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to + fp16_optimizer_obj.backward. + + .. note:: + The gradients found in an fp16 model's leaves after a call to + fp16_optimizer_obj.backward should not be regarded as valid in general, + because it's possible + they have been scaled (and in the case of dynamic loss scaling, + the scale factor may silently change over time). + If the user wants to inspect gradients after a call to fp16_optimizer_obj.backward, + he/she should query the .grad attribute of FP16_Optimizer's stored fp32 parameters. + + Args: + loss: The loss output by the user's model. loss may be either float or half (but see first Note above). + update_fp32_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay this copy, which is useful to eliminate redundant fp16->fp32 grad copies if fp16_optimizer_obj.backward is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling fp16_optimizer_obj.update_fp32_grads before calling fp16_optimizer_obj.step. + + Example:: + + # Ordinary operation: + optimizer.backward(loss) + + # Naive operation with multiple losses (technically valid, but less efficient): + # fp32 grads will be correct after the second call, but + # the first call incurs an unnecessary fp16->fp32 grad copy. + optimizer.backward(loss1) + optimizer.backward(loss2) + + # More efficient way to handle multiple losses: + # The fp16->fp32 grad copy is delayed until fp16 grads from all + # losses have been accumulated. + optimizer.backward(loss1, update_fp32_grads=False) + optimizer.backward(loss2, update_fp32_grads=False) + optimizer.update_fp32_grads() + """ + self.loss_scaler.backward(loss.float()) + if update_fp32_grads: + self.update_fp32_grads() + + def update_fp32_grads(self): + """ + Copy the .grad attribute from stored references to fp16 parameters to + the .grad attribute of the master fp32 parameters that are directly + updated by the optimizer. :attr:`update_fp32_grads` only needs to be called if + fp16_optimizer_obj.backward was called with update_fp32_grads=False. + """ + if self.dynamic_loss_scale: + self._check_overflow() + if self.overflow: return + self._copy_grads_fp16_to_fp32() + self._downscale_fp32() + + @property + def loss_scale(self): + return self.loss_scaler.loss_scale diff --git a/hparams.py b/hparams.py new file mode 100644 index 0000000..2165eca --- /dev/null +++ b/hparams.py @@ -0,0 +1,91 @@ +import tensorflow as tf +from text import symbols + + +def create_hparams(hparams_string=None, verbose=False): + """Create model hyperparameters. Parse nondefault from given string.""" + + hparams = tf.contrib.training.HParams( + ################################ + # Experiment Parameters # + ################################ + epochs=500, + iters_per_checkpoint=500, + seed=1234, + dynamic_loss_scaling=True, + fp16_run=False, + distributed_run=False, + dist_backend="nccl", + dist_url="file://distributed.dpt", + cudnn_enabled=True, + cudnn_benchmark=False, + + ################################ + # Data Parameters # + ################################ + training_files='ljs_audio_text_train_filelist.txt', + validation_files='ljs_audio_text_val_filelist.txt', + text_cleaners=['english_cleaners'], + sort_by_length=False, + + ################################ + # Audio Parameters # + ################################ + max_wav_value=32768.0, + sampling_rate=22050, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=80, + mel_fmin=0.0, + mel_fmax=None, # if None, half the sampling rate + + ################################ + # Model Parameters # + ################################ + n_symbols=len(symbols), + symbols_embedding_dim=512, + + # Encoder parameters + encoder_kernel_size=5, + encoder_n_convolutions=3, + encoder_embedding_dim=512, + + # Decoder parameters + n_frames_per_step=1, + decoder_rnn_dim=1024, + prenet_dim=256, + max_decoder_steps=1000, + gate_threshold=0.6, + + # Attention parameters + attention_rnn_dim=1024, + attention_dim=128, + + # Location Layer parameters + attention_location_n_filters=32, + attention_location_kernel_size=31, + + # Mel-post processing network parameters + postnet_embedding_dim=512, + postnet_kernel_size=5, + postnet_n_convolutions=5, + + ################################ + # Optimization Hyperparameters # + ################################ + learning_rate=1e-3, + weight_decay=1e-6, + grad_clip_thresh=1, + batch_size=48, + mask_padding=False # set model's padded outputs to padded values + ) + + if hparams_string: + tf.logging.info('Parsing command line hparams: %s', hparams_string) + hparams.parse(hparams_string) + + if verbose: + tf.logging.info('Final parsed hparams: %s', hparams.values()) + + return hparams diff --git a/layers.py b/layers.py new file mode 100644 index 0000000..f4935d5 --- /dev/null +++ b/layers.py @@ -0,0 +1,80 @@ +import torch +from librosa.filters import mel as librosa_mel_fn +from audio_processing import dynamic_range_compression +from audio_processing import dynamic_range_decompression +from stft import STFT + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear'): + super(ConvNorm, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + + +class TacotronSTFT(torch.nn.Module): + def __init__(self, filter_length=1024, hop_length=256, win_length=1024, + n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, + mel_fmax=None): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer('mel_basis', mel_basis) + + def spectral_normalize(self, magnitudes): + output = dynamic_range_compression(magnitudes) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert(torch.min(y.data) >= -1) + assert(torch.max(y.data) <= 1) + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output) + return mel_output diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..7be5e90 --- /dev/null +++ b/logger.py @@ -0,0 +1,48 @@ +import random +import torch.nn.functional as F +from tensorboardX import SummaryWriter +from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy +from plotting_utils import plot_gate_outputs_to_numpy + + +class Tacotron2Logger(SummaryWriter): + def __init__(self, logdir): + super(Tacotron2Logger, self).__init__(logdir) + + def log_training(self, reduced_loss, grad_norm, learning_rate, duration, + iteration): + self.add_scalar("training.loss", reduced_loss, iteration) + self.add_scalar("grad.norm", grad_norm, iteration) + self.add_scalar("learning.rate", learning_rate, iteration) + self.add_scalar("duration", duration, iteration) + + def log_validation(self, reduced_loss, model, y, y_pred, iteration): + self.add_scalar("validation.loss", reduced_loss, iteration) + _, mel_outputs, gate_outputs, alignments = y_pred + mel_targets, gate_targets = y + + # plot distribution of parameters + for tag, value in model.named_parameters(): + tag = tag.replace('.', '/') + self.add_histogram(tag, value.data.cpu().numpy(), iteration) + + # plot alignment, mel target and predicted, gate target and predicted + idx = random.randint(0, alignments.size(0) - 1) + self.add_image( + "alignment", + plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T), + iteration) + self.add_image( + "mel_target", + plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()), + iteration) + self.add_image( + "mel_predicted", + plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()), + iteration) + self.add_image( + "gate", + plot_gate_outputs_to_numpy( + gate_targets[idx].data.cpu().numpy(), + F.sigmoid(gate_outputs[idx]).data.cpu().numpy()), + iteration) diff --git a/loss_function.py b/loss_function.py new file mode 100644 index 0000000..99cae95 --- /dev/null +++ b/loss_function.py @@ -0,0 +1,19 @@ +from torch import nn + + +class Tacotron2Loss(nn.Module): + def __init__(self): + super(Tacotron2Loss, self).__init__() + + def forward(self, model_output, targets): + mel_target, gate_target = targets[0], targets[1] + mel_target.requires_grad = False + gate_target.requires_grad = False + gate_target = gate_target.view(-1, 1) + + mel_out, mel_out_postnet, gate_out, _ = model_output + gate_out = gate_out.view(-1, 1) + mel_loss = nn.MSELoss()(mel_out, mel_target) + \ + nn.MSELoss()(mel_out_postnet, mel_target) + gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) + return mel_loss + gate_loss diff --git a/loss_scaler.py b/loss_scaler.py new file mode 100644 index 0000000..c7dfa13 --- /dev/null +++ b/loss_scaler.py @@ -0,0 +1,132 @@ +import torch + +class LossScaler: + + def __init__(self, scale=1): + self.cur_scale = scale + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + return False + + # `overflow` is boolean indicating whether we overflowed in gradient + def update_scale(self, overflow): + pass + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss): + scaled_loss = loss*self.loss_scale + scaled_loss.backward() + +class DynamicLossScaler: + + def __init__(self, + init_scale=2**32, + scale_factor=2., + scale_window=1000): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): +# return False + for p in params: + if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): + return True + + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + inf_count = torch.sum(x.abs() == float('inf')) + if inf_count > 0: + return True + nan_count = torch.sum(x != x) + return nan_count > 0 + + # `overflow` is boolean indicating whether we overflowed in gradient + def update_scale(self, overflow): + if overflow: + #self.cur_scale /= self.scale_factor + self.cur_scale = max(self.cur_scale/self.scale_factor, 1) + self.last_overflow_iter = self.cur_iter + else: + if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: + self.cur_scale *= self.scale_factor +# self.cur_scale = 1 + self.cur_iter += 1 + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss): + scaled_loss = loss*self.loss_scale + scaled_loss.backward() + +############################################################## +# Example usage below here -- assuming it's in a separate file +############################################################## +if __name__ == "__main__": + import torch + from torch.autograd import Variable + from dynamic_loss_scaler import DynamicLossScaler + + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. + N, D_in, H, D_out = 64, 1000, 100, 10 + + # Create random Tensors to hold inputs and outputs, and wrap them in Variables. + x = Variable(torch.randn(N, D_in), requires_grad=False) + y = Variable(torch.randn(N, D_out), requires_grad=False) + + w1 = Variable(torch.randn(D_in, H), requires_grad=True) + w2 = Variable(torch.randn(H, D_out), requires_grad=True) + parameters = [w1, w2] + + learning_rate = 1e-6 + optimizer = torch.optim.SGD(parameters, lr=learning_rate) + loss_scaler = DynamicLossScaler() + + for t in range(500): + y_pred = x.mm(w1).clamp(min=0).mm(w2) + loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale + print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) + print('Iter {} scaled loss: {}'.format(t, loss.data[0])) + print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) + + # Run backprop + optimizer.zero_grad() + loss.backward() + + # Check for overflow + has_overflow = DynamicLossScaler.has_overflow(parameters) + + # If no overflow, unscale grad and update as usual + if not has_overflow: + for param in parameters: + param.grad.data.mul_(1. / loss_scaler.loss_scale) + optimizer.step() + # Otherwise, don't do anything -- ie, skip iteration + else: + print('OVERFLOW!') + + # Update loss scale for next iteration + loss_scaler.update_scale(has_overflow) + diff --git a/model.py b/model.py new file mode 100644 index 0000000..1f9e7d1 --- /dev/null +++ b/model.py @@ -0,0 +1,541 @@ +import torch +from torch.autograd import Variable +from torch import nn +from torch.nn import functional as F +from layers import ConvNorm, LinearNorm +from utils import to_gpu, get_mask_from_lengths +from fp16_optimizer import fp32_to_fp16, fp16_to_fp32 + + +class LocationLayer(nn.Module): + def __init__(self, attention_n_filters, attention_kernel_size, + attention_dim): + super(LocationLayer, self).__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = ConvNorm(2, attention_n_filters, + kernel_size=attention_kernel_size, + padding=padding, bias=False, stride=1, + dilation=1) + self.location_dense = LinearNorm(attention_n_filters, attention_dim, + bias=False, w_init_gain='tanh') + + def forward(self, attention_weights_cat): + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class Attention(nn.Module): + def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + attention_location_n_filters, attention_location_kernel_size): + super(Attention, self).__init__() + self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, + bias=False, w_init_gain='tanh') + self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, + w_init_gain='tanh') + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer(attention_location_n_filters, + attention_location_kernel_size, + attention_dim) + self.score_mask_value = -float("inf") + + def get_alignment_energies(self, query, processed_memory, + attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) + + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v(F.tanh( + processed_query + processed_attention_weights + processed_memory)) + + energies = energies.squeeze(-1) + return energies + + def forward(self, attention_hidden_state, memory, processed_memory, + attention_weights_cat, mask): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + alignment = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat) + + if mask is not None: + alignment.data.masked_fill_(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class Prenet(nn.Module): + def __init__(self, in_dim, sizes): + super(Prenet, self).__init__() + in_sizes = [in_dim] + sizes[:-1] + self.layers = nn.ModuleList( + [LinearNorm(in_size, out_size, bias=False) + for (in_size, out_size) in zip(in_sizes, sizes)]) + + def forward(self, x): + for linear in self.layers: + x = F.dropout(F.relu(linear(x)), p=0.5, training=True) + return x + + +class Postnet(nn.Module): + """Postnet + - Five 1-d convolution with 512 channels and kernel size 5 + """ + + def __init__(self, hparams): + super(Postnet, self).__init__() + self.dropout = nn.Dropout(0.5) + self.convolutions = nn.ModuleList() + + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim, + kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hparams.postnet_embedding_dim)) + ) + + for i in range(1, hparams.postnet_n_convolutions - 1): + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.postnet_embedding_dim, + hparams.postnet_embedding_dim, + kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hparams.postnet_embedding_dim)) + ) + + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels, + kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, w_init_gain='linear'), + nn.BatchNorm1d(hparams.n_mel_channels)) + ) + + def forward(self, x): + for i in range(len(self.convolutions) - 1): + x = self.dropout(F.tanh(self.convolutions[i](x))) + + x = self.dropout(self.convolutions[-1](x)) + + return x + + +class Encoder(nn.Module): + """Encoder module: + - Three 1-d convolution banks + - Bidirectional LSTM + """ + def __init__(self, hparams): + super(Encoder, self).__init__() + self.dropout = nn.Dropout(0.5) + + convolutions = [] + for _ in range(hparams.encoder_n_convolutions): + conv_layer = nn.Sequential( + ConvNorm(hparams.encoder_embedding_dim, + hparams.encoder_embedding_dim, + kernel_size=hparams.encoder_kernel_size, stride=1, + padding=int((hparams.encoder_kernel_size - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.BatchNorm1d(hparams.encoder_embedding_dim)) + convolutions.append(conv_layer) + self.convolutions = nn.ModuleList(convolutions) + + self.lstm = nn.LSTM(hparams.encoder_embedding_dim, + int(hparams.encoder_embedding_dim / 2), 1, + batch_first=True, bidirectional=True) + + def forward(self, x, input_lengths): + for conv in self.convolutions: + x = self.dropout(F.relu(conv(x))) + + x = x.transpose(1, 2) + + # pytorch tensor are not reversible, hence the conversion + input_lengths = input_lengths.cpu().numpy() + x = nn.utils.rnn.pack_padded_sequence( + x, input_lengths, batch_first=True) + + self.lstm.flatten_parameters() + outputs, _ = self.lstm(x) + + outputs, _ = nn.utils.rnn.pad_packed_sequence( + outputs, batch_first=True) + + return outputs + + def inference(self, x): + for conv in self.convolutions: + x = self.dropout(F.relu(conv(x))) + + x = x.transpose(1, 2) + + self.lstm.flatten_parameters() + outputs, _ = self.lstm(x) + + return outputs + + +class Decoder(nn.Module): + def __init__(self, hparams): + super(Decoder, self).__init__() + self.n_mel_channels = hparams.n_mel_channels + self.n_frames_per_step = hparams.n_frames_per_step + self.encoder_embedding_dim = hparams.encoder_embedding_dim + self.attention_rnn_dim = hparams.attention_rnn_dim + self.decoder_rnn_dim = hparams.decoder_rnn_dim + self.prenet_dim = hparams.prenet_dim + self.max_decoder_steps = hparams.max_decoder_steps + self.gate_threshold = hparams.gate_threshold + + self.prenet = Prenet( + hparams.n_mel_channels * hparams.n_frames_per_step, + [hparams.prenet_dim, hparams.prenet_dim]) + + self.attention_rnn = nn.LSTMCell( + hparams.prenet_dim + hparams.encoder_embedding_dim, + hparams.attention_rnn_dim) + + self.attention_layer = Attention( + hparams.attention_rnn_dim, hparams.encoder_embedding_dim, + hparams.attention_dim, hparams.attention_location_n_filters, + hparams.attention_location_kernel_size) + + self.decoder_rnn = nn.LSTMCell( + hparams.attention_rnn_dim + hparams.encoder_embedding_dim, + hparams.decoder_rnn_dim, 1) + + self.linear_projection = LinearNorm( + hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, + hparams.n_mel_channels*hparams.n_frames_per_step) + + self.gate_layer = LinearNorm( + hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1, + bias=True, w_init_gain='sigmoid') + + def get_go_frame(self, memory): + """ Gets all zeros frames to use as first decoder input + PARAMS + ------ + memory: decoder outputs + + RETURNS + ------- + decoder_input: all zeros frames + """ + B = memory.size(0) + decoder_input = Variable(memory.data.new( + B, self.n_mel_channels * self.n_frames_per_step).zero_()) + return decoder_input + + def initialize_decoder_states(self, memory, mask): + """ Initializes attention rnn states, decoder rnn states, attention + weights, attention cumulative weights, attention context, stores memory + and stores processed memory + PARAMS + ------ + memory: Encoder outputs + mask: Mask for padded data if training, expects None for inference + """ + B = memory.size(0) + MAX_TIME = memory.size(1) + + self.attention_hidden = Variable(memory.data.new( + B, self.attention_rnn_dim).zero_()) + self.attention_cell = Variable(memory.data.new( + B, self.attention_rnn_dim).zero_()) + + self.decoder_hidden = Variable(memory.data.new( + B, self.decoder_rnn_dim).zero_()) + self.decoder_cell = Variable(memory.data.new( + B, self.decoder_rnn_dim).zero_()) + + self.attention_weights = Variable(memory.data.new( + B, MAX_TIME).zero_()) + self.attention_weights_cum = Variable(memory.data.new( + B, MAX_TIME).zero_()) + self.attention_context = Variable(memory.data.new( + B, self.encoder_embedding_dim).zero_()) + + self.memory = memory + self.processed_memory = self.attention_layer.memory_layer(memory) + self.mask = mask + + def parse_decoder_inputs(self, decoder_inputs): + """ Prepares decoder inputs, i.e. mel outputs + PARAMS + ------ + decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs + + RETURNS + ------- + inputs: processed decoder inputs + + """ + # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) + decoder_inputs = decoder_inputs.transpose(1, 2) + decoder_inputs = decoder_inputs.view( + decoder_inputs.size(0), + int(decoder_inputs.size(1)/self.n_frames_per_step), -1) + # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) + decoder_inputs = decoder_inputs.transpose(0, 1) + return decoder_inputs + + def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): + """ Prepares decoder outputs for output + PARAMS + ------ + mel_outputs: + gate_outputs: gate output energies + alignments: + + RETURNS + ------- + mel_outputs: + gate_outpust: gate output energies + alignments: + """ + # (T_out, B) -> (B, T_out) + alignments = torch.stack(alignments).transpose(0, 1) + # (T_out, B) -> (B, T_out) + gate_outputs = torch.stack(gate_outputs).transpose(0, 1) + gate_outputs = gate_outputs.contiguous() + # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) + mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() + # decouple frames per step + mel_outputs = mel_outputs.view( + mel_outputs.size(0), -1, self.n_mel_channels) + # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) + mel_outputs = mel_outputs.transpose(1, 2) + + return mel_outputs, gate_outputs, alignments + + def decode(self, decoder_input): + """ Decoder step using stored states, attention and memory + PARAMS + ------ + decoder_input: previous mel output + + RETURNS + ------- + mel_output: + gate_output: gate output energies + attention_weights: + """ + + decoder_input = self.prenet(decoder_input) + cell_input = torch.cat((decoder_input, self.attention_context), -1) + self.attention_hidden, self.attention_cell = self.attention_rnn( + cell_input, (self.attention_hidden, self.attention_cell)) + + attention_weights_cat = torch.cat( + (self.attention_weights.unsqueeze(1), + self.attention_weights_cum.unsqueeze(1)), dim=1) + self.attention_context, self.attention_weights = self.attention_layer( + self.attention_hidden, self.memory, self.processed_memory, + attention_weights_cat, self.mask) + + self.attention_weights_cum += self.attention_weights + decoder_input = torch.cat( + (self.attention_hidden, self.attention_context), -1) + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + decoder_input, (self.decoder_hidden, self.decoder_cell)) + + decoder_hidden_attention_context = torch.cat( + (self.decoder_hidden, self.attention_context), dim=1) + decoder_output = self.linear_projection( + decoder_hidden_attention_context) + + gate_prediction = self.gate_layer(decoder_hidden_attention_context) + return decoder_output, gate_prediction, self.attention_weights + + def forward(self, memory, decoder_inputs, memory_lengths): + """ Decoder forward pass for training + PARAMS + ------ + memory: Encoder outputs + decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs + memory_lengths: Encoder output lengths for attention masking. + + RETURNS + ------- + mel_outputs: mel outputs from the decoder + gate_outputs: gate outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + + decoder_input = self.get_go_frame(memory) + decoder_inputs = self.parse_decoder_inputs(decoder_inputs) + self.initialize_decoder_states( + memory, mask=~get_mask_from_lengths(memory_lengths)) + + mel_outputs, gate_outputs, alignments = [], [], [] + + while len(mel_outputs) < decoder_inputs.size(0): + mel_output, gate_output, attention_weights = self.decode( + decoder_input) + + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output.squeeze()] + alignments += [attention_weights] + + decoder_input = decoder_inputs[len(mel_outputs) - 1] + + mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( + mel_outputs, gate_outputs, alignments) + + return mel_outputs, gate_outputs, alignments + + def inference(self, memory): + """ Decoder inference + PARAMS + ------ + memory: Encoder outputs + + RETURNS + ------- + mel_outputs: mel outputs from the decoder + gate_outputs: gate outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + decoder_input = self.get_go_frame(memory) + + self.initialize_decoder_states(memory, mask=None) + + mel_outputs, gate_outputs, alignments = [], [], [] + + while True: + mel_output, gate_output, alignment = self.decode(decoder_input) + + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output.squeeze()] + alignments += [alignment] + + if F.sigmoid(gate_output.data) > self.gate_threshold: + break + elif len(mel_outputs) == self.max_decoder_steps: + print("Warning! Reached max decoder steps") + break + + decoder_input = mel_output + + mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( + mel_outputs, gate_outputs, alignments) + + return mel_outputs, gate_outputs, alignments + + +class Tacotron2(nn.Module): + def __init__(self, hparams): + super(Tacotron2, self).__init__() + self.mask_padding = hparams.mask_padding + self.fp16_run = hparams.fp16_run + self.n_mel_channels = hparams.n_mel_channels + self.n_frames_per_step = hparams.n_frames_per_step + self.embedding = nn.Embedding( + hparams.n_symbols, hparams.symbols_embedding_dim) + self.encoder = Encoder(hparams) + self.decoder = Decoder(hparams) + self.postnet = Postnet(hparams) + + def parse_batch(self, batch): + text_padded, input_lengths, mel_padded, gate_padded, \ + output_lengths = batch + text_padded = to_gpu(text_padded).long() + input_lengths = to_gpu(input_lengths).long() + max_len = torch.max(input_lengths.data) + mel_padded = to_gpu(mel_padded).float() + gate_padded = to_gpu(gate_padded).float() + output_lengths = to_gpu(output_lengths).long() + + return ( + (text_padded, input_lengths, mel_padded, max_len, output_lengths), + (mel_padded, gate_padded)) + + def parse_input(self, inputs): + inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs + return inputs + + def parse_output(self, outputs, output_lengths=None): + if self.mask_padding and output_lengths is not None: + mask = ~get_mask_from_lengths(output_lengths+1) # +1 token + mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) + mask = mask.permute(1, 0, 2) + + outputs[0].data.masked_fill_(mask, 0.0) + outputs[1].data.masked_fill_(mask, 0.0) + outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies + + outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs + + return outputs + + def forward(self, inputs): + inputs, input_lengths, targets, max_len, \ + output_lengths = self.parse_input(inputs) + input_lengths, output_lengths = input_lengths.data, output_lengths.data + + embedded_inputs = self.embedding(inputs).transpose(1, 2) + + encoder_outputs = self.encoder(embedded_inputs, input_lengths) + + mel_outputs, gate_outputs, alignments = self.decoder( + encoder_outputs, targets, memory_lengths=input_lengths) + + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + + # DataParallel expects equal sized inputs/outputs, hence padding + if input_lengths is not None: + alignments = alignments.unsqueeze(0) + alignments = nn.functional.pad( + alignments, + (0, max_len - alignments.size(3), 0, 0), + "constant", 0) + alignments = alignments.squeeze() + return self.parse_output( + [mel_outputs, mel_outputs_postnet, gate_outputs, alignments], + output_lengths) + + def inference(self, inputs): + inputs = self.parse_input(inputs) + embedded_inputs = self.embedding(inputs).transpose(1, 2) + encoder_outputs = self.encoder.inference(embedded_inputs) + mel_outputs, gate_outputs, alignments = self.decoder.inference( + encoder_outputs) + + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + + outputs = self.parse_output( + [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]) + + return outputs diff --git a/multiproc.py b/multiproc.py new file mode 100644 index 0000000..060ff93 --- /dev/null +++ b/multiproc.py @@ -0,0 +1,23 @@ +import time +import torch +import sys +import subprocess + +argslist = list(sys.argv)[1:] +num_gpus = torch.cuda.device_count() +argslist.append('--n_gpus={}'.format(num_gpus)) +workers = [] +job_id = time.strftime("%Y_%m_%d-%H%M%S") +argslist.append("--group_name=group_{}".format(job_id)) + +for i in range(num_gpus): + argslist.append('--rank={}'.format(i)) + stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), + "w") + print(argslist) + p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) + workers.append(p) + argslist = argslist[:-1] + +for p in workers: + p.wait() diff --git a/plotting_utils.py b/plotting_utils.py new file mode 100644 index 0000000..ca7e168 --- /dev/null +++ b/plotting_utils.py @@ -0,0 +1,61 @@ +import matplotlib +matplotlib.use("Agg") +import matplotlib.pylab as plt +import numpy as np + + +def save_figure_to_numpy(fig): + # save it to a numpy array. + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_alignment_to_numpy(alignment, info=None): + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow(alignment, aspect='auto', origin='lower', + interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def plot_spectrogram_to_numpy(spectrogram): + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): + fig, ax = plt.subplots(figsize=(12, 3)) + ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5, + color='green', marker='+', s=1, label='target') + ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5, + color='red', marker='.', s=1, label='predicted') + + plt.xlabel("Frames (Green target, Red predicted)") + plt.ylabel("Gate State") + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data diff --git a/stft.py b/stft.py new file mode 100644 index 0000000..8e137d3 --- /dev/null +++ b/stft.py @@ -0,0 +1,140 @@ +""" +BSD 3-Clause License + +Copyright (c) 2017, Prem Seetharaman +All rights reserved. + +* Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +import numpy as np +import torch.nn.functional as F +from torch.autograd import Variable +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from audio_processing import window_sumsquare + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + def __init__(self, filter_length=800, hop_length=200, win_length=800, + window='hann'): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), + np.imag(fourier_basis[:cutoff, :])]) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :]) + + if window is not None: + assert(win_length >= filter_length) + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer('forward_basis', forward_basis.float()) + self.register_buffer('inverse_basis', inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode='reflect') + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable( + torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, magnitude.size(-1), hop_length=self.hop_length, + win_length=self.win_length, n_fft=self.filter_length, + dtype=np.float32) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0]) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False) + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] + inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction diff --git a/train.py b/train.py new file mode 100644 index 0000000..dd5de2e --- /dev/null +++ b/train.py @@ -0,0 +1,272 @@ +import os +import time +import argparse +import math + +import torch +from distributed import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler +from torch.nn import DataParallel +from torch.utils.data import DataLoader + +from fp16_optimizer import FP16_Optimizer + +from model import Tacotron2 +from data_utils import TextMelLoader, TextMelCollate +from loss_function import Tacotron2Loss +from logger import Tacotron2Logger +from hparams import create_hparams + + +def batchnorm_to_float(module): + """Converts batch norm modules to FP32""" + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + module.float() + for child in module.children(): + batchnorm_to_float(child) + return module + + +def reduce_tensor(tensor, num_gpus): + rt = tensor.clone() + torch.distributed.all_reduce(rt, op=torch.distributed.reduce_op.SUM) + rt /= num_gpus + return rt + + +def init_distributed(hparams, n_gpus, rank, group_name): + assert torch.cuda.is_available(), "Distributed mode requires CUDA." + print("Initializing distributed") + # Set cuda device so everything is done on the right GPU. + torch.cuda.set_device(rank % torch.cuda.device_count()) + + # Initialize distributed communication + torch.distributed.init_process_group( + backend=hparams.dist_backend, init_method=hparams.dist_url, + world_size=n_gpus, rank=rank, group_name=group_name) + + print("Done initializing distributed") + + +def prepare_dataloaders(hparams): + # Get data, data loaders and collate function ready + trainset = TextMelLoader(hparams.training_files, hparams) + valset = TextMelLoader(hparams.validation_files, hparams) + collate_fn = TextMelCollate(hparams.n_frames_per_step) + + train_sampler = DistributedSampler(trainset) \ + if hparams.distributed_run else None + + train_loader = DataLoader(trainset, num_workers=1, shuffle=False, + sampler=train_sampler, + batch_size=hparams.batch_size, pin_memory=False, + drop_last=True, collate_fn=collate_fn) + return train_loader, valset, collate_fn + + +def prepare_directories_and_logger(output_directory, log_directory, rank): + if rank == 0: + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + os.chmod(output_directory, 0o775) + logger = Tacotron2Logger(os.path.join(output_directory, log_directory)) + else: + logger = None + return logger + + +def load_model(hparams): + model = Tacotron2(hparams).cuda() + model = batchnorm_to_float(model.half()) if hparams.fp16_run else model + model = DistributedDataParallel(model) \ + if hparams.distributed_run else DataParallel(model) + return model + + +def warm_start_model(checkpoint_path, model): + assert os.path.isfile(checkpoint_path) + print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) + checkpoint_dict = torch.load(checkpoint_path) + model.load_state_dict(checkpoint_dict['state_dict']) + return model + + +def load_checkpoint(checkpoint_path, model, optimizer): + assert os.path.isfile(checkpoint_path) + print("Loading checkpoint '{}'".format(checkpoint_path)) + checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(checkpoint_dict['state_dict']) + optimizer.load_state_dict(checkpoint_dict['optimizer']) + learning_rate = checkpoint_dict['learning_rate'] + iteration = checkpoint_dict['iteration'] + print("Loaded checkpoint '{}' from iteration {}" .format( + checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): + print("Saving model and optimizer state at iteration {} to {}".format( + iteration, filepath)) + torch.save({'iteration': iteration, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'learning_rate': learning_rate}, filepath) + + +def validate(model, criterion, valset, iteration, batch_size, n_gpus, + collate_fn, logger, distributed_run, rank): + """Handles all the validation scoring and printing""" + model.eval() + with torch.no_grad(): + val_sampler = DistributedSampler(valset) if distributed_run else None + val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1, + shuffle=False, batch_size=batch_size, + pin_memory=False, collate_fn=collate_fn) + + val_loss = 0.0 + for i, batch in enumerate(val_loader): + x, y = model.module.parse_batch(batch) + y_pred = model(x) + loss = criterion(y_pred, y) + reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \ + if distributed_run else loss.data[0] + val_loss += reduced_val_loss + val_loss = val_loss / (i + 1) + + model.train() + return val_loss + + +def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, + rank, group_name, hparams): + """Training and validation logging results to tensorboard and stdout + + Params + ------ + output_directory (string): directory to save checkpoints + log_directory (string) directory to save tensorboard logs + checkpoint_path(string): checkpoint path + n_gpus (int): number of gpus + rank (int): rank of current gpu + hparams (object): comma separated list of "name=value" pairs. + """ + if hparams.distributed_run: + init_distributed(hparams, n_gpus, rank, group_name) + + torch.manual_seed(hparams.seed) + torch.cuda.manual_seed(hparams.seed) + + model = load_model(hparams) + learning_rate = hparams.learning_rate + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, + weight_decay=hparams.weight_decay) + if hparams.fp16_run: + optimizer = FP16_Optimizer( + optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling) + + criterion = Tacotron2Loss() + + logger = prepare_directories_and_logger( + output_directory, log_directory, rank) + + train_loader, valset, collate_fn = prepare_dataloaders(hparams) + + # Load checkpoint if one exists + iteration = 0 + epoch_offset = 0 + if checkpoint_path is not None: + if warm_start: + model = warm_start_model(checkpoint_path, model) + else: + model, optimizer, learning_rate, iteration = load_checkpoint( + checkpoint_path, model, optimizer) + iteration += 1 # next iteration is iteration + 1 + epoch_offset = max(0, int(iteration / len(train_loader))) + + model.train() + # ================ MAIN TRAINNIG LOOP! =================== + for epoch in range(epoch_offset, hparams.epochs): + print("Epoch: {}".format(epoch)) + for i, batch in enumerate(train_loader): + start = time.perf_counter() + for param_group in optimizer.param_groups: + param_group['lr'] = learning_rate + + model.zero_grad() + x, y = model.module.parse_batch(batch) + y_pred = model(x) + loss = criterion(y_pred, y) + reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \ + if hparams.distributed_run else loss.data[0] + + if hparams.fp16_run: + optimizer.backward(loss) + grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh) + else: + loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm( + model.module.parameters(), hparams.grad_clip_thresh) + + optimizer.step() + + overflow = optimizer.overflow if hparams.fp16_run else False + + if not overflow and not math.isnan(reduced_loss) and rank == 0: + duration = time.perf_counter() - start + print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( + iteration, reduced_loss, grad_norm, duration)) + + logger.log_training( + reduced_loss, grad_norm, learning_rate, duration, iteration) + + if not overflow and (iteration % hparams.iters_per_checkpoint == 0): + reduced_val_loss = validate( + model, criterion, valset, iteration, hparams.batch_size, + n_gpus, collate_fn, logger, hparams.distributed_run, rank) + + if rank == 0: + print("Validation loss {}: {:9f} ".format( + iteration, reduced_val_loss)) + logger.log_validation( + reduced_val_loss, model, y, y_pred, iteration) + checkpoint_path = os.path.join( + output_directory, "checkpoint_{}".format(iteration)) + save_checkpoint(model, optimizer, learning_rate, iteration, + checkpoint_path) + + iteration += 1 + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-o', '--output_directory', type=str, + help='directory to save checkpoints') + parser.add_argument('-l', '--log_directory', type=str, + help='directory to save tensorboard logs') + parser.add_argument('-c', '--checkpoint_path', type=str, default=None, + required=False, help='checkpoint path') + parser.add_argument('--warm_start', action='store_true', + help='load the model only (warm start)') + parser.add_argument('--n_gpus', type=int, default=1, + required=False, help='number of gpus') + parser.add_argument('--rank', type=int, default=0, + required=False, help='rank of current gpu') + parser.add_argument('--group_name', type=str, default='group_name', + required=False, help='Distributed group name') + parser.add_argument('--hparams', type=str, + required=False, help='comma separated name=value pairs') + + args = parser.parse_args() + hparams = create_hparams(args.hparams) + + torch.backends.cudnn.enabled = hparams.cudnn_enabled + torch.backends.cudnn.benchmark = hparams.cudnn_benchmark + + print("FP16 Run:", hparams.fp16_run) + print("Dynamic Loss Scaling", hparams.dynamic_loss_scaling) + print("Distributed Run:", hparams.distributed_run) + print("cuDNN Enabled:", hparams.cudnn_enabled) + print("cuDNN Benchmark:", hparams.cudnn_benchmark) + + train(args.output_directory, args.log_directory, args.checkpoint_path, + args.warm_start, args.n_gpus, args.rank, args.group_name, hparams) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..01046c3 --- /dev/null +++ b/utils.py @@ -0,0 +1,32 @@ +import numpy as np +from scipy.io.wavfile import read +import torch + + +def get_mask_from_lengths(lengths): + max_len = torch.max(lengths) + ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).cuda() + mask = (ids < lengths.unsqueeze(1)).byte() + return mask + + +def load_wav_to_torch(full_path, sr): + sampling_rate, data = read(full_path) + assert sr == sampling_rate, "{} SR doesn't match {} on path {}".format( + sr, sampling_rate, full_path) + return torch.FloatTensor(data.astype(np.float32)) + + +def load_filepaths_and_text(filename, sort_by_length, split="|"): + with open(filename, encoding='utf-8') as f: + filepaths_and_text = [line.strip().split(split) for line in f] + + if sort_by_length: + filepaths_and_text.sort(key=lambda x: len(x[1])) + + return filepaths_and_text + + +def to_gpu(x): + x = x.contiguous().cuda(async=True) + return torch.autograd.Variable(x)