diff --git a/train.py b/train.py index 1035d5c..4287016 100644 --- a/train.py +++ b/train.py @@ -59,7 +59,7 @@ def prepare_dataloaders(hparams): train_sampler = DistributedSampler(trainset) \ if hparams.distributed_run else None - train_loader = DataLoader(trainset, num_workers=1, shuffle=False, + train_loader = DataLoader(trainset, num_workers=1, shuffle=True, sampler=train_sampler, batch_size=hparams.batch_size, pin_memory=False, drop_last=True, collate_fn=collate_fn)