diff --git a/train.py b/train.py index 786d1dd..4413549 100644 --- a/train.py +++ b/train.py @@ -190,8 +190,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, if warm_start: model = warm_start_model(checkpoint_path, model) else: - model, optimizer, learning_rate, iteration = load_checkpoint( + model, optimizer, _learning_rate, iteration = load_checkpoint( checkpoint_path, model, optimizer) + if hparams.use_saved_learning_rate: + learning_rate = _learning_rate + iteration += 1 # next iteration is iteration + 1 epoch_offset = max(0, int(iteration / len(train_loader)))