import os
import time
import argparse
import math
from numpy import finfo

import torch
from distributed import apply_gradient_allreduce
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

from model import Tacotron2
from data_utils import TextMelLoader, TextMelCollate
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger
from hparams import create_hparams


def reduce_tensor(tensor, n_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n_gpus
    return rt
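

# reduce_tensor() is used below as reduce_tensor(loss.data, n_gpus).item() to
# average a scalar loss tensor across all distributed ranks, so every rank
# sees and logs the same value.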


def init_distributed(hparams, n_gpus, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=hparams.dist_backend, init_method=hparams.dist_url,
        world_size=n_gpus, rank=rank, group_name=group_name)

    print("Done initializing distributed")


def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    # DataLoader forbids shuffle=True together with a custom sampler, so let
    # the DistributedSampler do the shuffling in distributed mode.
    if hparams.distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn
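

# drop_last=True keeps every batch the same size, which matters in distributed
# mode where all ranks must step through identically-shaped batches.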


def prepare_directories_and_logger(output_directory, log_directory, rank):
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
    else:
        logger = None
    return logger


def load_model(hparams):
    model = Tacotron2(hparams).cuda()
    if hparams.fp16_run:
        model.decoder.attention_layer.score_mask_value = finfo('float16').min
    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)
    return model
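

# Under fp16 the attention score mask is set to the smallest finite float16
# value rather than -inf, which cannot be represented in half precision.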


def warm_start_model(checkpoint_path, model, ignore_layers):
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = checkpoint_dict['state_dict']
    if len(ignore_layers) > 0:
        model_dict = {k: v for k, v in model_dict.items()
                      if k not in ignore_layers}
        dummy_dict = model.state_dict()
        dummy_dict.update(model_dict)
        model_dict = dummy_dict
    model.load_state_dict(model_dict)
    return model


def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    learning_rate = checkpoint_dict['learning_rate']
    iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)
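

# The saved dictionary mirrors exactly what load_checkpoint() reads back:
# 'iteration', 'state_dict', 'optimizer' and 'learning_rate'.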


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        val_loss = 0.0
        for i, batch in enumerate(val_loader):
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
        # Average over all validation batches, not just the last one.
        val_loss = val_loss / (i + 1)

    model.train()
    if rank == 0:
        print("Validation loss {}: {:9f} ".format(iteration, val_loss))
        logger.log_validation(val_loss, model, y, y_pred, iteration)
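

# In distributed mode the DistributedSampler shards the validation set across
# ranks; reduce_tensor() then averages each per-batch loss over all ranks
# before it is accumulated.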


def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
          rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string): directory to save tensorboard logs
    checkpoint_path (string): checkpoint path
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): comma separated list of "name=value" pairs.
    """
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams)
    learning_rate = hparams.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)

    if hparams.fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(
            model, optimizer, opt_level='O2')

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)

    criterion = Tacotron2Loss()

    logger = prepare_directories_and_logger(
        output_directory, log_directory, rank)

    train_loader, valset, collate_fn = prepare_dataloaders(hparams)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            model = warm_start_model(
                checkpoint_path, model, hparams.ignore_layers)
        else:
            model, optimizer, _learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    is_overflow = False
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            model.zero_grad()
            x, y = model.parse_batch(batch)
            y_pred = model(x)

            loss = criterion(y_pred, y)
            if hparams.distributed_run:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()

            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if hparams.fp16_run:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), hparams.grad_clip_thresh)
                is_overflow = math.isnan(grad_norm)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            if not is_overflow and rank == 0:
                duration = time.perf_counter() - start
                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                    iteration, reduced_loss, grad_norm, duration))
                logger.log_training(
                    reduced_loss, grad_norm, learning_rate, duration,
                    iteration)

            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
                validate(model, criterion, valset, iteration,
                         hparams.batch_size, n_gpus, collate_fn, logger,
                         hparams.distributed_run, rank)
                if rank == 0:
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output_directory', type=str,
                        help='directory to save checkpoints')
    parser.add_argument('-l', '--log_directory', type=str,
                        help='directory to save tensorboard logs')
    parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
                        required=False, help='checkpoint path')
    parser.add_argument('--warm_start', action='store_true',
                        help='load model weights only, ignore specified layers')
    parser.add_argument('--n_gpus', type=int, default=1,
                        required=False, help='number of gpus')
    parser.add_argument('--rank', type=int, default=0,
                        required=False, help='rank of current gpu')
    parser.add_argument('--group_name', type=str, default='group_name',
                        required=False, help='Distributed group name')
    parser.add_argument('--hparams', type=str,
                        required=False, help='comma separated name=value pairs')

    args = parser.parse_args()
    hparams = create_hparams(args.hparams)

    torch.backends.cudnn.enabled = hparams.cudnn_enabled
    torch.backends.cudnn.benchmark = hparams.cudnn_benchmark

    print("FP16 Run:", hparams.fp16_run)
    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
    print("Distributed Run:", hparams.distributed_run)
    print("cuDNN Enabled:", hparams.cudnn_enabled)
    print("cuDNN Benchmark:", hparams.cudnn_benchmark)

    train(args.output_directory, args.log_directory, args.checkpoint_path,
          args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)
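
# Example invocations (paths below are placeholders, not part of the script):
#   single GPU:   python train.py -o outdir -l logdir
#   resume:       python train.py -o outdir -l logdir -c outdir/checkpoint_10000
#   distributed:  launch one process per GPU with --n_gpus=N and a distinct
#                 --rank per process, with hparams "distributed_run=True".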