diff --git a/train.py b/train.py index 4287016..b612b9a 100644 --- a/train.py +++ b/train.py @@ -56,10 +56,14 @@ def prepare_dataloaders(hparams): valset = TextMelLoader(hparams.validation_files, hparams) collate_fn = TextMelCollate(hparams.n_frames_per_step) - train_sampler = DistributedSampler(trainset) \ - if hparams.distributed_run else None + if hparams.distributed_run: + train_sampler = DistributedSampler(trainset) + shuffle = False + else: + train_sampler = None + shuffle = True - train_loader = DataLoader(trainset, num_workers=1, shuffle=True, + train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle, sampler=train_sampler, batch_size=hparams.batch_size, pin_memory=False, drop_last=True, collate_fn=collate_fn)