You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

92 lines
2.7 KiB

  1. import tensorflow as tf
  2. from text import symbols
  3. def create_hparams(hparams_string=None, verbose=False):
  4. """Create model hyperparameters. Parse nondefault from given string."""
  5. hparams = tf.contrib.training.HParams(
  6. ################################
  7. # Experiment Parameters #
  8. ################################
  9. epochs=500,
  10. iters_per_checkpoint=500,
  11. seed=1234,
  12. dynamic_loss_scaling=True,
  13. fp16_run=False,
  14. distributed_run=False,
  15. dist_backend="nccl",
  16. dist_url="file://distributed.dpt",
  17. cudnn_enabled=True,
  18. cudnn_benchmark=False,
  19. ################################
  20. # Data Parameters #
  21. ################################
  22. load_mel_from_disk=False,
  23. training_files='filelists/ljs_audio_text_train_filelist.txt',
  24. validation_files='filelists/ljs_audio_text_val_filelist.txt',
  25. text_cleaners=['english_cleaners'],
  26. sort_by_length=False,
  27. ################################
  28. # Audio Parameters #
  29. ################################
  30. max_wav_value=32768.0,
  31. sampling_rate=22050,
  32. filter_length=1024,
  33. hop_length=256,
  34. win_length=1024,
  35. n_mel_channels=80,
  36. mel_fmin=0.0,
  37. mel_fmax=None, # if None, half the sampling rate
  38. ################################
  39. # Model Parameters #
  40. ################################
  41. n_symbols=len(symbols),
  42. symbols_embedding_dim=512,
  43. # Encoder parameters
  44. encoder_kernel_size=5,
  45. encoder_n_convolutions=3,
  46. encoder_embedding_dim=512,
  47. # Decoder parameters
  48. n_frames_per_step=1,
  49. decoder_rnn_dim=1024,
  50. prenet_dim=256,
  51. max_decoder_steps=1000,
  52. gate_threshold=0.6,
  53. # Attention parameters
  54. attention_rnn_dim=1024,
  55. attention_dim=128,
  56. # Location Layer parameters
  57. attention_location_n_filters=32,
  58. attention_location_kernel_size=31,
  59. # Mel-post processing network parameters
  60. postnet_embedding_dim=512,
  61. postnet_kernel_size=5,
  62. postnet_n_convolutions=5,
  63. ################################
  64. # Optimization Hyperparameters #
  65. ################################
  66. learning_rate=1e-3,
  67. weight_decay=1e-6,
  68. grad_clip_thresh=1,
  69. batch_size=48,
  70. mask_padding=False # set model's padded outputs to padded values
  71. )
  72. if hparams_string:
  73. tf.logging.info('Parsing command line hparams: %s', hparams_string)
  74. hparams.parse(hparams_string)
  75. if verbose:
  76. tf.logging.info('Final parsed hparams: %s', hparams.values())
  77. return hparams