Foundation: Statistical Prediction & ML
Statistical Prediction and Supervised Learning
Before getting to deep learning and large language models, it'll be useful to have a solid grasp on some foundational concepts in probability theory and machine learning. In particular, it helps to understand:
- Random variables, expectations, and variance
- Supervised vs. unsupervised learning
- Regression vs. classification
- Linear models and regularization
- Empirical risk minimization
- Hypothesis classes and bias-variance tradeoffs
For general probability theory, having a solid understanding of how the Central Limit Theorem works is perhaps a reasonable litmus test for how much you'll need to know about random variables before tackling some of the later topics we'll cover. This beautifully-animated 3Blue1Brown video is a great starting point, and there are a couple other good probability videos to check out on the channel if you'd like. This set of course notes from UBC covers the basics of random variables.
If you're into blackboard lectures, I'm a big fan of many of Ryan O'Donnell's CMU courses on YouTube, and this video on random variables and the Central Limit Theorem (from the excellent "CS Theory Toolkit" course) is a nice overview.
For understanding linear models and other key machine learning principles, the first two chapters of Hastie's Elements of Statistical Learning ("Introduction" and "Overview of Supervised Learning") should be enough to get started.
Once you're familiar with the basics, this blog post by anonymous Twitter/X user @ryxcommar does a nice job discussing some common pitfalls and misconceptions related to linear regression. StatQuest on YouTube has a number of videos that might also be helpful.
Many phenomena in the real world are modeled quite well by linear equations — the average temperature over the past 7 days is likely a solid guess for the temperature tomorrow, barring any other information about weather pattern forecasts.
Introductions to machine learning tend to emphasize linear models, and for good reason. Linear systems and models are a lot easier to study, interpret, and optimize than their nonlinear counterparts. For more complex and high-dimensional problems with potential nonlinear dependencies between features, it's often useful to ask:
- What's a linear model for the problem?
- Why does the linear model fail?
- What's the best way to add nonlinearity, given the semantic structure of the problem?
In particular, this framing will be helpful for motivating some of the model architectures we'll look at later (e.g. LSTMs and Transformers).
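To make those three questions a bit more concrete, here's a small sketch using scikit-learn on a synthetic dataset: a plain linear model, and then a version with nonlinearity added through a simple feature transform. The data, regularization strength, and polynomial degree are all just illustrative choices.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=200)   # a mildly nonlinear target

# Question 1: the linear model (with a bit of regularization)
linear = Ridge(alpha=1.0).fit(X, y)

# Question 3: add nonlinearity via feature transforms, keeping the model linear in the features
cubic = make_pipeline(PolynomialFeatures(3), Ridge(alpha=1.0)).fit(X, y)

# R^2 scores: the purely linear fit plateaus, the polynomial features do noticeably better
print(linear.score(X, y), cubic.score(X, y))
```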
Time-Series Analysis
How much do you need to know about time-series analysis in order to understand the mechanics of more complex generative AI methods?
Short answer: just a tiny bit for LLMs, a good bit more for diffusion.
For modern Transformer-based LLMs, it'll be useful to know:
- The basic setup for sequential prediction problems
- The notion of an autoregressive model
There's not really a coherent way to "visualize" the full mechanics of a multi-billion-parameter model in your head, but much simpler autoregressive models (like ARIMA) can serve as a nice mental model to extrapolate from.
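If something executable helps anchor the intuition, here's a minimal autoregressive example: simulate an AR(2) series, fit its coefficients with ordinary least squares, and make a one-step forecast. Everything here is synthetic and purely illustrative.

```python
import numpy as np

# Fit an AR(2) model y_t ≈ a1*y_{t-1} + a2*y_{t-2} + c by ordinary least squares.
rng = np.random.default_rng(0)
y = np.zeros(200)
for t in range(2, 200):                                      # simulate a toy AR(2) process
    y[t] = 0.6 * y[t - 1] + 0.3 * y[t - 2] + rng.normal(scale=0.1)

X = np.column_stack([y[1:-1], y[:-2], np.ones(len(y) - 2)])  # lag-1, lag-2 features + intercept
targets = y[2:]
coef, *_ = np.linalg.lstsq(X, targets, rcond=None)
print(coef)                                                  # should land near [0.6, 0.3, 0.0]

forecast = coef @ np.array([y[-1], y[-2], 1.0])              # one-step-ahead prediction
print(forecast)
```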
When we get to neural state-space models, a working knowledge of linear time-invariant systems and control theory (which have many connections to classical time-series analysis) will be helpful for intuition, but diffusion is really where it's most essential to dive deeper into stochastic differential equations to get the full picture. But we can table that for now.
Key Resources
- Blog post: Forecasting with Stochastic Models from Towards Data Science
- Course notes: Time Series Analysis from UAlberta
Online Learning and Regret Minimization
It's debatable how important it is to have a strong grasp on regret minimization, but I think a basic familiarity is useful. The basic setting here is similar to supervised learning, but:
- Points arrive one-at-a-time in an arbitrary order
- We want low average error across this sequence
If you squint and tilt your head, most of the algorithms designed for these problems look basically like gradient descent, often with delicate choices of regularizers and learning rates required for the math to work out. But there's a lot of satisfying math here, and I have a soft spot for it, as it relates to a lot of the research I worked on during my PhD; I think it's conceptually fascinating. Like the previous section on time-series analysis, online learning is technically "sequential prediction", but you don't really need it to understand LLMs.
The most direct connection to it that we'll consider is when we look at GANs in Section VIII. There are many deep connections between regret minimization and equilibria in games, and GANs work basically by having two neural networks play a game against each other. Practical gradient-based optimization algorithms like Adam have their roots in this field as well, following the introduction of the AdaGrad algorithm, which was first analyzed for online and adversarial settings.
If you're doing gradient-based optimization with a sensible learning rate schedule, then the order in which you process data points doesn't actually matter much. Gradient descent can handle it.
I'd encourage you to at least skim Chapter 1 of "Introduction to Online Convex Optimization" by Elad Hazan to get a feel for the goal of regret minimization. I've spent a lot of time with this book and I think it's excellent.
Reinforcement Learning
Reinforcement Learning (RL) will come up most directly when we look at finetuning methods in Section IV, and may also be a useful mental model for thinking about "agent" applications and some of the "control theory" notions which come up for state-space models.
Like a lot of the topics discussed in this document, you can go quite deep down many different RL-related threads if you'd like; as it relates to language modeling and alignment, it'll be most important to be comfortable with the basic problem setup for Markov decision processes, the notion of policies and trajectories, and a high-level understanding of the standard iterative + gradient-based optimization methods for RL.
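As a concrete anchor for those concepts, here's a toy sketch of an MDP, a stochastic policy, and a sampled trajectory with its discounted return. The transition and reward tables are made up purely for illustration.

```python
import numpy as np

# A toy 2-state MDP with 2 actions.
# P[s, a] gives next-state probabilities; R[s, a] gives the (deterministic) reward.
P = np.array([[[0.9, 0.1], [0.2, 0.8]],
              [[0.5, 0.5], [0.1, 0.9]]])
R = np.array([[1.0, 0.0],
              [0.0, 2.0]])

def policy(state, rng):
    # A simple stochastic policy: usually pick the action matching the state.
    return state if rng.random() < 0.8 else 1 - state

def rollout(T=20, gamma=0.95, seed=0):
    rng = np.random.default_rng(seed)
    s, total, discount = 0, 0.0, 1.0
    trajectory = []
    for _ in range(T):
        a = policy(s, rng)
        r = R[s, a]
        s_next = rng.choice(2, p=P[s, a])
        trajectory.append((s, a, r))          # (state, action, reward) tuples
        total += discount * r
        discount *= gamma
        s = s_next
    return trajectory, total

traj, ret = rollout()
print(ret)   # discounted return of one sampled trajectory under this fixed policy
```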
Essential Resources
- Blog post: Lilian Weng's RL Overview - A concise yet comprehensive starting point
- Textbook: Reinforcement Learning: An Introduction by Sutton and Barto
- Video Series: Reinforcement Learning By the Book - 3Blue1Brown-style animations
If you want to jump ahead to some more neural-flavored content, Andrej Karpathy has a nice blog post on deep RL; this manuscript by Yuxi Li and this textbook by Aske Plaat may be useful for further deep dives.
Markov Models
Running a fixed policy in a Markov decision process yields a Markov chain; processes resembling this kind of setup are fairly abundant, and many branches of machine learning involve modeling systems under Markovian assumptions (i.e. lack of path-dependence, given the current state).
This blog post from Aja Hammerly makes a nice case for thinking about language models via Markov processes, and this post from "Essays on Data Science" has examples and code building up towards auto-regressive Hidden Markov Models, which will start to vaguely resemble some of the neural network architectures we'll look at later on.
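To see the Markov-chain view of language in action, here's a tiny word-level bigram model: count transitions in a corpus, then sample one word at a time conditioned only on the previous word. The "corpus" is obviously just a toy.

```python
import random
from collections import defaultdict

corpus = "the cat sat on the mat and the dog sat on the rug".split()

# Count bigram transitions: which words follow each word, and how often.
transitions = defaultdict(list)
for w1, w2 in zip(corpus, corpus[1:]):
    transitions[w1].append(w2)

def generate(start="the", length=8, seed=0):
    random.seed(seed)
    words = [start]
    for _ in range(length - 1):
        options = transitions.get(words[-1])
        if not options:
            break
        words.append(random.choice(options))   # sampling proportionally to observed counts
    return " ".join(words)

print(generate())
```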
This blog post from Simeon Carstens gives a nice coverage of Markov chain Monte Carlo methods, which are powerful and widely-used techniques for sampling from implicitly-represented distributions, and are helpful for thinking about probabilistic topics ranging from stochastic gradient descent to diffusion.
Markov models are also at the heart of many Bayesian methods. See this tutorial from Zoubin Ghahramani for a nice overview, the textbook "Pattern Recognition and Machine Learning" for Bayesian angles on many machine learning topics (as well as a more-involved HMM presentation), and this chapter of the Goodfellow et al. "Deep Learning" textbook for some connections to deep learning.
Key Takeaways
- Foundational statistical methods provide essential mental models for understanding LLMs
- Linear models serve as both a starting point and a point of comparison for more complex architectures
- Time-series analysis and autoregressive models help conceptualize how language models work
- Reinforcement learning principles are crucial for understanding fine-tuning and alignment
- Markov models offer a probabilistic framework for thinking about sequential prediction
Neural Networks & Deep Learning
Routes to Understanding Transformers
There are a couple different routes you can take from the basics of neural networks towards Transformers (the dominant architecture for most frontier LLMs in 2024). Once we cover the basics, I'll mostly focus on "deep sequence learning" methods like RNNs. Many deep learning books and courses will more heavily emphasize convolutional neural nets (CNNs), which are quite important for image-related applications and historically were one of the first areas where "scaling" was particularly successful, but technically they're fairly disconnected from Transformers. They'll make an appearance when we discuss state-space models and are definitely important for vision applications, but you'll mostly be okay skipping them for now. However, if you're in a rush and just want to get to the new stuff, you could consider diving right into decoder-only Transformers once you're comfortable with feed-forward neural nets --- this is the approach taken by the excellent "Let's build GPT" video from Andrej Karpathy, casting them as an extension of neural n-gram models for next-token prediction. That's probably your single best bet for speedrunning Transformers in under 2 hours. But if you've got a little more time, understanding the history of RNNs, LSTMs, and encoder-decoder Transformers is certainly worthwhile.
This section is mostly composed of signposts to content from the following sources (along with some blog posts):
- The "Dive Into Deep Learning" (d2l.ai) interactive textbook (nice graphics, in-line code, some theory)
- 3Blue1Brown's "Neural networks" video series (lots of animations)
- Andrej Karpathy's "Zero to Hero" video series (live coding + great intuitions)
- "StatQuest with Josh Starmer" videos
- The Goodfellow et al. "Deep Learning" textbook (theory-focused, no Transformers)
If your focus is on applications, you might find the interactive "Machine Learning with PyTorch and Scikit-Learn" book useful, but I'm not as familiar with it personally.
For these topics, you can also probably get away with asking conceptual questions to your preferred LLM chat interface. This likely won't be true for later sections --- some of those topics were introduced after the knowledge cutoff dates for many current LLMs, and there's also just a lot less text on the internet about them, so you end up with more "hallucinations".
Statistical Prediction with Neural Networks
I'm not actually sure where I first learned about neural nets --- they're pervasive enough in technical discussions and general online media that I'd assume you've picked up a good bit through osmosis even if you haven't studied them formally. Nonetheless, there are many worthwhile explainers out there, and I'll highlight some of my favorites.
- The first 4 videos in 3Blue1Brown's "Neural networks" series will take you from basic definitions up through the mechanics of backpropagation.
- This blog post from Andrej Karpathy (back when he was a PhD student) is a solid crash-course, well-accompanied by his video on building backprop from scratch.
- This blog post from Chris Olah has a nice and concise walk-through of the math behind backprop for neural nets.
- Chapters 3-5 of the d2l.ai book are great as a "classic textbook" presentation of deep nets for regression + classification, with code examples and visualizations throughout.
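To tie those pieces together, here's a minimal sketch (in PyTorch, on random data) of the standard supervised training loop: forward pass, cross-entropy loss, backpropagation via autograd, and a gradient step. Sizes and hyperparameters are arbitrary.

```python
import torch
import torch.nn as nn

# A tiny MLP classifier trained on random data.
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

X = torch.randn(128, 20)                 # 128 examples, 20 features
y = torch.randint(0, 3, (128,))          # 3 classes

for epoch in range(100):
    logits = model(X)                    # forward pass
    loss = loss_fn(logits, y)
    opt.zero_grad()
    loss.backward()                      # backpropagation computes gradients
    opt.step()                           # gradient descent update
print(loss.item())                       # training loss should shrink toward zero
```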
Recurrent Neural Networks
RNNs are where we start adding "state" to our models (as we process increasingly long sequences), and there are some high-level similarities to hidden Markov models. This blog post from Andrej Karpathy is a good starting point. Chapter 9 of the d2l.ai book is great for main ideas and code; check out Chapter 10 of "Deep Learning" if you want more theory.
For videos, here's a nice one from StatQuest.
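The core recurrence is small enough to write out directly; here's a sketch of a vanilla RNN cell unrolled over a short sequence, with random weights and inputs and no training loop.

```python
import torch

# Vanilla RNN cell applied over a sequence: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b)
d_in, d_hidden, seq_len = 8, 16, 5
W_xh = torch.randn(d_in, d_hidden) * 0.1
W_hh = torch.randn(d_hidden, d_hidden) * 0.1
b = torch.zeros(d_hidden)

x = torch.randn(seq_len, d_in)       # one input sequence
h = torch.zeros(d_hidden)            # initial hidden state (the model's "memory")
for t in range(seq_len):
    h = torch.tanh(x[t] @ W_xh + h @ W_hh + b)

# `h` now summarizes the whole sequence; a linear "readout" layer would map it to predictions.
print(h.shape)
```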
LSTMs and GRUs
Long Short-Term Memory (LSTM) networks and Gated Recurrent Unit (GRU) networks build upon RNNs with more specialized mechanisms for state representation (with semantic inspirations like "memory", "forgetting", and "resetting"), which have been useful for improving performance in more challenging data domains (like language).
Chapter 10 of d2l.ai covers both of these quite well (up through 10.3). The "Understanding LSTM Networks" blog post from Chris Olah is also excellent. This video from "The AI Hacker" gives solid high-level coverage of both; StatQuest also has a video on LSTMs, but not GRUs. GRUs are essentially a simplified alternative to LSTMs with the same basic objective, and it's up to you if you want to cover them specifically.
Neither LSTMs nor GRUs are really prerequisites for Transformers, which are "stateless", but they're useful for understanding the general challenges of neural sequence modeling and contextualizing the Transformer design choices.
They'll also help motivate some of the approaches towards addressing the "quadratic scaling problem" in Section VII.
Embeddings and Topic Modeling
Before digesting Transformers, it's worth first establishing a couple concepts which will be useful for reasoning about what's going on under the hood inside large language models. While deep learning has led to a large wave of progress in NLP, it's definitely a bit harder to reason about than some of the "old school" methods which deal with word frequencies and n-gram overlaps; however, even though these methods don't always scale to more complex tasks, they're useful mental models for the kinds of "features" that neural nets might be learning. For example, it's certainly worth knowing about Latent Dirichlet Allocation for topic modeling (blog post) and tf-idf to get a feel for what numerical similarity or relevance scores can represent for language.
Thinking about words (or tokens) as high-dimensional "meaning" vectors is quite useful, and the Word2Vec embedding method illustrates this quite well --- you may have seen the classic "King - Man + Woman = Queen" example referenced before. "The Illustrated Word2Vec" from Jay Alammar is great for building up this intuition, and these course notes from Stanford's CS224n are excellent as well. Here's also a nice video on Word2Vec from ritvikmath, and another fun one on neural word embeddings from Computerphile.
Key Resources for Embeddings
- Blog post: The Illustrated Word2Vec - Visual guide to word embeddings
- Video: Neural Word Embeddings from Computerphile
- Paper: Efficient Estimation of Word Representations in Vector Space - Original Word2Vec paper
Beyond being a useful intuition and an element of larger language models, standalone neural embedding models are also widely used today. Often these are encoder-only Transformers, trained via "contrastive loss" to construct high-quality vector representations of text inputs which are useful for retrieval tasks (like RAG). See this post+video from Cohere for a brief overview, and this blog post from Lilian Weng for more of a deep dive.
Encoders and Decoders
Up until now we've been pretty agnostic as to what the inputs to our networks are --- numbers, characters, words --- as long as it can be converted to a vector representation somehow. Recurrent models can be configured to both input and output either a single object (e.g. a vector) or an entire sequence. This observation enables the sequence-to-sequence encoder-decoder architecture, which rose to prominence for machine translation, and was the original design for the Transformer in the famed "Attention is All You Need" paper. Here, the goal is to take an input sequence (e.g. an English sentence), "encode" it into a vector object which captures its "meaning", and then "decode" that object into another sequence (e.g. a French sentence). Chapter 10 in d2l.ai (10.6-10.8) covers this setup as well, which sets the stage for the encoder-decoder formulation of Transformers in Chapter 11 (up through 11.7). For historical purposes you should certainly at least skim the original paper, though you might get a bit more out of the presentation of its contents via "The Annotated Transformer", or perhaps "The Illustrated Transformer" if you want more visualizations. These notes from Stanford's CS224n are great as well.
There are videos on encoder-decoder architectures and Attention from StatQuest, as well as a full walkthrough of the original Transformer by The AI Hacker.
However, note that these encoder-decoder Transformers differ from most modern LLMs, which are typically "decoder-only" -- if you're pressed for time, you may be okay jumping right to these models and skipping the history lesson.
Decoder-Only Transformers
There are a lot of moving pieces inside of Transformers --- multi-head attention, skip connections, positional encoding, etc. --- and it can be tough to appreciate it all the first time you see it. Building up intuitions for why some of these choices are made helps a lot, and here I'll recommend that pretty much anyone watch a video or two about them (even if you're normally a textbook learner), largely because there are a few videos which are really excellent:
- 3Blue1Brown's "But what is a GPT?" and "Attention in transformers, explained visually" -- beautiful animations + discussions, supposedly a 3rd video is on the way
- Andrej Karpathy's "Let's build GPT" video -- live coding and excellent explanations, really helped some things "click" for me
Here's a blog post from Cameron Wolfe walking through the decoder-only architecture in a similar style to the Illustrated/Annotated Transformer posts. There's also a nice section in d2l.ai (11.9) covering the relationships between encoder-only, encoder-decoder, and decoder-only Transformers.
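If you want a compact reference point while working through those, here's a sketch of single-head causal self-attention, the core operation inside a decoder-only Transformer. Dimensions and weights are arbitrary, and real implementations use multiple heads plus learned projections inside a full residual block.

```python
import math
import torch
import torch.nn.functional as F

def causal_self_attention(x, W_q, W_k, W_v):
    """Single-head scaled dot-product attention with a causal mask.
    x: (seq_len, d_model); W_q/W_k/W_v: (d_model, d_head)."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / math.sqrt(k.shape[-1])            # pairwise query-key similarity
    mask = torch.tril(torch.ones_like(scores)).bool()    # token i may only attend to j <= i
    scores = scores.masked_fill(~mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)                  # attention weights per query
    return weights @ v                                   # weighted sum of value vectors

seq_len, d_model, d_head = 5, 16, 8
x = torch.randn(seq_len, d_model)
out = causal_self_attention(x, *(torch.randn(d_model, d_head) for _ in range(3)))
print(out.shape)   # (5, 8)
```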
Key Takeaways
- Neural networks form the foundation of modern deep learning methods
- Recurrent neural networks (RNNs) and their variants introduced state for sequence modeling
- Word embeddings provide a way to represent semantic meaning in a vector space
- Encoder-decoder architectures were a critical step toward modern Transformers
- Decoder-only Transformers are the architecture behind most modern LLMs
LLM Architecture & Training
Tokenization
Character-level tokenization (like in several of the Karpathy videos) tends to be inefficient for large-scale Transformers vs. word-level tokenization, yet naively picking a fixed "dictionary" (e.g. Merriam-Webster) of full words runs the risk of encountering unseen words or misspellings at inference time. Instead, the typical approach is to use subword-level tokenization to "cover" the space of possible inputs, while maintaining the efficiency gains which come from a larger token pool, using algorithms like Byte-Pair Encoding (BPE) to select the appropriate set of tokens. If you've ever seen Huffman coding in an introductory algorithms class I think it's a somewhat useful analogy for BPE here, although the input-output format is notably different, as we don't know the set of "tokens" in advance. I'd recommend watching Andrej Karpathy's video on tokenization and checking out this tokenization guide from Masato Hagiwara.
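Here's a toy sketch of the BPE merge loop in the style of the original Sennrich et al. example: words are represented as space-separated symbols with corpus frequencies, and the most frequent adjacent pair is merged at each step. The tiny vocabulary and merge count are purely illustrative.

```python
from collections import Counter

def get_pair_counts(vocab):
    """Count adjacent symbol pairs across the corpus, weighted by word frequency."""
    pairs = Counter()
    for word, freq in vocab.items():
        symbols = word.split()
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs

def merge_pair(pair, vocab):
    """Replace every occurrence of the chosen symbol pair with a single merged symbol."""
    merged_vocab = {}
    for word, freq in vocab.items():
        symbols = word.split()
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged_vocab[" ".join(out)] = freq
    return merged_vocab

# Words as character sequences, with how often each appears in the corpus.
vocab = {"l o w": 5, "l o w e r": 2, "n e w e s t": 6, "w i d e s t": 3}
for _ in range(5):                                     # perform 5 merges
    pair = get_pair_counts(vocab).most_common(1)[0][0]
    vocab = merge_pair(pair, vocab)
    print("merged", pair)
print(vocab)
```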
Positional Encoding
As we saw in the previous section, Transformers don't natively have the same notion of adjacency or position within a context window (in contrast to RNNs), and position must instead be represented with some kind of vector encoding. While this could be done naively with something like one-hot encoding, this is impractical for context-scaling and suboptimal for learnability, as it throws away notions of ordinality. Originally, this was done with sinusoidal positional encodings, which may feel reminiscent of Fourier features if you're familiar with them; the most popular implementation of this type of approach nowadays is likely Rotary Positional Encoding, or RoPE, which tends to be more stable and faster to learn during training.
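For reference, here's a sketch of the original sinusoidal encoding from "Attention is All You Need": each position maps to a vector of sines and cosines at geometrically-spaced frequencies (assuming an even embedding dimension for simplicity).

```python
import numpy as np

def sinusoidal_positions(seq_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    pos = np.arange(seq_len)[:, None]
    i = np.arange(d_model // 2)[None, :]
    angles = pos / np.power(10000, 2 * i / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)     # even dimensions get sines
    pe[:, 1::2] = np.cos(angles)     # odd dimensions get cosines
    return pe

print(sinusoidal_positions(4, 8).round(2))
```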
Key Resources for Positional Encoding
- Blog post: Understanding Positional Embeddings by Harrison Pim on intuition for positional encodings
- Blog post: A Gentle Introduction to Positional Encoding by Mehreen Saeed on the original Transformer positional encodings
- Blog post: Rotary Embeddings on RoPE from Eleuther AI
- Animated video: Understanding Positional Encoding from DeepLearning Hero
Pretraining Recipes
Once you've committed to pretraining an LLM of a certain general size on a particular corpus of data (e.g. Common Crawl, FineWeb), there are still a number of choices to make before you're ready to go:
- Attention mechanisms (multi-head, multi-query, grouped-query)
- Activations (ReLU, GeLU, SwiGLU)
- Optimizers, learning rates, and schedulers (AdamW, warmup, cosine decay)
- Dropout?
- Hyperparameter choices and search strategies
- Batching, parallelization strategies, gradient accumulation
- How long to train for, how often to repeat data
- ...and many other axes of variation
As far as I can tell, there's not a one-size-fits-all rule book for how to go about this, but the resources below provide valuable insights from those who have navigated these challenges.
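As one concrete example of the scheduler choices mentioned above, here's a sketch of the common "linear warmup then cosine decay" learning-rate schedule; the particular peak/minimum values and warmup length are placeholders, not recommendations.

```python
import math

def lr_at(step, total_steps, peak_lr=3e-4, warmup_steps=2000, min_lr=3e-5):
    """Linear warmup to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(lr_at(1_000, 100_000), lr_at(50_000, 100_000), lr_at(100_000, 100_000))
```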
Essential Pretraining Resources
- Blog post: A Recipe for Training Neural Networks by Andrej Karpathy - While it predates the LLM era, this is a great starting point for framing many problems relevant throughout deep learning
- Guide: The Novice's LLM Training Guide by Alpin Dale, discussing hyperparameter choices in practice, as well as the finetuning techniques we'll see in future sections
- Blog post: How to train your own Large Language Models from Replit has some nice discussions on data pipelines and evaluations for training
- Article: Navigating the Attention Landscape: MHA, MQA, and GQA Decoded by Shobhit Agarwal for understanding attention mechanism tradeoffs
- Blog post: The Evolution of the Modern Transformer from Deci AI for discussion of "popular defaults"
- Chapter: Learning Rate Scheduling from the d2l.ai book (Chapter 12.11)
- Blog post: Response to NYT from Eleuther AI on controversy surrounding reporting of "best practices"
Distributed Training and FSDP
There are a number of additional challenges associated with training models which are too large to fit on individual GPUs (or even multi-GPU machines), typically necessitating the use of distributed training protocols like Fully Sharded Data Parallelism (FSDP), in which model parameters, gradients, and optimizer states are sharded across many devices during training. It's probably worth also understanding its precursor Distributed Data Parallelism (DDP), which is covered in the first post linked below.
Resources on Distributed Training
- Blog post: FSDP from Meta (who pioneered the method)
- Blog post: Understanding FSDP by Bar Rozenman, featuring many excellent visualizations
- Report: Training Great LLMs Entirely From Ground Zero in the Wilderness from Yi Tai on the challenges of pretraining a model in a startup environment
- Technical blog: FSDP QLora Deep Dive from Answer.AI on combining FSDP with parameter-efficient finetuning techniques for use on consumer GPUs
Scaling Laws
It's useful to know about scaling laws as a meta-topic which comes up a lot in discussions of LLMs (most prominently in reference to the "Chinchilla" paper), more so than any particular empirical finding or technique. In short, a language model's loss tends to follow fairly reliable trends as you scale up the model size, dataset size, and compute used for training, which makes it possible to predict the loss of a larger run before launching it. This then enables calibration of optimal hyperparameter settings without needing to run expensive grid searches.
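The parametric form fit in the Chinchilla paper is compact enough to write down; the sketch below uses coefficients roughly matching those reported there, but treat the exact numbers as illustrative rather than authoritative.

```python
def chinchilla_loss(N, D, E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28):
    """Parametric scaling-law fit L(N, D) = E + A / N^alpha + B / D^beta,
    where N is parameter count and D is training tokens. Constants are roughly
    those reported in the Chinchilla paper; treat them as illustrative."""
    return E + A / N**alpha + B / D**beta

# e.g. a 70B-parameter model trained on 1.4T tokens (roughly Chinchilla's "optimal" ratio)
print(chinchilla_loss(N=70e9, D=1.4e12))
```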
Resources on Scaling Laws
- Blog overview: Chinchilla Scaling Laws for Large Language Models by Rania Hossam
- Discussion: New Scaling Laws for LLMs on LessWrong
- Post: Chinchilla's Wild Implications on LessWrong
- Analysis: Chinchilla Scaling: A Replication Attempt (potential issues with Chinchilla findings)
- Blog post: Scaling Laws and Emergent Properties by Clément Thiriet
- Video lecture: Scaling Language Models from Stanford CS224n
Mixture-of-Experts
While many of the prominent LLMs (such as Llama3) used today are "dense" models (i.e. without enforced sparsification), Mixture-of-Experts (MoE) architectures are becoming increasingly popular for navigating tradeoffs between "knowledge" and efficiency, used perhaps most notably in the open-weights world by Mistral AI's "Mixtral" models (8x7B and 8x22B), and rumored to be used for GPT-4. In MoE models, only a fraction of the parameters are "active" for each step of inference, with trained router modules for selecting the parallel "experts" to use at each layer. This allows models to grow in size (and perhaps "knowledge" or "intelligence") while remaining efficient for training or inference compared to a comparably-sized dense model.
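Here's a toy sketch of the routing idea in PyTorch: a small router scores experts per token, the top-k are selected, and their outputs are mixed. For clarity every expert runs on every token here; real MoE layers dispatch tokens so that only the selected experts do any work, which is where the efficiency comes from.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """Toy mixture-of-experts layer with a learned top-k router."""
    def __init__(self, d_model, n_experts=8, k=2):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
            for _ in range(n_experts)])
        self.k = k

    def forward(self, x):                                   # x: (tokens, d_model)
        probs = F.softmax(self.router(x), dim=-1)           # (tokens, n_experts)
        topk_p, topk_idx = probs.topk(self.k, dim=-1)
        topk_p = topk_p / topk_p.sum(dim=-1, keepdim=True)  # renormalize over chosen experts
        # NOTE: computing all experts here is only for readability.
        expert_out = torch.stack([e(x) for e in self.experts], dim=1)   # (tokens, n_experts, d_model)
        gathered = expert_out.gather(1, topk_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]))
        return (topk_p.unsqueeze(-1) * gathered).sum(dim=1)

out = TopKMoE(d_model=32)(torch.randn(10, 32))
print(out.shape)   # (10, 32)
```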
Resources on Mixture-of-Experts
- Blog post: Mixture of Experts Explained from Hugging Face for a technical overview
- Video: Mixture of Experts Visualized from Trelis Research for a visualized explainer
Key Takeaways
- Subword tokenization strikes a balance between efficiency and handling unknown words
- Positional encoding schemes like RoPE are crucial for Transformers to understand sequence order
- LLM pretraining involves numerous architecture and optimization decisions
- Distributed training techniques like FSDP enable training of models too large for individual GPUs
- Scaling laws provide guidance on optimal allocation of compute, data, and model size
- Mixture-of-Experts models offer parameter efficiency by activating only relevant parameters during inference
Finetuning & Alignment
Instruct Fine-Tuning
Instruct fine-tuning (or "instruction tuning", or "supervised finetuning", or "chat tuning" -- the boundaries here are a bit fuzzy) is the primary technique used (at least initially) for coaxing LLMs to conform to a particular style or format. Here, data is presented as a sequence of (input, output) pairs where the input is a user question to answer, and the model's goal is to predict the output -- typically this also involves adding special "start"/"stop"/"role" tokens and other masking techniques, enabling the model to "understand" the difference between the user's input and its own outputs. This technique is also widely used for task-specific finetuning on datasets with a particular kind of problem structure (e.g. translation, math, general question-answering).
See this blog post from Sebastian Ruder or this video from Shayne Longpre for short overviews.
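To make the formatting idea concrete, here's a hypothetical chat template; the actual special tokens and masking details vary by model family (Llama-style, ChatML, etc.), so treat these markers purely as illustration.

```python
# Hypothetical special tokens -- real templates differ per model family.
def format_example(user_msg, assistant_msg):
    return (
        "<|user|>\n" + user_msg + "<|end|>\n"
        "<|assistant|>\n" + assistant_msg + "<|end|>"
    )

prompt = format_example("What's the capital of France?", "Paris.")
print(prompt)
# During finetuning, the loss is typically masked so that only the tokens in the
# assistant's span contribute to the training objective.
```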
Low-Rank Adapters (LoRA)
While pre-training (and "full finetuning") requires applying gradient updates to all parameters of a model, this is typically impractical on consumer GPUs or home setups; fortunately, it's often possible to significantly reduce the compute requirements by using parameter-efficient finetuning (PEFT) techniques like Low-Rank Adapters (LoRA). This can enable competitive performance even with relatively small datasets, particularly for application-specific use cases. The main idea behind LoRA is to train each weight matrix in a low-rank space by "freezing" the base matrix and training a factored representation with much smaller inner dimension, which is then added to the base matrix.
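The core of LoRA fits in a few lines; here's a sketch of a LoRA-wrapped linear layer in PyTorch. The rank and scaling values are typical-looking placeholders, not recommendations.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen base linear layer plus a trainable low-rank update: W x + (alpha / r) * (x A) B."""
    def __init__(self, base: nn.Linear, r=8, alpha=16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                                 # freeze the pretrained weights
        self.A = nn.Parameter(torch.randn(base.in_features, r) * 0.01)
        self.B = nn.Parameter(torch.zeros(r, base.out_features))    # zero init => no change at start
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A @ self.B) * self.scale

layer = LoRALinear(nn.Linear(512, 512))
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # only the 8,192 low-rank params train
```

Because the update is additive, the low-rank matrices can be merged back into the base weights after training, so inference costs nothing extra.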
Resources on LoRA
- Video: LoRA paper walkthrough (part 1)
- Video: LoRA code demo (part 2)
- Blog post: "Parameter-Efficient LLM Finetuning With Low-Rank Adaptation" by Sebastian Raschka
- Blog post: "Practical Tips for Finetuning LLMs Using LoRA" by Sebastian Raschka
Additionally, an "decomposed" LoRA variant called DoRA has been gaining popularity in recent months, often yielding performance improvements; see this post from Sebastian Raschka for more details.
Reward Models and RLHF
One of the most prominent techniques for "aligning" a language model is Reinforcement Learning from Human Feedback (RLHF); here, we typically assume that an LLM has already been instruction-tuned to respect a chat style, and that we additionally have a "reward model" which has been trained on human preferences. Given pairs of differing outputs to an input, where a preferred output has been chosen by a human, the learning objective of the reward model is to predict the preferred output, which involves implicitly learning preference "scores". This allows bootstrapping a general representation of human preferences (at least with respect to the dataset of output pairs), which can be used as a "reward simulator" for continual training of a LLM using RL policy gradient techniques like PPO.
RLHF represents a significant advancement in aligning LLMs with human values and preferences, enabling models to produce outputs that are not just factually accurate but also helpful, harmless, and honest.
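The reward-modeling objective itself is quite simple; here's a sketch of the standard Bradley-Terry-style pairwise loss on scalar rewards for chosen vs. rejected responses.

```python
import torch
import torch.nn.functional as F

def reward_loss(r_chosen, r_rejected):
    """Pairwise preference loss: push the reward of the preferred response
    above the rejected one. r_*: (batch,) scalar rewards from the reward model."""
    return -F.logsigmoid(r_chosen - r_rejected).mean()

print(reward_loss(torch.tensor([1.5, 0.2]), torch.tensor([0.3, 0.4])))
```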
Resources on RLHF
- Blog post: "Illustrating Reinforcement Learning from Human Feedback (RLHF)" from Hugging Face
- Blog post: "Reinforcement Learning from Human Feedback" from Chip Huyen
- Video: RLHF talk by Nathan Lambert
- Blog post: Insights on RewardBench from Sebastian Raschka
Direct Preference Optimization Methods
The space of alignment algorithms seems to be following a similar trajectory as we saw with stochastic optimization algorithms a decade ago. In this analogy, RLHF is like SGD --- it works, it's the original, and it's also become kind of a generic "catch-all" term for the class of algorithms that have followed it. Perhaps DPO is AdaGrad, and in the year since its release there's been a rapid wave of further algorithmic developments along the same lines (KTO, IPO, ORPO, etc.), whose relative merits are still under active debate. Maybe a year from now, everyone will have settled on a standard approach which will become the "Adam" of alignment.
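For reference, here's a sketch of the DPO objective itself, written over summed per-response log-probabilities from the policy and a frozen reference model; beta is a hyperparameter, and the default shown is just a common-looking value.

```python
import torch
import torch.nn.functional as F

def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """DPO: increase the policy's margin on preferred responses relative to a frozen
    reference model, with no explicit reward model or RL loop.
    All inputs are summed token log-probs of full responses, shape (batch,)."""
    chosen_ratio = logp_chosen - ref_logp_chosen
    rejected_ratio = logp_rejected - ref_logp_rejected
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()
```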
Resources on DPO
- Blog post: "Understanding the Implications of Direct Preference Optimization" by Matthew Gunton
- Blog post: "Fine-tuning language models with Direct Preference Optimization" from Hugging Face
- Blog post: "The Art of Preference Optimization" from Hugging Face (comparing DPO-flavored methods)
Context Scaling
Beyond task specification or alignment, another common goal of finetuning is to increase the effective context length of a model, either via additional training, adjusting parameters for positional encodings, or both. Even if adding more tokens to a model's context will "type-check", additional training on longer examples is generally necessary, since the model won't have seen sequences of that length during pretraining.
Resources on Context Scaling
- Blog post: "Scaling Rotational Embeddings for Long-Context Language Models" by Gradient AI
- Blog post: "Extending the RoPE" by Eleuther AI, introducing the YaRN method for increased context via attention temperature scaling
- Blog post: "Everything About Long Context Fine-tuning" by Wenbo Pan
Distillation and Merging
Here we'll look at two very different methods of consolidating knowledge across LLMs --- distillation and merging. Distillation was first popularized for BERT models, where the goal is to "distill" the knowledge and performance of a larger model into a smaller one (at least for some tasks) by having it serve as a "teacher" during the smaller model's training, bypassing the need for large quantities of human-labeled data.
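Here's a sketch of the classic Hinton-style distillation objective: a temperature-softened KL term against the teacher's logits, blended with the usual hard-label cross-entropy. The temperature and mixing weight are placeholders.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Soft targets from the teacher (temperature T) blended with hard-label cross-entropy."""
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                    F.softmax(teacher_logits / T, dim=-1),
                    reduction="batchmean") * T * T          # T^2 keeps gradient scales comparable
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard
```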
Resources on Distillation
- Blog post: "Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT" from Hugging Face
- Guide: "LLM distillation demystified: a complete guide" from Snorkel AI
- Research blog: "Distilling Step by Step" from Google Research
Merging is much more of a "wild west" technique, largely used by open-source engineers who want to combine the strengths of multiple finetuning efforts. It's kind of wild that it works at all, and perhaps grants some credence to "linear representation hypotheses".
The idea behind model merging is basically to take two different finetunes of the same base model and just average their weights. No training required. Technically, it's usually "spherical interpolation" (or "slerp"), but this is pretty much just fancy averaging with a normalization step. For more details, see the post Merge Large Language Models with mergekit by Maxime Labonne.
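Here's a sketch of slerp on two flattened weight tensors, which is essentially what merging tools apply matrix-by-matrix, with a fallback to plain linear interpolation when the vectors are nearly parallel.

```python
import numpy as np

def slerp(w_a, w_b, t=0.5, eps=1e-8):
    """Spherical interpolation between two flattened weight tensors."""
    a = w_a / (np.linalg.norm(w_a) + eps)
    b = w_b / (np.linalg.norm(w_b) + eps)
    omega = np.arccos(np.clip(a @ b, -1.0, 1.0))   # angle between the two weight vectors
    if omega < 1e-4:                               # nearly parallel: plain averaging is fine
        return (1 - t) * w_a + t * w_b
    return (np.sin((1 - t) * omega) * w_a + np.sin(t * omega) * w_b) / np.sin(omega)

merged = slerp(np.random.randn(1000), np.random.randn(1000), t=0.5)
```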
Key Takeaways
- Instruct fine-tuning transforms base LLMs into models that follow user instructions
- Parameter-efficient techniques like LoRA make fine-tuning feasible on consumer hardware
- RLHF aligns models with human preferences through reward modeling and reinforcement learning
- Direct Preference Optimization (DPO) offers a simpler alternative to RLHF for alignment
- Context scaling techniques enable LLMs to handle much longer inputs than their pretraining allowed
- Knowledge distillation creates smaller, faster models that retain much of their teacher's capabilities
- Model merging can combine strengths from different fine-tuned models without additional training
Applications & Interpretability
Before diving into the individual chapters, I'd recommend these two high-level overviews, which touch on many of the topics we'll examine here:
- "Building LLM applications for production" by Chip Huyen
- "What We Learned from a Year of Building with LLMs" Part 1 and Part 2 from O'Reilly (several authors)
These web courses also have a lot of relevant interactive materials:
- "Large Language Model Course" from Maxime Labonne
- "Generative AI for Beginners" from Microsoft
Benchmarking
Beyond the standard numerical performance measures used during LLM training like cross-entropy loss and perplexity, the true performance of frontier LLMs is more commonly judged according to a range of benchmarks, or "evals". Common types of these are:
- Human-evaluated outputs (e.g. LMSYS Chatbot Arena)
- AI-evaluated outputs (as used in RLAIF)
- Challenge question sets (e.g. those in HuggingFace's LLM Leaderboard)
Resources on Benchmarking
- Slides: LLM Evaluation from Stanford's CS224n
- Blog post: "How to evaluate LLMs" by Jason Wei
- Blog post: "Evaluating LLM Apps" by Peter Hayes
- Documentation: Inspect-AI framework with guidance on designing benchmarks and reliable evaluation pipelines
Sampling and Structured Outputs
While typical LLM inference samples tokens one at a time, there are a number of parameters controlling the token distribution (temperature, top_p, top_k) which can be modified to control the variety of responses, as well as non-greedy decoding strategies that allow some degree of "lookahead". This blog post by Maxime Labonne does a nice job discussing several of them.
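Here's a small sketch of what temperature and top-k do to a vector of raw logits before sampling; top-p filtering works similarly, but thresholds the sorted cumulative probabilities instead of keeping a fixed count.

```python
import numpy as np

def sample_token(logits, temperature=0.8, top_k=50, rng=None):
    """Sample one token id from raw logits with temperature and top-k filtering."""
    rng = rng or np.random.default_rng()
    logits = np.asarray(logits, dtype=np.float64) / max(temperature, 1e-6)  # sharpen or flatten
    if top_k is not None and top_k < len(logits):
        cutoff = np.sort(logits)[-top_k]                    # k-th largest logit
        logits = np.where(logits < cutoff, -np.inf, logits) # drop everything below it
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs)

print(sample_token(np.random.randn(100)))
```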
Sometimes we want our outputs to follow a particular structure, particularly if we are using LLMs as a component of a larger system rather than as just a chat interface. Few-shot prompting works okay, but not all the time, particularly as output schemas become more complicated.
For schema types like JSON, Pydantic and Outlines are popular tools for constraining the output structure from LLMs. Some useful resources:
Resources on Structured Outputs
- Documentation: Pydantic Concepts
- Documentation: Outlines for JSON
- Review: Outlines Demo and Review by Michael Wornow
Prompting Techniques
There are many prompting techniques, and many more prompt engineering guides out there, featuring methods for coaxing more desirable outputs from LLMs. Some of the classics:
- Few-Shot Examples
- Chain-of-Thought
- Retrieval-Augmented Generation (RAG)
- ReAct
Resources on Prompting
- Blog post: "Prompt Engineering" by Lilian Weng - discusses several of the most dominant approaches
- Guide: Prompt Engineering Guide - decent coverage and examples for a wider range of prominent techniques
Vector Databases and Reranking
RAG systems require the ability to quickly retrieve relevant documents from large corpuses. Relevancy is typically determined by similarity measures for semantic embedding vectors of both queries and documents, such as cosine similarity or Euclidean distance. If we have just a handful of documents, this can be computed between a query and each document, but this quickly becomes intractable when the number of documents grows large. This is the problem addressed by vector databases, which allow retrieval of the _approximate_ top-K matches (significantly faster than checking all pairs) by maintaining high-dimensional indices over vectors which efficiently encode their geometric structure.
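Here's the brute-force version of that retrieval step, which is exactly what vector databases approximate without scoring every document; the embeddings below are random placeholders standing in for outputs of an encoder model.

```python
import numpy as np

def top_k_cosine(query, doc_matrix, k=3):
    """Brute-force top-k retrieval by cosine similarity.
    query: (d,); doc_matrix: (n_docs, d)."""
    q = query / np.linalg.norm(query)
    docs = doc_matrix / np.linalg.norm(doc_matrix, axis=1, keepdims=True)
    sims = docs @ q                                   # one similarity score per document
    idx = np.argpartition(-sims, k)[:k]               # unordered top-k candidates
    top = idx[np.argsort(-sims[idx])]                 # sorted by similarity
    return top, sims[top]

docs = np.random.randn(10_000, 384)   # placeholder document embeddings
query = np.random.randn(384)
print(top_k_cosine(query, docs))
```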
Resources on Vector Databases
- Documentation: Pinecone on Vector Search Methods - walks through methods like Locality-Sensitive Hashing and Hierarchical Navigable Small Worlds
- Talk: Vector Databases Overview by Alexander Chatzizacharias
- Documentation: Reranking in RAG from Pinecone - overview of optimizing for metrics beyond query similarity
Retrieval-Augmented Generation
One of the most buzzed-about uses of LLMs over the past year, retrieval-augmented generation (RAG) is how you can "chat with a PDF" (even one larger than a model's context window) and how applications like Perplexity and Arc Search can "ground" their outputs using web sources. This retrieval is generally powered by embedding each document for storage in a vector database + querying with the relevant section of a user's input.
Resources on RAG
- Blog post: "Deconstructing RAG" from Langchain
- Blog post: "Building RAG with Open-Source and Custom AI Models" from Chaoyu Yang
- Course: Advanced RAG video course from DeepLearning.AI
Tool Use and "Agents"
The other big application buzzwords you've most likely encountered in some form are "tool use" and "agents", or "agentic programming". This typically starts with the ReAct framework we saw in the prompting section, then gets extended to elicit increasingly complex behaviors like software engineering (see the much-buzzed-about "Devin" system from Cognition, and several related open-source efforts like Devon/OpenDevin/SWE-Agent). There are many programming frameworks for building agent systems on top of LLMs, with Langchain and LlamaIndex being two of the most popular.
There also seems to be some value in having LLMs rewrite their own prompts + evaluate their own partial outputs; this observation is at the heart of the DSPy framework (for "compiling" a program's prompts, against a reference set of instructions or desired outputs) which has recently been seeing a lot of attention.
Resources on Agent Systems
- Blog post: "LLM Powered Autonomous Agents" from Lilian Weng
- Guide: "A Guide to LLM Abstractions" from Two Sigma
- Video: "DSPy Explained!" by Connor Shorten
- Blog post: LLMs with Knowledge Graphs from Neo4J
- Blog post: Data Wrangling with LLMs from Numbers Station
LLMs for Synthetic Data
An increasing number of applications are making use of LLM-generated data for training or evaluations, including distillation, dataset augmentation, AI-assisted evaluation and labeling, self-critique, and more.
Resources on Synthetic Data
- Guide: Synthetic Data for RAG - demonstrates how to construct synthetic datasets
- Blog post: RLAIF Overview from Argilla - AI-assisted feedback as an alternative to RLHF
- Blog post: Constitutional AI from Anthropic - overview of AI-assisted feedback for alignment
Representation Engineering
Representation Engineering is a new and promising technique for fine-grained steering of language model outputs via "control vectors". Somewhat similar to LoRA adapters, it has the effect of adding low-rank biases to the weights of a network which can elicit particular response styles (e.g. "humorous", "verbose", "creative", "honest"), yet is much more computationally efficient and can be implemented without any training required.
The method simply looks at differences in activations for pairs of inputs which vary along the axis of interest (e.g. honesty), which can be generated synthetically, and then performs dimensionality reduction.
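A sketch of that recipe, with random placeholder arrays standing in for the activations you'd actually extract at a chosen layer of a model:

```python
import numpy as np

# Hypothetical activations collected at one layer for contrastive prompt pairs
# (e.g. "answer honestly..." vs "answer dishonestly..."); random placeholders here.
acts_pos = np.random.randn(64, 4096)
acts_neg = np.random.randn(64, 4096)

direction = (acts_pos - acts_neg).mean(axis=0)   # or the first PCA component of the differences
direction /= np.linalg.norm(direction)

def steer(hidden_state, alpha=4.0):
    """Add the scaled control vector to the hidden-state activations at inference time."""
    return hidden_state + alpha * direction
```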
Resources on Representation Engineering
- Blog post: Representation Engineering Overview from Center for AI Safety
- Blog post: Technical Deep-Dive with Code from Theia Vogel
- Podcast: Representation Engineering Explained with Theia Vogel
Mechanistic Interpretability
Mechanistic Interpretability (MI) is the dominant paradigm for understanding the inner workings of LLMs by identifying sparse representations of "features" or "circuits" encoded in model weights. Beyond enabling potential modification or explanation of LLM outputs, MI is often viewed as an important step towards potentially "aligning" increasingly powerful systems.
Resources on Mechanistic Interpretability
- Guide: "A Comprehensive Mechanistic Interpretability Explainer & Glossary" by Neel Nanda
- List: "An Extremely Opinionated Annotated List of My Favourite Mechanistic Interpretability Papers" by Neel Nanda
- Guide: "Mechanistic Interpretability Quickstart Guide" (Neel Nanda on LessWrong)
- Discussion: "How useful is mechanistic interpretability?" (Neel and others on LessWrong)
- Spreadsheet: "200 Concrete Problems In Interpretability" (Annotated open problems from Neel)
- Article: "Toy Models of Superposition" from Anthropic
- Article: "Scaling Monosemanticity" from Anthropic
Linear Representation Hypotheses
An emerging theme from several lines of interpretability research has been the observation that internal representations of features in Transformers are often "linear" in high-dimensional space (a la Word2Vec). On one hand this may appear initially surprising, but it's also essentially an implicit assumption for techniques like similarity-based retrieval, merging, and the key-value similarity scores used by attention.
Resources on Linear Representations
- Blog post: "Deep Learning Models are Secretly Linear" by Beren Millidge
- Talk: Linear Representations in LLMs from Kiho Park
- Paper: "Language Models Represent Space and Time" - worth skimming for its figures
Key Takeaways
- Benchmarking LLMs requires multiple evaluation approaches including human feedback, AI evaluation, and challenge sets
- Sampling parameters and structured output tools help control LLM response characteristics
- Advanced prompting techniques like Chain-of-Thought and ReAct substantially improve performance
- Vector databases enable efficient semantic search critical for RAG applications
- Retrieval-Augmented Generation grounds LLM outputs in external knowledge
- Agent frameworks extend LLMs with tool-use capabilities for complex tasks
- LLM-generated synthetic data enables training improvements without human labeling
- Representation Engineering offers lightweight control over LLM behaviors
- Mechanistic Interpretability seeks to understand the internal workings of LLMs
- Linear representation of features in transformers enables many practical applications
Inference Optimization
Parameter Quantization
With the rapid increase in parameter counts for leading LLMs and difficulties (both in cost and availability) in acquiring GPUs to run models on, there's been a growing interest in quantizing LLM weights to use fewer bits each, which can often yield comparable output quality with a 50-75% (or more) reduction in required memory. Typically this shouldn't be done naively; Tim Dettmers, one of the pioneers of several modern quantization methods (LLM.int8(), QLoRA, bitsandbytes) has a great blog post for understanding quantization principles, and the need for mixed-precision quantization as it relates to emergent features in large-model training.
Effective quantization can reduce memory requirements by 50-75% while maintaining comparable output quality, making large models accessible on consumer hardware.
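Here's a sketch of the simplest flavor, symmetric "absmax" int8 quantization of a weight tensor; methods like LLM.int8() and GPTQ build on this with per-block scaling, outlier handling, and error correction.

```python
import numpy as np

def absmax_quantize(w):
    """Symmetric int8 quantization: map the largest-magnitude weight to 127."""
    scale = 127.0 / np.max(np.abs(w))
    q = np.round(w * scale).astype(np.int8)
    return q, scale

w = np.random.randn(256, 256).astype(np.float32)
q, scale = absmax_quantize(w)
w_hat = q.astype(np.float32) / scale              # dequantize for use in matmuls
print(q.nbytes / w.nbytes, np.abs(w - w_hat).max())   # 4x smaller, small per-weight error
```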
Resources on Quantization
- Blog post: Understanding Quantization Principles by Tim Dettmers
- Overview: What are Quantized LLMs from TensorOps - covers GGUF, AWQ, HQQ, and GPTQ
- Blog post: Quantization Methods Comparison by Maarten Grootendorst
- Talk: QLoRA Overview by Tim Dettmers
- Blog: 4-bit Transformers with bitsandbytes from Hugging Face
- Technical post: FSDP QLoRA Deep Dive from Answer.AI - combining QLoRA with FSDP for efficient finetuning
Speculative Decoding
The basic idea behind speculative decoding is to speed up inference from a larger model by primarily sampling tokens from a much smaller model and occasionally applying corrections (e.g. every N tokens) from the larger model whenever the output distributions diverge. These batched consistency checks tend to be much faster than sampling N tokens directly, and so there can be large overall speedups if the token sequences from the smaller model only diverge periodically.
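Here's a deliberately simplified, greedy sketch of the draft-and-verify loop with toy stand-in "models"; real speculative sampling uses a probabilistic accept/reject rule that provably preserves the target model's distribution, and batches the verification into a single forward pass of the large model.

```python
def greedy_speculative_step(prefix, draft_next, target_next, k=4):
    """One draft-and-verify step. `draft_next` / `target_next` are toy stand-ins
    for greedy next-token calls to the small and large models."""
    seq = list(prefix)
    draft = []
    for _ in range(k):                       # small model cheaply proposes k tokens
        draft.append(draft_next(seq + draft))
    for t in draft:                          # large model checks them (batched in practice)
        t_star = target_next(seq)
        if t_star == t:
            seq.append(t)                    # accepted: this token came almost for free
        else:
            seq.append(t_star)               # rejected: take the large model's token and stop
            break
    return seq

# Toy "models": the draft model always increments; the target model usually agrees.
draft_next = lambda s: (s[-1] + 1) % 10
target_next = lambda s: 0 if len(s) % 7 == 0 else (s[-1] + 1) % 10

seq = [1]
while len(seq) < 20:
    seq = greedy_speculative_step(seq, draft_next, target_next)
print(seq)
```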
Resources on Speculative Decoding
- Blog post: Speculative Sampling Walkthrough from Jay Mody
- Article: Hitchhiker's Guide to Speculative Decoding from PyTorch - includes evaluation results
- Video: Speculative Decoding Overview from Trelis Research
FlashAttention
Computing attention matrices tends to be a primary bottleneck in inference and training for Transformers, and FlashAttention has become one of the most widely-used techniques for speeding it up. In contrast to some of the techniques we'll see in Section VII which approximate attention with a more concise representation (incurring some approximation error as a result), FlashAttention is an exact representation whose speedup comes from hardware-aware implementation.
FlashAttention applies tiling and recomputation to decompose the expression of attention matrices, enabling significantly reduced memory I/O and faster wall-clock performance (even while slightly increasing the required FLOPS).
Resources on FlashAttention
- Talk: FlashAttention Explained by Tri Dao (author of FlashAttention)
- Explainer: ELI5: FlashAttention by Aleksa Gordić
Key-Value Caching and Paged Attention
As noted in the NVIDIA blog referenced above, key-value caching is fairly standard in Transformer implementations, avoiding redundant recomputation of the key and value matrices for previously-processed tokens. This enables a tradeoff between speed and resource utilization, as these matrices are kept in GPU VRAM. While managing this is fairly straightforward for a single "thread" of inference, a number of complexities arise when considering parallel inference or multiple users for a single hosted model instance.
How can you avoid recomputing values for system prompts and few-shot examples? When should you evict cache elements for a user who may or may not want to continue a chat session? PagedAttention addresses these challenges by leveraging ideas from classical paging in operating systems.
PagedAttention and its popular implementation vLLM have become a standard for self-hosted multi-user inference servers.
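Setting paging aside, here's a sketch of the basic bookkeeping a single-head KV cache does during decoding: each new token appends its key/value vectors and attends over everything stored so far, so earlier projections are never recomputed.

```python
import math
import torch
import torch.nn.functional as F

class KVCache:
    """Minimal single-head KV cache for autoregressive decoding."""
    def __init__(self):
        self.k, self.v = None, None

    def append(self, k_new, v_new):            # k_new, v_new: (1, d_head)
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=0)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=0)

    def attend(self, q_new):                   # q_new: (1, d_head)
        scores = q_new @ self.k.T / math.sqrt(self.k.shape[-1])
        return F.softmax(scores, dim=-1) @ self.v

cache = KVCache()
for _ in range(5):                             # one decode step per new token
    k, v, q = (torch.randn(1, 64) for _ in range(3))
    cache.append(k, v)
    out = cache.attend(q)
print(cache.k.shape)                           # (5, 64) -- this is what accumulates in VRAM
```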
Resources on KV Caching
- Video: The KV Cache: Memory Usage in Transformers by Efficient NLP
- Video: Fast LLM Serving with vLLM and PagedAttention by Anyscale
- Blog post: vLLM: Easy, Fast, and Cheap LLM Serving from vLLM team
CPU Offloading
The primary method used for running LLMs either partially or entirely on CPU (vs. GPU) is llama.cpp. This approach is particularly valuable for those without access to high-end GPUs or for deployment in resource-constrained environments.
Resources on CPU Offloading
- Tutorial: Llama.cpp Tutorial from DataCamp - high-level overview
- Blog post: CPU Matrix Multiplication Optimizations - technical details about CPU performance improvements
- Note: llama.cpp serves as the backend for popular self-hosted LLM tools like LMStudio and Ollama
Key Takeaways
- Parameter quantization makes large models accessible on consumer hardware with minimal quality loss
- Speculative decoding accelerates inference by using smaller models to "draft" outputs for larger models
- FlashAttention significantly speeds up attention computation through hardware-aware implementation
- Key-value caching avoids redundant computation during autoregressive decoding
- PagedAttention enables efficient memory management for multi-user inference
- CPU offloading techniques like llama.cpp allow running models without dedicated GPU hardware
Addressing the Quadratic Scaling Problem
Sliding Window Attention
Introduced in the "Longformer" paper, sliding window attention acts as a sub-quadratic drop-in replacement for standard attention which allows attending only to a sliding window (shocking, right?) of recent tokens/states rather than the entire context window, under the premise that vectors for these states have already attended to earlier ones and thus have sufficient representational power to encode relevant pieces of early context. Due to its simplicity, it's become one of the more widely adopted approaches towards sub-quadratic scaling, and is used in Mistral's popular Mixtral-8x7B model (among others).
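The mask itself is simple to write down; here's a sketch of a causal sliding-window mask that you'd apply to the attention scores before the softmax.

```python
import torch

def sliding_window_mask(seq_len, window):
    """True where attention is allowed: causal, and only within the last `window` tokens."""
    i = torch.arange(seq_len).unsqueeze(1)     # query positions
    j = torch.arange(seq_len).unsqueeze(0)     # key positions
    return (j <= i) & (j > i - window)

print(sliding_window_mask(6, 3).int())
# Applied as: scores.masked_fill(~mask, float("-inf")) before the softmax.
```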
Resources on Sliding Window Attention
- Blog post: "What is Sliding Window Attention?" by Stephen M. Walker
- Blog post: "Sliding Window Attention" by Manoj Kumal
- Video: "Longformer: The Long-Document Transformer" by Yannic Kilcher
Ring Attention
Another modification to standard attention mechanisms, Ring Attention enables sub-quadratic full-context interaction via incremental computation with a "message-passing" structure, wherein "blocks" of context communicate with each other over a series of steps rather than all at once. Within each block, the technique is essentially classical attention.
While largely a research direction rather than a standard technique (at least within the open-weights world), Google's Gemini is rumored to possibly be using Ring Attention in order to enable its million-plus-token context.
Resources on Ring Attention
- Blog post: "Breaking the Boundaries: Understanding Context Window Limitations and the idea of Ring Attention" by Tanuj Sharma
- Blog post: "Understanding Ring Attention: Building Transformers With Near-Infinite Context" from E2E Networks
- Video: "Ring Attention Explained"
Linear Attention (RWKV)
The Receptance-Weighted Key Value (RWKV) architecture is a return to the general structure of RNN models (e.g. LSTMs), with modifications to enable increased scaling and a linear attention-style mechanism which supports recurrent "unrolling" of its representation (allowing constant computation per output token as context length scales).
Resources on RWKV
- Blog post: "Getting Started With RWKV" from Hugging Face
- Blog post: "The RWKV language model: An RNN with the advantages of a transformer" - Pt. 1 by Johan Wind
- Blog post: "How the RWKV language model works" - Pt. 2 by Johan Wind
- Video: "RWKV: Reinventing RNNs for the Transformer Era (Paper Explained)" by Yannic Kilcher
Structured State Space Models
Structured State Space Models (SSMs) have become one of the most popular alternatives to Transformers in terms of current research focus, with several notable variants (S4, Hyena, Mamba/S6, Jamba, Mamba-2), but are somewhat notorious for their complexity.
The architecture draws inspiration from classical control theory and linear time-invariant systems, with a number of optimizations to translate from continuous to discrete time, and to avoid dense representations of large matrices. They support both recurrent and convolutional representations, which allows efficiency gains both for training and at inference.
Many variants require carefully-conditioned "hidden state matrix" representations to support "memorization" of context without needing all-pairs attention. SSMs also seem to be becoming more practical at scale, and have recently resulted in breakthrough speed improvements for high-quality text to speech (via Cartesia AI, founded by the inventors of SSMs).
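At its core, the recurrent view is just a discrete linear state-space update; here's a sketch with toy, unstructured matrices, whereas real SSMs use carefully parameterized state matrices (e.g. HiPPO-derived) and learn the discretization.

```python
import numpy as np

# Discrete linear state-space recurrence: x_t = A x_{t-1} + B u_t,  y_t = C x_t
d_state, T = 4, 10
rng = np.random.default_rng(0)
A = 0.9 * np.eye(d_state)            # toy stable state matrix (a placeholder, not a real SSM parameterization)
B = rng.normal(size=(d_state, 1))
C = rng.normal(size=(1, d_state))
u = rng.normal(size=T)               # a 1-D input sequence

x = np.zeros((d_state, 1))
ys = []
for t in range(T):                   # recurrent view: O(1) work per step, regardless of context length
    x = A @ x + B * u[t]
    ys.append((C @ x).item())

# Equivalent convolutional view: y = u * kernel with kernel[t] = C A^t B,
# which is what makes parallel training efficient.
print(np.round(ys, 3))
```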
Resources on SSMs
- Tutorial: "The Annotated S4" - comprehensive explainer focused on the S4 paper from which SSMs originated
- Blog post: "A Visual Guide to Mamba and State Space Models" by Maarten Grootendorst - great for intuitions and visuals with slightly less math
- Video: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Paper Explained)" by Yannic Kilcher
Recently, the Mamba authors released their follow-up "Mamba 2" paper, and their accompanying series of blog posts discusses some newly-uncovered connections between SSM representations and linear attention which may be interesting:
Mamba-2 Blog Series
HyperAttention
Like RWKV and SSMs, HyperAttention is a proposal for achieving near-linear scaling in attention-like mechanisms, but it relies on locality-sensitive hashing (think vector DBs) rather than recurrent representations. I don't see it discussed as much as the others, but it may be worth being aware of nonetheless.
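To illustrate the hashing idea only (this is not the actual HyperAttention algorithm, which adds sorting, sampling, and error guarantees), here's a sketch where queries and keys are bucketed with random hyperplanes and attention is computed only within buckets, so most token pairs are never compared.

```python
import torch

def lsh_bucketed_attention(q, k, v, n_hyperplanes: int = 4):
    """Hash queries/keys by the sign pattern of random projections, then run
    ordinary attention within each bucket."""
    d = q.shape[-1]
    planes = torch.randn(d, n_hyperplanes)

    def bucket_ids(x):
        bits = (x @ planes > 0).long()                       # sign pattern
        return (bits * (2 ** torch.arange(n_hyperplanes))).sum(dim=-1)

    qb, kb = bucket_ids(q), bucket_ids(k)
    out = torch.zeros_like(q)
    for b in torch.unique(qb):
        q_idx = (qb == b).nonzero(as_tuple=True)[0]
        k_idx = (kb == b).nonzero(as_tuple=True)[0]
        if len(k_idx) == 0:
            continue  # empty bucket; a real method needs a fallback here
        scores = q[q_idx] @ k[k_idx].T / d ** 0.5
        out[q_idx] = torch.softmax(scores, dim=-1) @ v[k_idx]
    return out

q = k = v = torch.randn(64, 16)
print(lsh_bucketed_attention(q, k, v).shape)  # torch.Size([64, 16])
```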
Resources on HyperAttention
- Blog post: "Linear Time Magic: How HyperAttention Optimizes Large Language Models" by Yousra Aoudi
- Video: "HyperAttention Explained" by Tony Shin
Key Takeaways
- Sliding Window Attention provides a simple way to achieve sub-quadratic scaling by limiting attention to recent tokens
- Ring Attention enables full-context interaction over very long sequences by passing blocks of keys and values around a ring of devices
- RWKV combines RNN structure with linear attention to achieve constant computation per token as context scales
- Structured State Space Models draw from control theory to create efficient alternatives to Transformers
- HyperAttention uses locality-sensitive hashing to achieve near-linear scaling for attention mechanisms
- These approaches represent a significant research direction for scaling context length beyond what's feasible with standard attention
Beyond Transformers: Other Generative Models
Distribution Modeling
Recalling our first glimpse of language models as simple bigram distributions, the most basic thing you can do in distributional modeling is just count co-occurrence frequencies in your dataset and treat them as ground-truth probabilities. This idea extends to conditional sampling and classification as "Naive Bayes" (blog post and video), often one of the simplest algorithms covered in introductory machine learning courses.
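As a reminder of how bare-bones this counting approach is, here's a toy bigram "model" built purely from corpus counts (the corpus is made up for illustration):

```python
import random
from collections import Counter, defaultdict

# Estimate P(next word | word) directly from bigram counts, then sample from it.
corpus = "the cat sat on the mat and the cat slept".split()
counts = defaultdict(Counter)
for w, nxt in zip(corpus, corpus[1:]):
    counts[w][nxt] += 1

def sample_next(word):
    nexts = counts[word]
    if not nexts:                      # dead end (e.g. the last word in the corpus)
        return random.choice(corpus)
    words, weights = zip(*nexts.items())
    return random.choices(words, weights=weights)[0]

word, generated = "the", ["the"]
for _ in range(5):
    word = sample_next(word)
    generated.append(word)
print(" ".join(generated))
```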
The next generative model students are often taught is the Gaussian Mixture Model and its Expectation-Maximization algorithm. This blog post and this video give decent overviews; the core idea here is assuming that data distributions can be approximated as a mixture of multivariate Gaussian distributions. GMMs can also be used for clustering if individual groups can be assumed to be approximately Gaussian.
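For a hands-on feel, here's a short example using scikit-learn's GaussianMixture (which runs EM under the hood) to fit a two-component mixture to synthetic blobs and sample new points from it:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic 2-D data: two Gaussian blobs
rng = np.random.default_rng(0)
data = np.vstack([
    rng.normal(loc=[-2, 0], scale=0.5, size=(200, 2)),
    rng.normal(loc=[3, 3], scale=1.0, size=(200, 2)),
])

gmm = GaussianMixture(n_components=2, random_state=0).fit(data)  # EM fitting
samples, component_labels = gmm.sample(50)                       # generate 50 new points
print(gmm.means_)      # should land near (-2, 0) and (3, 3)
print(samples.shape)   # (50, 2)
```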
While these methods aren't very effective at representing complex structures like images or language, related ideas will appear as components of some of the more advanced methods we'll see.
Variational Auto-Encoders
Auto-encoders and variational auto-encoders are widely used for learning compressed representations of data distributions, and can also be useful for "denoising" inputs, which will come into play when we discuss diffusion.
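Here's a minimal PyTorch sketch of a plain (non-variational) autoencoder, just to make the encode-compress-decode loop concrete; a VAE additionally treats the latent code as a distribution (mean and variance) and adds a KL-divergence term to the reconstruction loss.

```python
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Compress inputs to a small latent code, then reconstruct them."""
    def __init__(self, dim=784, latent=32):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(dim, 256), nn.ReLU(), nn.Linear(256, latent))
        self.decoder = nn.Sequential(nn.Linear(latent, 256), nn.ReLU(), nn.Linear(256, dim))

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = AutoEncoder()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(64, 784)                      # stand-in for a batch of flattened images
loss = nn.functional.mse_loss(model(x), x)   # reconstruction error
loss.backward()
opt.step()
print(loss.item())
```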
Resources on Variational Auto-Encoders
- Textbook chapter: "Autoencoders" in the "Deep Learning" book
- Blog post: "From Autoencoder to Beta-VAE" from Lilian Weng
- Video: "Variational Autoencoders" from Arxiv Insights
- Blog post: "Deep Generative Models" from Prakash Pandey - covers both VAEs and GANs
Generative Adversarial Nets
The basic idea behind Generative Adversarial Networks (GANs) is to simulate a "game" between two neural nets --- the Generator wants to create samples which are indistinguishable from real data by the Discriminator, who wants to identify the generated samples --- and both nets are trained simultaneously until an equilibrium (or desired sample quality) is reached.
Following from von Neumann's minimax theorem for zero-sum games, you basically get a "theorem" promising that GANs succeed at learning distributions, if you assume that gradient descent finds global minimizers and allow both networks to grow arbitrarily large.
Granted, neither of these assumptions is literally true in practice, but GANs do tend to be quite effective (although they've fallen out of favor somewhat in recent years, partly due to the instabilities of this simultaneous training).
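A minimal sketch of one round of this game on toy 1-D data (network sizes and the "real" distribution are made up for illustration):

```python
import torch
import torch.nn as nn

G = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))   # noise -> sample
D = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))   # sample -> logit
opt_g = torch.optim.Adam(G.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(D.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

real = torch.randn(64, 1) * 2 + 5          # "real" data drawn from N(5, 2)
noise = torch.randn(64, 8)

# Discriminator step: push real samples toward label 1, generated ones toward 0
fake = G(noise).detach()
d_loss = bce(D(real), torch.ones(64, 1)) + bce(D(fake), torch.zeros(64, 1))
opt_d.zero_grad(); d_loss.backward(); opt_d.step()

# Generator step: make the discriminator label fresh fakes as real
g_loss = bce(D(G(noise)), torch.ones(64, 1))
opt_g.zero_grad(); g_loss.backward(); opt_g.step()
print(d_loss.item(), g_loss.item())
```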
Resources on GANs
- Guide: "Complete Guide to Generative Adversarial Networks" from Paperspace
- Tutorial: "Generative Adversarial Networks (GANs): End-to-End Introduction"
- Textbook chapter: Deep Learning, Ch. 20 - Generative Models (theory-focused)
Conditional GANs
Conditional GANs are where we'll start going from vanilla "distribution learning" to something which more closely resembles interactive generative tools like DALL-E and Midjourney, incorporating text-image multimodality. A key idea is to learn "representations" (in the sense of text embeddings or autoencoders) which are more abstract and can be applied to either text or image inputs.
For example, you could imagine training a vanilla GAN on (image, caption) pairs by embedding the text and concatenating it with an image, which could then learn this joint distribution over images and captions. This implicitly involves learning conditional distributions if part of the input (image or caption) is fixed.
This can be extended to enable automatic captioning (given an image) or image generation (given a caption). There are a number of variants on this setup with differing bells and whistles. The VQGAN+CLIP architecture is worth knowing about, as it was a major popular source of early "AI art" generated from input text.
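Here's a sketch of the conditioning trick itself, with a hypothetical label embedding standing in for a caption encoder: the condition is fed to both the generator and the discriminator, so the generator learns p(image | condition) and the discriminator judges (sample, condition) pairs rather than samples alone.

```python
import torch
import torch.nn as nn

n_classes, noise_dim, img_dim, emb_dim = 10, 16, 64, 8
embed = nn.Embedding(n_classes, emb_dim)   # stand-in for a caption/label encoder
G = nn.Sequential(nn.Linear(noise_dim + emb_dim, 128), nn.ReLU(), nn.Linear(128, img_dim))
D = nn.Sequential(nn.Linear(img_dim + emb_dim, 128), nn.ReLU(), nn.Linear(128, 1))

labels = torch.randint(0, n_classes, (32,))
cond = embed(labels)                                    # (32, emb_dim)
fake = G(torch.cat([torch.randn(32, noise_dim), cond], dim=-1))
logits = D(torch.cat([fake, cond], dim=-1))             # judge (sample, condition) pairs
print(fake.shape, logits.shape)                         # (32, 64), (32, 1)
```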
Resources on Conditional GANs
- Blog post: "Implementing Conditional Generative Adversarial Networks" from Paperspace
- Article: "Conditional Generative Adversarial Network — How to Gain Control Over GAN Outputs" by Saul Dobilas
- Tutorial: "The Illustrated VQGAN" by LJ Miranda
- Talk: "Using Deep Learning to Generate Artwork with VQGAN-CLIP" from Paperspace
Normalizing Flows
Normalizing flows aim to learn a series of invertible transformations between Gaussian noise and an output distribution, avoiding the "simultaneous training" required by GANs; they've been popular for generative modeling in a number of domains.
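For reference, the change-of-variables identity that flows are built on, where f is the learned invertible map from data to the Gaussian base distribution p_Z:

```latex
% Model log-density = base log-density + log-determinant of the Jacobian of f
\log p_X(x) = \log p_Z\bigl(f(x)\bigr) + \log \left| \det \frac{\partial f(x)}{\partial x} \right|
```

Flow architectures are designed so that both the inverse of f and this log-determinant are cheap to compute.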
Resources on Normalizing Flows
- Blog post: "Flow-based Deep Generative Models" from Lilian Weng
I haven't personally gone very deep on normalizing flows, but they come up enough that they're probably worth being aware of.
Diffusion Models
One of the central ideas behind diffusion models (like StableDiffusion) is the iterative, guided application of denoising operations, refining random noise into something that increasingly resembles an image. Diffusion originates from the worlds of stochastic differential equations and statistical physics --- relating to the "Schrödinger bridge" problem and optimal transport for probability distributions --- and a fair amount of math is basically unavoidable if you want to understand the whole picture.
Diffusion models work by gradually adding noise to training data and then learning to reverse this process, effectively learning how to transform random noise into structured data that matches the target distribution.
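Here's a compact sketch of a DDPM-style training step under those ideas (a toy MLP on flat vectors stands in for the real denoising network, and the noise schedule values are illustrative): noise a clean example using the closed-form forward process, then train the network to predict the noise that was added.

```python
import torch
import torch.nn as nn

T, dim = 1000, 64
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)          # cumulative signal-retention factors

model = nn.Sequential(nn.Linear(dim + 1, 256), nn.ReLU(), nn.Linear(256, dim))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x0 = torch.randn(32, dim)                               # stand-in for clean training data
t = torch.randint(0, T, (32,))                          # random timestep per example
eps = torch.randn_like(x0)
ab = alpha_bars[t].unsqueeze(-1)
x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps            # forward process q(x_t | x_0) in closed form

# Condition the network on the (normalized) timestep and predict the added noise
pred = model(torch.cat([x_t, (t.float() / T).unsqueeze(-1)], dim=-1))
loss = nn.functional.mse_loss(pred, eps)
opt.zero_grad(); loss.backward(); opt.step()
print(loss.item())
```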
Resources on Diffusion Models
- Introduction: "A friendly Introduction to Denoising Diffusion Probabilistic Models" by Antony Gitau
- Deep dive: "What are Diffusion Models?" by Lilian Weng
- Code walkthrough: "The Annotated Diffusion Model" from Hugging Face
- Advanced technique: "Fine-tuning Diffusion Models with LoRA" from Hugging Face
Key Takeaways
- Simple models like Naive Bayes and Gaussian Mixture Models form the foundation of generative modeling
- Variational Auto-Encoders learn compressed data representations useful for generation and denoising
- Generative Adversarial Networks create realistic outputs through an adversarial training process
- Conditional GANs extend the GAN framework to enable text-to-image generation
- Normalizing Flows learn invertible transformations between simple distributions and complex ones
- Diffusion Models iteratively denoise random inputs to create structured outputs like images
- Each architecture presents different tradeoffs in training stability, output quality, and controllability
Multimodal Models
Tokenization Beyond Text
The idea of tokenization isn't only relevant to text; audio, images, and video can also be "tokenized" for use in Transformer-style architectures, and there's a range of tradeoffs to consider between tokenization and other methods like convolution. The next two sections will look more into visual inputs; this blog post from AssemblyAI touches on a number of relevant topics for audio tokenization and representation in sequence models, for applications like audio generation, text-to-speech, and speech-to-text.
Just as text can be broken into tokens, other modalities like images can be divided into "patches" or audio into "frames" that serve as tokens for multimodal transformers.
VQ-VAE
The VQ-VAE architecture has become quite popular for image generation in recent years, and underlies at least the earlier versions of DALL-E.
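The defining "VQ" step is easy to show in isolation. Here's a minimal sketch with a random codebook; in a real VQ-VAE the codebook is learned and gradients flow around the non-differentiable argmin via a straight-through estimator.

```python
import torch

codebook = torch.randn(512, 64)            # 512 code vectors of size 64 (learned in practice)
z_e = torch.randn(32, 64)                  # encoder outputs for 32 spatial positions

dists = torch.cdist(z_e, codebook)         # distances from each encoder output to every code
indices = dists.argmin(dim=-1)             # (32,) discrete token ids
z_q = codebook[indices]                    # quantized vectors passed to the decoder
print(indices[:5], z_q.shape)              # an image becomes a grid of these discrete ids
```

It's this grid of discrete ids that an autoregressive prior (as in the original DALL-E) can then model, just like a sequence of text tokens.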
Resources on VQ-VAE
- Blog post: "Understanding VQ-VAE (DALL-E Explained Pt. 1)" from the Machine Learning @ Berkeley blog
- Blog post: "How is it so good? (DALL-E Explained Pt. 2)" from Machine Learning @ Berkeley
- Tutorial: "Understanding Vector Quantized Variational Autoencoders (VQ-VAE)" by Shashank Yadav
Vision Transformers
Vision Transformers extend the Transformer architecture to domains like image and video, and have become popular for applications like self-driving cars as well as for multimodal LLMs. There's a nice section in the d2l.ai book about how they work.
Vision Transformers (ViT) adapt the transformer architecture to work with images by splitting them into patches, embedding these patches, and processing them just like tokens in a standard transformer model.
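Here's a small sketch of that patch-embedding step, with ImageNet-style sizes assumed for illustration; a ViT then adds position embeddings (and typically a learned [CLS] token) before the standard Transformer blocks.

```python
import torch

img = torch.randn(1, 3, 224, 224)          # (batch, channels, height, width)
P, d_model = 16, 768                       # patch size and model dimension

# Split into non-overlapping 16x16 patches, flatten each, and project to d_model
patches = img.unfold(2, P, P).unfold(3, P, P)           # (1, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 14 * 14, 3 * P * P)
proj = torch.nn.Linear(3 * P * P, d_model)
tokens = proj(patches)                                   # (1, 196, 768): 196 patch "tokens"
print(tokens.shape)
```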
Resources on Vision and Multimodal Models
- Blog post: "Generalized Visual Language Models" by Lilian Weng - discusses a range of different approaches for training multimodal Transformer-style models
- Guide: "Guide to Vision Language Models" from Encord's blog - overviews several architectures for mixing text and vision
- Paper: MM1 from Apple - examines several architecture and data tradeoffs with experimental evidence for Vision Transformers
- Visualization: "Multimodal Neurons in Artificial Neural Networks" from Distill.pub - very fun visualizations of concept representations in multimodal networks
Key Takeaways
- Tokenization concepts extend beyond text to images, audio, and video
- VQ-VAE architectures provide a foundation for image generation, including early versions of DALL-E
- Vision Transformers adapt the transformer architecture to process images by dividing them into patches
- Multimodal models combine different forms of data (text, images, audio) for richer understanding and generation
- Research in multimodal architectures continues to evolve rapidly, with various approaches to combining different data types