|
@ -5,7 +5,7 @@ import torch |
|
|
|
|
|
|
|
|
def get_mask_from_lengths(lengths): |
|
|
def get_mask_from_lengths(lengths): |
|
|
max_len = torch.max(lengths) |
|
|
max_len = torch.max(lengths) |
|
|
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).cuda() |
|
|
|
|
|
|
|
|
ids = torch.arange(0, max_len).long().cuda() |
|
|
mask = (ids < lengths.unsqueeze(1)).byte() |
|
|
mask = (ids < lengths.unsqueeze(1)).byte() |
|
|
return mask |
|
|
return mask |
|
|
|
|
|
|
|
|