how to load cifar10 in python

Solutions on MaxInterview for how to load cifar10 in python by the best coders in the world

showing results for - "how to load cifar10 in python"
Daniel
10 Sep 2016
"""Load from /home/USER/data/cifar10 or elsewhere; download if missing."""

import tarfile
import os
from urllib.request import urlretrieve
import numpy as np

# One CIFAR-10 binary example is 1 label byte followed by 32 * 32 * 3 pixel
# bytes (all red values, then all green values, then all blue values).
_RECORD_SIZE = 1 + 32 * 32 * 3
_N_CLASSES = 10


def _read_records(tar_object, names):
    """Return the examples stored in tar members `names`, in that order.

    The result is a uint8 matrix with one row per example and
    `_RECORD_SIZE` columns (label byte first, then the pixel bytes).

    Raises:
        ValueError: If a member is missing, is not a regular file, or its
            size is not a positive multiple of `_RECORD_SIZE`.
    """
    chunks = []
    for name in names:
        try:
            member = tar_object.getmember(name)
        except KeyError:
            raise ValueError(
                "%s is missing from the CIFAR-10 archive" % name) from None

        # extractfile returns None for members that are not regular files
        f = tar_object.extractfile(member)
        if f is None:
            raise ValueError("%s is not a regular file" % name)
        with f:
            raw = np.frombuffer(f.read(), dtype='uint8')

        if raw.size == 0 or raw.size % _RECORD_SIZE:
            raise ValueError(
                "%s holds %d bytes, which is not a positive multiple of %d"
                % (name, raw.size, _RECORD_SIZE))
        chunks.append(raw.reshape(-1, _RECORD_SIZE))
    return np.concatenate(chunks)


def _onehot(integer_labels):
    """Return matrix whose rows are onehot encodings of integers.

    The matrix always has `_N_CLASSES` columns, so train and test labels
    have the same width regardless of which classes happen to be present.
    """
    n_rows = len(integer_labels)
    onehot = np.zeros((n_rows, _N_CLASSES), dtype='uint8')
    onehot[np.arange(n_rows), integer_labels] = 1
    return onehot


def _parse(records):
    """Split raw example rows into (scaled float32 images, onehot labels)."""
    labels = records[:, 0]
    if labels.max() >= _N_CLASSES:
        raise ValueError(
            "found label %d, expected labels below %d"
            % (labels.max(), _N_CLASSES))
    images = records[:, 1:].astype('float32') / 255
    return images, _onehot(labels)


def cifar10(path=None):
    r"""Return (train_images, train_labels, test_images, test_labels).

    Args:
        path (str): Directory containing CIFAR-10. Default is
            /home/USER/data/cifar10 or C:\Users\USER\data\cifar10.
            Create if nonexistent. Download CIFAR-10 if missing.

    Returns:
        Tuple of (train_images, train_labels, test_images, test_labels), each
            a matrix. Rows are examples. Columns of images are float32 pixel
            values scaled to [0, 1], with the order (red -> green -> blue).
            Columns of labels are a uint8 onehot encoding of the correct
            class (always 10 columns).

    Raises:
        ValueError: If the archive lacks one of the expected data files or
            a data file is malformed.
    """
    url = 'https://www.cs.toronto.edu/~kriz/'
    tar = 'cifar-10-binary.tar.gz'
    train_files = ['cifar-10-batches-bin/data_batch_1.bin',
                   'cifar-10-batches-bin/data_batch_2.bin',
                   'cifar-10-batches-bin/data_batch_3.bin',
                   'cifar-10-batches-bin/data_batch_4.bin',
                   'cifar-10-batches-bin/data_batch_5.bin']
    test_files = ['cifar-10-batches-bin/test_batch.bin']

    if path is None:
        # Set path to /home/USER/data/cifar10 or C:\Users\USER\data\cifar10
        path = os.path.join(os.path.expanduser('~'), 'data', 'cifar10')

    # Create path if it doesn't exist
    os.makedirs(path, exist_ok=True)

    # Download tarfile if missing
    tar_path = os.path.join(path, tar)
    if not os.path.isfile(tar_path):
        # Download under a temporary name and rename on success, so an
        # interrupted download is never mistaken for a complete archive.
        partial = tar_path + '.part'
        try:
            urlretrieve(''.join((url, tar)), partial)
            os.replace(partial, tar_path)
        finally:
            if os.path.exists(partial):
                os.remove(partial)
        print("Downloaded %s to %s" % (tar, path))

    # Load data from tarfile
    # -- The tar contains README's and other extraneous stuff, so only the
    #    named data files are read, train batches in numeric order
    with tarfile.open(tar_path) as tar_object:
        train_records = _read_records(tar_object, train_files)
        test_records = _read_records(tar_object, test_files)

    train_images, train_labels = _parse(train_records)
    test_images, test_labels = _parse(test_records)
    return train_images, train_labels, test_images, test_labels