|
|
- import tensorflow as tf
- from text import symbols
-
-
def create_hparams(hparams_string=None, verbose=False):
    """Build the default hyperparameter set, optionally overridden from a string.

    Args:
        hparams_string: Optional override string. When truthy it is logged and
            handed to ``HParams.parse`` (so it must be in the format that
            method accepts); when falsy the defaults are returned untouched.
        verbose: When truthy, the final hyperparameter values are logged
            through ``tf.logging.info``.

    Returns:
        The ``tf.contrib.training.HParams`` instance holding the defaults
        below, with any overrides from ``hparams_string`` applied.
    """
    # Resolve the constructor first, then collect the defaults in a plain
    # dict that is expanded into keyword arguments.
    hparams_cls = tf.contrib.training.HParams

    defaults = {
        ################################
        # Experiment Parameters        #
        ################################
        'epochs': 500,
        'iters_per_checkpoint': 500,
        'seed': 1234,
        'dynamic_loss_scaling': True,
        'fp16_run': False,
        'distributed_run': False,
        'dist_backend': "nccl",
        'dist_url': "file://distributed.dpt",
        'cudnn_enabled': True,
        'cudnn_benchmark': False,

        ################################
        # Data Parameters              #
        ################################
        'training_files': 'ljs_audio_text_train_filelist.txt',
        'validation_files': 'ljs_audio_text_val_filelist.txt',
        'text_cleaners': ['english_cleaners'],
        'sort_by_length': False,

        ################################
        # Audio Parameters             #
        ################################
        'max_wav_value': 32768.0,
        'sampling_rate': 22050,
        'filter_length': 1024,
        'hop_length': 256,
        'win_length': 1024,
        'n_mel_channels': 80,
        'mel_fmin': 0.0,
        'mel_fmax': None,  # if None, half the sampling rate

        ################################
        # Model Parameters             #
        ################################
        'n_symbols': len(symbols),
        'symbols_embedding_dim': 512,

        # Encoder parameters
        'encoder_kernel_size': 5,
        'encoder_n_convolutions': 3,
        'encoder_embedding_dim': 512,

        # Decoder parameters
        'n_frames_per_step': 1,
        'decoder_rnn_dim': 1024,
        'prenet_dim': 256,
        'max_decoder_steps': 1000,
        'gate_threshold': 0.6,

        # Attention parameters
        'attention_rnn_dim': 1024,
        'attention_dim': 128,

        # Location Layer parameters
        'attention_location_n_filters': 32,
        'attention_location_kernel_size': 31,

        # Mel-post processing network parameters
        'postnet_embedding_dim': 512,
        'postnet_kernel_size': 5,
        'postnet_n_convolutions': 5,

        ################################
        # Optimization Hyperparameters #
        ################################
        'learning_rate': 1e-3,
        'weight_decay': 1e-6,
        'grad_clip_thresh': 1,
        'batch_size': 48,
        'mask_padding': False,  # set model's padded outputs to padded values
    }

    hp = hparams_cls(**defaults)

    # Apply caller-supplied overrides on top of the defaults.
    if hparams_string:
        tf.logging.info('Parsing command line hparams: %s', hparams_string)
        hp.parse(hparams_string)

    if verbose:
        tf.logging.info('Final parsed hparams: %s', hp.values())

    return hp
|