Browse Source

model.py removing top of three code, cleanup

master
Rafael Valle 6 years ago
parent
commit
646ab0d8c8
1 changed files with 11 additions and 10 deletions
  1. +11
    -10
      train.py

+ 11
- 10
train.py View File

@ -74,22 +74,23 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
logger = None logger = None
return logger return logger
def load_model(hparams): def load_model(hparams):
model = Tacotron2(hparams).cuda() model = Tacotron2(hparams).cuda()
model = batchnorm_to_float(model.half()) if hparams.fp16_run else model model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
tacotron_model = model
if hparams.distributed_run: if hparams.distributed_run:
model = DistributedDataParallel(model) model = DistributedDataParallel(model)
elif torch.cuda.device_count() > 1: elif torch.cuda.device_count() > 1:
model = DataParallel(model) model = DataParallel(model)
return model, tacotron
return model
def warm_start_model(checkpoint_path, model): def warm_start_model(checkpoint_path, model):
assert os.path.isfile(checkpoint_path) assert os.path.isfile(checkpoint_path)
print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
checkpoint_dict = torch.load(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint_dict['state_dict']) model.load_state_dict(checkpoint_dict['state_dict'])
return model return model
@ -117,7 +118,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
def validate(model, criterion, valset, iteration, batch_size, n_gpus, def validate(model, criterion, valset, iteration, batch_size, n_gpus,
collate_fn, logger, distributed_run, rank, batch_parser):
collate_fn, logger, distributed_run, rank):
"""Handles all the validation scoring and printing""" """Handles all the validation scoring and printing"""
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@ -128,7 +129,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
val_loss = 0.0 val_loss = 0.0
for i, batch in enumerate(val_loader): for i, batch in enumerate(val_loader):
x, y = batch_parser(batch)
x, y = model.parse_batch(batch)
y_pred = model(x) y_pred = model(x)
loss = criterion(y_pred, y) loss = criterion(y_pred, y)
reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \ reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
@ -196,11 +197,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
param_group['lr'] = learning_rate param_group['lr'] = learning_rate
model.zero_grad() model.zero_grad()
x, y = tacotron_model.parse_batch(batch)
x, y = model.parse_batch(batch)
y_pred = model(x) y_pred = model(x)
loss = criterion(y_pred, y) loss = criterion(y_pred, y)
reduced_loss = reduce_tensor(loss.data, n_gpus).item() \
if hparams.distributed_run else loss.item()
reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
if hparams.distributed_run else loss.data[0]
if hparams.fp16_run: if hparams.fp16_run:
optimizer.backward(loss) optimizer.backward(loss)
@ -208,7 +209,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
else: else:
loss.backward() loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm( grad_norm = torch.nn.utils.clip_grad_norm(
tacotron_model.parameters(), hparams.grad_clip_thresh)
model.parameters(), hparams.grad_clip_thresh)
optimizer.step() optimizer.step()
@ -225,7 +226,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
if not overflow and (iteration % hparams.iters_per_checkpoint == 0): if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
reduced_val_loss = validate( reduced_val_loss = validate(
model, criterion, valset, iteration, hparams.batch_size, model, criterion, valset, iteration, hparams.batch_size,
n_gpus, collate_fn, logger, hparams.distributed_run, rank, tacotron_model.parse_batch)
n_gpus, collate_fn, logger, hparams.distributed_run, rank)
if rank == 0: if rank == 0:
print("Validation loss {}: {:9f} ".format( print("Validation loss {}: {:9f} ".format(

Loading…
Cancel
Save