class ConcatDataset(torch.utils.data.Dataset):
    """Zip several map-style datasets so item ``i`` is a tuple of their ``i``-th items.

    Despite the name, this does not chain datasets end to end (that is what
    ``torch.utils.data.ConcatDataset`` does); it pairs them index by index,
    like ``zip``. The length is that of the shortest dataset, so trailing
    items of longer datasets are never returned, and the same index always
    yields the same combination of items.

    Args:
        *datasets: one or more objects supporting ``len()`` and integer
            indexing (e.g. map-style ``Dataset`` instances).

    Raises:
        ValueError: if no datasets are given.
    """

    def __init__(self, *datasets):
        if not datasets:
            # Without this, the failure would surface later as an opaque
            # "min() arg is an empty sequence" from __len__.
            raise ValueError("ConcatDataset requires at least one dataset")
        self.datasets = datasets

    def __getitem__(self, i):
        """Return ``(datasets[0][i], datasets[1][i], ...)``.

        Negative indices are resolved against ``len(self)`` (the shortest
        length) before being forwarded, so ``self[-1] == self[len(self) - 1]``.
        Forwarding ``-1`` directly would pick the last item of each dataset
        independently and mis-pair datasets of different lengths.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        idx = i + n if i < 0 else i
        if not 0 <= idx < n:
            # Check explicitly: a longer dataset would happily serve
            # indices beyond the shortest one's length.
            raise IndexError(
                f"index {i} is out of range for ConcatDataset of length {n}"
            )
        return tuple(d[idx] for d in self.datasets)

    def __len__(self):
        """Return the length of the shortest wrapped dataset."""
        return min(len(d) for d in self.datasets)
# Build a loader whose batches pair samples from two ImageFolder datasets.
# NOTE(review): neither ImageFolder is given a ``transform``. Assuming
# ``datasets`` is ``torchvision.datasets``, ImageFolder's default loader
# yields PIL images, which the DataLoader's default collate function cannot
# stack into a batch; pass e.g. a ToTensor transform (plus a resize/crop so
# all images share one shape) before running this -- TODO confirm.
train_loader = torch.utils.data.DataLoader(
    ConcatDataset(  # zips the two datasets index by index, not end to end
        datasets.ImageFolder(traindir_A),
        datasets.ImageFolder(traindir_B),
    ),
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)

# Each loader item is (item_A, item_B), because ConcatDataset returns d[i]
# from every dataset for the same i. Each ImageFolder item is presumably a
# (sample, class_index) pair (per torchvision docs), so a collated batch
# nests the same way. Unpack both levels: a flat ``(input, target)`` would
# bind the whole A batch and the whole B batch, not images and labels, and
# ``input`` would shadow the builtin.
for i, ((input_a, target_a), (input_b, target_b)) in enumerate(train_loader):
    ...