from typing import Any, List, Sequence, Tuple

from torch import Tensor
from torch.nn.utils.rnn import pack_sequence
from torch.utils.data import DataLoader
3
def my_collate(batch: Sequence[Tuple[Tensor, Any]]) -> List[Any]:
    """Collate variable-length sequences into a single PackedSequence.

    Intended as the ``collate_fn`` of a ``DataLoader`` whose dataset yields
    ``(sequence, target)`` samples where the sequences may differ in length.

    Args:
        batch: The samples for one batch. Element ``[0]`` of each sample is
            the sequence tensor and element ``[1]`` is its target.

    Returns:
        A two-element list ``[packed, targets]``:
        ``packed`` is the ``PackedSequence`` built from all sequences, and
        ``targets`` is a plain Python list of the targets in batch order.
        The targets are NOT stacked or converted to a tensor.
    """
    sequences = [sample[0] for sample in batch]
    # Samples arrive in arbitrary length order. enforce_sorted=False lets
    # pack_sequence sort them internally and record the permutation, so the
    # original batch order can be restored (e.g. by pad_packed_sequence).
    packed = pack_sequence(sequences, enforce_sorted=False)
    targets = [sample[1] for sample in batch]
    return [packed, targets]
10
# ...
# Later in your code, when you define your DataLoader, pass the custom
# collate function so every batch is assembled by my_collate.
loader = DataLoader(
    dataset,
    # Pass these by keyword: positional batch_size/shuffle silently break if
    # the arguments are ever reordered, and a bare boolean is unreadable.
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=my_collate,  # use the custom collate function here
    # With pin_memory=True the DataLoader copies the tensors in each batch
    # into pinned (page-locked) host memory, which speeds up later
    # host-to-CUDA transfers.
    pin_memory=True,
)
18