You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

292 lines
11 KiB

  1. import os
  2. import time
  3. import argparse
  4. import math
  5. from numpy import finfo
  6. import torch
  7. from distributed import apply_gradient_allreduce
  8. import torch.distributed as dist
  9. from torch.utils.data.distributed import DistributedSampler
  10. from torch.utils.data import DataLoader
  11. from fp16_optimizer import FP16_Optimizer
  12. from model import Tacotron2
  13. from data_utils import TextMelLoader, TextMelCollate
  14. from loss_function import Tacotron2Loss
  15. from logger import Tacotron2Logger
  16. from hparams import create_hparams
  17. def batchnorm_to_float(module):
  18. """Converts batch norm modules to FP32"""
  19. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  20. module.float()
  21. for child in module.children():
  22. batchnorm_to_float(child)
  23. return module
  24. def reduce_tensor(tensor, n_gpus):
  25. rt = tensor.clone()
  26. dist.all_reduce(rt, op=dist.reduce_op.SUM)
  27. rt /= n_gpus
  28. return rt
  29. def init_distributed(hparams, n_gpus, rank, group_name):
  30. assert torch.cuda.is_available(), "Distributed mode requires CUDA."
  31. print("Initializing Distributed")
  32. # Set cuda device so everything is done on the right GPU.
  33. torch.cuda.set_device(rank % torch.cuda.device_count())
  34. # Initialize distributed communication
  35. dist.init_process_group(
  36. backend=hparams.dist_backend, init_method=hparams.dist_url,
  37. world_size=n_gpus, rank=rank, group_name=group_name)
  38. print("Done initializing distributed")
  39. def prepare_dataloaders(hparams):
  40. # Get data, data loaders and collate function ready
  41. trainset = TextMelLoader(hparams.training_files, hparams)
  42. valset = TextMelLoader(hparams.validation_files, hparams)
  43. collate_fn = TextMelCollate(hparams.n_frames_per_step)
  44. train_sampler = DistributedSampler(trainset) \
  45. if hparams.distributed_run else None
  46. train_loader = DataLoader(trainset, num_workers=1, shuffle=True,
  47. sampler=train_sampler,
  48. batch_size=hparams.batch_size, pin_memory=False,
  49. drop_last=True, collate_fn=collate_fn)
  50. return train_loader, valset, collate_fn
  51. def prepare_directories_and_logger(output_directory, log_directory, rank):
  52. if rank == 0:
  53. if not os.path.isdir(output_directory):
  54. os.makedirs(output_directory)
  55. os.chmod(output_directory, 0o775)
  56. logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
  57. else:
  58. logger = None
  59. return logger
  60. def load_model(hparams):
  61. model = Tacotron2(hparams).cuda()
  62. if hparams.fp16_run:
  63. model = batchnorm_to_float(model.half())
  64. model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
  65. if hparams.distributed_run:
  66. model = apply_gradient_allreduce(model)
  67. return model
  68. def warm_start_model(checkpoint_path, model, ignore_layers):
  69. assert os.path.isfile(checkpoint_path)
  70. print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
  71. checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
  72. model_dict = checkpoint_dict['state_dict']
  73. if len(ignore_layers) > 0:
  74. model_dict = {k: v for k, v in model_dict.items()
  75. if k not in ignore_layers}
  76. dummy_dict = model.state_dict()
  77. dummy_dict.update(model_dict)
  78. model_dict = dummy_dict
  79. model.load_state_dict(model_dict)
  80. return model
  81. def load_checkpoint(checkpoint_path, model, optimizer):
  82. assert os.path.isfile(checkpoint_path)
  83. print("Loading checkpoint '{}'".format(checkpoint_path))
  84. checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
  85. model.load_state_dict(checkpoint_dict['state_dict'])
  86. optimizer.load_state_dict(checkpoint_dict['optimizer'])
  87. learning_rate = checkpoint_dict['learning_rate']
  88. iteration = checkpoint_dict['iteration']
  89. print("Loaded checkpoint '{}' from iteration {}" .format(
  90. checkpoint_path, iteration))
  91. return model, optimizer, learning_rate, iteration
  92. def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
  93. print("Saving model and optimizer state at iteration {} to {}".format(
  94. iteration, filepath))
  95. torch.save({'iteration': iteration,
  96. 'state_dict': model.state_dict(),
  97. 'optimizer': optimizer.state_dict(),
  98. 'learning_rate': learning_rate}, filepath)
  99. def validate(model, criterion, valset, iteration, batch_size, n_gpus,
  100. collate_fn, logger, distributed_run, rank):
  101. """Handles all the validation scoring and printing"""
  102. model.eval()
  103. with torch.no_grad():
  104. val_sampler = DistributedSampler(valset) if distributed_run else None
  105. val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
  106. shuffle=False, batch_size=batch_size,
  107. pin_memory=False, collate_fn=collate_fn)
  108. val_loss = 0.0
  109. for i, batch in enumerate(val_loader):
  110. x, y = model.parse_batch(batch)
  111. y_pred = model(x)
  112. loss = criterion(y_pred, y)
  113. if distributed_run:
  114. reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
  115. else:
  116. reduced_val_loss = loss.item()
  117. val_loss += reduced_val_loss
  118. val_loss = val_loss / (i + 1)
  119. model.train()
  120. if rank == 0:
  121. print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
  122. logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)
  123. def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
  124. rank, group_name, hparams):
  125. """Training and validation logging results to tensorboard and stdout
  126. Params
  127. ------
  128. output_directory (string): directory to save checkpoints
  129. log_directory (string) directory to save tensorboard logs
  130. checkpoint_path(string): checkpoint path
  131. n_gpus (int): number of gpus
  132. rank (int): rank of current gpu
  133. hparams (object): comma separated list of "name=value" pairs.
  134. """
  135. if hparams.distributed_run:
  136. init_distributed(hparams, n_gpus, rank, group_name)
  137. torch.manual_seed(hparams.seed)
  138. torch.cuda.manual_seed(hparams.seed)
  139. model = load_model(hparams)
  140. learning_rate = hparams.learning_rate
  141. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
  142. weight_decay=hparams.weight_decay)
  143. if hparams.fp16_run:
  144. optimizer = FP16_Optimizer(
  145. optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)
  146. if hparams.distributed_run:
  147. model = apply_gradient_allreduce(model)
  148. criterion = Tacotron2Loss()
  149. logger = prepare_directories_and_logger(
  150. output_directory, log_directory, rank)
  151. train_loader, valset, collate_fn = prepare_dataloaders(hparams)
  152. # Load checkpoint if one exists
  153. iteration = 0
  154. epoch_offset = 0
  155. if checkpoint_path is not None:
  156. if warm_start:
  157. model = warm_start_model(
  158. checkpoint_path, model, hparams.ignore_layers)
  159. else:
  160. model, optimizer, _learning_rate, iteration = load_checkpoint(
  161. checkpoint_path, model, optimizer)
  162. if hparams.use_saved_learning_rate:
  163. learning_rate = _learning_rate
  164. iteration += 1 # next iteration is iteration + 1
  165. epoch_offset = max(0, int(iteration / len(train_loader)))
  166. model.train()
  167. # ================ MAIN TRAINNIG LOOP! ===================
  168. for epoch in range(epoch_offset, hparams.epochs):
  169. print("Epoch: {}".format(epoch))
  170. for i, batch in enumerate(train_loader):
  171. start = time.perf_counter()
  172. for param_group in optimizer.param_groups:
  173. param_group['lr'] = learning_rate
  174. model.zero_grad()
  175. x, y = model.parse_batch(batch)
  176. y_pred = model(x)
  177. loss = criterion(y_pred, y)
  178. if hparams.distributed_run:
  179. reduced_loss = reduce_tensor(loss.data, n_gpus).item()
  180. else:
  181. reduced_loss = loss.item()
  182. if hparams.fp16_run:
  183. optimizer.backward(loss)
  184. grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
  185. else:
  186. loss.backward()
  187. grad_norm = torch.nn.utils.clip_grad_norm_(
  188. model.parameters(), hparams.grad_clip_thresh)
  189. optimizer.step()
  190. overflow = optimizer.overflow if hparams.fp16_run else False
  191. if not overflow and not math.isnan(reduced_loss) and rank == 0:
  192. duration = time.perf_counter() - start
  193. print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
  194. iteration, reduced_loss, grad_norm, duration))
  195. logger.log_training(
  196. reduced_loss, grad_norm, learning_rate, duration, iteration)
  197. if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
  198. validate(model, criterion, valset, iteration,
  199. hparams.batch_size, n_gpus, collate_fn, logger,
  200. hparams.distributed_run, rank)
  201. if rank == 0:
  202. checkpoint_path = os.path.join(
  203. output_directory, "checkpoint_{}".format(iteration))
  204. save_checkpoint(model, optimizer, learning_rate, iteration,
  205. checkpoint_path)
  206. iteration += 1
  207. if __name__ == '__main__':
  208. parser = argparse.ArgumentParser()
  209. parser.add_argument('-o', '--output_directory', type=str,
  210. help='directory to save checkpoints')
  211. parser.add_argument('-l', '--log_directory', type=str,
  212. help='directory to save tensorboard logs')
  213. parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
  214. required=False, help='checkpoint path')
  215. parser.add_argument('--warm_start', action='store_true',
  216. help='load model weights only, ignore specified layers')
  217. parser.add_argument('--n_gpus', type=int, default=1,
  218. required=False, help='number of gpus')
  219. parser.add_argument('--rank', type=int, default=0,
  220. required=False, help='rank of current gpu')
  221. parser.add_argument('--group_name', type=str, default='group_name',
  222. required=False, help='Distributed group name')
  223. parser.add_argument('--hparams', type=str,
  224. required=False, help='comma separated name=value pairs')
  225. args = parser.parse_args()
  226. hparams = create_hparams(args.hparams)
  227. torch.backends.cudnn.enabled = hparams.cudnn_enabled
  228. torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
  229. print("FP16 Run:", hparams.fp16_run)
  230. print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
  231. print("Distributed Run:", hparams.distributed_run)
  232. print("cuDNN Enabled:", hparams.cudnn_enabled)
  233. print("cuDNN Benchmark:", hparams.cudnn_benchmark)
  234. train(args.output_directory, args.log_directory, args.checkpoint_path,
  235. args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)