Browse Source

train.py single gpu and 0.4 update

master
Raul Puri 6 years ago
committed by GitHub
parent
commit
9343f34b0b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 11 deletions
  1. +14
    -11
      train.py

+ 14
- 11
train.py View File

@ -74,14 +74,17 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
logger = None
return logger
def load_model(hparams):
model = Tacotron2(hparams).cuda()
model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
model = DistributedDataParallel(model) \
if hparams.distributed_run else DataParallel(model)
return model
tacotron_model = model
if hparams.distributed_run:
model = DistributedDataParallel(model)
elif torch.cuda.device_count() > 1:
model = DataParallel(model)
return model, tacotron
def warm_start_model(checkpoint_path, model):
assert os.path.isfile(checkpoint_path)
@ -114,7 +117,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
def validate(model, criterion, valset, iteration, batch_size, n_gpus,
collate_fn, logger, distributed_run, rank):
collate_fn, logger, distributed_run, rank, batch_parser):
"""Handles all the validation scoring and printing"""
model.eval()
with torch.no_grad():
@ -125,7 +128,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
val_loss = 0.0
for i, batch in enumerate(val_loader):
x, y = model.module.parse_batch(batch)
x, y = batch_parser(batch)
y_pred = model(x)
loss = criterion(y_pred, y)
reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
@ -193,11 +196,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
param_group['lr'] = learning_rate
model.zero_grad()
x, y = model.module.parse_batch(batch)
x, y = tacotron_model.parse_batch(batch)
y_pred = model(x)
loss = criterion(y_pred, y)
reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
if hparams.distributed_run else loss.data[0]
reduced_loss = reduce_tensor(loss.data, n_gpus).item() \
if hparams.distributed_run else loss.item()
if hparams.fp16_run:
optimizer.backward(loss)
@ -205,7 +208,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
else:
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm(
model.module.parameters(), hparams.grad_clip_thresh)
tacotron_model.parameters(), hparams.grad_clip_thresh)
optimizer.step()
@ -222,7 +225,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
reduced_val_loss = validate(
model, criterion, valset, iteration, hparams.batch_size,
n_gpus, collate_fn, logger, hparams.distributed_run, rank)
n_gpus, collate_fn, logger, hparams.distributed_run, rank, tacotron_model.parse_batch)
if rank == 0:
print("Validation loss {}: {:9f} ".format(

Loading…
Cancel
Save