3.12. Alice in Wonderland

This notebook demonstrates generating sequences using a Simple Recurrent Network (SimpleRNN).

For this example, we will use the unprocessed text from Lewis Carroll’s “Alice in Wonderland”. However, the sequence can really be anything, including code, music, or knitting instructions.

[1]:
import conx as cx
Using TensorFlow backend.
ConX, version 3.7.5

First, we find a copy of Alice in Wonderland, download it, and read it in:

[2]:
INPUT_FILE = "alice_in_wonderland.txt"
[3]:
cx.download("http://www.gutenberg.org/files/11/11-0.txt", filename=INPUT_FILE)
Using cached http://www.gutenberg.org/files/11/11-0.txt as './alice_in_wonderland.txt'.
[4]:
# extract the input as a stream of characters
lines = []
with open(INPUT_FILE, 'rb') as fp:
    for line in fp:
        line = line.strip().lower()
        line = line.decode("ascii", "ignore")
        if len(line) == 0:
            continue
        lines.append(line)
text = " ".join(lines)
lines = None # clean up memory

Next, we create some utility dictionaries for mapping the characters to indices and back:

[5]:
chars = set([c for c in text])
nb_chars = len(chars)
char2index = dict((c, i) for i, c in enumerate(chars))
index2char = dict((i, c) for i, c in enumerate(chars))
[6]:
nb_chars
[6]:
55

In this text, there are 55 different characters.

Each character has a unique mapping to an integer:

[7]:
char2index["a"]
[7]:
37
[8]:
index2char[5]
[8]:
'9'
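Since both dictionaries are built from the same enumeration, mapping a character to its index and back should always return the original character. A quick sanity check:

assert all(index2char[char2index[c]] == c for c in chars)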

3.12.1. Build the Dataset

Next we build the dataset. We do this by stepping through the text one character at a time, building an input sequence of length SEQLEN and its associated target: the single character that immediately follows the sequence.

For example, given the input text “the sky was falling” and a SEQLEN of 10, we would get the following inputs and targets:

Inputs       -> Target
------------    ------
"the sky wa" -> "s"
"he sky was" -> " "
"e sky was " -> "f"
" sky was f" -> "a"
"sky was fa" -> "l"
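You can reproduce this sliding window in plain Python; using repr makes the space target in the second row visible:

s = "the sky was falling"
for i in range(5):
    print(repr(s[i:i + 10]), "->", repr(s[i + 10]))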

How can we represent the characters? There are many ways, including using an EmbeddingLayer. In this example, we simply use a one-hot encoding of each character's index. Note that the total length of the one-hot encoding is one more than the total number of distinct characters; that is because we leave a position for the zero index as well.
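cx.onehot(index, length) produces a vector of the given length that is all zeros except for a one at position index; conceptually it is equivalent to something like the following sketch:

def onehot(index, length):
    # return a list of `length` zeros with a single one at position `index`
    vector = [0] * length
    vector[index] = 1
    return vector

onehot(3, 6)    # [0, 0, 0, 1, 0, 0]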

[9]:
SEQLEN = 10
data = []
for i in range(0, len(text) - SEQLEN):
    # input: a window of SEQLEN one-hot encoded characters
    inputs = [cx.onehot(char2index[char], nb_chars + 1) for char in text[i:i + SEQLEN]]
    # target: the one-hot encoded character immediately following the window
    targets = cx.onehot(char2index[text[i + SEQLEN]], nb_chars + 1)
    data.append([inputs, targets])
text = None # clean up memory
[10]:
dataset = cx.Dataset()
dataset.load(data)
data = None # clean up memory; not needed
[11]:
len(dataset)
[11]:
158773
[12]:
cx.shape(dataset.inputs[0])
[12]:
(10, 56)

The shape of each input is 10 x 56: a sequence of 10 characters, each encoded as a one-hot vector of length 56.
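Each target, in contrast, is a single one-hot vector of length 56, which we can confirm:

cx.shape(dataset.targets[0])    # should be (56,)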

Let’s check the inputs and targets to make sure everything is encoded properly:

[13]:
def onehot_to_char(vector):
    # convert a one-hot (or probability) vector back into its character
    index = cx.argmax(vector)
    return index2char[index]
[14]:
for i in range(10):
    print("".join([onehot_to_char(v) for v in dataset.inputs[i]]),
          "->",
          onehot_to_char(dataset.targets[i]))
project gu -> t
roject gut -> e
oject gute -> n
ject guten -> b
ect gutenb -> e
ct gutenbe -> r
t gutenber -> g
 gutenberg -> s
gutenbergs ->
utenbergs  -> a

Looks good!

3.12.2. Build the Network

We will use a single SimpleRNNLayer with a fully-connected output bank to compute the most likely predicted output character.

Note that we can use the categorical cross-entropy error function since we are using the “softmax” activation function on the output layer.
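For a single example, categorical cross-entropy is just the negative log of the probability that the softmax assigned to the correct character. A small illustration with made-up numbers:

import math

target = [0, 1, 0]           # one-hot target: the correct class is index 1
output = [0.1, 0.7, 0.2]     # hypothetical softmax output
loss = -sum(t * math.log(o) for t, o in zip(target, output))
print(loss)                  # -log(0.7), about 0.357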

In this example, we set unroll=True, which replaces the RNN's internal loop with an explicit computation for each of the SEQLEN timesteps. The recurrent weights are still shared across timesteps; unrolling simply trades extra memory for speed on short sequences like these.

[15]:
network = cx.Network("Alice in Wonderland")
network.add(
    cx.Layer("input", (SEQLEN, nb_chars + 1)),
    cx.SimpleRNNLayer("rnn", 128,
                      return_sequences=False,
                      unroll=True),
    cx.Layer("output", nb_chars + 1, activation="softmax"),
)
network.connect()
network.compile(error="categorical_crossentropy", optimizer="rmsprop")
[16]:
network.set_dataset(dataset)
[17]:
network.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input (InputLayer)           (None, 10, 56)            0
_________________________________________________________________
rnn (SimpleRNN)              (None, 128)               23680
_________________________________________________________________
output (Dense)               (None, 56)                7224
=================================================================
Total params: 30,904
Trainable params: 30,904
Non-trainable params: 0
_________________________________________________________________
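We can verify these parameter counts by hand. A SimpleRNN layer has units * (input_features + units + 1) weights (input weights, recurrent weights, and biases), and the Dense output layer has outputs * (units + 1):

128 * (56 + 128 + 1)    # 23,680 parameters for the rnn layer
56 * (128 + 1)          # 7,224 parameters for the output layer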
[18]:
network.dashboard()

3.12.3. Train the Network

After each training epoch, we will generate some text to see how the network is progressing.

We could use cx.choice(p=output) or cx.argmax(output) for picking the next character. Which works best for you?

[19]:
def generate_text(sequence, count):
    # generate `count` characters, always choosing the most likely next character
    for i in range(count):
        output = network.propagate(sequence)
        char = index2char[cx.argmax(output)]
        print(char, end="")
        # slide the window: drop the oldest vector and append the raw output distribution
        sequence = sequence[1:] + [output]
    print()
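A variant that samples the next character from the output distribution, using cx.choice(p=output) as mentioned above, might look like the following sketch; note that it feeds back a one-hot encoding of the sampled character rather than the raw output vector:

def generate_text_sampled(sequence, count):
    for i in range(count):
        output = network.propagate(sequence)
        index = cx.choice(p=output)                # sample an index from the distribution
        print(index2char[index], end="")
        sequence = sequence[1:] + [cx.onehot(index, nb_chars + 1)]
    print()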
[20]:
network.reset()
[21]:
for iteration in range(25):
    print("=" * 50)
    print("Iteration #: %d" % (network.epoch_count))
    results = network.train(1, batch_size=128, plot=False, verbose=0)
    sequence = network.dataset.inputs[cx.choice(len(network.dataset))]
    print("Generating from seed: '%s'" % ("".join([onehot_to_char(v) for v in sequence])))
    print("".join([onehot_to_char(v) for v in sequence]), end="")
    generate_text(sequence, 100)
network.plot_results()
==================================================
Iteration #: 0
Generating from seed: 'ent on eag'
ent on eag tt t
==================================================
Iteration #: 1
Generating from seed: 'ging her s'
ging her she e  i    i t  t t    t t  i t    t t  i t    t t  i t t  t t  i t t  t t    t t  i t    t t  i t
==================================================
Iteration #: 2
Generating from seed: 'of rock, a'
of rock, and a  t s  a t  t t  t t  t t    t t  t t  t t  t t  t t    t t  t t  t t  t t  t t    t t  t t  t t
==================================================
Iteration #: 3
Generating from seed: 'estriction'
estriction  to tet it   t  t t t  t t  t t  t t t  t t  t t tt t t  t t  t t tt t t  t t  t t tt t t  t t  t t
==================================================
Iteration #: 4
Generating from seed: 'nd began s'
nd began she  ase tol  o    t  o t to t t  a t  a t  t t  t t  a t  a t  a t  a t  a t  a t  a t  a t  a t  a
==================================================
Iteration #: 5
Generating from seed: 'ainst acce'
ainst acce site eit te e ae t ee t  s et t es e et e  i e  i e  e e te e ee e  e es t es e es e et e et e  i e
==================================================
Iteration #: 6
Generating from seed: 'ce, scream'
ce, screame t t t  e ti e as t e  t te t  e to e tit  a t  t ai t ee to e ti e a te t te t me t te t t  t t  t
==================================================
Iteration #: 7
Generating from seed: ' it seemed'
 it seemed to hiee tile ti e ai e te aile aite tite ai e at eite aite tite ai e at aile aite aite ai e te aite
==================================================
Iteration #: 8
Generating from seed: 'lice (she '
lice (she waste  o ee ee t ee tile ao ees  t le tele s at tele as e et tire a lees at tere atlees at ee tele t
==================================================
Iteration #: 9
Generating from seed: 's; these w'
s; these whtl  ie to te te te to te to tert tt tee  i t  te te te  o teete to tort  te tert i  tee so  eete to
==================================================
Iteration #: 10
Generating from seed: 'y never ex'
y never exelaten to tome t et seat so s es to s ts ese  on ire  ti sece se s oe t  ee t to eees t   eetts   ot
==================================================
Iteration #: 11
Generating from seed: 'but i dont'
but i dont kent tertsle  ioe te ti e aret ee tt mile sere ie tioe tr tiee tr t ee trrt et ti tice sort te ts e
==================================================
Iteration #: 12
Generating from seed: ' remedies-'
 remedies-- h li tors  o tied tr sont ti toe  il  eet se  se  ast te es t et the  te see  on tom  en sie   t
==================================================
Iteration #: 13
Generating from seed: 'hed it on '
hed it on a mis  ailte etten do teee to teet the eea eteit te seee i  t et es the eeat terte ae ie e tt to t t
==================================================
Iteration #: 14
Generating from seed: 'he queen w'
he queen wat e ete  i t  tn sel  tetee ert  oll ore tar e  tt se s a  erte ttee tires ot  oe  orte iee  ot toe
==================================================
Iteration #: 15
Generating from seed: 'hat makes '
hat makes trem  otteet sor siats fot io  set et tols tees toa se asles frrtee tor  es siek ioa ee to sees tre
==================================================
Iteration #: 16
Generating from seed: ' on the ta'
 on the table  fal orlt getten  t ine tor  eelt neate  it wrre iurt neet celt selt dlen i et wore trrted it wo
==================================================
Iteration #: 17
Generating from seed: ' this busi'
 this busine st es er t ot oo teles oe  ome, on eeteteoe tol ol pt ts p  e tece olltese tt ee e te sese tore
==================================================
Iteration #: 18
Generating from seed: ' about her'
 about her fiot   i  sot mas soete te  ilees oole otee to e ot tise son s ot turedenes tol soot  oe e ot wise
==================================================
Iteration #: 19
Generating from seed: 'all her co'
all her coul tted te te   i  t tlite flit orett ee te ai t etele te te ert  yet  enet to tor  ton in whe corti
==================================================
Iteration #: 20
Generating from seed: 'sting date'
sting date  tt d et tot  ouce tini lint  eats et tore toete oel ot meet  ou  tite tol  tal ite s on tor  anter
==================================================
Iteration #: 21
Generating from seed: 'oprietary '
oprietary iast ten e ee  o li ler forts or ilets oes to  ine  eet  o  orl tee io mested til  i at tea  orr tte
==================================================
Iteration #: 22
Generating from seed: 'ly remarke'
ly remarked arr  aine tott  e t lend out  te tel  t ol tire tenes i tint, and then ire tori  o leran of  teot
==================================================
Iteration #: 23
Generating from seed: 'ttle now a'
ttle now as thement te t  sot oo to sons outt netten a tore tor, tor lot se tile o r  licted ot oret tortee af
==================================================
Iteration #: 24
Generating from seed: 'ion in the'
ion in the somsst on  omt ftre, oos tt sott oi, to some ting  on tore to d. ther ant oult  en  oos  ot sor lon
[Plot: training results from network.plot_results()]

What can you say about the text generated in later epochs compared to the earlier generated text?

This was the simplest and most straightforward of network architectures and parameter settings. Can you do better? Can you generate text that is better English, or even text that captures the style of Lewis Carroll?

Next, you might like to try this kind of experiment on your own sequential data.