Browse Source

adding python files

master
Rafael Valle 6 years ago
parent
commit
09bbec073d
15 changed files with 2134 additions and 0 deletions
  1. +93
    -0
      audio_processing.py
  2. +101
    -0
      data_utils.py
  3. +120
    -0
      distributed.py
  4. +381
    -0
      fp16_optimizer.py
  5. +91
    -0
      hparams.py
  6. +80
    -0
      layers.py
  7. +48
    -0
      logger.py
  8. +19
    -0
      loss_function.py
  9. +132
    -0
      loss_scaler.py
  10. +541
    -0
      model.py
  11. +23
    -0
      multiproc.py
  12. +61
    -0
      plotting_utils.py
  13. +140
    -0
      stft.py
  14. +272
    -0
      train.py
  15. +32
    -0
      utils.py

+ 93
- 0
audio_processing.py View File

@ -0,0 +1,93 @@
import torch
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
n_fft=800, dtype=np.float32, norm=None):
"""
# from librosa 0.6
Compute the sum-square envelope of a window function at a given hop length.
This is used to estimate modulation effects induced by windowing
observations in short-time fourier transforms.
Parameters
----------
window : string, tuple, number, callable, or list-like
Window specification, as in `get_window`
n_frames : int > 0
The number of analysis frames
hop_length : int > 0
The number of samples to advance between frames
win_length : [optional]
The length of the window function. By default, this matches `n_fft`.
n_fft : int > 0
The length of each analysis frame.
dtype : np.dtype
The data type of the output
Returns
-------
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
The sum-squared envelope of the window function
"""
if win_length is None:
win_length = n_fft
n = n_fft + hop_length * (n_frames - 1)
x = np.zeros(n, dtype=dtype)
# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
win_sq = librosa_util.pad_center(win_sq, n_fft)
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
return x
def griffin_lim(magnitudes, stft_fn, n_iters=30):
"""
PARAMS
------
magnitudes: spectrogram magnitudes
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
"""
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
angles = angles.astype(np.float32)
angles = torch.autograd.Variable(torch.from_numpy(angles))
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
for i in range(n_iters):
_, angles = stft_fn.transform(signal)
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
return signal
def dynamic_range_compression(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C

+ 101
- 0
data_utils.py View File

@ -0,0 +1,101 @@
import random
import torch
import torch.utils.data
import layers
from utils import load_wav_to_torch, load_filepaths_and_text
from text import text_to_sequence
class TextMelLoader(torch.utils.data.Dataset):
"""
1) loads audio,text pairs
2) normalizes text and converts them to sequences of one-hot vectors
3) computes mel-spectrograms from audio files.
"""
def __init__(self, audiopaths_and_text, hparams, shuffle=True):
self.audiopaths_and_text = load_filepaths_and_text(
audiopaths_and_text, hparams.sort_by_length)
self.text_cleaners = hparams.text_cleaners
self.max_wav_value = hparams.max_wav_value
self.sampling_rate = hparams.sampling_rate
self.stft = layers.TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length,
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
hparams.mel_fmax)
random.seed(1234)
if shuffle:
random.shuffle(self.audiopaths_and_text)
def get_mel_text_pair(self, audiopath_and_text):
# separate filename and text
audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
text = self.get_text(text)
mel = self.get_mel(audiopath)
return (text, mel)
def get_mel(self, filename):
audio = load_wav_to_torch(filename, self.sampling_rate)
audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
melspec = self.stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
return melspec
def get_text(self, text):
text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
return text_norm
def __getitem__(self, index):
return self.get_mel_text_pair(self.audiopaths_and_text[index])
def __len__(self):
return len(self.audiopaths_and_text)
class TextMelCollate():
""" Zero-pads model inputs and targets based on number of frames per setep
"""
def __init__(self, n_frames_per_step):
self.n_frames_per_step = n_frames_per_step
def __call__(self, batch):
"""Collate's training batch from normalized text and mel-spectrogram
PARAMS
------
batch: [text_normalized, mel_normalized]
"""
# Right zero-pad all one-hot text sequences to max input length
input_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x[0]) for x in batch]),
dim=0, descending=True)
max_input_len = input_lengths[0]
text_padded = torch.LongTensor(len(batch), max_input_len)
text_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
text = batch[ids_sorted_decreasing[i]][0]
text_padded[i, :text.size(0)] = text
# Right zero-pad mel-spec with extra single zero vector to mark the end
num_mels = batch[0][1].size(0)
max_target_len = max([x[1].size(1) for x in batch]) + 1
if max_target_len % self.n_frames_per_step != 0:
max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
assert max_target_len % self.n_frames_per_step == 0
# include mel padded and gate padded
mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
mel_padded.zero_()
gate_padded = torch.FloatTensor(len(batch), max_target_len)
gate_padded.zero_()
output_lengths = torch.LongTensor(len(batch))
for i in range(len(ids_sorted_decreasing)):
mel = batch[ids_sorted_decreasing[i]][1]
mel_padded[i, :, :mel.size(1)] = mel
gate_padded[i, mel.size(1):] = 1
output_lengths[i] = mel.size(1)
return text_padded, input_lengths, mel_padded, gate_padded, \
output_lengths

+ 120
- 0
distributed.py View File

@ -0,0 +1,120 @@
import torch
import torch.distributed as dist
from torch.nn.modules import Module
def _flatten_dense_tensors(tensors):
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
same dense type.
Since inputs are dense, the resulting tensor will be a concatenated 1D
buffer. Element-wise operation on this buffer will be equivalent to
operating individually.
Arguments:
tensors (Iterable[Tensor]): dense tensors to flatten.
Returns:
A contiguous 1D buffer containing input tensors.
"""
if len(tensors) == 1:
return tensors[0].contiguous().view(-1)
flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
return flat
def _unflatten_dense_tensors(flat, tensors):
"""View a flat buffer using the sizes of tensors. Assume that tensors are of
same dense type, and that flat is given by _flatten_dense_tensors.
Arguments:
flat (Tensor): flattened dense tensors to unflatten.
tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
unflatten flat.
Returns:
Unflattened dense tensors with sizes same as tensors and values from
flat.
"""
outputs = []
offset = 0
for tensor in tensors:
numel = tensor.numel()
outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
offset += numel
return tuple(outputs)
'''
This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
launcher included with this example. It assumes that your run is using multiprocess with 1
GPU/process, that the model is on the correct device, and that torch.set_device has been
used to set the device.
Parameters are broadcasted to the other processes on initialization of DistributedDataParallel,
and will be allreduced at the finish of the backward pass.
'''
class DistributedDataParallel(Module):
def __init__(self, module):
super(DistributedDataParallel, self).__init__()
#fallback for PyTorch 0.3
if not hasattr(dist, '_backend'):
self.warn_on_half = True
else:
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
self.module = module
for p in self.module.state_dict().values():
if not torch.is_tensor(p):
continue
dist.broadcast(p, 0)
def allreduce_params():
if(self.needs_reduction):
self.needs_reduction = False
buckets = {}
for param in self.module.parameters():
if param.requires_grad and param.grad is not None:
tp = type(param.data)
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
if self.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case. This currently requires" +
"PyTorch built from top of tree master.")
self.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced)
coalesced /= dist.get_world_size()
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
for param in list(self.module.parameters()):
def allreduce_hook(*unused):
param._execution_engine.queue_callback(allreduce_params)
if param.requires_grad:
param.register_hook(allreduce_hook)
def forward(self, *inputs, **kwargs):
self.needs_reduction = True
return self.module(*inputs, **kwargs)
'''
def _sync_buffers(self):
buffers = list(self.module._all_buffers())
if len(buffers) > 0:
# cross-node buffer sync
flat_buffers = _flatten_dense_tensors(buffers)
dist.broadcast(flat_buffers, 0)
for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
buf.copy_(synced)
def train(self, mode=True):
# Clear NCCL communicator and CUDA event cache of the default group ID,
# These cache will be recreated at the later call. This is currently a
# work-around for a potential NCCL deadlock.
if dist._backend == dist.dist_backend.NCCL:
dist._clear_group_cache()
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''

+ 381
- 0
fp16_optimizer.py View File

@ -0,0 +1,381 @@
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from loss_scaler import DynamicLossScaler, LossScaler
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, FLOAT_TYPES):
val = val.half()
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, HALF_TYPES):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class FP16_Module(nn.Module):
def __init__(self, module):
super(FP16_Module, self).__init__()
self.add_module('module', module.half())
def forward(self, *inputs, **kwargs):
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
class FP16_Optimizer(object):
"""
FP16_Optimizer is designed to wrap an existing PyTorch optimizer,
and enable an fp16 model to be trained using a master copy of fp32 weights.
Args:
optimizer (torch.optim.optimizer): Existing optimizer containing initialized fp16 parameters. Internally, FP16_Optimizer replaces the passed optimizer's fp16 parameters with new fp32 parameters copied from the original ones. FP16_Optimizer also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy after each step.
static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale fp16 gradients computed by the model. Scaled gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so static_loss_scale should not affect learning rate.
dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any static_loss_scale option.
"""
def __init__(self, optimizer, static_loss_scale=1.0, dynamic_loss_scale=False):
if not torch.cuda.is_available:
raise SystemError('Cannot use fp16 without CUDA')
self.fp16_param_groups = []
self.fp32_param_groups = []
self.fp32_flattened_groups = []
for i, param_group in enumerate(optimizer.param_groups):
print("FP16_Optimizer processing param group {}:".format(i))
fp16_params_this_group = []
fp32_params_this_group = []
for param in param_group['params']:
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
.format(param.size()))
fp16_params_this_group.append(param)
elif param.type() == 'torch.cuda.FloatTensor':
print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
.format(param.size()))
fp32_params_this_group.append(param)
else:
raise TypeError("Wrapped parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
"Received {}".format(param.type()))
fp32_flattened_this_group = None
if len(fp16_params_this_group) > 0:
fp32_flattened_this_group = _flatten_dense_tensors(
[param.detach().data.clone().float() for param in fp16_params_this_group])
fp32_flattened_this_group = Variable(fp32_flattened_this_group, requires_grad = True)
fp32_flattened_this_group.grad = fp32_flattened_this_group.new(
*fp32_flattened_this_group.size())
# python's lovely list concatenation via +
if fp32_flattened_this_group is not None:
param_group['params'] = [fp32_flattened_this_group] + fp32_params_this_group
else:
param_group['params'] = fp32_params_this_group
self.fp16_param_groups.append(fp16_params_this_group)
self.fp32_param_groups.append(fp32_params_this_group)
self.fp32_flattened_groups.append(fp32_flattened_this_group)
# print("self.fp32_flattened_groups = ", self.fp32_flattened_groups)
# print("self.fp16_param_groups = ", self.fp16_param_groups)
self.optimizer = optimizer.__class__(optimizer.param_groups)
# self.optimizer.load_state_dict(optimizer.state_dict())
self.param_groups = self.optimizer.param_groups
if dynamic_loss_scale:
self.dynamic_loss_scale = True
self.loss_scaler = DynamicLossScaler()
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(static_loss_scale)
self.overflow = False
self.first_closure_call_this_step = True
def zero_grad(self):
"""
Zero fp32 and fp16 parameter grads.
"""
self.optimizer.zero_grad()
for fp16_group in self.fp16_param_groups:
for param in fp16_group:
if param.grad is not None:
param.grad.detach_() # This does appear in torch.optim.optimizer.zero_grad(),
# but I'm not sure why it's needed.
param.grad.zero_()
def _check_overflow(self):
params = []
for group in self.fp16_param_groups:
for param in group:
params.append(param)
for group in self.fp32_param_groups:
for param in group:
params.append(param)
self.overflow = self.loss_scaler.has_overflow(params)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
def _copy_grads_fp16_to_fp32(self):
for fp32_group, fp16_group in zip(self.fp32_flattened_groups, self.fp16_param_groups):
if len(fp16_group) > 0:
# This might incur one more deep copy than is necessary.
fp32_group.grad.data.copy_(
_flatten_dense_tensors([fp16_param.grad.data for fp16_param in fp16_group]))
def _downscale_fp32(self):
if self.loss_scale != 1.0:
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
param.grad.data.mul_(1./self.loss_scale)
def clip_fp32_grads(self, clip=-1):
if not self.overflow:
fp32_params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
fp32_params.append(param)
if clip > 0:
return torch.nn.utils.clip_grad_norm(fp32_params, clip)
def _copy_params_fp32_to_fp16(self):
for fp16_group, fp32_group in zip(self.fp16_param_groups, self.fp32_flattened_groups):
if len(fp16_group) > 0:
for fp16_param, fp32_data in zip(fp16_group,
_unflatten_dense_tensors(fp32_group.data, fp16_group)):
fp16_param.data.copy_(fp32_data)
def state_dict(self):
"""
Returns a dict containing the current state of this FP16_Optimizer instance.
This dict contains attributes of FP16_Optimizer, as well as the state_dict
of the contained Pytorch optimizer.
Untested.
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict.
Untested.
"""
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
def step(self, closure=None): # could add clip option.
"""
If no closure is supplied, step should be called after fp16_optimizer_obj.backward(loss).
step updates the fp32 master copy of parameters using the optimizer supplied to
FP16_Optimizer's constructor, then copies the updated fp32 params into the fp16 params
originally referenced by Fp16_Optimizer's constructor, so the user may immediately run
another forward pass using their model.
If a closure is supplied, step may be called without a prior call to self.backward(loss).
However, the user should take care that any loss.backward() call within the closure
has been replaced by fp16_optimizer_obj.backward(loss).
Args:
closure (optional): Closure that will be supplied to the underlying optimizer originally passed to FP16_Optimizer's constructor. closure should call zero_grad on the FP16_Optimizer object, compute the loss, call .backward(loss), and return the loss.
Closure example::
# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
# existing pytorch optimizer.
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
optimizer.backward(loss)
return loss
optimizer.step(closure)
.. note::
The only changes that need to be made compared to
`ordinary optimizer closures`_ are that "optimizer" itself should be an instance of
FP16_Optimizer, and that the call to loss.backward should be replaced by
optimizer.backward(loss).
.. warning::
Currently, calling step with a closure is not compatible with dynamic loss scaling.
.. _`ordinary optimizer closures`:
http://pytorch.org/docs/master/optim.html#optimizer-step-closure
"""
if closure is not None and isinstance(self.loss_scaler, DynamicLossScaler):
raise TypeError("Using step with a closure is currently not "
"compatible with dynamic loss scaling.")
scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow)
if self.overflow:
print("OVERFLOW! Skipping step. Attempted loss scale: {}".format(scale))
return
if closure is not None:
self._step_with_closure(closure)
else:
self.optimizer.step()
self._copy_params_fp32_to_fp16()
return
def _step_with_closure(self, closure):
def wrapped_closure():
if self.first_closure_call_this_step:
"""
We expect that the fp16 params are initially fresh on entering self.step(),
so _copy_params_fp32_to_fp16() is unnecessary the first time wrapped_closure()
is called within self.optimizer.step().
"""
self.first_closure_call_this_step = False
else:
"""
If self.optimizer.step() internally calls wrapped_closure more than once,
it may update the fp32 params after each call. However, self.optimizer
doesn't know about the fp16 params at all. If the fp32 params get updated,
we can't rely on self.optimizer to refresh the fp16 params. We need
to handle that manually:
"""
self._copy_params_fp32_to_fp16()
"""
Our API expects the user to give us ownership of the backward() call by
replacing all calls to loss.backward() with optimizer.backward(loss).
This requirement holds whether or not the call to backward() is made within
a closure.
If the user is properly calling optimizer.backward(loss) within "closure,"
calling closure() here will give the fp32 master params fresh gradients
for the optimizer to play with,
so all wrapped_closure needs to do is call closure() and return the loss.
"""
temp_loss = closure()
return temp_loss
self.optimizer.step(wrapped_closure)
self.first_closure_call_this_step = True
def backward(self, loss, update_fp32_grads=True):
"""
fp16_optimizer_obj.backward performs the following conceptual operations:
fp32_loss = loss.float() (see first Note below)
scaled_loss = fp32_loss*loss_scale
scaled_loss.backward(), which accumulates scaled gradients into the .grad attributes of the
fp16 model's leaves.
fp16 grads are then copied to the stored fp32 params' .grad attributes (see second Note).
Finally, fp32 grads are divided by loss_scale.
In this way, after fp16_optimizer_obj.backward, the fp32 parameters have fresh gradients,
and fp16_optimizer_obj.step may be called.
.. note::
Converting the loss to fp32 before applying the loss scale provides some
additional safety against overflow if the user has supplied an fp16 value.
However, for maximum overflow safety, the user should
compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
fp16_optimizer_obj.backward.
.. note::
The gradients found in an fp16 model's leaves after a call to
fp16_optimizer_obj.backward should not be regarded as valid in general,
because it's possible
they have been scaled (and in the case of dynamic loss scaling,
the scale factor may silently change over time).
If the user wants to inspect gradients after a call to fp16_optimizer_obj.backward,
he/she should query the .grad attribute of FP16_Optimizer's stored fp32 parameters.
Args:
loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
update_fp32_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay this copy, which is useful to eliminate redundant fp16->fp32 grad copies if fp16_optimizer_obj.backward is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling fp16_optimizer_obj.update_fp32_grads before calling fp16_optimizer_obj.step.
Example::
# Ordinary operation:
optimizer.backward(loss)
# Naive operation with multiple losses (technically valid, but less efficient):
# fp32 grads will be correct after the second call, but
# the first call incurs an unnecessary fp16->fp32 grad copy.
optimizer.backward(loss1)
optimizer.backward(loss2)
# More efficient way to handle multiple losses:
# The fp16->fp32 grad copy is delayed until fp16 grads from all
# losses have been accumulated.
optimizer.backward(loss1, update_fp32_grads=False)
optimizer.backward(loss2, update_fp32_grads=False)
optimizer.update_fp32_grads()
"""
self.loss_scaler.backward(loss.float())
if update_fp32_grads:
self.update_fp32_grads()
def update_fp32_grads(self):
"""
Copy the .grad attribute from stored references to fp16 parameters to
the .grad attribute of the master fp32 parameters that are directly
updated by the optimizer. :attr:`update_fp32_grads` only needs to be called if
fp16_optimizer_obj.backward was called with update_fp32_grads=False.
"""
if self.dynamic_loss_scale:
self._check_overflow()
if self.overflow: return
self._copy_grads_fp16_to_fp32()
self._downscale_fp32()
@property
def loss_scale(self):
return self.loss_scaler.loss_scale

+ 91
- 0
hparams.py View File

@ -0,0 +1,91 @@
import tensorflow as tf
from text import symbols
def create_hparams(hparams_string=None, verbose=False):
"""Create model hyperparameters. Parse nondefault from given string."""
hparams = tf.contrib.training.HParams(
################################
# Experiment Parameters #
################################
epochs=500,
iters_per_checkpoint=500,
seed=1234,
dynamic_loss_scaling=True,
fp16_run=False,
distributed_run=False,
dist_backend="nccl",
dist_url="file://distributed.dpt",
cudnn_enabled=True,
cudnn_benchmark=False,
################################
# Data Parameters #
################################
training_files='ljs_audio_text_train_filelist.txt',
validation_files='ljs_audio_text_val_filelist.txt',
text_cleaners=['english_cleaners'],
sort_by_length=False,
################################
# Audio Parameters #
################################
max_wav_value=32768.0,
sampling_rate=22050,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=80,
mel_fmin=0.0,
mel_fmax=None, # if None, half the sampling rate
################################
# Model Parameters #
################################
n_symbols=len(symbols),
symbols_embedding_dim=512,
# Encoder parameters
encoder_kernel_size=5,
encoder_n_convolutions=3,
encoder_embedding_dim=512,
# Decoder parameters
n_frames_per_step=1,
decoder_rnn_dim=1024,
prenet_dim=256,
max_decoder_steps=1000,
gate_threshold=0.6,
# Attention parameters
attention_rnn_dim=1024,
attention_dim=128,
# Location Layer parameters
attention_location_n_filters=32,
attention_location_kernel_size=31,
# Mel-post processing network parameters
postnet_embedding_dim=512,
postnet_kernel_size=5,
postnet_n_convolutions=5,
################################
# Optimization Hyperparameters #
################################
learning_rate=1e-3,
weight_decay=1e-6,
grad_clip_thresh=1,
batch_size=48,
mask_padding=False # set model's padded outputs to padded values
)
if hparams_string:
tf.logging.info('Parsing command line hparams: %s', hparams_string)
hparams.parse(hparams_string)
if verbose:
tf.logging.info('Final parsed hparams: %s', hparams.values())
return hparams

+ 80
- 0
layers.py View File

@ -0,0 +1,80 @@
import torch
from librosa.filters import mel as librosa_mel_fn
from audio_processing import dynamic_range_compression
from audio_processing import dynamic_range_decompression
from stft import STFT
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
padding=None, dilation=1, bias=True, w_init_gain='linear'):
super(ConvNorm, self).__init__()
if padding is None:
assert(kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation,
bias=bias)
torch.nn.init.xavier_uniform(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
class TacotronSTFT(torch.nn.Module):
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
mel_fmax=None):
super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
mel_basis = librosa_mel_fn(
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer('mel_basis', mel_basis)
def spectral_normalize(self, magnitudes):
output = dynamic_range_compression(magnitudes)
return output
def spectral_de_normalize(self, magnitudes):
output = dynamic_range_decompression(magnitudes)
return output
def mel_spectrogram(self, y):
"""Computes mel-spectrograms from a batch of waves
PARAMS
------
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
RETURNS
-------
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
"""
assert(torch.min(y.data) >= -1)
assert(torch.max(y.data) <= 1)
magnitudes, phases = self.stft_fn.transform(y)
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
return mel_output

+ 48
- 0
logger.py View File

@ -0,0 +1,48 @@
import random
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy
from plotting_utils import plot_gate_outputs_to_numpy
class Tacotron2Logger(SummaryWriter):
def __init__(self, logdir):
super(Tacotron2Logger, self).__init__(logdir)
def log_training(self, reduced_loss, grad_norm, learning_rate, duration,
iteration):
self.add_scalar("training.loss", reduced_loss, iteration)
self.add_scalar("grad.norm", grad_norm, iteration)
self.add_scalar("learning.rate", learning_rate, iteration)
self.add_scalar("duration", duration, iteration)
def log_validation(self, reduced_loss, model, y, y_pred, iteration):
self.add_scalar("validation.loss", reduced_loss, iteration)
_, mel_outputs, gate_outputs, alignments = y_pred
mel_targets, gate_targets = y
# plot distribution of parameters
for tag, value in model.named_parameters():
tag = tag.replace('.', '/')
self.add_histogram(tag, value.data.cpu().numpy(), iteration)
# plot alignment, mel target and predicted, gate target and predicted
idx = random.randint(0, alignments.size(0) - 1)
self.add_image(
"alignment",
plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
iteration)
self.add_image(
"mel_target",
plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
iteration)
self.add_image(
"mel_predicted",
plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
iteration)
self.add_image(
"gate",
plot_gate_outputs_to_numpy(
gate_targets[idx].data.cpu().numpy(),
F.sigmoid(gate_outputs[idx]).data.cpu().numpy()),
iteration)

+ 19
- 0
loss_function.py View File

@ -0,0 +1,19 @@
from torch import nn
class Tacotron2Loss(nn.Module):
def __init__(self):
super(Tacotron2Loss, self).__init__()
def forward(self, model_output, targets):
mel_target, gate_target = targets[0], targets[1]
mel_target.requires_grad = False
gate_target.requires_grad = False
gate_target = gate_target.view(-1, 1)
mel_out, mel_out_postnet, gate_out, _ = model_output
gate_out = gate_out.view(-1, 1)
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
nn.MSELoss()(mel_out_postnet, mel_target)
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
return mel_loss + gate_loss

+ 132
- 0
loss_scaler.py View File

@ -0,0 +1,132 @@
import torch
class LossScaler:
def __init__(self, scale=1):
self.cur_scale = scale
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
return False
# `overflow` is boolean indicating whether we overflowed in gradient
def update_scale(self, overflow):
pass
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss):
scaled_loss = loss*self.loss_scale
scaled_loss.backward()
class DynamicLossScaler:
def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
# return False
for p in params:
if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
return True
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
inf_count = torch.sum(x.abs() == float('inf'))
if inf_count > 0:
return True
nan_count = torch.sum(x != x)
return nan_count > 0
# `overflow` is boolean indicating whether we overflowed in gradient
def update_scale(self, overflow):
if overflow:
#self.cur_scale /= self.scale_factor
self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
self.last_overflow_iter = self.cur_iter
else:
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
self.cur_scale *= self.scale_factor
# self.cur_scale = 1
self.cur_iter += 1
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss):
scaled_loss = loss*self.loss_scale
scaled_loss.backward()
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
if __name__ == "__main__":
import torch
from torch.autograd import Variable
from dynamic_loss_scaler import DynamicLossScaler
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in), requires_grad=False)
y = Variable(torch.randn(N, D_out), requires_grad=False)
w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)
parameters = [w1, w2]
learning_rate = 1e-6
optimizer = torch.optim.SGD(parameters, lr=learning_rate)
loss_scaler = DynamicLossScaler()
for t in range(500):
y_pred = x.mm(w1).clamp(min=0).mm(w2)
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
# Run backprop
optimizer.zero_grad()
loss.backward()
# Check for overflow
has_overflow = DynamicLossScaler.has_overflow(parameters)
# If no overflow, unscale grad and update as usual
if not has_overflow:
for param in parameters:
param.grad.data.mul_(1. / loss_scaler.loss_scale)
optimizer.step()
# Otherwise, don't do anything -- ie, skip iteration
else:
print('OVERFLOW!')
# Update loss scale for next iteration
loss_scaler.update_scale(has_overflow)

+ 541
- 0
model.py View File

@ -0,0 +1,541 @@
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from layers import ConvNorm, LinearNorm
from utils import to_gpu, get_mask_from_lengths
from fp16_optimizer import fp32_to_fp16, fp16_to_fp32
class LocationLayer(nn.Module):
def __init__(self, attention_n_filters, attention_kernel_size,
attention_dim):
super(LocationLayer, self).__init__()
padding = int((attention_kernel_size - 1) / 2)
self.location_conv = ConvNorm(2, attention_n_filters,
kernel_size=attention_kernel_size,
padding=padding, bias=False, stride=1,
dilation=1)
self.location_dense = LinearNorm(attention_n_filters, attention_dim,
bias=False, w_init_gain='tanh')
def forward(self, attention_weights_cat):
processed_attention = self.location_conv(attention_weights_cat)
processed_attention = processed_attention.transpose(1, 2)
processed_attention = self.location_dense(processed_attention)
return processed_attention
class Attention(nn.Module):
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
attention_location_n_filters, attention_location_kernel_size):
super(Attention, self).__init__()
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
bias=False, w_init_gain='tanh')
self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
w_init_gain='tanh')
self.v = LinearNorm(attention_dim, 1, bias=False)
self.location_layer = LocationLayer(attention_location_n_filters,
attention_location_kernel_size,
attention_dim)
self.score_mask_value = -float("inf")
def get_alignment_energies(self, query, processed_memory,
attention_weights_cat):
"""
PARAMS
------
query: decoder output (batch, n_mel_channels * n_frames_per_step)
processed_memory: processed encoder outputs (B, T_in, attention_dim)
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
RETURNS
-------
alignment (batch, max_time)
"""
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(F.tanh(
processed_query + processed_attention_weights + processed_memory))
energies = energies.squeeze(-1)
return energies
def forward(self, attention_hidden_state, memory, processed_memory,
attention_weights_cat, mask):
"""
PARAMS
------
attention_hidden_state: attention rnn last output
memory: encoder outputs
processed_memory: processed encoder outputs
attention_weights_cat: previous and cummulative attention weights
mask: binary mask for padded data
"""
alignment = self.get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat)
if mask is not None:
alignment.data.masked_fill_(mask, self.score_mask_value)
attention_weights = F.softmax(alignment, dim=1)
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
attention_context = attention_context.squeeze(1)
return attention_context, attention_weights
class Prenet(nn.Module):
def __init__(self, in_dim, sizes):
super(Prenet, self).__init__()
in_sizes = [in_dim] + sizes[:-1]
self.layers = nn.ModuleList(
[LinearNorm(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_sizes, sizes)])
def forward(self, x):
for linear in self.layers:
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
return x
class Postnet(nn.Module):
"""Postnet
- Five 1-d convolution with 512 channels and kernel size 5
"""
def __init__(self, hparams):
super(Postnet, self).__init__()
self.dropout = nn.Dropout(0.5)
self.convolutions = nn.ModuleList()
self.convolutions.append(
nn.Sequential(
ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
kernel_size=hparams.postnet_kernel_size, stride=1,
padding=int((hparams.postnet_kernel_size - 1) / 2),
dilation=1, w_init_gain='tanh'),
nn.BatchNorm1d(hparams.postnet_embedding_dim))
)
for i in range(1, hparams.postnet_n_convolutions - 1):
self.convolutions.append(
nn.Sequential(
ConvNorm(hparams.postnet_embedding_dim,
hparams.postnet_embedding_dim,
kernel_size=hparams.postnet_kernel_size, stride=1,
padding=int((hparams.postnet_kernel_size - 1) / 2),
dilation=1, w_init_gain='tanh'),
nn.BatchNorm1d(hparams.postnet_embedding_dim))
)
self.convolutions.append(
nn.Sequential(
ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
kernel_size=hparams.postnet_kernel_size, stride=1,
padding=int((hparams.postnet_kernel_size - 1) / 2),
dilation=1, w_init_gain='linear'),
nn.BatchNorm1d(hparams.n_mel_channels))
)
def forward(self, x):
for i in range(len(self.convolutions) - 1):
x = self.dropout(F.tanh(self.convolutions[i](x)))
x = self.dropout(self.convolutions[-1](x))
return x
class Encoder(nn.Module):
"""Encoder module:
- Three 1-d convolution banks
- Bidirectional LSTM
"""
def __init__(self, hparams):
super(Encoder, self).__init__()
self.dropout = nn.Dropout(0.5)
convolutions = []
for _ in range(hparams.encoder_n_convolutions):
conv_layer = nn.Sequential(
ConvNorm(hparams.encoder_embedding_dim,
hparams.encoder_embedding_dim,
kernel_size=hparams.encoder_kernel_size, stride=1,
padding=int((hparams.encoder_kernel_size - 1) / 2),
dilation=1, w_init_gain='relu'),
nn.BatchNorm1d(hparams.encoder_embedding_dim))
convolutions.append(conv_layer)
self.convolutions = nn.ModuleList(convolutions)
self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
int(hparams.encoder_embedding_dim / 2), 1,
batch_first=True, bidirectional=True)
def forward(self, x, input_lengths):
for conv in self.convolutions:
x = self.dropout(F.relu(conv(x)))
x = x.transpose(1, 2)
# pytorch tensor are not reversible, hence the conversion
input_lengths = input_lengths.cpu().numpy()
x = nn.utils.rnn.pack_padded_sequence(
x, input_lengths, batch_first=True)
self.lstm.flatten_parameters()
outputs, _ = self.lstm(x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(
outputs, batch_first=True)
return outputs
def inference(self, x):
for conv in self.convolutions:
x = self.dropout(F.relu(conv(x)))
x = x.transpose(1, 2)
self.lstm.flatten_parameters()
outputs, _ = self.lstm(x)
return outputs
class Decoder(nn.Module):
def __init__(self, hparams):
super(Decoder, self).__init__()
self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step
self.encoder_embedding_dim = hparams.encoder_embedding_dim
self.attention_rnn_dim = hparams.attention_rnn_dim
self.decoder_rnn_dim = hparams.decoder_rnn_dim
self.prenet_dim = hparams.prenet_dim
self.max_decoder_steps = hparams.max_decoder_steps
self.gate_threshold = hparams.gate_threshold
self.prenet = Prenet(
hparams.n_mel_channels * hparams.n_frames_per_step,
[hparams.prenet_dim, hparams.prenet_dim])
self.attention_rnn = nn.LSTMCell(
hparams.prenet_dim + hparams.encoder_embedding_dim,
hparams.attention_rnn_dim)
self.attention_layer = Attention(
hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
hparams.attention_dim, hparams.attention_location_n_filters,
hparams.attention_location_kernel_size)
self.decoder_rnn = nn.LSTMCell(
hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
hparams.decoder_rnn_dim, 1)
self.linear_projection = LinearNorm(
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
hparams.n_mel_channels*hparams.n_frames_per_step)
self.gate_layer = LinearNorm(
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
bias=True, w_init_gain='sigmoid')
def get_go_frame(self, memory):
""" Gets all zeros frames to use as first decoder input
PARAMS
------
memory: decoder outputs
RETURNS
-------
decoder_input: all zeros frames
"""
B = memory.size(0)
decoder_input = Variable(memory.data.new(
B, self.n_mel_channels * self.n_frames_per_step).zero_())
return decoder_input
def initialize_decoder_states(self, memory, mask):
""" Initializes attention rnn states, decoder rnn states, attention
weights, attention cumulative weights, attention context, stores memory
and stores processed memory
PARAMS
------
memory: Encoder outputs
mask: Mask for padded data if training, expects None for inference
"""
B = memory.size(0)
MAX_TIME = memory.size(1)
self.attention_hidden = Variable(memory.data.new(
B, self.attention_rnn_dim).zero_())
self.attention_cell = Variable(memory.data.new(
B, self.attention_rnn_dim).zero_())
self.decoder_hidden = Variable(memory.data.new(
B, self.decoder_rnn_dim).zero_())
self.decoder_cell = Variable(memory.data.new(
B, self.decoder_rnn_dim).zero_())
self.attention_weights = Variable(memory.data.new(
B, MAX_TIME).zero_())
self.attention_weights_cum = Variable(memory.data.new(
B, MAX_TIME).zero_())
self.attention_context = Variable(memory.data.new(
B, self.encoder_embedding_dim).zero_())
self.memory = memory
self.processed_memory = self.attention_layer.memory_layer(memory)
self.mask = mask
def parse_decoder_inputs(self, decoder_inputs):
""" Prepares decoder inputs, i.e. mel outputs
PARAMS
------
decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
RETURNS
-------
inputs: processed decoder inputs
"""
# (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
decoder_inputs = decoder_inputs.transpose(1, 2)
decoder_inputs = decoder_inputs.view(
decoder_inputs.size(0),
int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
# (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
decoder_inputs = decoder_inputs.transpose(0, 1)
return decoder_inputs
def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
""" Prepares decoder outputs for output
PARAMS
------
mel_outputs:
gate_outputs: gate output energies
alignments:
RETURNS
-------
mel_outputs:
gate_outpust: gate output energies
alignments:
"""
# (T_out, B) -> (B, T_out)
alignments = torch.stack(alignments).transpose(0, 1)
# (T_out, B) -> (B, T_out)
gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
gate_outputs = gate_outputs.contiguous()
# (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
# decouple frames per step
mel_outputs = mel_outputs.view(
mel_outputs.size(0), -1, self.n_mel_channels)
# (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
mel_outputs = mel_outputs.transpose(1, 2)
return mel_outputs, gate_outputs, alignments
def decode(self, decoder_input):
""" Decoder step using stored states, attention and memory
PARAMS
------
decoder_input: previous mel output
RETURNS
-------
mel_output:
gate_output: gate output energies
attention_weights:
"""
decoder_input = self.prenet(decoder_input)
cell_input = torch.cat((decoder_input, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
attention_weights_cat = torch.cat(
(self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)), dim=1)
self.attention_context, self.attention_weights = self.attention_layer(
self.attention_hidden, self.memory, self.processed_memory,
attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
decoder_input = torch.cat(
(self.attention_hidden, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell))
decoder_hidden_attention_context = torch.cat(
(self.decoder_hidden, self.attention_context), dim=1)
decoder_output = self.linear_projection(
decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return decoder_output, gate_prediction, self.attention_weights
def forward(self, memory, decoder_inputs, memory_lengths):
""" Decoder forward pass for training
PARAMS
------
memory: Encoder outputs
decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
memory_lengths: Encoder output lengths for attention masking.
RETURNS
-------
mel_outputs: mel outputs from the decoder
gate_outputs: gate outputs from the decoder
alignments: sequence of attention weights from the decoder
"""
decoder_input = self.get_go_frame(memory)
decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
self.initialize_decoder_states(
memory, mask=~get_mask_from_lengths(memory_lengths))
mel_outputs, gate_outputs, alignments = [], [], []
while len(mel_outputs) < decoder_inputs.size(0):
mel_output, gate_output, attention_weights = self.decode(
decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze()]
alignments += [attention_weights]
decoder_input = decoder_inputs[len(mel_outputs) - 1]
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
mel_outputs, gate_outputs, alignments)
return mel_outputs, gate_outputs, alignments
def inference(self, memory):
""" Decoder inference
PARAMS
------
memory: Encoder outputs
RETURNS
-------
mel_outputs: mel outputs from the decoder
gate_outputs: gate outputs from the decoder
alignments: sequence of attention weights from the decoder
"""
decoder_input = self.get_go_frame(memory)
self.initialize_decoder_states(memory, mask=None)
mel_outputs, gate_outputs, alignments = [], [], []
while True:
mel_output, gate_output, alignment = self.decode(decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze()]
alignments += [alignment]
if F.sigmoid(gate_output.data) > self.gate_threshold:
break
elif len(mel_outputs) == self.max_decoder_steps:
print("Warning! Reached max decoder steps")
break
decoder_input = mel_output
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
mel_outputs, gate_outputs, alignments)
return mel_outputs, gate_outputs, alignments
class Tacotron2(nn.Module):
def __init__(self, hparams):
super(Tacotron2, self).__init__()
self.mask_padding = hparams.mask_padding
self.fp16_run = hparams.fp16_run
self.n_mel_channels = hparams.n_mel_channels
self.n_frames_per_step = hparams.n_frames_per_step
self.embedding = nn.Embedding(
hparams.n_symbols, hparams.symbols_embedding_dim)
self.encoder = Encoder(hparams)
self.decoder = Decoder(hparams)
self.postnet = Postnet(hparams)
def parse_batch(self, batch):
text_padded, input_lengths, mel_padded, gate_padded, \
output_lengths = batch
text_padded = to_gpu(text_padded).long()
input_lengths = to_gpu(input_lengths).long()
max_len = torch.max(input_lengths.data)
mel_padded = to_gpu(mel_padded).float()
gate_padded = to_gpu(gate_padded).float()
output_lengths = to_gpu(output_lengths).long()
return (
(text_padded, input_lengths, mel_padded, max_len, output_lengths),
(mel_padded, gate_padded))
def parse_input(self, inputs):
inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
return inputs
def parse_output(self, outputs, output_lengths=None):
if self.mask_padding and output_lengths is not None:
mask = ~get_mask_from_lengths(output_lengths+1) # +1 <stop> token
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
mask = mask.permute(1, 0, 2)
outputs[0].data.masked_fill_(mask, 0.0)
outputs[1].data.masked_fill_(mask, 0.0)
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
return outputs
def forward(self, inputs):
inputs, input_lengths, targets, max_len, \
output_lengths = self.parse_input(inputs)
input_lengths, output_lengths = input_lengths.data, output_lengths.data
embedded_inputs = self.embedding(inputs).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, input_lengths)
mel_outputs, gate_outputs, alignments = self.decoder(
encoder_outputs, targets, memory_lengths=input_lengths)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
# DataParallel expects equal sized inputs/outputs, hence padding
if input_lengths is not None:
alignments = alignments.unsqueeze(0)
alignments = nn.functional.pad(
alignments,
(0, max_len - alignments.size(3), 0, 0),
"constant", 0)
alignments = alignments.squeeze()
return self.parse_output(
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
output_lengths)
def inference(self, inputs):
inputs = self.parse_input(inputs)
embedded_inputs = self.embedding(inputs).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)
mel_outputs, gate_outputs, alignments = self.decoder.inference(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
outputs = self.parse_output(
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
return outputs

+ 23
- 0
multiproc.py View File

@ -0,0 +1,23 @@
import time
import torch
import sys
import subprocess
argslist = list(sys.argv)[1:]
num_gpus = torch.cuda.device_count()
argslist.append('--n_gpus={}'.format(num_gpus))
workers = []
job_id = time.strftime("%Y_%m_%d-%H%M%S")
argslist.append("--group_name=group_{}".format(job_id))
for i in range(num_gpus):
argslist.append('--rank={}'.format(i))
stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i),
"w")
print(argslist)
p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout)
workers.append(p)
argslist = argslist[:-1]
for p in workers:
p.wait()

+ 61
- 0
plotting_utils.py View File

@ -0,0 +1,61 @@
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np
def save_figure_to_numpy(fig):
# save it to a numpy array.
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data
def plot_alignment_to_numpy(alignment, info=None):
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment, aspect='auto', origin='lower',
interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
if info is not None:
xlabel += '\n\n' + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data
def plot_spectrogram_to_numpy(spectrogram):
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data
def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
fig, ax = plt.subplots(figsize=(12, 3))
ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5,
color='green', marker='+', s=1, label='target')
ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5,
color='red', marker='.', s=1, label='predicted')
plt.xlabel("Frames (Green target, Red predicted)")
plt.ylabel("Gate State")
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data

+ 140
- 0
stft.py View File

@ -0,0 +1,140 @@
"""
BSD 3-Clause License
Copyright (c) 2017, Prem Seetharaman
All rights reserved.
* Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny
from audio_processing import window_sumsquare
class STFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(self, filter_length=800, hop_length=200, win_length=800,
window='hann'):
super(STFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.forward_transform = None
scale = self.filter_length / self.hop_length
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
np.imag(fourier_basis[:cutoff, :])])
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window is not None:
assert(win_length >= filter_length)
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window
inverse_basis *= fft_window
self.register_buffer('forward_basis', forward_basis.float())
self.register_buffer('inverse_basis', inverse_basis.float())
def transform(self, input_data):
num_batches = input_data.size(0)
num_samples = input_data.size(1)
self.num_samples = num_samples
# similar to librosa, reflect-pad the input
input_data = input_data.view(num_batches, 1, num_samples)
input_data = F.pad(
input_data.unsqueeze(1),
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
mode='reflect')
input_data = input_data.squeeze(1)
forward_transform = F.conv1d(
input_data,
Variable(self.forward_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
cutoff = int((self.filter_length / 2) + 1)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
phase = torch.autograd.Variable(
torch.atan2(imag_part.data, real_part.data))
return magnitude, phase
def inverse(self, magnitude, phase):
recombine_magnitude_phase = torch.cat(
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
inverse_transform = F.conv_transpose1d(
recombine_magnitude_phase,
Variable(self.inverse_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
if self.window is not None:
window_sum = window_sumsquare(
self.window, magnitude.size(-1), hop_length=self.hop_length,
win_length=self.win_length, n_fft=self.filter_length,
dtype=np.float32)
# remove modulation effects
approx_nonzero_indices = torch.from_numpy(
np.where(window_sum > tiny(window_sum))[0])
window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False)
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
# scale by hop ratio
inverse_transform *= float(self.filter_length) / self.hop_length
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
return inverse_transform
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction

+ 272
- 0
train.py View File

@ -0,0 +1,272 @@
import os
import time
import argparse
import math
import torch
from distributed import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from fp16_optimizer import FP16_Optimizer
from model import Tacotron2
from data_utils import TextMelLoader, TextMelCollate
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger
from hparams import create_hparams
def batchnorm_to_float(module):
"""Converts batch norm modules to FP32"""
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
module.float()
for child in module.children():
batchnorm_to_float(child)
return module
def reduce_tensor(tensor, num_gpus):
rt = tensor.clone()
torch.distributed.all_reduce(rt, op=torch.distributed.reduce_op.SUM)
rt /= num_gpus
return rt
def init_distributed(hparams, n_gpus, rank, group_name):
assert torch.cuda.is_available(), "Distributed mode requires CUDA."
print("Initializing distributed")
# Set cuda device so everything is done on the right GPU.
torch.cuda.set_device(rank % torch.cuda.device_count())
# Initialize distributed communication
torch.distributed.init_process_group(
backend=hparams.dist_backend, init_method=hparams.dist_url,
world_size=n_gpus, rank=rank, group_name=group_name)
print("Done initializing distributed")
def prepare_dataloaders(hparams):
# Get data, data loaders and collate function ready
trainset = TextMelLoader(hparams.training_files, hparams)
valset = TextMelLoader(hparams.validation_files, hparams)
collate_fn = TextMelCollate(hparams.n_frames_per_step)
train_sampler = DistributedSampler(trainset) \
if hparams.distributed_run else None
train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
sampler=train_sampler,
batch_size=hparams.batch_size, pin_memory=False,
drop_last=True, collate_fn=collate_fn)
return train_loader, valset, collate_fn
def prepare_directories_and_logger(output_directory, log_directory, rank):
if rank == 0:
if not os.path.isdir(output_directory):
os.makedirs(output_directory)
os.chmod(output_directory, 0o775)
logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
else:
logger = None
return logger
def load_model(hparams):
model = Tacotron2(hparams).cuda()
model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
model = DistributedDataParallel(model) \
if hparams.distributed_run else DataParallel(model)
return model
def warm_start_model(checkpoint_path, model):
assert os.path.isfile(checkpoint_path)
print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
checkpoint_dict = torch.load(checkpoint_path)
model.load_state_dict(checkpoint_dict['state_dict'])
return model
def load_checkpoint(checkpoint_path, model, optimizer):
assert os.path.isfile(checkpoint_path)
print("Loading checkpoint '{}'".format(checkpoint_path))
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint_dict['state_dict'])
optimizer.load_state_dict(checkpoint_dict['optimizer'])
learning_rate = checkpoint_dict['learning_rate']
iteration = checkpoint_dict['iteration']
print("Loaded checkpoint '{}' from iteration {}" .format(
checkpoint_path, iteration))
return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
print("Saving model and optimizer state at iteration {} to {}".format(
iteration, filepath))
torch.save({'iteration': iteration,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'learning_rate': learning_rate}, filepath)
def validate(model, criterion, valset, iteration, batch_size, n_gpus,
collate_fn, logger, distributed_run, rank):
"""Handles all the validation scoring and printing"""
model.eval()
with torch.no_grad():
val_sampler = DistributedSampler(valset) if distributed_run else None
val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
shuffle=False, batch_size=batch_size,
pin_memory=False, collate_fn=collate_fn)
val_loss = 0.0
for i, batch in enumerate(val_loader):
x, y = model.module.parse_batch(batch)
y_pred = model(x)
loss = criterion(y_pred, y)
reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
if distributed_run else loss.data[0]
val_loss += reduced_val_loss
val_loss = val_loss / (i + 1)
model.train()
return val_loss
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
rank, group_name, hparams):
"""Training and validation logging results to tensorboard and stdout
Params
------
output_directory (string): directory to save checkpoints
log_directory (string) directory to save tensorboard logs
checkpoint_path(string): checkpoint path
n_gpus (int): number of gpus
rank (int): rank of current gpu
hparams (object): comma separated list of "name=value" pairs.
"""
if hparams.distributed_run:
init_distributed(hparams, n_gpus, rank, group_name)
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)
model = load_model(hparams)
learning_rate = hparams.learning_rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
weight_decay=hparams.weight_decay)
if hparams.fp16_run:
optimizer = FP16_Optimizer(
optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)
criterion = Tacotron2Loss()
logger = prepare_directories_and_logger(
output_directory, log_directory, rank)
train_loader, valset, collate_fn = prepare_dataloaders(hparams)
# Load checkpoint if one exists
iteration = 0
epoch_offset = 0
if checkpoint_path is not None:
if warm_start:
model = warm_start_model(checkpoint_path, model)
else:
model, optimizer, learning_rate, iteration = load_checkpoint(
checkpoint_path, model, optimizer)
iteration += 1 # next iteration is iteration + 1
epoch_offset = max(0, int(iteration / len(train_loader)))
model.train()
# ================ MAIN TRAINNIG LOOP! ===================
for epoch in range(epoch_offset, hparams.epochs):
print("Epoch: {}".format(epoch))
for i, batch in enumerate(train_loader):
start = time.perf_counter()
for param_group in optimizer.param_groups:
param_group['lr'] = learning_rate
model.zero_grad()
x, y = model.module.parse_batch(batch)
y_pred = model(x)
loss = criterion(y_pred, y)
reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
if hparams.distributed_run else loss.data[0]
if hparams.fp16_run:
optimizer.backward(loss)
grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
else:
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm(
model.module.parameters(), hparams.grad_clip_thresh)
optimizer.step()
overflow = optimizer.overflow if hparams.fp16_run else False
if not overflow and not math.isnan(reduced_loss) and rank == 0:
duration = time.perf_counter() - start
print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
iteration, reduced_loss, grad_norm, duration))
logger.log_training(
reduced_loss, grad_norm, learning_rate, duration, iteration)
if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
reduced_val_loss = validate(
model, criterion, valset, iteration, hparams.batch_size,
n_gpus, collate_fn, logger, hparams.distributed_run, rank)
if rank == 0:
print("Validation loss {}: {:9f} ".format(
iteration, reduced_val_loss))
logger.log_validation(
reduced_val_loss, model, y, y_pred, iteration)
checkpoint_path = os.path.join(
output_directory, "checkpoint_{}".format(iteration))
save_checkpoint(model, optimizer, learning_rate, iteration,
checkpoint_path)
iteration += 1
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output_directory', type=str,
help='directory to save checkpoints')
parser.add_argument('-l', '--log_directory', type=str,
help='directory to save tensorboard logs')
parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
required=False, help='checkpoint path')
parser.add_argument('--warm_start', action='store_true',
help='load the model only (warm start)')
parser.add_argument('--n_gpus', type=int, default=1,
required=False, help='number of gpus')
parser.add_argument('--rank', type=int, default=0,
required=False, help='rank of current gpu')
parser.add_argument('--group_name', type=str, default='group_name',
required=False, help='Distributed group name')
parser.add_argument('--hparams', type=str,
required=False, help='comma separated name=value pairs')
args = parser.parse_args()
hparams = create_hparams(args.hparams)
torch.backends.cudnn.enabled = hparams.cudnn_enabled
torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
print("FP16 Run:", hparams.fp16_run)
print("Dynamic Loss Scaling", hparams.dynamic_loss_scaling)
print("Distributed Run:", hparams.distributed_run)
print("cuDNN Enabled:", hparams.cudnn_enabled)
print("cuDNN Benchmark:", hparams.cudnn_benchmark)
train(args.output_directory, args.log_directory, args.checkpoint_path,
args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)

+ 32
- 0
utils.py View File

@ -0,0 +1,32 @@
import numpy as np
from scipy.io.wavfile import read
import torch
def get_mask_from_lengths(lengths):
max_len = torch.max(lengths)
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).cuda()
mask = (ids < lengths.unsqueeze(1)).byte()
return mask
def load_wav_to_torch(full_path, sr):
sampling_rate, data = read(full_path)
assert sr == sampling_rate, "{} SR doesn't match {} on path {}".format(
sr, sampling_rate, full_path)
return torch.FloatTensor(data.astype(np.float32))
def load_filepaths_and_text(filename, sort_by_length, split="|"):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = [line.strip().split(split) for line in f]
if sort_by_length:
filepaths_and_text.sort(key=lambda x: len(x[1]))
return filepaths_and_text
def to_gpu(x):
x = x.contiguous().cuda(async=True)
return torch.autograd.Variable(x)

Loading…
Cancel
Save