You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

95 lines
2.8 KiB

  1. import tensorflow as tf
  2. from text import symbols
  3. def create_hparams(hparams_string=None, verbose=False):
  4. """Create model hyperparameters. Parse nondefault from given string."""
  5. hparams = tf.contrib.training.HParams(
  6. ################################
  7. # Experiment Parameters #
  8. ################################
  9. epochs=500,
  10. iters_per_checkpoint=1000,
  11. seed=1234,
  12. dynamic_loss_scaling=True,
  13. fp16_run=False,
  14. distributed_run=False,
  15. dist_backend="nccl",
  16. dist_url="tcp://localhost:54321",
  17. cudnn_enabled=True,
  18. cudnn_benchmark=False,
  19. ignore_layers=['embedding.weight'],
  20. ################################
  21. # Data Parameters #
  22. ################################
  23. load_mel_from_disk=False,
  24. training_files='filelists/ljs_audio_text_train_filelist.txt',
  25. validation_files='filelists/ljs_audio_text_val_filelist.txt',
  26. text_cleaners=['english_cleaners'],
  27. ################################
  28. # Audio Parameters #
  29. ################################
  30. max_wav_value=32768.0,
  31. sampling_rate=22050,
  32. filter_length=1024,
  33. hop_length=256,
  34. win_length=1024,
  35. n_mel_channels=80,
  36. mel_fmin=0.0,
  37. mel_fmax=8000.0,
  38. ################################
  39. # Model Parameters #
  40. ################################
  41. n_symbols=len(symbols),
  42. symbols_embedding_dim=512,
  43. # Encoder parameters
  44. encoder_kernel_size=5,
  45. encoder_n_convolutions=3,
  46. encoder_embedding_dim=512,
  47. # Decoder parameters
  48. n_frames_per_step=1, # currently only 1 is supported
  49. decoder_rnn_dim=1024,
  50. prenet_dim=256,
  51. max_decoder_steps=1000,
  52. gate_threshold=0.5,
  53. p_attention_dropout=0.1,
  54. p_decoder_dropout=0.1,
  55. # Attention parameters
  56. attention_rnn_dim=1024,
  57. attention_dim=128,
  58. # Location Layer parameters
  59. attention_location_n_filters=32,
  60. attention_location_kernel_size=31,
  61. # Mel-post processing network parameters
  62. postnet_embedding_dim=512,
  63. postnet_kernel_size=5,
  64. postnet_n_convolutions=5,
  65. ################################
  66. # Optimization Hyperparameters #
  67. ################################
  68. use_saved_learning_rate=False,
  69. learning_rate=1e-3,
  70. weight_decay=1e-6,
  71. grad_clip_thresh=1.0,
  72. batch_size=64,
  73. mask_padding=True # set model's padded outputs to padded values
  74. )
  75. if hparams_string:
  76. tf.logging.info('Parsing command line hparams: %s', hparams_string)
  77. hparams.parse(hparams_string)
  78. if verbose:
  79. tf.logging.info('Final parsed hparams: %s', hparams.values())
  80. return hparams