3.5. CIFAR10 CNN
Train a simple deep convolutional neural network (CNN) on the CIFAR-10 image dataset.
In [1]:
import conx as cx
Using TensorFlow backend.
Conx, version 3.6.0
In [3]:
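# Build the network: two conv/conv/pool blocks, then a dense classifier
net = cx.Network("CIFAR10")
net.add(cx.ImageLayer("input", (32, 32), 3))  # 32x32 RGB images
# Block 1: 'same' padding keeps 32x32; the unpadded conv2 shrinks it to 30x30
net.add(cx.Conv2DLayer("conv1", 32, (3, 3), padding='same', activation='relu'))
net.add(cx.Conv2DLayer("conv2", 32, (3, 3), activation='relu'))
net.add(cx.MaxPool2DLayer("pool1", pool_size=(2, 2), dropout=0.25))  # -> 15x15
# Block 2: 64 filters; the unpadded conv4 shrinks 15x15 to 13x13
net.add(cx.Conv2DLayer("conv3", 64, (3, 3), padding='same', activation='relu'))
net.add(cx.Conv2DLayer("conv4", 64, (3, 3), activation='relu'))
net.add(cx.MaxPool2DLayer("pool2", pool_size=(2, 2), dropout=0.25))  # -> 6x6
# Classifier: flatten 6x6x64 = 2304 values, then a 512-unit hidden layer
net.add(cx.FlattenLayer("flatten"))
net.add(cx.Layer("hidden1", 512, activation='relu', vshape=(16, 32), dropout=0.5))
net.add(cx.Layer("output", 10, activation='softmax'))  # one probability per class
net.connect()  # connect the layers in the order they were added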
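As a cross-check on the architecture, the sketch below traces the output shape and trainable-parameter count of each layer in plain Python. It assumes the Keras conventions conx builds on (channels-last tensors, stride-1 convolutions whose default padding is 'valid', and 2x2 pooling that discards any odd remainder); `conv_out` and `trace_cifar10_cnn` are hypothetical helpers, not part of conx. Dropout adds no parameters, and `vshape=(16, 32)` appears to be only a display shape for the 512 hidden units (16 x 32 = 512), so neither affects the count.

```python
def conv_out(size, kernel, padding):
    """Spatial size after a stride-1 convolution ('same' keeps it, 'valid' trims)."""
    return size if padding == "same" else size - kernel + 1

def trace_cifar10_cnn():
    """Return per-layer (name, shape, params) rows and the total parameter count."""
    h, w, c = 32, 32, 3  # ImageLayer("input", (32, 32), 3)
    total = 0
    rows = []
    # (name, filters, padding) per Conv2DLayer; filters=None marks a 2x2 max-pool
    stack = [("conv1", 32, "same"), ("conv2", 32, "valid"), ("pool1", None, None),
             ("conv3", 64, "same"), ("conv4", 64, "valid"), ("pool2", None, None)]
    for name, filters, padding in stack:
        if filters is None:
            h, w = h // 2, w // 2  # 2x2 pooling, remainder dropped
            params = 0
        else:
            params = 3 * 3 * c * filters + filters  # 3x3 kernels over c channels + biases
            h, w, c = conv_out(h, 3, padding), conv_out(w, 3, padding), filters
        total += params
        rows.append((name, (h, w, c), params))
    flat = h * w * c
    rows.append(("flatten", (flat,), 0))
    for name, units in [("hidden1", 512), ("output", 10)]:
        params = flat * units + units  # dense weights + biases
        total += params
        rows.append((name, (units,), params))
        flat = units
    return rows, total

rows, total = trace_cifar10_cnn()
for name, shape, params in rows:
    print(f"{name:8s} {str(shape):15s} {params:>9,d}")
print(f"total trainable parameters: {total:,d}")
```

Under these assumptions the dense connection from `flatten` to `hidden1` dominates: about 1.18 million of the roughly 1.25 million parameters sit there, while all four convolutional layers together hold fewer than 66,000.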
In [5]:
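# Categorical cross-entropy suits one-hot targets with a softmax output;
# lr and decay are settings for the RMSprop optimizer
net.compile(error='categorical_crossentropy',
            optimizer='rmsprop', lr=0.0001, decay=1e-6)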
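The compile call selects RMSprop with a small learning rate and a decay term. Assuming conx forwards `lr` and `decay` to the Keras 2 RMSprop optimizer, a single weight update can be sketched in NumPy as below; `rho=0.9` and `eps=1e-7` are assumed defaults that the notebook does not show, and `effective_lr` and `rmsprop_step` are hypothetical names. The decay shrinks the step size as lr / (1 + decay * t), where t counts updates, so with `decay=1e-6` the rate only halves after a million updates.

```python
import numpy as np

def effective_lr(lr, decay, iterations):
    """Learning rate after `iterations` updates under time-based decay."""
    return lr / (1.0 + decay * iterations)

def rmsprop_step(param, grad, avg_sq, lr, decay, iterations, rho=0.9, eps=1e-7):
    """One RMSprop update (assumed rho/eps defaults); returns (param, avg_sq)."""
    lr_t = effective_lr(lr, decay, iterations)
    avg_sq = rho * avg_sq + (1.0 - rho) * grad ** 2        # running mean of squared gradients
    param = param - lr_t * grad / (np.sqrt(avg_sq) + eps)  # step scaled by the RMS gradient
    return param, avg_sq

w = np.array([0.5, -0.3])
g = np.array([0.2, -0.1])
w_new, acc = rmsprop_step(w, g, np.zeros_like(w), lr=0.0001, decay=1e-6, iterations=0)
print(w_new)
print(effective_lr(0.0001, 1e-6, 1_000_000))  # about half the initial rate
```

Because each gradient is divided by its own running RMS, the first step has nearly the same magnitude (about 3.2e-4 here) for both weights even though their gradients differ by a factor of two.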
In [25]:
net.dataset.get("cifar10")
In [26]:
net.dataset.info()
Dataset name: CIFAR-10
Original source: https://www.cs.toronto.edu/~kriz/cifar.html
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class.
The classes are completely mutually exclusive. There is no overlap between automobiles and trucks. “Automobile” includes sedans, SUVs, things of that sort. “Truck” includes only big trucks. Neither includes pickup trucks.
Dataset Split:
* training : 54000
* testing : 6000
* total : 60000
Input Summary:
* shape : (32, 32, 3)
* range : (0.0, 1.0)
Target Summary:
* shape : (10,)
* range : (0.0, 1.0)
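The summary says each input is a 32x32x3 array in the range 0 to 1 and each target is a 10-element vector in the same range. That fits scaling 8-bit pixel values by 1/255 and one-hot encoding the class label. The notebook does not show how conx actually prepares the data, so the sketch below is a hypothetical illustration of that encoding, using a random stand-in image; `to_pattern` is not a conx function.

```python
import numpy as np

NUM_CLASSES = 10

def to_pattern(image_uint8, label):
    """Hypothetical conversion of one raw example into an (input, target) pair."""
    inputs = image_uint8.astype(np.float32) / 255.0  # shape (32, 32, 3), values in 0..1
    target = np.zeros(NUM_CLASSES, dtype=np.float32)
    target[label] = 1.0                              # one-hot vector, shape (10,)
    return inputs, target

rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)  # stand-in image
x, y = to_pattern(raw, 3)
print(x.shape, float(x.min()), float(x.max()))
print(y)
```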
In [27]:
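net.dataset.chop(59000)  # drop 59,000 patterns from the end, keeping 1,000
net.dataset.split(0.1)   # hold out 10% of the remainder for validation
net.dataset.split()      # with no argument, report the (train, test) sizes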
WARNING: dataset split reset to 0
Out[27]:
(900, 100)
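The sizes printed above follow from simple arithmetic: chopping 59,000 of the 60,000 patterns leaves 1,000, and holding out 10% of those gives 900 training and 100 validation patterns. The hypothetical helper below reproduces both the original 54000/6000 split from the dataset summary and the new one; how conx rounds a fractional split size is an assumption here.

```python
def chop_and_split(total, chop, fraction):
    """Hypothetical: remove `chop` patterns, then hold out `fraction` of the rest."""
    remaining = total - chop
    n_test = int(round(remaining * fraction))
    return remaining - n_test, n_test

print(chop_and_split(60000, 0, 0.1))      # (54000, 6000)
print(chop_and_split(60000, 59000, 0.1))  # (900, 100)
```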
3.5.1. Examine Input as Image
In [28]:
net.dataset.inputs.shape
Out[28]:
[(32, 32, 3)]
In [29]:
net.dataset.targets.shape
Out[29]:
[(10,)]
In [30]:
image = cx.array_to_image(net.dataset.inputs[0], scale=5.0)
image
Out[30]:
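The `scale=5.0` argument enlarges the 32x32 input so it is easier to see. Assuming it simply multiplies each side by the scale factor, a nearest-neighbour version in NumPy looks like this; `upscale_nearest` is a hypothetical name, and conx's own `array_to_image` may resample differently.

```python
import numpy as np

def upscale_nearest(array, scale):
    """Enlarge an (H, W, C) float image in [0, 1] by an integer factor,
    repeating each pixel, then convert to 8-bit values for display."""
    factor = int(scale)
    big = np.repeat(np.repeat(array, factor, axis=0), factor, axis=1)
    return (np.clip(big, 0.0, 1.0) * 255).astype(np.uint8)

img = np.random.default_rng(1).random((32, 32, 3))  # stand-in for an input pattern
print(upscale_nearest(img, 5.0).shape)  # (160, 160, 3)
```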
In [31]:
net.dashboard()
In [32]:
net.propagate(net.dataset.inputs[1])
Out[32]:
[0.07045340538024902,
0.3037354350090027,
0.04076571762561798,
0.042184844613075256,
0.01479960698634386,
0.03623291105031967,
0.020028751343488693,
0.03102997876703739,
0.20176030695438385,
0.23900897800922394]
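The propagate output is the network's softmax distribution over the 10 classes. Using the printed values, the sketch below picks the predicted class with an argmax, checks that the outputs sum to about 1, and computes categorical cross-entropy, which for a one-hot target reduces to -log of the probability assigned to the true class. The true label of this image is not shown, so `true_index` is a hypothetical choice; for comparison, a network that spreads probability uniformly over 10 classes scores ln 10, about 2.303.

```python
import math

# Softmax output printed by net.propagate(net.dataset.inputs[1]) above
probs = [0.07045340538024902, 0.3037354350090027, 0.04076571762561798,
         0.042184844613075256, 0.01479960698634386, 0.03623291105031967,
         0.020028751343488693, 0.03102997876703739, 0.20176030695438385,
         0.23900897800922394]

predicted = max(range(len(probs)), key=lambda i: probs[i])  # argmax over outputs

def categorical_crossentropy(true_index, probs):
    """Loss for a one-hot target: -log of the probability of the true class."""
    return -math.log(probs[true_index])

true_index = 8  # hypothetical label, for illustration only
print("predicted class index:", predicted)
print("sum of outputs:", round(sum(probs), 6))
print("loss if the true class were", true_index, ":",
      categorical_crossentropy(true_index, probs))
print("uniform-guess baseline:", math.log(10))
```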
In [33]:
net.train(5, batch_size=256)
========================================================
       |  Training |  Training |  Validate |  Validate
Epochs |     Error |  Accuracy |     Error |  Accuracy
------ | --------- | --------- | --------- | ---------
#    8 |   1.78160 |   0.34889 |   1.68297 |   0.39000
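With 900 training patterns and `batch_size=256`, each epoch performs four weight updates, the last on a partial batch of 132 patterns. The accuracies in the table are also consistent with whole-pattern counts: 314 of 900 gives 0.34889 and 39 of 100 gives 0.39. The sketch below assumes accuracy means the fraction of patterns whose largest output lines up with the 1 in the one-hot target; conx may define accuracy differently, and all three helpers are hypothetical.

```python
import math

def batches_per_epoch(n_patterns, batch_size):
    """Weight updates per epoch; the final batch may be smaller than batch_size."""
    return math.ceil(n_patterns / batch_size)

def argmax(vector):
    """Index of the largest element."""
    return max(range(len(vector)), key=lambda i: vector[i])

def accuracy(outputs, targets):
    """Fraction of patterns whose largest output matches the one-hot target's class."""
    hits = sum(argmax(o) == argmax(t) for o, t in zip(outputs, targets))
    return hits / len(targets)

print(batches_per_epoch(900, 256))    # 4 updates: 256 + 256 + 256 + 132
print(round(314 / 900, 5), 39 / 100)  # matches the table's accuracy columns
```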