|
|
@ -89,11 +89,18 @@ def load_model(hparams): |
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
def warm_start_model(checkpoint_path, model, ignore_layers=None):
    """Initialise ``model`` from the weights stored in a checkpoint file.

    Only the ``'state_dict'`` entry of the checkpoint is used.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load``
            that contains a ``'state_dict'`` entry.
        model: Module whose weights are overwritten in place.
        ignore_layers: Optional collection of state-dict keys that must NOT
            be taken from the checkpoint; ``model`` keeps its current values
            for them. ``None`` or an empty collection loads every entry.

    Returns:
        The same ``model`` object, after loading.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
    """
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = checkpoint_dict['state_dict']
    if ignore_layers:
        # Drop the ignored entries first: loading the raw checkpoint before
        # filtering would overwrite (or fail on) exactly the layers the
        # caller asked to skip.
        model_dict = {k: v for k, v in model_dict.items()
                      if k not in ignore_layers}
        # Back-fill the dropped keys with the model's current values so the
        # final load still sees a complete state dict.
        dummy_dict = model.state_dict()
        dummy_dict.update(model_dict)
        model_dict = dummy_dict
    model.load_state_dict(model_dict)
    return model
|
|
|
|
|
|
|
|
|
|
@ -189,7 +196,8 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, |
|
|
|
epoch_offset = 0 |
|
|
|
if checkpoint_path is not None: |
|
|
|
if warm_start: |
|
|
|
model = warm_start_model(checkpoint_path, model) |
|
|
|
model = warm_start_model( |
|
|
|
checkpoint_path, model, hparams.ignore_layers) |
|
|
|
else: |
|
|
|
model, optimizer, _learning_rate, iteration = load_checkpoint( |
|
|
|
checkpoint_path, model, optimizer) |
|
|
@ -258,7 +266,7 @@ if __name__ == '__main__': |
|
|
|
parser.add_argument('-c', '--checkpoint_path', type=str, default=None, |
|
|
|
required=False, help='checkpoint path') |
|
|
|
parser.add_argument('--warm_start', action='store_true', |
|
|
|
help='load the model only (warm start)') |
|
|
|
help='load model weights only, ignore specified layers') |
|
|
|
parser.add_argument('--n_gpus', type=int, default=1, |
|
|
|
required=False, help='number of gpus') |
|
|
|
parser.add_argument('--rank', type=int, default=0, |
|
|
|