A Toy Diffusion model you can run on your laptop

Thiago Lira
8 min read · Aug 4, 2022

It seems to have become more and more common for niche new advancements in AI to break into the mainstream media and dazzle laymen and practitioners alike with the amazing things one can achieve with Deep Learning.

The latest such occasion was OpenAI's amazing DALL-E 2, the text-to-image model capable of generating an image of almost anything you can imagine, provided it can be described with text.

In this post I will code a toy model representing one part of this amazing system: the diffusion-based decoder. I will focus on implementing it in Python and will use very little mathematical notation. This post is meant to help someone understand the technical aspects of the model, not the (amazing) mathematics behind it.

If you want references and mathematical explanations of this model, I provide links in the next section. If you just want to see and run the code, go directly to my repo.

"Nietzsche and Jesus do a rap battle in the middle of a crowd" (Image generated with Dalle-2 by me)

Even working in the field, sometimes I just have to convince myself something is really true by coding some version of it myself, albeit a very simple one. So here it is.

There are many excellent resources to learn how DALL-E 2 works. In this post I will focus on the one piece of mathematics that DALL-E 2 is built upon, called a diffusion model. This is a model that learns to discern noise added to data (e.g. which part of an image is blur) and, surprisingly, can learn to remove this noise! Even more surprisingly, we can use a diffusion model trained on this task of removing noise from images to create a new image from scratch, by removing noise from pure noise over and over, guided by some embedding (i.e. numerical representation) of what the image should look like. Again, this is quite involved, and I think other resources can explain the process of creating an embedding from text and converting it to an image embedding much better than I can.

In the DALL-E 2 pipeline this part sits right at the end, where the model takes an image embedding and uses a diffusion-based model to guide pure noise, step by step, into an image that represents this embedding.

What **is** a diffusion model?

All the rest of this post will be based on the original proposal of diffusion models in this work, and on an improvement to the training objective proposed by this other work. If you want a simple yet comprehensive deep dive into the mathematics behind this model, watch these excellent videos.

Our objective: Make a model learn to make a scatterplot that looks like Homer Simpson starting from pure noise!

Our initial data.

The big picture without much mathematical notation:

The whole process is composed of 2 steps:

a) We add T "steps" of noise to the data (e.g. wiggle every point by a random number), train the model to predict how many steps T of noise were added, and run many backpropagation iterations with many, many values of T. We will have a schedule of how much randomness we add at every step, so that we can calculate each step directly from the data. Although we are using random numbers, data corrupted by 50 steps of noise will look like other data corrupted by the same amount.

We select a maximum number of steps of noise that can be added to the data, normally something around 1000 (let's call this number N). At that point we get some nice mathematical properties, because we can assume that we have turned our data into something indistinguishable from pure noise.

We will call this the forward process.

b) After the model is good at task a), you can give pure noise to the model and ask it to remove N steps of noise from it. Ideally the model will get the pure noise to look like the original data you used on the first task.

We will call this the reverse process.

A Mathematical Interlude to formalize things a little bit:

(You can skip this part if you want, I will try my best to use less mathematical notation in the following sections)

The forward and the reverse processes are Markov Chains.

The Forward Process

Given initial data X_0, we produce X_T (noise) by applying multiple Gaussian transitions (more on this in the next section). Each transition applies a little Gaussian noise to X_t to generate X_t+1.
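In the notation of the DDPM paper, each forward transition adds a little Gaussian noise with a fixed variance β_t taken from the schedule mentioned earlier:

```latex
q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\; \sqrt{1-\beta_t}\,x_{t-1},\; \beta_t \mathbf{I}\right)
```

Defining α_t = 1 − β_t and ᾱ_t = ∏_{s=1}^{t} α_s, the composition of t such transitions has a closed form, which is why we can corrupt the data by t steps in one shot:

```latex
q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\; \sqrt{\bar\alpha_t}\,x_0,\; (1-\bar\alpha_t)\,\mathbf{I}\right)
```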

The reverse process is not so easy. We have some distribution p_theta that will do the opposite. We will try to find parameters (theta) that make this distribution gradually recover the data X_0 from pure noise X_T.

The Reverse Process

Every p and q will be defined to be Multivariate Gaussians. This is the topic of the next section. Our objective will be to learn the means for our reverse process distributions p!
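Concretely, in the same notation, each reverse transition is a Gaussian whose mean is what we learn (the variance is kept fixed):

```latex
p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\; \mu_\theta(x_t, t),\; \sigma_t^2\,\mathbf{I}\right)
```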

The forward process:

We first need a function to add noise to an initial data distribution, which basically means that we will take all our data points and wiggle each one by a random number. What does that mean in Python? Let's first familiarize ourselves with the Multivariate Normal Distribution. Skip the next couple of paragraphs if you already know how such a distribution works.

Let's start with a single point on the plane.

Say we want to take some point and randomly move it a little bit on the 2D plane. If the point is at coordinates (x1, y1), we sample 2 small random numbers ε_1 and ε_2 and move the point to its new location (x1 + ε_1, y1 + ε_2).

A Multivariate Normal Distribution is one such distribution: it lets us sample as many ε's as we have coordinates. We will use our starting points as the mean of the distribution, which means (haha) that every point (x_n, y_n) gets a new position of the form (x_n + ε_1, y_n + ε_2). And where do the ε's come from? From our covariance matrix, which is a fancy name for a big matrix holding one variance for each coordinate of our distribution, i.e. how small or big each ε will likely be.
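As a small illustration (using numpy, with variable names that are my own, not the repo's), here is the "wiggle" in code, both as per-coordinate ε's and as a direct multivariate normal sample:

```python
import numpy as np

rng = np.random.default_rng(0)

# Three points on the 2D plane.
points = np.array([[0.0, 0.0],
                   [1.0, 2.0],
                   [-1.0, 0.5]])

# One epsilon per coordinate, drawn from a Gaussian with a small variance.
variance = 0.01
epsilons = rng.normal(loc=0.0, scale=np.sqrt(variance), size=points.shape)
wiggled = points + epsilons

# Equivalently, sample the first point's new position directly from a
# multivariate normal whose mean is the point itself and whose
# covariance matrix is variance * identity.
wiggled_first = rng.multivariate_normal(mean=points[0], cov=variance * np.eye(2))
```

Both routes are the same operation: a diagonal covariance just means each coordinate gets its own independent ε.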

The first function will just take our data, sample ε's for each point and return the resulting data corrupted by those ε's.

This list_bar_alphas argument is just the amount of noise we add for every step t of diffusion, which is the other argument of this function. For some t we get the corresponding alpha from list_bar_alphas and use this alpha to calculate the variance of our noise. For every step of diffusion, the parameters of the added noise are fixed beforehand! These values come from the original paper.
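The actual gist lives in the repo; here is a minimal sketch of what such a function can look like (the function names and defaults are my assumptions, only `list_bar_alphas` comes from the text), corrupting the data by t steps in one shot via the cumulative ᾱ_t:

```python
import numpy as np

def make_bar_alphas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    # Linear beta schedule, the values used in the original paper;
    # bar_alpha_t is the cumulative product of alpha_t = 1 - beta_t.
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)

def forward_diffuse(data, t, list_bar_alphas, rng):
    # Corrupt the data with t steps of noise in a single shot:
    # x_t = sqrt(bar_alpha_t) * x_0 + sqrt(1 - bar_alpha_t) * epsilon
    bar_alpha = list_bar_alphas[t]
    epsilon = rng.normal(size=data.shape)
    return np.sqrt(bar_alpha) * data + np.sqrt(1.0 - bar_alpha) * epsilon

rng = np.random.default_rng(0)
list_bar_alphas = make_bar_alphas()
x0 = np.array([[0.0, 1.0], [2.0, -1.0]])
x_mild = forward_diffuse(x0, 50, list_bar_alphas, rng)    # slightly noisy
x_heavy = forward_diffuse(x0, 999, list_bar_alphas, rng)  # close to pure noise
```

Note how ᾱ_t shrinks toward zero as t grows, so at t near 1000 the data term vanishes and only noise is left.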

Now let's see some actual data. Say we want to diffuse this scatter plot of Homer Simpson. This is how the plot will look if we apply diffusion for some number t of steps.

See how the more steps of noise we add to poor Homer, the more he looks like pure noise.

Calculating the Posterior Distribution:

This part is going to get a bit technical on the mathematical details. The code is simple to read but there is indeed a lot going on theory-wise. We have a forward process corrupting the data, which in mathematical notation is something like this:

For every step t we will sample a more corrupted sample from this distribution. This is a Markov Chain generating new corrupted X's from the original data X_0.

This distribution is fixed: for any t we know exactly the distribution to sample X_t+1 from. But remember I wrote that, for each step, we will try to predict how much noise was added? So, and stay with me here, we want to predict the mean of the posterior distribution of q. In other words, the center, or the original data we had, as if we were walking backwards and removing the noise.

Pay attention to the indexes! Our model will try to predict the mean of this distribution.
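For reference, this posterior (as derived in the DDPM paper) conditions on both x_t and the original data x_0, and is again a Gaussian with a known mean and variance:

```latex
q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\; \tilde\mu_t(x_t, x_0),\; \tilde\beta_t\,\mathbf{I}\right)
```

```latex
\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t,
\qquad
\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t
```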

This is not the reverse process! If q is the distribution we parametrized to corrupt the data, p will be the distribution that recovers the data: our model! And the clever mathematical insight from the papers I linked before is that we can assume our reverse distribution has the same functional form as the forward process, which we defined to be simple! So it is enough for our model to be some function that tries to match the output of the posterior of q during training. And this is exactly what is happening in the training loop:

The training loop, reproduced here in its entirety.
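The notebook's actual loop uses a neural network; as a minimal numpy sketch of the same idea (every name here is my own, and the deliberately tiny linear "model" with a hand-written gradient step stands in for a real network trained with autograd), each iteration picks a random step t, corrupts the data, computes the known posterior mean as the target, and nudges the model toward it with a mean-squared-error loss:

```python
import numpy as np

rng = np.random.default_rng(0)

# Fixed noise schedule, same idea as in the forward process.
T = 100
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
bar_alphas = np.cumprod(alphas)

x0 = rng.normal(size=(256, 2))   # stand-in for the Homer scatter plot

# A deliberately tiny linear "model": mu(x_t, t) = x_t @ W + (t / T) * b.
W = 0.5 * np.eye(2)
b = np.zeros(2)
lr = 0.1

for step in range(200):
    t = int(rng.integers(1, T))              # random diffusion step
    eps = rng.normal(size=x0.shape)
    x_t = np.sqrt(bar_alphas[t]) * x0 + np.sqrt(1 - bar_alphas[t]) * eps

    # Known posterior mean of q(x_{t-1} | x_t, x_0): the training target.
    coef0 = np.sqrt(bar_alphas[t - 1]) * betas[t] / (1 - bar_alphas[t])
    coeft = np.sqrt(alphas[t]) * (1 - bar_alphas[t - 1]) / (1 - bar_alphas[t])
    target = coef0 * x0 + coeft * x_t

    pred = x_t @ W + (t / T) * b             # model's guess at that mean
    err = pred - target
    loss = np.mean(err ** 2)

    # Manual gradient descent on the mean squared error.
    W -= lr * (x_t.T @ err) / len(x_t)
    b -= lr * (t / T) * err.mean(axis=0)
```

The structure is the point: corrupt, compute the analytic posterior mean, regress the model onto it. A real implementation swaps the linear map for a network and the manual gradients for an optimizer.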

Something important to understand: the posterior distribution q(x_t-1 | x_t, x_0) of the forward process is known! We just need some algebraic manipulation to calculate it. What is definitely not known is the other distribution p, the reverse process. It is easy to corrupt the data; it is very much not easy to get something like the data back from pure noise. The insight, which I will not try to formalize here, is that the functional form of the reverse process closely follows this posterior, which we can easily calculate. So we can assume it to be Gaussian as well. (*)

(*) For the mathematically keen among you, the paper indeed uses a KL Divergence term to push the reverse process distribution p toward the known forward process posterior q. And it is from some clever manipulation of that term that we get a simple loss target.

The Reverse Process:

Now for the fun part: we trained a model that is very good at discerning the noise in a corrupted sample produced by a known Gaussian distribution. Let's give it some random noise and see if it finds something resembling our training data in it! We will run a loop, starting at random noise (t=N), calculating the data at step t-1 until we supposedly have the original data (t=0).

This is the function that does that inside the loop. We are just sampling from a Multivariate Normal Distribution whose mean is the next mean of the process, predicted by the model, as I explained in the last section. Our model gives just the mean; the variance we define to be exactly the same as in the forward process!
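A minimal sketch of one such reverse step (the names, including the `model` callable that predicts the posterior mean, are placeholders of mine, not the repo's API), with the variance taken straight from the forward process as the post describes:

```python
import numpy as np

def reverse_step(x_t, t, model, betas, bar_alphas, rng):
    # One reverse step: sample x_{t-1} from a Gaussian whose mean the
    # model predicts and whose variance is the fixed posterior variance
    # of the forward process.
    mean = model(x_t, t)
    if t > 1:
        var = (1 - bar_alphas[t - 1]) / (1 - bar_alphas[t]) * betas[t]
        return mean + np.sqrt(var) * rng.normal(size=x_t.shape)
    return mean  # no extra noise on the final step

T = 100
betas = np.linspace(1e-4, 0.02, T)
bar_alphas = np.cumprod(1.0 - betas)
rng = np.random.default_rng(0)

# A dummy stand-in "model" that just shrinks everything toward the origin;
# a trained network would go here instead.
dummy_model = lambda x_t, t: 0.99 * x_t

x = rng.normal(size=(500, 2))    # start from pure noise (t = N)
for t in range(T - 1, 0, -1):    # walk back step by step to t = 1
    x = reverse_step(x, t, dummy_model, betas, bar_alphas, rng)
```

With a real trained model in place of the dummy, this loop is the whole sampling procedure: predict the mean, add the scheduled noise, repeat.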

Results:

So, starting with a 2D scatter plot of pure Gaussian noise, I animated the reverse diffusion process the model learned! Here is this process in its full GIF glory. Not bad, right?

The model finds something similar to the original data in pure Gaussian noise!

If you want to play with the code, just clone my repo and run every cell on the notebook RunDiffusion.ipynb!
