Browse Source

Merge pull request #96 from NVIDIA/clean_slate

Clean slate
master
Rafael Valle 6 years ago
committed by GitHub
parent
commit
f02704f338
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 232 additions and 180 deletions
  1. +19
    -14
      README.md
  2. +10
    -9
      data_utils.py
  3. +52
    -0
      distributed.py
  4. +10
    -9
      hparams.py
  5. +60
    -57
      inference.ipynb
  6. +3
    -3
      layers.py
  7. +39
    -38
      model.py
  8. +0
    -1
      requirements.txt
  9. +1
    -1
      stft.py
  10. +0
    -2
      text/__init__.py
  11. +4
    -3
      text/symbols.py
  12. +25
    -31
      train.py
  13. +9
    -12
      utils.py

+ 19
- 14
README.md View File

@ -1,6 +1,6 @@
# Tacotron 2 (without wavenet)
Tacotron 2 PyTorch implementation of [Natural TTS Synthesis By Conditioning
PyTorch implementation of [Natural TTS Synthesis By Conditioning
Wavenet On Mel Spectrogram Predictions](https://arxiv.org/pdf/1712.05884.pdf).
This implementation includes **distributed** and **fp16** support
@ -11,9 +11,7 @@ Distributed and FP16 support relies on work by Christian Sarofeen and NVIDIA's
![Alignment, Predicted Mel Spectrogram, Target Mel Spectrogram](tensorboard.png)
[Download demo audio](https://github.com/NVIDIA/tacotron2/blob/master/demo.wav) trained on LJS and using Ryuchi Yamamoto's [pre-trained Mixture of Logistics
wavenet](https://github.com/r9y9/wavenet_vocoder/)
"Scientists at the CERN laboratory say they have discovered a new particle."
Visit our [website] for audio samples.
## Pre-requisites
1. NVIDIA GPU + CUDA cuDNN
@ -24,11 +22,9 @@ wavenet](https://github.com/r9y9/wavenet_vocoder/)
3. CD into this repo: `cd tacotron2`
4. Update .wav paths: `sed -i -- 's,DUMMY,ljs_dataset_folder/wavs,g' filelists/*.txt`
- Alternatively, set `load_mel_from_disk=True` in `hparams.py` and update mel-spectrogram paths
5. Install [pytorch 0.4](https://github.com/pytorch/pytorch)
5. Install [PyTorch 1.0]
6. Install python requirements or build docker image
- Install python requirements: `pip install -r requirements.txt`
- **OR**
- Build docker image: `docker build --tag tacotron2 .`
## Training
1. `python train.py --output_directory=outdir --log_directory=logdir`
@ -37,17 +33,22 @@ wavenet](https://github.com/r9y9/wavenet_vocoder/)
## Multi-GPU (distributed) and FP16 Training
1. `python -m multiproc train.py --output_directory=outdir --log_directory=logdir --hparams=distributed_run=True,fp16_run=True`
## Inference
When performing Mel-Spectrogram to Audio synthesis with a WaveNet model, make sure Tacotron 2 and WaveNet were trained on the same mel-spectrogram representation. Follow these steps to use a a simple inference pipeline using griffin-lim:
1. `jupyter notebook --ip=127.0.0.1 --port=31337`
2. load inference.ipynb
## Inference demo
1. Download our published [Tacotron 2] model
2. Download our published [WaveGlow] model
3. `jupyter notebook --ip=127.0.0.1 --port=31337`
4. Load inference.ipynb
N.b. When performing Mel-Spectrogram to Audio synthesis, make sure Tacotron 2
and the Mel decoder were trained on the same mel-spectrogram representation.
## Related repos
[nv-wavenet](https://github.com/NVIDIA/nv-wavenet/): Faster than real-time
wavenet inference
[WaveGlow](https://github.com/NVIDIA/WaveGlow) Faster than real time Flow-based
Generative Network for Speech Synthesis
[nv-wavenet](https://github.com/NVIDIA/nv-wavenet/) Faster than real time
WaveNet.
## Acknowledgements
This implementation uses code from the following repos: [Keith
@ -61,3 +62,7 @@ We are thankful to the Tacotron 2 paper authors, specially Jonathan Shen, Yuxuan
Wang and Zongheng Yang.
[WaveGlow]: https://drive.google.com/file/d/1cjKPHbtAMh_4HTHmuIGNkbOkPBD9qwhj/view?usp=sharing
[Tacotron 2]: https://drive.google.com/file/d/1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA/view?usp=sharing
[pytorch 1.0]: https://github.com/pytorch/pytorch#installation
[website]: https://nv-adlr.github.io/WaveGlow

+ 10
- 9
data_utils.py View File

@ -14,9 +14,8 @@ class TextMelLoader(torch.utils.data.Dataset):
2) normalizes text and converts them to sequences of one-hot vectors
3) computes mel-spectrograms from audio files.
"""
def __init__(self, audiopaths_and_text, hparams, shuffle=True):
self.audiopaths_and_text = load_filepaths_and_text(
audiopaths_and_text, hparams.sort_by_length)
def __init__(self, audiopaths_and_text, hparams):
self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
self.text_cleaners = hparams.text_cleaners
self.max_wav_value = hparams.max_wav_value
self.sampling_rate = hparams.sampling_rate
@ -26,8 +25,7 @@ class TextMelLoader(torch.utils.data.Dataset):
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
hparams.mel_fmax)
random.seed(1234)
if shuffle:
random.shuffle(self.audiopaths_and_text)
random.shuffle(self.audiopaths_and_text)
def get_mel_text_pair(self, audiopath_and_text):
# separate filename and text
@ -38,7 +36,10 @@ class TextMelLoader(torch.utils.data.Dataset):
def get_mel(self, filename):
if not self.load_mel_from_disk:
audio = load_wav_to_torch(filename, self.sampling_rate)
audio, sampling_rate = load_wav_to_torch(filename)
if sampling_rate != self.stft.sampling_rate:
raise ValueError("{} {} SR doesn't match target {} SR".format(
sampling_rate, self.stft.sampling_rate))
audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
@ -87,9 +88,9 @@ class TextMelCollate():
text = batch[ids_sorted_decreasing[i]][0]
text_padded[i, :text.size(0)] = text
# Right zero-pad mel-spec with extra single zero vector to mark the end
# Right zero-pad mel-spec
num_mels = batch[0][1].size(0)
max_target_len = max([x[1].size(1) for x in batch]) + 1
max_target_len = max([x[1].size(1) for x in batch])
if max_target_len % self.n_frames_per_step != 0:
max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
assert max_target_len % self.n_frames_per_step == 0
@ -103,7 +104,7 @@ class TextMelCollate():
for i in range(len(ids_sorted_decreasing)):
mel = batch[ids_sorted_decreasing[i]][1]
mel_padded[i, :, :mel.size(1)] = mel
gate_padded[i, mel.size(1):] = 1
gate_padded[i, mel.size(1)-1:] = 1
output_lengths[i] = mel.size(1)
return text_padded, input_lengths, mel_padded, gate_padded, \

+ 52
- 0
distributed.py View File

@ -118,3 +118,55 @@ class DistributedDataParallel(Module):
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
'''
Modifies existing model to do gradient allreduce, but doesn't change class
so you don't need "module"
'''
def apply_gradient_allreduce(module):
if not hasattr(dist, '_backend'):
module.warn_on_half = True
else:
module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
for p in module.state_dict().values():
if not torch.is_tensor(p):
continue
dist.broadcast(p, 0)
def allreduce_params():
if(module.needs_reduction):
module.needs_reduction = False
buckets = {}
for param in module.parameters():
if param.requires_grad and param.grad is not None:
tp = type(param.data)
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
if module.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case. This currently requires" +
"PyTorch built from top of tree master.")
module.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced)
coalesced /= dist.get_world_size()
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
for param in list(module.parameters()):
def allreduce_hook(*unused):
param._execution_engine.queue_callback(allreduce_params)
if param.requires_grad:
param.register_hook(allreduce_hook)
def set_needs_reduction(self, input, output):
self.needs_reduction = True
module.register_forward_hook(set_needs_reduction)
return module

+ 10
- 9
hparams.py View File

@ -10,7 +10,7 @@ def create_hparams(hparams_string=None, verbose=False):
# Experiment Parameters #
################################
epochs=500,
iters_per_checkpoint=500,
iters_per_checkpoint=1000,
seed=1234,
dynamic_loss_scaling=True,
fp16_run=False,
@ -24,10 +24,9 @@ def create_hparams(hparams_string=None, verbose=False):
# Data Parameters #
################################
load_mel_from_disk=False,
training_files='filelists/ljs_audio_text_train_filelist.txt',
validation_files='filelists/ljs_audio_text_val_filelist.txt',
training_files='filelists/ljs_audio22khz_text_train_filelist.txt',
validation_files='filelists/ljs_audio22khz_text_val_filelist.txt',
text_cleaners=['english_cleaners'],
sort_by_length=False,
################################
# Audio Parameters #
@ -39,7 +38,7 @@ def create_hparams(hparams_string=None, verbose=False):
win_length=1024,
n_mel_channels=80,
mel_fmin=0.0,
mel_fmax=None, # if None, half the sampling rate
mel_fmax=8000.0,
################################
# Model Parameters #
@ -57,7 +56,9 @@ def create_hparams(hparams_string=None, verbose=False):
decoder_rnn_dim=1024,
prenet_dim=256,
max_decoder_steps=1000,
gate_threshold=0.6,
gate_threshold=0.5,
p_attention_dropout=0.1,
p_decoder_dropout=0.1,
# Attention parameters
attention_rnn_dim=1024,
@ -78,9 +79,9 @@ def create_hparams(hparams_string=None, verbose=False):
use_saved_learning_rate=False,
learning_rate=1e-3,
weight_decay=1e-6,
grad_clip_thresh=1,
batch_size=48,
mask_padding=False # set model's padded outputs to padded values
grad_clip_thresh=1.0,
batch_size=64,
mask_padding=True # set model's padded outputs to padded values
)
if hparams_string:

+ 60
- 57
inference.ipynb
File diff suppressed because it is too large
View File


+ 3
- 3
layers.py View File

@ -10,7 +10,7 @@ class LinearNorm(torch.nn.Module):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform(
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))
@ -31,7 +31,7 @@ class ConvNorm(torch.nn.Module):
padding=padding, dilation=dilation,
bias=bias)
torch.nn.init.xavier_uniform(
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, signal):
@ -42,7 +42,7 @@ class ConvNorm(torch.nn.Module):
class TacotronSTFT(torch.nn.Module):
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
mel_fmax=None):
mel_fmax=8000.0):
super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate

+ 39
- 38
model.py View File

@ -1,3 +1,4 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
@ -56,7 +57,7 @@ class Attention(nn.Module):
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(F.tanh(
energies = self.v(torch.tanh(
processed_query + processed_attention_weights + processed_memory))
energies = energies.squeeze(-1)
@ -107,7 +108,6 @@ class Postnet(nn.Module):
def __init__(self, hparams):
super(Postnet, self).__init__()
self.dropout = nn.Dropout(0.5)
self.convolutions = nn.ModuleList()
self.convolutions.append(
@ -141,9 +141,8 @@ class Postnet(nn.Module):
def forward(self, x):
for i in range(len(self.convolutions) - 1):
x = self.dropout(F.tanh(self.convolutions[i](x)))
x = self.dropout(self.convolutions[-1](x))
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
return x
@ -155,7 +154,6 @@ class Encoder(nn.Module):
"""
def __init__(self, hparams):
super(Encoder, self).__init__()
self.dropout = nn.Dropout(0.5)
convolutions = []
for _ in range(hparams.encoder_n_convolutions):
@ -175,7 +173,7 @@ class Encoder(nn.Module):
def forward(self, x, input_lengths):
for conv in self.convolutions:
x = self.dropout(F.relu(conv(x)))
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
x = x.transpose(1, 2)
@ -194,7 +192,7 @@ class Encoder(nn.Module):
def inference(self, x):
for conv in self.convolutions:
x = self.dropout(F.relu(conv(x)))
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
x = x.transpose(1, 2)
@ -215,13 +213,15 @@ class Decoder(nn.Module):
self.prenet_dim = hparams.prenet_dim
self.max_decoder_steps = hparams.max_decoder_steps
self.gate_threshold = hparams.gate_threshold
self.p_attention_dropout = hparams.p_attention_dropout
self.p_decoder_dropout = hparams.p_decoder_dropout
self.prenet = Prenet(
hparams.n_mel_channels * hparams.n_frames_per_step,
[hparams.prenet_dim, hparams.prenet_dim])
self.attention_rnn = nn.LSTMCell(
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
hparams.prenet_dim + hparams.encoder_embedding_dim,
hparams.attention_rnn_dim)
self.attention_layer = Attention(
@ -230,12 +230,12 @@ class Decoder(nn.Module):
hparams.attention_location_kernel_size)
self.decoder_rnn = nn.LSTMCell(
hparams.prenet_dim + hparams.encoder_embedding_dim,
hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
hparams.decoder_rnn_dim, 1)
self.linear_projection = LinearNorm(
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
hparams.n_mel_channels*hparams.n_frames_per_step)
hparams.n_mel_channels * hparams.n_frames_per_step)
self.gate_layer = LinearNorm(
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
@ -350,10 +350,13 @@ class Decoder(nn.Module):
gate_output: gate output energies
attention_weights:
"""
cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
cell_input = torch.cat((decoder_input, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(
self.attention_hidden, self.p_attention_dropout, self.training)
self.attention_cell = F.dropout(
self.attention_cell, self.p_attention_dropout, self.training)
attention_weights_cat = torch.cat(
(self.attention_weights.unsqueeze(1),
@ -363,10 +366,14 @@ class Decoder(nn.Module):
attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
prenet_output = self.prenet(decoder_input)
decoder_input = torch.cat((prenet_output, self.attention_context), -1)
decoder_input = torch.cat(
(self.attention_hidden, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(
self.decoder_hidden, self.p_decoder_dropout, self.training)
self.decoder_cell = F.dropout(
self.decoder_cell, self.p_decoder_dropout, self.training)
decoder_hidden_attention_context = torch.cat(
(self.decoder_hidden, self.attention_context), dim=1)
@ -391,22 +398,23 @@ class Decoder(nn.Module):
alignments: sequence of attention weights from the decoder
"""
decoder_input = self.get_go_frame(memory)
decoder_input = self.get_go_frame(memory).unsqueeze(0)
decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
decoder_inputs = self.prenet(decoder_inputs)
self.initialize_decoder_states(
memory, mask=~get_mask_from_lengths(memory_lengths))
mel_outputs, gate_outputs, alignments = [], [], []
while len(mel_outputs) < decoder_inputs.size(0):
while len(mel_outputs) < decoder_inputs.size(0) - 1:
decoder_input = decoder_inputs[len(mel_outputs)]
mel_output, gate_output, attention_weights = self.decode(
decoder_input)
mel_outputs += [mel_output]
gate_outputs += [gate_output.squeeze(1)]
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze()]
alignments += [attention_weights]
decoder_input = decoder_inputs[len(mel_outputs) - 1]
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
mel_outputs, gate_outputs, alignments)
@ -430,13 +438,14 @@ class Decoder(nn.Module):
mel_outputs, gate_outputs, alignments = [], [], []
while True:
decoder_input = self.prenet(decoder_input)
mel_output, gate_output, alignment = self.decode(decoder_input)
mel_outputs += [mel_output]
gate_outputs += [gate_output.squeeze(1)]
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output]
alignments += [alignment]
if F.sigmoid(gate_output.data) > self.gate_threshold:
if torch.sigmoid(gate_output.data) > self.gate_threshold:
break
elif len(mel_outputs) == self.max_decoder_steps:
print("Warning! Reached max decoder steps")
@ -459,8 +468,9 @@ class Tacotron2(nn.Module):
self.n_frames_per_step = hparams.n_frames_per_step
self.embedding = nn.Embedding(
hparams.n_symbols, hparams.symbols_embedding_dim)
torch.nn.init.xavier_uniform_(self.embedding.weight.data)
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(hparams)
self.decoder = Decoder(hparams)
self.postnet = Postnet(hparams)
@ -469,8 +479,8 @@ class Tacotron2(nn.Module):
text_padded, input_lengths, mel_padded, gate_padded, \
output_lengths = batch
text_padded = to_gpu(text_padded).long()
max_len = int(torch.max(input_lengths.data).numpy())
input_lengths = to_gpu(input_lengths).long()
max_len = torch.max(input_lengths.data).item()
mel_padded = to_gpu(mel_padded).float()
gate_padded = to_gpu(gate_padded).float()
output_lengths = to_gpu(output_lengths).long()
@ -485,7 +495,7 @@ class Tacotron2(nn.Module):
def parse_output(self, outputs, output_lengths=None):
if self.mask_padding and output_lengths is not None:
mask = ~get_mask_from_lengths(output_lengths+1) # +1 <stop> token
mask = ~get_mask_from_lengths(output_lengths)
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
mask = mask.permute(1, 0, 2)
@ -494,7 +504,6 @@ class Tacotron2(nn.Module):
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
return outputs
def forward(self, inputs):
@ -512,14 +521,6 @@ class Tacotron2(nn.Module):
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
# DataParallel expects equal sized inputs/outputs, hence padding
if input_lengths is not None:
alignments = alignments.unsqueeze(0)
alignments = nn.functional.pad(
alignments,
(0, max_len - alignments.size(3), 0, 0),
"constant", 0)
alignments = alignments.squeeze()
return self.parse_output(
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
output_lengths)

+ 0
- 1
requirements.txt View File

@ -1,4 +1,3 @@
torch==0.4.0
matplotlib==2.1.0
tensorflow
numpy==1.13.3

+ 1
- 1
stft.py View File

@ -61,7 +61,7 @@ class STFT(torch.nn.Module):
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window is not None:
assert(win_length >= filter_length)
assert(filter_length >= win_length)
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, filter_length)

+ 0
- 2
text/__init__.py View File

@ -37,8 +37,6 @@ def text_to_sequence(text, cleaner_names):
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# Append EOS token
sequence.append(_symbol_to_id['~'])
return sequence

+ 4
- 3
text/symbols.py View File

@ -7,11 +7,12 @@ The default is a set of ASCII characters that works well for English or text tha
from text import cmudict
_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in cmudict.valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet

+ 25
- 31
train.py View File

@ -5,9 +5,9 @@ import math
from numpy import finfo
import torch
from distributed import DistributedDataParallel
from distributed import apply_gradient_allreduce
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from fp16_optimizer import FP16_Optimizer
@ -30,19 +30,20 @@ def batchnorm_to_float(module):
def reduce_tensor(tensor, num_gpus):
rt = tensor.clone()
torch.distributed.all_reduce(rt, op=torch.distributed.reduce_op.SUM)
dist.all_reduce(rt, op=dist.reduce_op.SUM)
rt /= num_gpus
return rt
def init_distributed(hparams, n_gpus, rank, group_name):
assert torch.cuda.is_available(), "Distributed mode requires CUDA."
print("Initializing distributed")
print("Initializing Distributed")
# Set cuda device so everything is done on the right GPU.
torch.cuda.set_device(rank % torch.cuda.device_count())
# Initialize distributed communication
torch.distributed.init_process_group(
dist.init_process_group(
backend=hparams.dist_backend, init_method=hparams.dist_url,
world_size=n_gpus, rank=rank, group_name=group_name)
@ -131,22 +132,20 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
pin_memory=False, collate_fn=collate_fn)
val_loss = 0.0
if distributed_run or torch.cuda.device_count() > 1:
batch_parser = model.module.parse_batch
else:
batch_parser = model.parse_batch
for i, batch in enumerate(val_loader):
x, y = batch_parser(batch)
x, y = model.parse_batch(batch)
y_pred = model(x)
loss = criterion(y_pred, y)
reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
if distributed_run else loss.data[0]
if distributed_run:
reduced_val_loss = reduce_tensor(loss.data, num_gpus).item()
else:
reduced_val_loss = loss.item()
val_loss += reduced_val_loss
val_loss = val_loss / (i + 1)
model.train()
return val_loss
print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
@ -176,6 +175,9 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
optimizer = FP16_Optimizer(
optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)
if hparams.distributed_run:
model = apply_gradient_allreduce(model)
criterion = Tacotron2Loss()
logger = prepare_directories_and_logger(
@ -194,15 +196,10 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
checkpoint_path, model, optimizer)
if hparams.use_saved_learning_rate:
learning_rate = _learning_rate
iteration += 1 # next iteration is iteration + 1
epoch_offset = max(0, int(iteration / len(train_loader)))
model.train()
if hparams.distributed_run or torch.cuda.device_count() > 1:
batch_parser = model.module.parse_batch
else:
batch_parser = model.parse_batch
# ================ MAIN TRAINNIG LOOP! ===================
for epoch in range(epoch_offset, hparams.epochs):
print("Epoch: {}".format(epoch))
@ -212,18 +209,21 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
param_group['lr'] = learning_rate
model.zero_grad()
x, y = batch_parser(batch)
x, y = model.parse_batch(batch)
y_pred = model(x)
loss = criterion(y_pred, y)
reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
if hparams.distributed_run else loss.data[0]
if hparams.distributed_run:
reduced_loss = reduce_tensor(loss.data, num_gpus).item()
else:
reduced_loss = loss.item()
if hparams.fp16_run:
optimizer.backward(loss)
grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
else:
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm(
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), hparams.grad_clip_thresh)
optimizer.step()
@ -234,20 +234,14 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
duration = time.perf_counter() - start
print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
iteration, reduced_loss, grad_norm, duration))
logger.log_training(
reduced_loss, grad_norm, learning_rate, duration, iteration)
if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
reduced_val_loss = validate(
model, criterion, valset, iteration, hparams.batch_size,
n_gpus, collate_fn, logger, hparams.distributed_run, rank)
validate(model, criterion, valset, iteration, hparams.batch_size,
n_gpus, collate_fn, logger, hparams.distributed_run, rank)
if rank == 0:
print("Validation loss {}: {:9f} ".format(
iteration, reduced_val_loss))
logger.log_validation(
reduced_val_loss, model, y, y_pred, iteration)
checkpoint_path = os.path.join(
output_directory, "checkpoint_{}".format(iteration))
save_checkpoint(model, optimizer, learning_rate, iteration,

+ 9
- 12
utils.py View File

@ -4,29 +4,26 @@ import torch
def get_mask_from_lengths(lengths):
max_len = torch.max(lengths)
ids = torch.arange(0, max_len).long().cuda()
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
mask = (ids < lengths.unsqueeze(1)).byte()
return mask
def load_wav_to_torch(full_path, sr):
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
assert sr == sampling_rate, "{} SR doesn't match {} on path {}".format(
sr, sampling_rate, full_path)
return torch.FloatTensor(data.astype(np.float32))
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def load_filepaths_and_text(filename, sort_by_length, split="|"):
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = [line.strip().split(split) for line in f]
if sort_by_length:
filepaths_and_text.sort(key=lambda x: len(x[1]))
return filepaths_and_text
def to_gpu(x):
x = x.contiguous().cuda(async=True)
x = x.contiguous()
if torch.cuda.is_available():
x = x.cuda(non_blocking=True)
return torch.autograd.Variable(x)

Loading…
Cancel
Save