You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

94 lines
2.8 KiB

  1. import tensorflow as tf
  2. from text import symbols
  3. def create_hparams(hparams_string=None, verbose=False):
  4. """Create model hyperparameters. Parse nondefault from given string."""
  5. hparams = tf.contrib.training.HParams(
  6. ################################
  7. # Experiment Parameters #
  8. ################################
  9. epochs=500,
  10. iters_per_checkpoint=1000,
  11. seed=1234,
  12. dynamic_loss_scaling=True,
  13. fp16_run=False,
  14. distributed_run=False,
  15. dist_backend="nccl",
  16. dist_url="file://distributed.dpt",
  17. cudnn_enabled=True,
  18. cudnn_benchmark=False,
  19. ################################
  20. # Data Parameters #
  21. ################################
  22. load_mel_from_disk=False,
  23. training_files='filelists/ljs_audio_text_train_filelist.txt',
  24. validation_files='filelists/ljs_audio_text_val_filelist.txt',
  25. text_cleaners=['english_cleaners'],
  26. ################################
  27. # Audio Parameters #
  28. ################################
  29. max_wav_value=32768.0,
  30. sampling_rate=22050,
  31. filter_length=1024,
  32. hop_length=256,
  33. win_length=1024,
  34. n_mel_channels=80,
  35. mel_fmin=0.0,
  36. mel_fmax=8000.0,
  37. ################################
  38. # Model Parameters #
  39. ################################
  40. n_symbols=len(symbols),
  41. symbols_embedding_dim=512,
  42. # Encoder parameters
  43. encoder_kernel_size=5,
  44. encoder_n_convolutions=3,
  45. encoder_embedding_dim=512,
  46. # Decoder parameters
  47. n_frames_per_step=1, # currently only 1 is supported
  48. decoder_rnn_dim=1024,
  49. prenet_dim=256,
  50. max_decoder_steps=1000,
  51. gate_threshold=0.5,
  52. p_attention_dropout=0.1,
  53. p_decoder_dropout=0.1,
  54. # Attention parameters
  55. attention_rnn_dim=1024,
  56. attention_dim=128,
  57. # Location Layer parameters
  58. attention_location_n_filters=32,
  59. attention_location_kernel_size=31,
  60. # Mel-post processing network parameters
  61. postnet_embedding_dim=512,
  62. postnet_kernel_size=5,
  63. postnet_n_convolutions=5,
  64. ################################
  65. # Optimization Hyperparameters #
  66. ################################
  67. use_saved_learning_rate=False,
  68. learning_rate=1e-3,
  69. weight_decay=1e-6,
  70. grad_clip_thresh=1.0,
  71. batch_size=64,
  72. mask_padding=True # set model's padded outputs to padded values
  73. )
  74. if hparams_string:
  75. tf.logging.info('Parsing command line hparams: %s', hparams_string)
  76. hparams.parse(hparams_string)
  77. if verbose:
  78. tf.logging.info('Final parsed hparams: %s', hparams.values())
  79. return hparams