shakespeare mlp test run

jan 24 2026

Ok so at this point we have most of what we need to train our own primitive language model.

Specifically, we’re going to build and train an MLP network on the complete works of Shakespeare, about 5 MB of text.

To get a baseline, I had Claude cook up a quick n-gram model.

Because of what we said about overfitting earlier, the standard procedure here is to train the model on most of the text (80-90%) and save the rest for testing. We don’t want to test the model on stuff it’s already seen; pretty straightforward.
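A minimal sketch of that split, assuming we just cut the text at a fixed fraction (the `split_corpus` name and the 90/10 default are mine, not necessarily what the actual run used):

```python
def split_corpus(text: str, train_frac: float = 0.9) -> tuple[str, str]:
    """Split text into a leading training part and a trailing held-out part."""
    if not 0.0 < train_frac < 1.0:
        raise ValueError("train_frac must be strictly between 0 and 1")
    cut = int(len(text) * train_frac)
    return text[:cut], text[cut:]


# Tiny stand-in corpus; the real one is the ~5 MB of Shakespeare.
corpus = "To be, or not to be, that is the question. " * 10
train_text, val_text = split_corpus(corpus, train_frac=0.9)
print(len(train_text), len(val_text))
```

Cutting at one point (rather than shuffling characters) keeps the held-out text contiguous, so the model really is tested on passages it never saw.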

The results of the ngram are:

ngram dataset

Perplexity & Accuracy

n   Train PPL   Val PPL   Train Acc   Val Acc   Contexts
2   12.35       12.29     26.0%       26.7%     100
3   6.93        7.87      40.4%       38.5%     2,341
4   4.92        5.88      49.0%       47.0%     20,634
5   3.99        5.46      55.1%       51.5%     98,493
6   3.48        6.21      59.3%       51.9%     309,366

5-gram is the sweet spot. At 6-gram, val perplexity increases (overfitting to training contexts).
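For reference, here is roughly what a character n-gram model looks like. This is a sketch and not the actual baseline code: the add-alpha smoothing and the function names are my assumptions, and the toy strings stand in for the real corpus.

```python
import math
from collections import Counter, defaultdict


def train_ngram(text: str, n: int) -> dict[str, Counter]:
    """Count which character follows each (n-1)-character context."""
    counts: dict[str, Counter] = defaultdict(Counter)
    for i in range(n - 1, len(text)):
        counts[text[i - n + 1:i]][text[i]] += 1
    return dict(counts)


def perplexity(counts: dict[str, Counter], text: str, n: int,
               vocab_size: int, alpha: float = 1.0) -> float:
    """exp(mean negative log-likelihood), with add-alpha smoothing for unseen events."""
    nll, steps = 0.0, 0
    for i in range(n - 1, len(text)):
        ctx_counts = counts.get(text[i - n + 1:i], Counter())
        p = (ctx_counts[text[i]] + alpha) / (sum(ctx_counts.values()) + alpha * vocab_size)
        nll -= math.log(p)
        steps += 1
    return math.exp(nll / steps)


# Toy data standing in for the train / validation split.
train_text = "the cat sat on the mat. " * 20
val_text = "the cat sat on the mat. "
vocab_size = len(set(train_text + val_text))

model = train_ngram(train_text, n=3)
n_contexts = len(model)  # the 'Contexts' column: distinct contexts seen in training
train_ppl = perplexity(model, train_text, 3, vocab_size)
val_ppl = perplexity(model, val_text, 3, vocab_size)
print(n_contexts, round(train_ppl, 3), round(val_ppl, 3))
```

The number of distinct contexts grows quickly with n, which is why the bigger tables start memorising the training text.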

Generation Samples

(5-gram, temperature=0.8)

“Enter”

Enter Caesar said the country many outward's scene II. And his dead; for that
with your him like not please in the done abused, thee are not from his power
Into France.

HORTENSIO.
This valoursed by the did so prattle of his about of men disdainter they'll
for throat we were is somethink your commonweal w

“Shall”

Shall of our some thou take thy life the was a man this one this,
Come, let sir, I am all seeks, and danger bottle, have doubless so in that
I am much,
Unless his but let me of him to my soul's since in since to be the thou the rams
Pink on the are so thee a cart good consieur; and such combine,
And, but

Notes

MLP time

Ok so the n-gram is pretty good. Time to see if we can beat it on shakespeare with our MLP.

And at this point the one thing we haven’t covered is how we encode text. For comparison with the n-gram, and since it’s a clear starting point, we’re going to predict text character by character.

To do this, we first get all the characters in the training set:

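import json
from pathlib import Path

def create_vocab(corpus: list[Path], vocab_path: Path) -> dict[int, str]:
    # corpus_iter (not shown here) yields the contents of each file as a string
    char_set = set()
    for file_contents in corpus_iter(corpus):
        char_set.update(file_contents)

    # sort so the index assigned to each character is stable across runs
    vocab = {i: c for i, c in enumerate(sorted(char_set))}
    with open(vocab_path, 'w') as f:
        # note: JSON object keys are strings, so the int keys are saved as "0", "1", ...
        json.dump(vocab, f)
    print(f'saved vocab json to {vocab_path}')

    return vocab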

So now each character is representable as an integer. The obvious problem is that if we have \(a = 1\), \(b = 2\), and \(c = 3\), the relationship \(1+2=3\) is meaningful for the integers, but in language-world \(a+b=c\) is not.

The way to fix this is to give each character its own unique direction in an ‘embedding space’.

I.e. the \(j^{\mathrm{th}}\) character in our dataset corresponds to the one hot or standard basis vector \(e_j\) in \(\mathbb{R}^d\) where \(d:=\text{vocab size}\).

Theoretically, that’s all there is to it.

However, for practical (computational) reasons, we don’t actually use these one-hot encodings. We just use the character indices, and feed them into an ‘embedding’ layer, before passing the result into the rest of the network.

self.embedding = nn.Embedding(len(vocab), emb_dim)

In other words, the embedding layer has type \(\mathbb{Z} \to \mathbb{R}^{\text{emb\_dim}}\), which is identical to converting the index to a one-hot encoding in \(\{0,1\}^d\) and passing it through a linear layer with no bias. (No bias because \(xW + b\) with \(x\) one-hot gives \(W[i] + b\) for token \(i\), and this constant \(b\) can just be absorbed into every row of the weight matrix.)
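To make the equivalence concrete, here is a small plain-Python check (no PyTorch needed). The table `W`, the bias `b`, and the sizes are made-up illustration values, not the real model's weights:

```python
import random

vocab_size, emb_dim = 5, 3
random.seed(0)
# Hypothetical embedding table: one row of emb_dim weights per character index.
W = [[random.gauss(0.0, 1.0) for _ in range(emb_dim)] for _ in range(vocab_size)]


def one_hot(index: int, size: int) -> list[float]:
    return [1.0 if j == index else 0.0 for j in range(size)]


def vecmat(x: list[float], weight: list[list[float]]) -> list[float]:
    """Row vector times matrix: (x W)_k = sum_j x_j * W[j][k]."""
    return [sum(x[j] * weight[j][k] for j in range(len(x)))
            for k in range(len(weight[0]))]


token = 2
via_one_hot = vecmat(one_hot(token, vocab_size), W)  # one-hot, then linear layer
via_lookup = W[token]                                # what an embedding lookup does

# Adding a bias b is the same as adding b to every row of W beforehand.
b = [0.5, -1.0, 2.0]
with_bias = [v + bk for v, bk in zip(via_one_hot, b)]
W_absorbed = [[w + bk for w, bk in zip(row, b)] for row in W]

print(via_one_hot == via_lookup, with_bias == W_absorbed[token])
```

The lookup skips the multiply-by-zero work entirely, which is the practical reason for using indices.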

The rest of the network is fairly straightforward, since we just have linear layers and relu to play with.

We’ll start with:

h_size = emb_dim * ctx_len
nn.Sequential(
    nn.Linear(h_size, h_size),
    nn.ReLU(),
    nn.Linear(h_size, len(vocab))
)

And the training loop is just boilerplate for our optimizer with some logging:

def train(model: nn.Module, 
          optimizer: optim.Optimizer,
          train_loader: DataLoader,
          val_loader: DataLoader,
          logger: logging.Logger,
          device: torch.device,
          batch_log_interval: int=0,
          batch_val_interval: int=0):

    log_train_info(model, train_loader, logger)
    loss_fn = nn.CrossEntropyLoss()
    model.to(device)
    model.train()

    n_batches = len(train_loader)
    
    if batch_log_interval == 0:
        batch_log_interval = n_batches // 10
    val_loss = 0

    if batch_val_interval == 0:
        batch_val_interval = batch_log_interval * 2

    for batch_no, batch in enumerate(train_loader):
        batch_t0 = time.perf_counter()
        batch = batch.to(device)
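        x = batch[:, :-1]   # first ctx_len characters of each row: the input
        y = batch[:, -1]    # last character of each row: the prediction target
        y_pred = model(x)   # call the module rather than .forward() so hooks run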

        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_no+1) % batch_val_interval == 0:
            val_loss = validate(model, val_loader, device)

        if (batch_no+1) % batch_log_interval == 0:
            log_batch(logger, loss.item(), val_loss, batch_no + 1, batch_t0, train_loader)

    final_val_loss = validate(model, val_loader, device)
    logger.info(f"training finished w/ final val ppl={math.exp(final_val_loss)}\n\n")
    return final_val_loss
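The loop above expects each row of a batch to hold ctx_len + 1 character indices. The dataset code isn't shown here, so the following is only a guess at how such rows could be built (`make_windows` is a hypothetical helper), using plain lists instead of tensors:

```python
def make_windows(token_ids: list[int], ctx_len: int) -> list[list[int]]:
    """Every run of ctx_len + 1 consecutive tokens: ctx_len inputs plus one target."""
    return [token_ids[i:i + ctx_len + 1] for i in range(len(token_ids) - ctx_len)]


text = "hello world"
vocab = {i: c for i, c in enumerate(sorted(set(text)))}  # index -> char, like create_vocab
char_to_idx = {c: i for i, c in vocab.items()}
token_ids = [char_to_idx[c] for c in text]

windows = make_windows(token_ids, ctx_len=4)
x = [w[:-1] for w in windows]  # mirrors batch[:, :-1]
y = [w[-1] for w in windows]   # mirrors batch[:, -1]

first_input = "".join(vocab[i] for i in x[0])
first_target = vocab[y[0]]
print(repr(first_input), "->", repr(first_target))
```

Each character of the corpus (after the first ctx_len) ends up as the target of exactly one window.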

running the MLP

There are a few ‘hyperparameters’ that we haven’t set yet. We’ll just kind of pick random ones to start:

context_length=10   # how many chars does the network see at once
batch_size=1024     # how many samples do we compute at once on gpu
lr=0.01             # learning rate (step size) for sgd
momentum=0.9        # momentum coefficient for sgd
n_layers=2          # how many layers in our MLP
emb_dim=64          # output dim of the first implicit (embedding) layer
hidden_dim=64*10    # subsequent layer dimensions; (emb_dim * context_length)
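For intuition on what lr and momentum do, here is the SGD-with-momentum update written out on a one-dimensional toy problem. The update rule follows the form PyTorch's optim.SGD documents (no dampening, no Nesterov); the quadratic and the step count are made up for the demo:

```python
def sgd_momentum_step(param: float, grad: float, velocity: float,
                      lr: float = 0.01, momentum: float = 0.9) -> tuple[float, float]:
    """One update: the velocity accumulates gradients, the parameter moves along it."""
    velocity = momentum * velocity + grad
    param = param - lr * velocity
    return param, velocity


# Minimise f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w, v = 0.0, 0.0
for _ in range(500):
    w, v = sgd_momentum_step(w, 2.0 * (w - 3.0), v)

print(w)  # ends up very close to the minimum at w = 3
```

With momentum=0.9, a steady gradient gets amplified by up to 1 / (1 - 0.9) = 10x, so the effective step size is larger than lr alone suggests.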

Let’s run it and see how we do:

 batch 250/4711; 'train_loss': 2.617, 'train_ppl': 13.694, 'val_loss': 0, 'val_ppl': 1.0}
 batch 500/4711; 'train_loss': 2.262, 'train_ppl': 9.603, 'val_loss': 0, 'val_ppl': 1.0}
 batch 750/4711; 'train_loss': 2.215, 'train_ppl': 9.166, 'val_loss': 0, 'val_ppl': 1.0}
 batch 1000/4711; 'train_loss': 2.2, 'train_ppl': 9.025, 'val_loss': 0, 'val_ppl': 1.0}
 batch 1250/4711; 'train_loss': 2.133, 'train_ppl': 8.437, 'val_loss': 0, 'val_ppl': 1.0}
 batch 1500/4711; 'train_loss': 1.999, 'train_ppl': 7.384, 'val_loss': 0, 'val_ppl': 1.0}
 batch 1750/4711; 'train_loss': 2.04, 'train_ppl': 7.69, 'val_loss': 0, 'val_ppl': 1.0}
 batch 2000/4711; 'train_loss': 1.894, 'train_ppl': 6.649, 'val_loss': 2.039, 'val_ppl': 7.684}
 batch 2250/4711; 'train_loss': 1.919, 'train_ppl': 6.817, 'val_loss': 2.039, 'val_ppl': 7.684}
 batch 2500/4711; 'train_loss': 1.95, 'train_ppl': 7.031, 'val_loss': 2.039, 'val_ppl': 7.684}
 batch 2750/4711; 'train_loss': 1.869, 'train_ppl': 6.48, 'val_loss': 2.039, 'val_ppl': 7.684}
 batch 3000/4711; 'train_loss': 1.821, 'train_ppl': 6.176, 'val_loss': 2.039, 'val_ppl': 7.684}
 batch 3250/4711; 'train_loss': 1.779, 'train_ppl': 5.927, 'val_loss': 2.039, 'val_ppl': 7.684}
 batch 3500/4711; 'train_loss': 1.802, 'train_ppl': 6.064, 'val_loss': 2.039, 'val_ppl': 7.684}
 batch 3750/4711; 'train_loss': 1.867, 'train_ppl': 6.472, 'val_loss': 2.039, 'val_ppl': 7.684}
 batch 4000/4711; 'train_loss': 1.732, 'train_ppl': 5.65, 'val_loss': 1.885, 'val_ppl': 6.585}
 batch 4250/4711; 'train_loss': 1.879, 'train_ppl': 6.546, 'val_loss': 1.885, 'val_ppl': 6.585}
 batch 4500/4711; 'train_loss': 1.852, 'train_ppl': 6.376, 'val_loss': 1.885, 'val_ppl': 6.585}
 training finished w/ final val ppl=6.404950198477691

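Not terrible! Final perplexity of 6.4, so we’re close behind our highly inductively biased n-gram (5.46 for the 5-gram)! And it’s worth noting our model is ingesting 10 characters at once; roughly double the ‘context’ of the n-gram. (The val_ppl of 1.0 in the early log lines is just the val_loss = 0 placeholder from before the first validation pass.)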
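Since the logs report both loss and perplexity, here is the relationship between them: perplexity is exp of the mean cross-entropy loss. The helper names below are mine; the 2.039 and 1.885 values are the logged val losses from above.

```python
import math


def cross_entropy(probs: list[float], target: int) -> float:
    """Negative log of the probability the model assigned to the true character."""
    return -math.log(probs[target])


def perplexity_from_losses(losses: list[float]) -> float:
    """exp of the mean per-character cross-entropy."""
    return math.exp(sum(losses) / len(losses))


# A model that guesses uniformly over V characters has perplexity exactly V.
vocab_size = 100  # assumed size, only for this illustration
uniform = [1.0 / vocab_size] * vocab_size
uniform_ppl = perplexity_from_losses([cross_entropy(uniform, 0)] * 4)

# Converting the logged validation losses back to perplexities.
logged_val_ppl = [math.exp(loss) for loss in (2.039, 1.885)]
placeholder_ppl = math.exp(0)  # the val_loss = 0 placeholder shows up as ppl 1.0

print(uniform_ppl, logged_val_ppl, placeholder_ppl)
```

So a perplexity of 6.4 means the model is, on average, about as uncertain as if it were choosing uniformly among 6.4 characters.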

text generation:

One thing that’s worth noting before we look at the sample generations… There are many ways we could sample from the model output!

One way (“temperature zero”) is to just pick the biggest “logit” (a logit being one of the raw scores the model spits out at the end of the network, before we softmax them into probabilities for the loss function).

Another way is to divide the logits by a constant \(T\) inside the softmax, and do multinomial sampling from the resulting probabilities:

\[ \mathrm{softmax}(x,T)_i = \frac{e^{x_i / T}}{\sum_j e^{x_j/ T}} \]

where \(T\) refers to the ‘temperature’. Lower \(T\) sharpens the distribution towards the biggest logit; higher \(T\) flattens it.
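Both sampling strategies fit in a few lines. This is a plain-Python sketch with made-up logits, not the generation code used for the samples below:

```python
import math
import random


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Softmax of logits / T, shifted by the max logit for numerical stability."""
    m = max(logits)
    exps = [math.exp((x - m) / temperature) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample(logits: list[float], temperature: float, rng: random.Random) -> int:
    """Temperature zero means argmax; otherwise draw from the tempered softmax."""
    if temperature == 0:
        return max(range(len(logits)), key=logits.__getitem__)
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-character vocabulary
rng = random.Random(0)

cold = softmax_with_temperature(logits, 0.2)  # sharply peaked on index 0
warm = softmax_with_temperature(logits, 1.0)  # the ordinary softmax
greedy = sample(logits, 0.0, rng)
drawn = sample(logits, 0.8, rng)
print(cold, warm, greedy, drawn)
```

At low temperature the model keeps picking its favourite continuation, which is why the temperature = 0.2 samples tend to loop on a handful of common words.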

sample generations, temperature = 0.2:

constrain and a stremant the with the stranger the word of the prother to the sure of the poor

her love to the words the words that have the word when and man the will be a the earth,well

shall not a strain the ward of the come..lord, and the with the are are the

Not bad! There’s some primitive broken grammar, legible words, etc…

Here’s a model generation for another run at a higher temperature:

“By fair and dorn-ooking.maye’d pergate,standed is to ming ahis fighty fropose,”

Also it’s worth noting that on this re-run, we finished with a val ppl of 6.33, about 0.07 better than our original run (6.40), even though nothing changed except the demo settings at the end… SGD is definitely stochastic!

training dynamics

One other thing to note is that we’re still improving on the validation perplexity… and one advantage we have over the n-gram model is we can just… train more times on the same corpus and keep feeding more gradient information to our optimizer.

Let’s try it with a few ‘epochs’ (iterations on the same training set):

After a few epochs, we get a final val perplexity of 5.39, beating the 5-gram baseline!

Sample generations:

the more than the court the common the too love the common the common and the lives a p

the common and the content to the state and the charge of the time of the comes the commo

the heart of the court the rest the court the part of the common the common the common a

Except … the generation quality really sucks.

hyperparameter adjustment

Now it’s time to adjust some hyperparameters and see if we can’t improve the baseline:

Adjustment #1: longer context window:

Result: final ppl of 5.53 after 4 epochs.

Sample generations:

by all men’s judgements,if the shall be for the langer of the wine of the hands..say the seart of

minds,therefore are the comes of faith,they have love the heart of the stay and the forthhave

Dost thou hear, Camillo,can the reason of the state of the heart of the field.III. The same. A

Better! Except now the model has 4.4M params instead of 400k. Training takes longer, but not 10 times longer, presumably thanks to the GPU.

Adjustment #2: wider network

This time we set hidden_dim = h_size + 40 (arbitrary number)

Result: final ppl of 5.53 after 4 epochs.

Sample generations:

ot so.am as ignorant in that hath made of this from the sentle to the street,stard shall see his but a mounter str

basses, but oneamongst and so the death.TOBY.shall be my lord, and the shall be the stronged and str

ayers urgeth stillUnder what thou art thou shall be the willshe will be the hath of the world,therefore th

Adjustment #3: deeper network

This time we add another layer:

layers = [
    nn.Linear(h_size, h_size + k),
    nn.ReLU(),
    nn.Linear(h_size + k, h_size + k),  # input width must match the previous layer's output
    nn.ReLU(),
    nn.Linear(h_size + k, len(vocab))
]

Result: final ppl of 5.13 after 4 epochs.

Sample generations:

rio, whom your gentle daughter here,there is the sealth of the man.II. The same. A Room in the Ca

us lawfully, or business;the emperor the course of the common the state the court of the state of the place of th

.—Fairest Emily,gods by the time of the course of the streets the present the heart of the present to the metter.[

Adjustment #3b: wider and deeper network

This time we use the extra layer from #3 together with the wider layers from #2.

Result: final ppl of 5.14 after 4 epochs.

Sample generations:

maculate, look on thy virgin;then the prince of the like and courtesy..am an and then, in

mportance, ’twerepiteous to the court of the wind of the world,he is not the common in the street on the like

ight.sighs resound through the rest of the state,when the higher that he will be the will have me, and therefore

Adjustment #4: layer gluttony

Ok so the last result was really good. What if we stack more layers and add more context?

ctx_len = 64  # context_length
h_size = emb_dim * ctx_len
k = 80
layers = [
    nn.Linear(h_size, h_size + k),
    nn.ReLU(),
    nn.Linear(h_size + k, h_size + k),
    nn.ReLU(),
    nn.Linear(h_size + k, h_size + k),
    nn.ReLU(),
    nn.Linear(h_size + k, len(vocab))
]

It’s worth noting that the weight matrices scale quadratically: each hidden weight matrix has on the order of (context_length * emb_dim)^2 parameters. At context_length = 64 that’s about 17M parameters per hidden layer, or ~52M parameters for the whole network.
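Here is the arithmetic behind that parameter count for the layer stack above. The vocab size of 100 is an assumption (the post never states it exactly); the output layer is small enough that it barely moves the total.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


emb_dim, ctx_len, k = 64, 64, 80
vocab_size = 100  # assumption, not stated in the post

h_size = emb_dim * ctx_len  # 4096 inputs after flattening the embeddings
width = h_size + k          # hidden width of 4176

layer_shapes = [
    (h_size, width),
    (width, width),
    (width, width),
    (width, vocab_size),
]
per_layer = [linear_params(n_in, n_out) for n_in, n_out in layer_shapes]
mlp_params = sum(per_layer)
embedding_params = vocab_size * emb_dim

print(per_layer, mlp_params, embedding_params)
```

Doubling the context length doubles h_size and so roughly quadruples every hidden matrix, which is why long contexts get expensive fast for a plain MLP.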

Result: final ppl of 5.22 after 4 epochs.

Sample generations (temperature = 0.5):

PROMPT:  " shall not come about her,\nAway with him, and let her sport hers"
OUTPUT:  ".\n\nCASSIUS.\nWho, you are do read and him with the danger.\n\nFIRST LORD.\nI thou didst thee with this not shall be even and of his wing'st there,\nAnd this shall shall be so much father's the devinest, an"

PROMPT:  "es,\n    So mild, that patience seemed to scorn his woes.\n\nIn him"
OUTPUT:  " that the spiton of this daughter,\nBut with the world the thrive the strong under\nAnd predition of all the legged them their pleased,\nStand the properations and his more the nature.\n\nGLOUCESTER.\nWhat "

PROMPT:  "sk the spotted princess how she fares.\n\nShe says her subjects wi"
OUTPUT:  "th the bright of her be my delivery.\n\nARGANDREY.\nI will not not him to the storest,\nAnd therefore commanded.\n\nPANDARUS.\nNo, my lord, the country time.\n\nGLOUD.\nNo, not thou didst you seek you that me t"

Sample generations (temperature = 0.2):

PROMPT:  " about, let them depart.\n\nLEONTES.\nProceed:\nNo foot shall stir.\n"
OUTPUT:  "\nALLIA.\nNow now, my lord.\n\nPATRO.\nWhy, that I have seen the street the greater the sure of the find.\n\nPRINCE.\nWhy, then the strange of the blood, and so the time.\n\nLUCENTIO.\nI have not the time of the"

PROMPT:  "her trust her.\nFor every inch of woman in the world,\nAy, every d"
OUTPUT:  "eath the sing of my such as the\nThat he shall be a head the stand of the street\nThe strange that where they shall be the street\nThat the should be the beggard the eyes shall be so himself,\nAnd bring t"

PROMPT:  "roving.\nMy will is strong, past reason's weak removing.\n    Who "
OUTPUT:  "should be so much the heart of me,\nAnd with the town of the with the strong of the world,\nAnd then the time of the world so be the street.\n\nBRUTUS.\nI have not the world be that he would not have thee "

Adjustment #5: adjust the learning rate

Setting lr=0.001 yields a final val ppl of 8.6 and substantially degraded output. It’s possible this is fixable by just letting it run longer, but training time is already long enough for 5 MB of text input.

Setting lr=0.1 gets validation perplexity down to 4.979 before overfitting all the way back to 6.01. But even then the generation quality is pretty good… Let’s check what happens if we stop before overfitting at 2 epochs:

Final val ppl of 4.91, and the sample generations are quite good!

main | INFO | 2026-01-24 17:12:38 | Sample generations (temperature = 0.5):
PROMPT:  ' your bare words.\n\nSILVIA.\nNo more, gentlemen, no more. Here com'
OUTPUT:  'es the sea and my father was\nwith you are the next. I shall not say,\nMy brace the name of soon to the great still;\nAnd I you of all hangs and the three the sad,\nAnd as my banished to the brought and g'

PROMPT:  'emove.”\n\n“Ay me,” quoth Venus, “young, and so unkind!\nWhat bare '
OUTPUT:  'what we strings with the great so fast.\n\nCAESAR.\nAy, and the siege, let him in the sun by me?\n\nBRUTUS.\nWhy shall the first which return you there?\n\nMARCUS.\nWhy, sir, I shall be how thee shall be that '

PROMPT:  'mercy, Theseus. ’Tis to me\nA thing as soon to die as thee to say'
OUTPUT:  ' to you.\n\nSECOND CITIZEN.\nMy lord, I will; and you well as I not gate.\n\n [_Exit Dost that will in a man is storn.\n\nKING EDWARD.\nWhat a fair within. He will me, there of the sun the wingst a court.\n\nEn'

The above output came from a smaller network, so let’s bump up the depth/width and try again.

main | INFO | 2026-01-24 17:17:22 | training finished w/ final val ppl=4.694544337616929

main | INFO | 2026-01-24 17:17:22 | Sample generations (temperature = 0.5):
PROMPT:  'res; he hears, and nods, and hums,\nAnd then cries “Rare!” and I '
OUTPUT:  'have not with thee.\n\nFALSTAFF.\nI am not the state of my master for this comes and the town and his humour\nof the man of bloody me the content\nTo thy other with her been before your first,\nTo be respec'

PROMPT:  'y words.\nHere, youth, there is my purse. I give thee this\nFor th'
OUTPUT:  'e town of the field of my horns\nTo the device is a fool, and man the glect of earth\n                                                                         \n\n                                         '

PROMPT:  'at mean. So, over that art\nWhich you say adds to nature, is an a'
OUTPUT:  'll man\nIn both offended to the cause of his horse,\nTo know not when the working of your service.\nCome, come, I will not for the kingly soul son\nBut her by the seas of life an every tears\nAs here, shal'

main | INFO | 2026-01-24 17:17:23 | Sample generations (temperature = 0.2):
PROMPT:  'ale weakness numbs each feeling part;     892\n  Like soldiers wh'
OUTPUT:  'en the stand of the world.\n\n [_Exeunt all but sound and the palace\nScene III. The same. The same of the particular the seas of the state\nof the state of the world shall be some person of the season an'

PROMPT:  ' mov’d him, was he such a storm\nAs oft ’twixt May and April is t'
OUTPUT:  'he world.\n\n [_Exit._]\n\nSCENE III. A Room in the Castle.\n\n Enter Antony and Athens.\n\nMARGARET.\nWhat news the shall be the seat of the state\nThat the state of the state of the state\nThat they are so sha'

PROMPT:  'olve you\nFor more amazement. If you can behold it,\nI’ll make the'
OUTPUT:  'e a man and the state of thee.\n\n[_Exit._]\n\n\n\n\nACT II\n\nSCENE I. Another part of the Capitol.\n\nEnter Mardinal.\n\nACHILLES.\nHow now, my lord.\n\n[_Exit Mardian._]\n\nEnter Marina.\n\nCASSIUS.\nI am not the work '

Final val ppl of 4.7! And the generation quality is really good. Well. For a first attempt.

RECAP

It’s clear there’s more juice to squeeze out of this architecture with hyperparameter tweaks.

But honestly, it’s pretty clear we’re hitting a wall. It’s pretty impressive that “stack more layers” seems to work at the MLP level, and that just choosing this simple architecture from ~first principles can actually beat an n-gram on perplexity. But the generation quality leaves much to be desired…

And from looking at the training logs, it’s clear we’re entering ‘overfitting’ territory. Validation perplexity starts to increase towards the end of the runs, at around 4 epochs.
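One cheap guard against this is early stopping: keep the checkpoint from the epoch where validation loss was lowest and quit once it stops improving. A sketch, with made-up per-epoch losses (these are not numbers from the runs above) and a hypothetical `best_checkpoint` helper:

```python
def best_checkpoint(val_losses: list[float], patience: int = 1) -> int:
    """Index of the epoch to keep; stop after `patience` epochs in a row without improvement."""
    best_epoch, best_loss, bad_epochs = 0, float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_epoch, best_loss, bad_epochs = epoch, loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    return best_epoch


# Made-up validation losses, for illustration only.
illustrative_val_losses = [1.80, 1.62, 1.59, 1.66, 1.79]
keep = best_checkpoint(illustrative_val_losses)
print(keep)
```

A larger patience tolerates noisy validation curves at the cost of a few wasted epochs.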

I think the obvious next thing to do is just to scale up the dataset and see if that solves the overfitting.

Then, once we start hitting walls on a bigger dataset, we can make architectural improvements that actually benefit from scale. In other words, there’s no reason to try improving the architecture on a 5.2 MB corpus.

Next time: scale up, see what needs fixing, see what we can improve.