# 3.6. Auto-encoder with Conv2D

This notebook demonstrates using a Conv2D network in an autoencoding task with the MNIST dataset.

[1]:

import conx as cx

Using TensorFlow backend.
ConX, version 3.7.5


First, we make a network. We will work with the 3-dimensional input from MNIST (28 × 28 × 1), but a flat target vector.

[2]:

net = cx.Network("Auto-Encoding with Conv")
net.add(cx.ImageLayer("input", (28, 28), 1),
        cx.Conv2DLayer("Conv2D-1", 16, (5,5), colormap="gray", activation="relu"),
        cx.MaxPool2DLayer("maxpool1", (2,2)),
        cx.Conv2DLayer("Conv2D-2", 32, (5,5), activation="relu"),
        cx.MaxPool2DLayer("maxpool2", (2,2)),
        cx.FlattenLayer("flatten"))
net.add(cx.Layer("output", 28 * 28, vshape=(28,28), activation='sigmoid'))
net.connect()
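The layer stack above shrinks the image step by step before the flat output layer. Assuming "valid" (no-padding) convolutions, Keras's default, non-overlapping 2×2 pooling, and 32 filters in Conv2D-2, we can trace the shapes by hand:

```python
# Trace the spatial size through the encoder, assuming 'valid'
# (no-padding) convolutions and non-overlapping 2x2 max pooling.
def conv_valid(size, kernel):
    return size - kernel + 1

def pool(size, window):
    return size // window

side = 28                      # MNIST images are 28x28
side = conv_valid(side, 5)     # Conv2D-1, 5x5 kernel -> 24
side = pool(side, 2)           # maxpool1           -> 12
side = conv_valid(side, 5)     # Conv2D-2, 5x5 kernel -> 8
side = pool(side, 2)           # maxpool2           -> 4

flattened = side * side * 32   # 32 filters in Conv2D-2
print(side, flattened)         # -> 4 512
```

So the flatten layer produces a 512-unit vector, which the dense output layer then expands back to 784 units (28 × 28).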

[3]:

net.compile(error="mse", optimizer="adam")
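Here `mse` is the mean squared error between the target image and the reconstruction, averaged over all 784 output units. A minimal NumPy sketch of that criterion (not ConX's internal implementation):

```python
import numpy as np

def mse(targets, outputs):
    """Mean squared error over all output units."""
    targets = np.asarray(targets, dtype=float)
    outputs = np.asarray(outputs, dtype=float)
    return np.mean((targets - outputs) ** 2)

# A perfect reconstruction has zero error:
print(mse([0.0, 1.0], [0.0, 1.0]))   # -> 0.0
# Otherwise, squared differences are averaged over every unit:
print(mse([0.0, 1.0], [0.5, 0.5]))   # -> 0.25
```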


We load the MNIST dataset and examine the shapes of the inputs and targets.

[4]:

net.get_dataset("mnist")

[5]:

net.dataset.info()


Dataset: MNIST

Original source: http://yann.lecun.com/exdb/mnist/

The MNIST dataset contains 70,000 images of handwritten digits (zero to nine) that have been size-normalized and centered in a square grid of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers representing grayscale intensities ranging from 0 (black) to 1 (white). The target data consists of one-hot binary vectors of size 10, corresponding to the digit classification categories zero through nine. Some example MNIST images are shown below:

Information:
   * name : MNIST
   * length : 70000

Input Summary:
   * shape : (28, 28, 1)
   * range : (0.0, 1.0)

Target Summary:
   * shape : (10,)
   * range : (0.0, 1.0)

Because this is an auto-encoding task, we want the targets to be the same as the inputs:

[6]:

net.dataset.set_targets_from_inputs()

WARNING: network 'Auto-Encoding with Conv' target bank #0 has a multi-dimensional shape; is this correct?


However, that gives a warning. ConX does not allow targets to have a multi-dimensional shape, so we flatten them:

[7]:

net.dataset.targets.reshape(28 * 28)
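The reshape keeps every pixel value and only drops the spatial structure. With NumPy, the equivalent operation on a single stand-in target looks like this:

```python
import numpy as np

# A stand-in for one MNIST target: a 28x28x1 grayscale image.
target = np.random.rand(28, 28, 1)

flat = target.reshape(28 * 28)   # same values, now a 784-vector
print(flat.shape)                # -> (784,)

# No information is lost; the image can be recovered exactly:
assert (flat.reshape(28, 28, 1) == target).all()
```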

[8]:

net.dataset.targets.shape

[8]:

[(784,)]

[9]:

net.dashboard()


Just to test our design quickly, we chop off most of the patterns, leaving only 100.

[10]:

net.dataset.chop(69900)


And save 10% for testing/validation:

[11]:

net.dataset.split(0.1)
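Together, these two calls leave 90 training patterns and 10 validation patterns. A plain-Python sketch of the bookkeeping, assuming `chop` removes patterns from the end of the dataset and `split` reserves the final fraction for validation:

```python
patterns = list(range(70000))    # stand-ins for the 70,000 MNIST patterns

# chop(69900): drop 69,900 patterns, keeping 100
patterns = patterns[:len(patterns) - 69900]
print(len(patterns))             # -> 100

# split(0.1): reserve 10% for validation, train on the rest
n_val = int(len(patterns) * 0.1)
train, validate = patterns[:-n_val], patterns[-n_val:]
print(len(train), len(validate))  # -> 90 10
```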

[12]:

net.reset()
net.train(50)

========================================================
       |  Training |  Training |  Validate |  Validate
Epochs |     Error |  Accuracy |     Error |  Accuracy
------ | --------- | --------- | --------- | ---------
#   50 |   0.02027 |   0.00000 |   0.02228 |   0.00000
