3.5. CIFAR10 CNN

Train a simple deep CNN on the CIFAR-10 image dataset.

[1]:
import conx as cx
Using TensorFlow backend.
ConX, version 3.7.5
[2]:
net = cx.Network("CIFAR10")
net.add(cx.ImageLayer("input", (32, 32), 3))
net.add(cx.Conv2DLayer("conv1", 32, (3, 3), padding='same', activation='relu'))
net.add(cx.Conv2DLayer("conv2", 32, (3, 3), activation='relu'))
net.add(cx.MaxPool2DLayer("pool1", pool_size=(2, 2), dropout=0.25))
net.add(cx.Conv2DLayer("conv3", 64, (3, 3), padding='same', activation='relu'))
net.add(cx.Conv2DLayer("conv4", 64, (3, 3), activation='relu'))
net.add(cx.MaxPool2DLayer("pool2", pool_size=(2, 2), dropout=0.25))
net.add(cx.FlattenLayer("flatten"))
net.add(cx.Layer("hidden1", 512, activation='relu', vshape=(16, 32), dropout=0.5))
net.add(cx.Layer("output", 10, activation='softmax'))
net.connect()
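To see how many values the flatten layer hands to "hidden1", the spatial size can be traced through the stack by hand. The sketch below is plain Python, not ConX code. It assumes ConX forwards these arguments to Keras unchanged, so that a convolution without `padding='same'` uses 'valid' padding with stride 1, and that each 2x2 pool uses a stride equal to its pool size. The helpers `conv_out` and `pool_out` are hypothetical names introduced only for this illustration.

```python
def conv_out(size, kernel, padding):
    """Spatial output size of a stride-1 convolution (assumed Keras defaults)."""
    return size if padding == "same" else size - kernel + 1


def pool_out(size, pool):
    """Output size of non-overlapping max pooling (floor division)."""
    return size // pool


size = 32                          # input images are 32x32
size = conv_out(size, 3, "same")   # conv1: 32
size = conv_out(size, 3, "valid")  # conv2: 30
size = pool_out(size, 2)           # pool1: 15
size = conv_out(size, 3, "same")   # conv3: 15
size = conv_out(size, 3, "valid")  # conv4: 13
size = pool_out(size, 2)           # pool2: 6

flat = size * size * 64            # conv4 has 64 filters
print(size, flat)                  # 6 2304

# vshape only changes how the 512 hidden units are drawn, so it must cover all of them
assert 16 * 32 == 512
```

Under those assumptions the flatten layer produces 2304 values, which the 512-unit hidden layer then reduces before the 10-way softmax output.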
[3]:
net.compile(error='categorical_crossentropy',
            optimizer='rmsprop', lr=0.0001, decay=1e-6)
[4]:
net.get_dataset("cifar10")
[5]:
net.dataset.info()

Dataset: CIFAR-10

Original source: https://www.cs.toronto.edu/~kriz/cifar.html

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class.

The classes are completely mutually exclusive. There is no overlap between automobiles and trucks. “Automobile” includes sedans, SUVs, things of that sort. “Truck” includes only big trucks. Neither includes pickup trucks.

Information:
* name : CIFAR-10
* length : 60000

Input Summary:
* shape : (32, 32, 3)
* range : (0.0, 1.0)

Target Summary:
* shape : (10,)
* range : (0.0, 1.0)

[6]:
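net.dataset.chop(59000)   # drop 59000 patterns to keep this demo fast
net.dataset.split(0.1)    # hold out 10% of what remains for validation
net.dataset.split()       # no argument: report (training, validation) sizes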
[6]:
(900, 100)
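The (900, 100) result follows from simple arithmetic on the dataset length reported above. The sketch below reproduces that arithmetic in plain Python; it is not how ConX computes the split internally, and the variable names are chosen only for this illustration.

```python
total = 60000        # length reported by net.dataset.info()
chopped = 59000      # amount passed to chop()
fraction = 0.1       # fraction passed to split()

remaining = total - chopped
validate = int(remaining * fraction)
train = remaining - validate
print((train, validate))  # (900, 100)
```

Only 1000 of the 60000 images take part in this run, so the accuracy figures that follow are indicative rather than a benchmark of the architecture.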

3.5.1. Examine Input as Image

[7]:
net.dataset.inputs.shape
[7]:
[(32, 32, 3)]
[8]:
net.dataset.targets.shape
[8]:
[(10,)]
[9]:
image = cx.array_to_image(net.dataset.inputs[0], scale=5.0)
image
[9]:
_images/CIFAR10_CNN_11_0.png
[10]:
net.dashboard()
[11]:
net.propagate(net.dataset.inputs[1])
[11]:
[0.09143757820129395,
 0.10652760416269302,
 0.08812205493450165,
 0.10210786759853363,
 0.11382994800806046,
 0.09381558746099472,
 0.09608996659517288,
 0.10617616027593613,
 0.10079711675643921,
 0.10109605640172958]
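Because the output layer uses a softmax activation, the ten values above should form a probability distribution over the classes. A minimal check in plain Python, using the numbers copied from the output of cell [11]:

```python
outputs = [0.09143757820129395, 0.10652760416269302, 0.08812205493450165,
           0.10210786759853363, 0.11382994800806046, 0.09381558746099472,
           0.09608996659517288, 0.10617616027593613, 0.10079711675643921,
           0.10109605640172958]

total = sum(outputs)
# index of the largest probability, i.e. the class the network would pick
predicted = max(range(len(outputs)), key=lambda i: outputs[i])
print(abs(total - 1.0) < 1e-6, predicted)  # True 4
```

Every value is close to 0.1, which is what an untrained 10-class network is expected to produce: it has no real preference yet, and index 4 wins by only a small margin.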
[12]:
net.train(5, batch_size=256)
_images/CIFAR10_CNN_14_0.svg
========================================================
       |  Training |  Training |  Validate |  Validate
Epochs |     Error |  Accuracy |     Error |  Accuracy
------ | --------- | --------- | --------- | ---------
#    5 |   2.25270 |   0.17778 |   2.23633 |   0.21000
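A useful reference point for these numbers is the chance baseline. For categorical cross-entropy with 10 equally likely classes, a network that outputs 0.1 for every class has an error of ln(10), and its expected accuracy is 0.1. The sketch below computes both values; it is a plain Python illustration and does not call ConX.

```python
import math

num_classes = 10
# cross-entropy when the true class is always assigned probability 1/num_classes
chance_loss = -math.log(1.0 / num_classes)
chance_accuracy = 1.0 / num_classes
print(round(chance_loss, 4), chance_accuracy)  # 2.3026 0.1
```

After five epochs the reported errors (2.25270 training, 2.23633 validation) are only slightly below this baseline, and the accuracies (0.17778 and 0.21000) are modestly above 0.1, so the network has only begun to learn from the 900 training images.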