diff --git a/utils.py b/utils.py index c843d95..4395201 100644 --- a/utils.py +++ b/utils.py @@ -6,7 +6,7 @@ import torch def get_mask_from_lengths(lengths): max_len = torch.max(lengths).item() ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) - mask = (ids < lengths.unsqueeze(1)).byte() + mask = (ids < lengths.unsqueeze(1)).bool() return mask