import sys
|
|
import copy
|
|
import torch
|
|
|
|
def _check_model_old_version(model):
|
|
if hasattr(model.WN[0], 'res_layers'):
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
def update_model(old_model):
|
|
if not _check_model_old_version(old_model):
|
|
return old_model
|
|
new_model = copy.deepcopy(old_model)
|
|
for idx in range(0, len(new_model.WN)):
|
|
wavenet = new_model.WN[idx]
|
|
wavenet.res_skip_layers = torch.nn.ModuleList()
|
|
n_channels = wavenet.n_channels
|
|
n_layers = wavenet.n_layers
|
|
for i in range(0, n_layers):
|
|
if i < n_layers - 1:
|
|
res_skip_channels = 2*n_channels
|
|
else:
|
|
res_skip_channels = n_channels
|
|
res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
|
|
skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
|
|
if i < n_layers - 1:
|
|
res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
|
|
res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
|
|
res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
|
|
else:
|
|
res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
|
|
res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
|
|
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
|
|
wavenet.res_skip_layers.append(res_skip_layer)
|
|
del wavenet.res_layers
|
|
del wavenet.skip_layers
|
|
return new_model
|
|
|
|
if __name__ == '__main__':
|
|
old_model_path = sys.argv[1]
|
|
new_model_path = sys.argv[2]
|
|
model = torch.load(old_model_path)
|
|
model['model'] = update_model(model['model'])
|
|
torch.save(model, new_model_path)
|
|
|