You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

91 lines
2.7 KiB

  1. import tensorflow as tf
  2. from text import symbols
  3. def create_hparams(hparams_string=None, verbose=False):
  4. """Create model hyperparameters. Parse nondefault from given string."""
  5. hparams = tf.contrib.training.HParams(
  6. ################################
  7. # Experiment Parameters #
  8. ################################
  9. epochs=500,
  10. iters_per_checkpoint=500,
  11. seed=1234,
  12. dynamic_loss_scaling=True,
  13. fp16_run=False,
  14. distributed_run=False,
  15. dist_backend="nccl",
  16. dist_url="file://distributed.dpt",
  17. cudnn_enabled=True,
  18. cudnn_benchmark=False,
  19. ################################
  20. # Data Parameters #
  21. ################################
  22. training_files='ljs_audio_text_train_filelist.txt',
  23. validation_files='ljs_audio_text_val_filelist.txt',
  24. text_cleaners=['english_cleaners'],
  25. sort_by_length=False,
  26. ################################
  27. # Audio Parameters #
  28. ################################
  29. max_wav_value=32768.0,
  30. sampling_rate=22050,
  31. filter_length=1024,
  32. hop_length=256,
  33. win_length=1024,
  34. n_mel_channels=80,
  35. mel_fmin=0.0,
  36. mel_fmax=None, # if None, half the sampling rate
  37. ################################
  38. # Model Parameters #
  39. ################################
  40. n_symbols=len(symbols),
  41. symbols_embedding_dim=512,
  42. # Encoder parameters
  43. encoder_kernel_size=5,
  44. encoder_n_convolutions=3,
  45. encoder_embedding_dim=512,
  46. # Decoder parameters
  47. n_frames_per_step=1,
  48. decoder_rnn_dim=1024,
  49. prenet_dim=256,
  50. max_decoder_steps=1000,
  51. gate_threshold=0.6,
  52. # Attention parameters
  53. attention_rnn_dim=1024,
  54. attention_dim=128,
  55. # Location Layer parameters
  56. attention_location_n_filters=32,
  57. attention_location_kernel_size=31,
  58. # Mel-post processing network parameters
  59. postnet_embedding_dim=512,
  60. postnet_kernel_size=5,
  61. postnet_n_convolutions=5,
  62. ################################
  63. # Optimization Hyperparameters #
  64. ################################
  65. learning_rate=1e-3,
  66. weight_decay=1e-6,
  67. grad_clip_thresh=1,
  68. batch_size=48,
  69. mask_padding=False # set model's padded outputs to padded values
  70. )
  71. if hparams_string:
  72. tf.logging.info('Parsing command line hparams: %s', hparams_string)
  73. hparams.parse(hparams_string)
  74. if verbose:
  75. tf.logging.info('Final parsed hparams: %s', hparams.values())
  76. return hparams