diff --git a/train.py b/train.py index cd7635f..1035d5c 100644 --- a/train.py +++ b/train.py @@ -89,11 +89,18 @@ def load_model(hparams): return model -def warm_start_model(checkpoint_path, model): +def warm_start_model(checkpoint_path, model, ignore_layers): assert os.path.isfile(checkpoint_path) print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') - model.load_state_dict(checkpoint_dict['state_dict']) + model_dict = checkpoint_dict['state_dict'] + if len(ignore_layers) > 0: + model_dict = {k: v for k, v in model_dict.items() + if k not in ignore_layers} + dummy_dict = model.state_dict() + dummy_dict.update(model_dict) + model_dict = dummy_dict + model.load_state_dict(model_dict) return model @@ -189,7 +196,8 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, epoch_offset = 0 if checkpoint_path is not None: if warm_start: - model = warm_start_model(checkpoint_path, model) + model = warm_start_model( + checkpoint_path, model, hparams.ignore_layers) else: model, optimizer, _learning_rate, iteration = load_checkpoint( checkpoint_path, model, optimizer) @@ -258,7 +266,7 @@ if __name__ == '__main__': parser.add_argument('-c', '--checkpoint_path', type=str, default=None, required=False, help='checkpoint path') parser.add_argument('--warm_start', action='store_true', - help='load the model only (warm start)') + help='load model weights only, ignore specified layers') parser.add_argument('--n_gpus', type=int, default=1, required=False, help='number of gpus') parser.add_argument('--rank', type=int, default=0,