You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

286 lines
11 KiB

  1. import os
  2. import time
  3. import argparse
  4. import math
  5. import torch
  6. from distributed import DistributedDataParallel
  7. from torch.utils.data.distributed import DistributedSampler
  8. from torch.nn import DataParallel
  9. from torch.utils.data import DataLoader
  10. from fp16_optimizer import FP16_Optimizer
  11. from model import Tacotron2
  12. from data_utils import TextMelLoader, TextMelCollate
  13. from loss_function import Tacotron2Loss
  14. from logger import Tacotron2Logger
  15. from hparams import create_hparams
  16. def batchnorm_to_float(module):
  17. """Converts batch norm modules to FP32"""
  18. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  19. module.float()
  20. for child in module.children():
  21. batchnorm_to_float(child)
  22. return module
  23. def reduce_tensor(tensor, num_gpus):
  24. rt = tensor.clone()
  25. torch.distributed.all_reduce(rt, op=torch.distributed.reduce_op.SUM)
  26. rt /= num_gpus
  27. return rt
  28. def init_distributed(hparams, n_gpus, rank, group_name):
  29. assert torch.cuda.is_available(), "Distributed mode requires CUDA."
  30. print("Initializing distributed")
  31. # Set cuda device so everything is done on the right GPU.
  32. torch.cuda.set_device(rank % torch.cuda.device_count())
  33. # Initialize distributed communication
  34. torch.distributed.init_process_group(
  35. backend=hparams.dist_backend, init_method=hparams.dist_url,
  36. world_size=n_gpus, rank=rank, group_name=group_name)
  37. print("Done initializing distributed")
  38. def prepare_dataloaders(hparams):
  39. # Get data, data loaders and collate function ready
  40. trainset = TextMelLoader(hparams.training_files, hparams)
  41. valset = TextMelLoader(hparams.validation_files, hparams)
  42. collate_fn = TextMelCollate(hparams.n_frames_per_step)
  43. train_sampler = DistributedSampler(trainset) \
  44. if hparams.distributed_run else None
  45. train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
  46. sampler=train_sampler,
  47. batch_size=hparams.batch_size, pin_memory=False,
  48. drop_last=True, collate_fn=collate_fn)
  49. return train_loader, valset, collate_fn
  50. def prepare_directories_and_logger(output_directory, log_directory, rank):
  51. if rank == 0:
  52. if not os.path.isdir(output_directory):
  53. os.makedirs(output_directory)
  54. os.chmod(output_directory, 0o775)
  55. logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
  56. else:
  57. logger = None
  58. return logger
  59. def load_model(hparams):
  60. model = Tacotron2(hparams).cuda()
  61. model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
  62. if hparams.distributed_run:
  63. model = DistributedDataParallel(model)
  64. elif torch.cuda.device_count() > 1:
  65. model = DataParallel(model)
  66. return model
  67. def warm_start_model(checkpoint_path, model):
  68. assert os.path.isfile(checkpoint_path)
  69. print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
  70. checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
  71. model.load_state_dict(checkpoint_dict['state_dict'])
  72. return model
  73. def load_checkpoint(checkpoint_path, model, optimizer):
  74. assert os.path.isfile(checkpoint_path)
  75. print("Loading checkpoint '{}'".format(checkpoint_path))
  76. checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
  77. model.load_state_dict(checkpoint_dict['state_dict'])
  78. optimizer.load_state_dict(checkpoint_dict['optimizer'])
  79. learning_rate = checkpoint_dict['learning_rate']
  80. iteration = checkpoint_dict['iteration']
  81. print("Loaded checkpoint '{}' from iteration {}" .format(
  82. checkpoint_path, iteration))
  83. return model, optimizer, learning_rate, iteration
  84. def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
  85. print("Saving model and optimizer state at iteration {} to {}".format(
  86. iteration, filepath))
  87. torch.save({'iteration': iteration,
  88. 'state_dict': model.state_dict(),
  89. 'optimizer': optimizer.state_dict(),
  90. 'learning_rate': learning_rate}, filepath)
  91. def validate(model, criterion, valset, iteration, batch_size, n_gpus,
  92. collate_fn, logger, distributed_run, rank):
  93. """Handles all the validation scoring and printing"""
  94. model.eval()
  95. with torch.no_grad():
  96. val_sampler = DistributedSampler(valset) if distributed_run else None
  97. val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
  98. shuffle=False, batch_size=batch_size,
  99. pin_memory=False, collate_fn=collate_fn)
  100. val_loss = 0.0
  101. if distributed_run or torch.cuda.device_count() > 1:
  102. batch_parser = model.module.parse_batch
  103. else:
  104. batch_parser = model.parse_batch
  105. for i, batch in enumerate(val_loader):
  106. x, y = batch_parser(batch)
  107. y_pred = model(x)
  108. loss = criterion(y_pred, y)
  109. reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
  110. if distributed_run else loss.data[0]
  111. val_loss += reduced_val_loss
  112. val_loss = val_loss / (i + 1)
  113. model.train()
  114. return val_loss
  115. def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
  116. rank, group_name, hparams):
  117. """Training and validation logging results to tensorboard and stdout
  118. Params
  119. ------
  120. output_directory (string): directory to save checkpoints
  121. log_directory (string) directory to save tensorboard logs
  122. checkpoint_path(string): checkpoint path
  123. n_gpus (int): number of gpus
  124. rank (int): rank of current gpu
  125. hparams (object): comma separated list of "name=value" pairs.
  126. """
  127. if hparams.distributed_run:
  128. init_distributed(hparams, n_gpus, rank, group_name)
  129. torch.manual_seed(hparams.seed)
  130. torch.cuda.manual_seed(hparams.seed)
  131. model = load_model(hparams)
  132. learning_rate = hparams.learning_rate
  133. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
  134. weight_decay=hparams.weight_decay)
  135. if hparams.fp16_run:
  136. optimizer = FP16_Optimizer(
  137. optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)
  138. criterion = Tacotron2Loss()
  139. logger = prepare_directories_and_logger(
  140. output_directory, log_directory, rank)
  141. train_loader, valset, collate_fn = prepare_dataloaders(hparams)
  142. # Load checkpoint if one exists
  143. iteration = 0
  144. epoch_offset = 0
  145. if checkpoint_path is not None:
  146. if warm_start:
  147. model = warm_start_model(checkpoint_path, model)
  148. else:
  149. model, optimizer, learning_rate, iteration = load_checkpoint(
  150. checkpoint_path, model, optimizer)
  151. iteration += 1 # next iteration is iteration + 1
  152. epoch_offset = max(0, int(iteration / len(train_loader)))
  153. model.train()
  154. if distributed_run or torch.cuda.device_count() > 1:
  155. batch_parser = model.module.parse_batch
  156. else:
  157. batch_parser = model.parse_batch
  158. # ================ MAIN TRAINNIG LOOP! ===================
  159. for epoch in range(epoch_offset, hparams.epochs):
  160. print("Epoch: {}".format(epoch))
  161. for i, batch in enumerate(train_loader):
  162. start = time.perf_counter()
  163. for param_group in optimizer.param_groups:
  164. param_group['lr'] = learning_rate
  165. model.zero_grad()
  166. x, y = batch_parser(batch)
  167. y_pred = model(x)
  168. loss = criterion(y_pred, y)
  169. reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
  170. if hparams.distributed_run else loss.data[0]
  171. if hparams.fp16_run:
  172. optimizer.backward(loss)
  173. grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
  174. else:
  175. loss.backward()
  176. grad_norm = torch.nn.utils.clip_grad_norm(
  177. model.parameters(), hparams.grad_clip_thresh)
  178. optimizer.step()
  179. overflow = optimizer.overflow if hparams.fp16_run else False
  180. if not overflow and not math.isnan(reduced_loss) and rank == 0:
  181. duration = time.perf_counter() - start
  182. print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
  183. iteration, reduced_loss, grad_norm, duration))
  184. logger.log_training(
  185. reduced_loss, grad_norm, learning_rate, duration, iteration)
  186. if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
  187. reduced_val_loss = validate(
  188. model, criterion, valset, iteration, hparams.batch_size,
  189. n_gpus, collate_fn, logger, hparams.distributed_run, rank)
  190. if rank == 0:
  191. print("Validation loss {}: {:9f} ".format(
  192. iteration, reduced_val_loss))
  193. logger.log_validation(
  194. reduced_val_loss, model, y, y_pred, iteration)
  195. checkpoint_path = os.path.join(
  196. output_directory, "checkpoint_{}".format(iteration))
  197. save_checkpoint(model, optimizer, learning_rate, iteration,
  198. checkpoint_path)
  199. iteration += 1
  200. if __name__ == '__main__':
  201. parser = argparse.ArgumentParser()
  202. parser.add_argument('-o', '--output_directory', type=str,
  203. help='directory to save checkpoints')
  204. parser.add_argument('-l', '--log_directory', type=str,
  205. help='directory to save tensorboard logs')
  206. parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
  207. required=False, help='checkpoint path')
  208. parser.add_argument('--warm_start', action='store_true',
  209. help='load the model only (warm start)')
  210. parser.add_argument('--n_gpus', type=int, default=1,
  211. required=False, help='number of gpus')
  212. parser.add_argument('--rank', type=int, default=0,
  213. required=False, help='rank of current gpu')
  214. parser.add_argument('--group_name', type=str, default='group_name',
  215. required=False, help='Distributed group name')
  216. parser.add_argument('--hparams', type=str,
  217. required=False, help='comma separated name=value pairs')
  218. args = parser.parse_args()
  219. hparams = create_hparams(args.hparams)
  220. torch.backends.cudnn.enabled = hparams.cudnn_enabled
  221. torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
  222. print("FP16 Run:", hparams.fp16_run)
  223. print("Dynamic Loss Scaling", hparams.dynamic_loss_scaling)
  224. print("Distributed Run:", hparams.distributed_run)
  225. print("cuDNN Enabled:", hparams.cudnn_enabled)
  226. print("cuDNN Benchmark:", hparams.cudnn_benchmark)
  227. train(args.output_directory, args.log_directory, args.checkpoint_path,
  228. args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)