Browse Source

train.py: adding routine to warm start and ignore layers, e.g. embedding.weight

master
rafaelvalle 5 years ago
parent
commit
3869781877
1 changed files with 12 additions and 4 deletions
  1. +12
    -4
      train.py

+ 12
- 4
train.py View File

@ -89,11 +89,18 @@ def load_model(hparams):
return model return model
def warm_start_model(checkpoint_path, model):
def warm_start_model(checkpoint_path, model, ignore_layers):
assert os.path.isfile(checkpoint_path) assert os.path.isfile(checkpoint_path)
print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint_dict['state_dict'])
model_dict = checkpoint_dict['state_dict']
if len(ignore_layers) > 0:
model_dict = {k: v for k, v in model_dict.items()
if k not in ignore_layers}
dummy_dict = model.state_dict()
dummy_dict.update(model_dict)
model_dict = dummy_dict
model.load_state_dict(model_dict)
return model return model
@ -189,7 +196,8 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
epoch_offset = 0 epoch_offset = 0
if checkpoint_path is not None: if checkpoint_path is not None:
if warm_start: if warm_start:
model = warm_start_model(checkpoint_path, model)
model = warm_start_model(
checkpoint_path, model, hparams.ignore_layers)
else: else:
model, optimizer, _learning_rate, iteration = load_checkpoint( model, optimizer, _learning_rate, iteration = load_checkpoint(
checkpoint_path, model, optimizer) checkpoint_path, model, optimizer)
@ -258,7 +266,7 @@ if __name__ == '__main__':
parser.add_argument('-c', '--checkpoint_path', type=str, default=None, parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
required=False, help='checkpoint path') required=False, help='checkpoint path')
parser.add_argument('--warm_start', action='store_true', parser.add_argument('--warm_start', action='store_true',
help='load the model only (warm start)')
help='load model weights only, ignore specified layers')
parser.add_argument('--n_gpus', type=int, default=1, parser.add_argument('--n_gpus', type=int, default=1,
required=False, help='number of gpus') required=False, help='number of gpus')
parser.add_argument('--rank', type=int, default=0, parser.add_argument('--rank', type=int, default=0,

Loading…
Cancel
Save