Machine learning is, we might say loosely, a one-trick pony: we can find functions that minimize a loss when evaluated on a set of data. I’ll call this trick the “imitation machine”, since that’s usually how it’s used: the loss is computed by comparing the function’s output with some reference output (we’ll see some important exceptions though). But machine learning scholars are creative: we can use that one trick to do all kinds of things.
Language Modeling
I’ve written a lot about LM on this blog, so I won’t bore you here. Let’s just say:
- We can make functions that output discrete probability distributions (using softmax).
- We can train them to minimize perplexity of sequences from naturalistic language (by minimizing cross-entropy loss for next-token prediction).
- We can generate new text by sampling from that distribution one token at a time, appending each new token to the context and repeating.
Thus we can use an imitation machine to make a language parrot.
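To make the parrot recipe concrete, here’s a minimal sketch of both steps, assuming PyTorch. The TinyLM below is a toy that only conditions on the previous token (a real language model would be a Transformer over a long context), and the training text is made up, but the cross-entropy loss and the sample-append-repeat loop are the real thing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB = 256  # pretend tokens are just bytes

class TinyLM(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, dim)
        self.out = nn.Linear(dim, VOCAB)       # map to logits; softmax turns them into a distribution

    def forward(self, tokens):                 # tokens: (batch, seq)
        return self.out(self.embed(tokens))    # logits: (batch, seq, VOCAB)

model = TinyLM()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

text = b"the cat sat on the mat. the cat sat on the hat. " * 20
data = torch.tensor(list(text)).unsqueeze(0)   # (1, seq) of byte ids

# Training: minimize the cross-entropy of each next token given the current one.
for step in range(200):
    logits = model(data[:, :-1])
    loss = F.cross_entropy(logits.reshape(-1, VOCAB), data[:, 1:].reshape(-1))
    opt.zero_grad(); loss.backward(); opt.step()

# Generation: sample one token from the predicted distribution, append it, repeat.
ctx = torch.tensor([[ord("t")]])
for _ in range(40):
    probs = F.softmax(model(ctx)[:, -1], dim=-1)   # distribution over the next token
    nxt = torch.multinomial(probs, num_samples=1)
    ctx = torch.cat([ctx, nxt], dim=1)
print(bytes(ctx[0].tolist()).decode(errors="replace"))
```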
What I didn’t appreciate until recently was just how many other tasks this can end up learning. For example, it’s long been clear that this ends up learning syntactic analysis (because that helps predict whether, say, a verb is singular or plural). But it wasn’t obvious that this would learn translation until GPT-2 showed that (because documents often include glosses in a second language). And while most natural text corpora don’t contain a lot of the sort of instruction-response dialogues that characterize the modern chatbot (e.g., ChatGPT or Bard), InstructGPT showed us that the instruction is just a minor variation on the sort of context that these models do see often in training data, so a small amount of fine-tuning is sufficient to get them to recognize instructions as another way to cue that context.
Reinforcement Learning
The “imitation machine” framing assumes that we know what the right result should be. But that’s a big assumption. Examples:
- When you’re playing a video game, you don’t get a reward for moving towards the gold coin, only for reaching it. In fact, you may temporarily incur a cost to get that reward.
- When a robot is picking up an object to move it, it doesn’t get feedback every millisecond about whether it’s being successful, only when it puts the object down at the end.
- When playing chess, you may sacrifice a queen in order to get checkmate two moves later.
Those are some classic examples, but here’s another one that’s become important recently:
- When a dialogue agent is generating text, it doesn’t get feedback after every single token about whether the response is useful to the user, only after generating potentially thousands of tokens.
The common element of all of these situations is delayed reward: we don’t get immediate feedback on what the right action was.
The clever trick of practical Reinforcement Learning algorithms, though, is, you guessed it, to turn this into an imitation problem. Here are some common approaches (oversimplified; the first one is sketched in code after the list):
- Make an imitation model that attempts to predict later rewards from current states and actions. Use that model to give a fake reward for each action.
- Make an imitation model that includes both actions and rewards, and just ask it to imagine what might come next.
- Just give a lot of examples of successful attempts (including ones where some difficulty was encountered), and use a language model.
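Here’s a minimal sketch of the first approach, assuming PyTorch. The one-dimensional “walk right to reach the coin” environment, the network size, and the hyperparameters are all invented for illustration; the point is that the Q-network is itself an imitation machine, regressing toward a target value for each (state, action) pair, and that target stands in as the fake per-step reward signal.

```python
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

N_STATES, N_ACTIONS, GAMMA = 10, 2, 0.9       # a 1-D track; actions: 0 = left, 1 = right

def env_step(state, action):
    """Reward only arrives at the far-right state; every other step pays nothing."""
    nxt = max(0, min(N_STATES - 1, state + (1 if action == 1 else -1)))
    done = (nxt == N_STATES - 1)
    return nxt, (1.0 if done else 0.0), done

# The imitation model: predicts eventual reward for each action in a given state.
q_net = nn.Sequential(nn.Linear(N_STATES, 32), nn.ReLU(), nn.Linear(32, N_ACTIONS))
opt = torch.optim.Adam(q_net.parameters(), lr=1e-2)

def one_hot(state):
    return F.one_hot(torch.tensor([state]), N_STATES).float()

for episode in range(200):
    state, done = 0, False
    while not done:
        action = random.randrange(N_ACTIONS)          # explore at random
        nxt, reward, done = env_step(state, action)
        # The "fake reward" for this step: observed reward plus the model's own
        # prediction of what the next state is eventually worth.
        with torch.no_grad():
            future = 0.0 if done else GAMMA * q_net(one_hot(nxt)).max().item()
        target = torch.tensor(reward + future)
        pred = q_net(one_hot(state))[0, action]
        loss = F.mse_loss(pred, target)               # plain regression toward the target
        opt.zero_grad(); loss.backward(); opt.step()
        state = nxt
```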
Image generation
How do you generate an image using an imitation machine?
- Pretend it’s language. Organize the image into a sequence of tokens (e.g., scan pixels left-to-right, row by row; or split the image into, say, 8×8-pixel patches and give common patches a “token id”). Then sample one token at a time. ImageGPT, Taming Transformers, etc.
- Learn a function from noise to image. (Noise is great because you can sample it, so you can generate many different images by starting with different noise samples.) There are two basic approaches to doing this, differing in what they do about the loss.
  - Diffusion: Learn to map a noisy image to a less-noisy image. Start with an image, add noise, give the noised image as input, and penalize the difference between the output and the starting image. The trick is repeating this for different levels of noise; once you’ve added noise many times, you get a totally noisy image, i.e., all images look alike with enough noise added. This is the approach taken by DALL-E 2, Stable Diffusion, and other popular image generators. Major con: speed. You have to start with noise and generate many successively less noisy images sequentially. (A minimal training-step sketch follows this list.)
  - Generative Adversarial Networks: These are actually older than diffusion models. The basic idea is: learn a one-step mapping (the generator) from noise to image1, and simultaneously learn a discriminator that achieves low loss when it can correctly label whether an image is generated or real. (Also sketched in code after the list.)
    - The training regime is a bit tricky: we alternate training the discriminator (to reduce its loss) and training the generator (to increase the discriminator’s loss, i.e., to make it generate images that are more difficult to distinguish).
    - When we backpropagate through the discriminator, we get an image-shaped forgery-detection signal: each pixel comes with a value indicating how much changing that pixel would make the discriminator more confused. This gives the generator a pixel-by-pixel target to aim for, analogous to the denoised image from the diffusion model.
    - We need to make sure the two networks are balanced in training: if the discriminator gets too much better than the generator, there are no pixel-by-pixel changes that would meaningfully change its judgments, so the generator can’t learn.
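Here’s roughly what the diffusion training step looks like, as a minimal sketch assuming PyTorch. The tiny MLP denoiser, the linear noising schedule, and the random stand-in “images” are all invented for illustration; a real system uses a U-Net or Transformer over real images (and usually predicts the added noise rather than the clean image), but the add-noise-then-penalize-the-difference structure is the same.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

IMG_DIM = 28 * 28
# Tiny stand-in for a real denoiser (a real one would be a U-Net over 2-D images).
denoiser = nn.Sequential(nn.Linear(IMG_DIM + 1, 256), nn.ReLU(), nn.Linear(256, IMG_DIM))
opt = torch.optim.Adam(denoiser.parameters(), lr=1e-3)

def training_step(clean_images):                      # (batch, IMG_DIM), values in [0, 1]
    batch = clean_images.shape[0]
    t = torch.rand(batch, 1)                          # noise level: 0 = clean, 1 = pure noise
    noise = torch.randn_like(clean_images)
    noisy = (1 - t) * clean_images + t * noise        # a simple linear noising schedule
    pred = denoiser(torch.cat([noisy, t], dim=1))     # the model is told the noise level
    loss = F.mse_loss(pred, clean_images)             # penalize distance from the clean image
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

fake_dataset = torch.rand(64, IMG_DIM)                # stand-in for a batch of real images
for _ in range(100):
    training_step(fake_dataset)
```

Generation then runs this in reverse: start from pure noise and repeatedly denoise at decreasing noise levels, which is exactly why sampling is slow.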
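And here’s the corresponding GAN loop, with the same caveats (toy networks, random stand-in data): the only point is the alternating updates, and the fact that the generator’s learning signal arrives by backpropagating through the discriminator.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NOISE_DIM, IMG_DIM = 16, 28 * 28
generator = nn.Sequential(nn.Linear(NOISE_DIM, 256), nn.ReLU(),
                          nn.Linear(256, IMG_DIM), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(IMG_DIM, 256), nn.ReLU(), nn.Linear(256, 1))
g_opt = torch.optim.Adam(generator.parameters(), lr=2e-4)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

real_images = torch.rand(64, IMG_DIM)                 # stand-in for a batch of real images
real_label, fake_label = torch.ones(64, 1), torch.zeros(64, 1)

for step in range(100):
    # 1) Train the discriminator: low loss when it labels real vs. generated correctly.
    fake = generator(torch.randn(64, NOISE_DIM)).detach()   # detach: don't update the generator here
    d_loss = (F.binary_cross_entropy_with_logits(discriminator(real_images), real_label)
              + F.binary_cross_entropy_with_logits(discriminator(fake), fake_label))
    d_opt.zero_grad(); d_loss.backward(); d_opt.step()

    # 2) Train the generator: backpropagating through the discriminator gives each
    #    generated pixel a gradient, i.e., the direction that would confuse the
    #    discriminator more. The generator follows that signal.
    fake = generator(torch.randn(64, NOISE_DIM))
    g_loss = F.binary_cross_entropy_with_logits(discriminator(fake), real_label)
    g_opt.zero_grad(); g_loss.backward(); g_opt.step()
```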
Other
We could also discuss clustering, inferential stats, and even causal inference as other clever applications of the “imitation machine” to useful tasks.
On generalization
A table-lookup function could be an imitation machine, but it would perform poorly on data it hasn’t seen, and more practically, it’s brittle: mess up the table a bit and the loss spikes even on the training data.2 But happily we’ve found many types of functions that are much less brittle because they spread out the work of computing the output into a sequence of information-processing operations. Each operation creates a progressively more robust representation of the current datum, where situations that should be treated similarly are brought closer together compared with situations that should be treated differently. In other words, each layer of processing filters out noise, i.e., variation that doesn’t meaningfully change the output.
This means that the imitation machines we often use in practice, deep neural nets, turn out to be modular and adaptable, because they’re constructing internal representations (which we call embeddings) that are useful perceptions of each datum.
Footnotes
This can be a quite nuanced and powerful step, though, as StyleGAN shows us.↩︎
I mean this in the sense of the sharpness of the loss surface.↩︎