|
@ -5,9 +5,9 @@ import math |
|
|
from numpy import finfo |
|
|
from numpy import finfo |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
|
from distributed import DistributedDataParallel |
|
|
|
|
|
|
|
|
from distributed import apply_gradient_allreduce |
|
|
|
|
|
import torch.distributed as dist |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
from torch.nn import DataParallel |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
|
|
|
from fp16_optimizer import FP16_Optimizer |
|
|
from fp16_optimizer import FP16_Optimizer |
|
@ -30,19 +30,20 @@ def batchnorm_to_float(module): |
|
|
|
|
|
|
|
|
def reduce_tensor(tensor, num_gpus): |
|
|
def reduce_tensor(tensor, num_gpus): |
|
|
rt = tensor.clone() |
|
|
rt = tensor.clone() |
|
|
torch.distributed.all_reduce(rt, op=torch.distributed.reduce_op.SUM) |
|
|
|
|
|
|
|
|
dist.all_reduce(rt, op=dist.reduce_op.SUM) |
|
|
rt /= num_gpus |
|
|
rt /= num_gpus |
|
|
return rt |
|
|
return rt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_distributed(hparams, n_gpus, rank, group_name): |
|
|
def init_distributed(hparams, n_gpus, rank, group_name): |
|
|
assert torch.cuda.is_available(), "Distributed mode requires CUDA." |
|
|
assert torch.cuda.is_available(), "Distributed mode requires CUDA." |
|
|
print("Initializing distributed") |
|
|
|
|
|
|
|
|
print("Initializing Distributed") |
|
|
|
|
|
|
|
|
# Set cuda device so everything is done on the right GPU. |
|
|
# Set cuda device so everything is done on the right GPU. |
|
|
torch.cuda.set_device(rank % torch.cuda.device_count()) |
|
|
torch.cuda.set_device(rank % torch.cuda.device_count()) |
|
|
|
|
|
|
|
|
# Initialize distributed communication |
|
|
# Initialize distributed communication |
|
|
torch.distributed.init_process_group( |
|
|
|
|
|
|
|
|
dist.init_process_group( |
|
|
backend=hparams.dist_backend, init_method=hparams.dist_url, |
|
|
backend=hparams.dist_backend, init_method=hparams.dist_url, |
|
|
world_size=n_gpus, rank=rank, group_name=group_name) |
|
|
world_size=n_gpus, rank=rank, group_name=group_name) |
|
|
|
|
|
|
|
@ -131,22 +132,20 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus, |
|
|
pin_memory=False, collate_fn=collate_fn) |
|
|
pin_memory=False, collate_fn=collate_fn) |
|
|
|
|
|
|
|
|
val_loss = 0.0 |
|
|
val_loss = 0.0 |
|
|
if distributed_run or torch.cuda.device_count() > 1: |
|
|
|
|
|
batch_parser = model.module.parse_batch |
|
|
|
|
|
else: |
|
|
|
|
|
batch_parser = model.parse_batch |
|
|
|
|
|
|
|
|
|
|
|
for i, batch in enumerate(val_loader): |
|
|
for i, batch in enumerate(val_loader): |
|
|
x, y = batch_parser(batch) |
|
|
|
|
|
|
|
|
x, y = model.parse_batch(batch) |
|
|
y_pred = model(x) |
|
|
y_pred = model(x) |
|
|
loss = criterion(y_pred, y) |
|
|
loss = criterion(y_pred, y) |
|
|
reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \ |
|
|
|
|
|
if distributed_run else loss.data[0] |
|
|
|
|
|
|
|
|
if distributed_run: |
|
|
|
|
|
reduced_val_loss = reduce_tensor(loss.data, num_gpus).item() |
|
|
|
|
|
else: |
|
|
|
|
|
reduced_val_loss = loss.item() |
|
|
val_loss += reduced_val_loss |
|
|
val_loss += reduced_val_loss |
|
|
val_loss = val_loss / (i + 1) |
|
|
val_loss = val_loss / (i + 1) |
|
|
|
|
|
|
|
|
model.train() |
|
|
model.train() |
|
|
return val_loss |
|
|
|
|
|
|
|
|
print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss)) |
|
|
|
|
|
logger.log_validation(reduced_val_loss, model, y, y_pred, iteration) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, |
|
|
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, |
|
@ -176,6 +175,9 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, |
|
|
optimizer = FP16_Optimizer( |
|
|
optimizer = FP16_Optimizer( |
|
|
optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling) |
|
|
optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling) |
|
|
|
|
|
|
|
|
|
|
|
if hparams.distributed_run: |
|
|
|
|
|
model = apply_gradient_allreduce(model) |
|
|
|
|
|
|
|
|
criterion = Tacotron2Loss() |
|
|
criterion = Tacotron2Loss() |
|
|
|
|
|
|
|
|
logger = prepare_directories_and_logger( |
|
|
logger = prepare_directories_and_logger( |
|
@ -194,15 +196,10 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, |
|
|
checkpoint_path, model, optimizer) |
|
|
checkpoint_path, model, optimizer) |
|
|
if hparams.use_saved_learning_rate: |
|
|
if hparams.use_saved_learning_rate: |
|
|
learning_rate = _learning_rate |
|
|
learning_rate = _learning_rate |
|
|
|
|
|
|
|
|
iteration += 1 # next iteration is iteration + 1 |
|
|
iteration += 1 # next iteration is iteration + 1 |
|
|
epoch_offset = max(0, int(iteration / len(train_loader))) |
|
|
epoch_offset = max(0, int(iteration / len(train_loader))) |
|
|
|
|
|
|
|
|
model.train() |
|
|
model.train() |
|
|
if hparams.distributed_run or torch.cuda.device_count() > 1: |
|
|
|
|
|
batch_parser = model.module.parse_batch |
|
|
|
|
|
else: |
|
|
|
|
|
batch_parser = model.parse_batch |
|
|
|
|
|
# ================ MAIN TRAINNIG LOOP! =================== |
|
|
# ================ MAIN TRAINNIG LOOP! =================== |
|
|
for epoch in range(epoch_offset, hparams.epochs): |
|
|
for epoch in range(epoch_offset, hparams.epochs): |
|
|
print("Epoch: {}".format(epoch)) |
|
|
print("Epoch: {}".format(epoch)) |
|
@ -212,18 +209,21 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, |
|
|
param_group['lr'] = learning_rate |
|
|
param_group['lr'] = learning_rate |
|
|
|
|
|
|
|
|
model.zero_grad() |
|
|
model.zero_grad() |
|
|
x, y = batch_parser(batch) |
|
|
|
|
|
|
|
|
x, y = model.parse_batch(batch) |
|
|
y_pred = model(x) |
|
|
y_pred = model(x) |
|
|
|
|
|
|
|
|
loss = criterion(y_pred, y) |
|
|
loss = criterion(y_pred, y) |
|
|
reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \ |
|
|
|
|
|
if hparams.distributed_run else loss.data[0] |
|
|
|
|
|
|
|
|
if hparams.distributed_run: |
|
|
|
|
|
reduced_loss = reduce_tensor(loss.data, num_gpus).item() |
|
|
|
|
|
else: |
|
|
|
|
|
reduced_loss = loss.item() |
|
|
|
|
|
|
|
|
if hparams.fp16_run: |
|
|
if hparams.fp16_run: |
|
|
optimizer.backward(loss) |
|
|
optimizer.backward(loss) |
|
|
grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh) |
|
|
grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh) |
|
|
else: |
|
|
else: |
|
|
loss.backward() |
|
|
loss.backward() |
|
|
grad_norm = torch.nn.utils.clip_grad_norm( |
|
|
|
|
|
|
|
|
grad_norm = torch.nn.utils.clip_grad_norm_( |
|
|
model.parameters(), hparams.grad_clip_thresh) |
|
|
model.parameters(), hparams.grad_clip_thresh) |
|
|
|
|
|
|
|
|
optimizer.step() |
|
|
optimizer.step() |
|
@ -234,20 +234,14 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, |
|
|
duration = time.perf_counter() - start |
|
|
duration = time.perf_counter() - start |
|
|
print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( |
|
|
print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( |
|
|
iteration, reduced_loss, grad_norm, duration)) |
|
|
iteration, reduced_loss, grad_norm, duration)) |
|
|
|
|
|
|
|
|
logger.log_training( |
|
|
logger.log_training( |
|
|
reduced_loss, grad_norm, learning_rate, duration, iteration) |
|
|
reduced_loss, grad_norm, learning_rate, duration, iteration) |
|
|
|
|
|
|
|
|
if not overflow and (iteration % hparams.iters_per_checkpoint == 0): |
|
|
if not overflow and (iteration % hparams.iters_per_checkpoint == 0): |
|
|
reduced_val_loss = validate( |
|
|
|
|
|
model, criterion, valset, iteration, hparams.batch_size, |
|
|
|
|
|
n_gpus, collate_fn, logger, hparams.distributed_run, rank) |
|
|
|
|
|
|
|
|
validate(model, criterion, valset, iteration, hparams.batch_size, |
|
|
|
|
|
n_gpus, collate_fn, logger, hparams.distributed_run, rank) |
|
|
|
|
|
|
|
|
if rank == 0: |
|
|
if rank == 0: |
|
|
print("Validation loss {}: {:9f} ".format( |
|
|
|
|
|
iteration, reduced_val_loss)) |
|
|
|
|
|
logger.log_validation( |
|
|
|
|
|
reduced_val_loss, model, y, y_pred, iteration) |
|
|
|
|
|
checkpoint_path = os.path.join( |
|
|
checkpoint_path = os.path.join( |
|
|
output_directory, "checkpoint_{}".format(iteration)) |
|
|
output_directory, "checkpoint_{}".format(iteration)) |
|
|
save_checkpoint(model, optimizer, learning_rate, iteration, |
|
|
save_checkpoint(model, optimizer, learning_rate, iteration, |
|
|