"""Load from /home/USER/data/cifar10 or elsewhere; download if missing."""
import tarfile
import os
from urllib.request import urlretrieve
import numpy as np
def cifar10(path=None):
    r"""Return (train_images, train_labels, test_images, test_labels).

    Args:
        path (str): Directory containing CIFAR-10. Default is
            /home/USER/data/cifar10 or C:\Users\USER\data\cifar10.
            Created if nonexistent. CIFAR-10 is downloaded if missing.

    Returns:
        Tuple of (train_images, train_labels, test_images, test_labels), each
            a matrix. Rows are examples. Columns of images are float32 pixel
            values scaled to [0, 1], with the channel order
            (red -> green -> blue). Columns of labels are a uint8 onehot
            encoding of the correct class and always number 10.

    Raises:
        ValueError: If the archive lacks one of the six batch files, a batch
            file is not a whole number of 3,073-byte records, or a label is
            not in the range 0-9.
    """
    url = 'https://www.cs.toronto.edu/~kriz/'
    tar = 'cifar-10-binary.tar.gz'
    train_files = ['cifar-10-batches-bin/data_batch_%d.bin' % i
                   for i in range(1, 6)]
    test_files = ['cifar-10-batches-bin/test_batch.bin']

    if path is None:
        path = os.path.join(os.path.expanduser('~'), 'data', 'cifar10')
    os.makedirs(path, exist_ok=True)

    tar_path = os.path.join(path, tar)
    if not os.path.exists(tar_path):
        # Download under a temporary name and rename on success, so that an
        # interrupted download is never mistaken for a complete archive.
        partial_path = tar_path + '.part'
        try:
            urlretrieve(url + tar, partial_path)
            os.replace(partial_path, tar_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print("Downloaded %s to %s" % (tar, path))

    with tarfile.open(tar_path) as tar_object:
        # The tar also holds READMEs etc.; keep only the batch files.
        wanted = set(train_files + test_files)
        members = {m.name: m for m in tar_object if m.name in wanted}
        missing = sorted(wanted.difference(members))
        if missing:
            raise ValueError(
                '%s is missing: %s' % (tar_path, ', '.join(missing)))

        # Load in the explicit file order so the train batches stay in
        # sequence regardless of their order inside the archive.
        train = np.concatenate(
            [_read_records(tar_object, members[n]) for n in train_files])
        test = np.concatenate(
            [_read_records(tar_object, members[n]) for n in test_files])

    # Column 0 of every record is the label; the rest is the image.
    return (_to_images(train), _onehot(train[:, 0]),
            _to_images(test), _onehot(test[:, 0]))


def _read_records(tar_object, member, record_bytes=3073):
    """Return one batch file as a uint8 matrix with one record per row."""
    handle = tar_object.extractfile(member)
    if handle is None:
        raise ValueError('%s is not a regular file' % member.name)
    with handle:
        raw = np.frombuffer(handle.read(), dtype='uint8')
    if raw.size % record_bytes:
        raise ValueError(
            '%s is not a whole number of %d-byte records'
            % (member.name, record_bytes))
    return raw.reshape(-1, record_bytes)


def _to_images(records):
    """Return the pixel columns of records as float32 scaled to [0, 1]."""
    images = records[:, 1:].astype('float32')
    images /= 255
    return images


def _onehot(integer_labels, n_classes=10):
    """Return matrix whose rows are onehot encodings of integers."""
    n_rows = len(integer_labels)
    if n_rows and int(integer_labels.max()) >= n_classes:
        raise ValueError(
            'label %d is outside the %d known classes'
            % (int(integer_labels.max()), n_classes))
    # Fixed width: every split gets the same columns even if a class
    # happens to be absent from it.
    onehot = np.zeros((n_rows, n_classes), dtype='uint8')
    onehot[np.arange(n_rows), integer_labels] = 1
    return onehot