The Grand AI Handbook

Inference Optimization

Techniques for improving the speed and efficiency of LLM inference.

Here we'll look at a handful of techniques for improving the speed and efficiency of inference from pre-trained Transformer language models, most of which are fairly widely used in practice. It's worth first reading this short Nvidia blog post for a crash course in several of the topics we'll look at (and a number of others).

Parameter Quantization

With the rapid increase in parameter counts for leading LLMs and difficulties (both in cost and availability) in acquiring GPUs to run models on, there’s been growing interest in quantizing LLM weights to use fewer bits each, which can often yield comparable output quality with a 50-75% (or more) reduction in required memory. Typically this shouldn’t be done naively; Tim Dettmers, one of the pioneers of several modern quantization methods (LLM.int8(), QLoRA, bitsandbytes), has a great blog post for understanding quantization principles and the need for mixed-precision quantization as it relates to emergent features in large-model training.

Effective quantization can reduce memory requirements by 50-75% while maintaining comparable output quality, making large models accessible on consumer hardware.
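As a rough sketch of what this looks like in practice, the Hugging Face transformers integration with bitsandbytes can load a model with 4-bit weights in a few lines. The model id below is just a placeholder and the config values are illustrative, not a recommendation.

```python
# Minimal sketch: loading a causal LM with 4-bit quantized weights via bitsandbytes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit NF4 format
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # keep activations/compute in higher precision
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",             # placeholder model id
    quantization_config=quant_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
```

Note the mixed-precision flavor here: weights are stored in 4 bits, while compute happens in a higher-precision dtype, in line with the considerations Dettmers discusses.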

Speculative Decoding

The basic idea behind speculative decoding is to speed up inference from a larger model by primarily sampling tokens from a much smaller draft model and periodically applying corrections (e.g. every N tokens) from the larger model whenever the output distributions diverge. These batched consistency checks tend to be much faster than sampling N tokens directly from the larger model, so there can be large overall speedups if the token sequences from the smaller model only diverge occasionally. The toy sketch below illustrates the draft-then-verify flow.
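This is a toy greedy sketch, not a faithful implementation: `draft_model` and `target_model` are hypothetical callables returning next-token logits for a token sequence, and real implementations verify the entire draft with a single batched forward pass plus a rejection-sampling correction rather than the sequential check shown here.

```python
# Toy sketch of greedy speculative decoding: draft k tokens with the small
# model, then accept them until the first disagreement with the large model.
import torch

def speculative_decode(target_model, draft_model, tokens, k=4, max_len=64):
    tokens = list(tokens)
    while len(tokens) < max_len:
        # 1) Draft k candidate tokens cheaply with the small model.
        ctx = list(tokens)
        draft = []
        for _ in range(k):
            nxt = int(torch.argmax(draft_model(ctx)))
            draft.append(nxt)
            ctx.append(nxt)

        # 2) Verify with the large model: keep drafted tokens until the first
        #    disagreement, then substitute the large model's token and redraft.
        for tok in draft:
            target_tok = int(torch.argmax(target_model(tokens)))
            if target_tok == tok:
                tokens.append(tok)
            else:
                tokens.append(target_tok)
                break
    return tokens[:max_len]
```

Because accepted tokens match what the larger model would have produced greedily, the output is unchanged; the speedup comes from the larger model verifying drafts in parallel rather than generating every token itself.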

FlashAttention

Computing attention matrices tends to be a primary bottleneck in inference and training for Transformers, and FlashAttention has become one of the most widely-used techniques for speeding it up. In contrast to some of the techniques we’ll see in Section 7, which approximate attention with a more concise representation (incurring some approximation error as a result), FlashAttention computes attention exactly; its speedup comes from a hardware-aware implementation.

FlashAttention applies tiling and recomputation to decompose the computation of attention matrices, significantly reducing memory I/O and improving wall-clock performance (even while slightly increasing the required FLOPs).
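As a loose illustration of the interface rather than the kernel itself, PyTorch's fused `scaled_dot_product_attention` can dispatch to a FlashAttention-style kernel on supported GPUs; the shapes and dtypes below are arbitrary, and the naive version is shown only for contrast.

```python
# Sketch: fused attention avoids materializing the full (seq x seq) matrix.
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); half precision on a CUDA device assumed.
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Fused, memory-efficient attention; the causal mask is applied inside the kernel.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Naive reference for comparison: materializes the full attention matrix in memory.
scores = (q @ k.transpose(-2, -1)) / (64 ** 0.5)
causal_mask = torch.triu(torch.ones(1024, 1024, device="cuda", dtype=torch.bool), 1)
ref = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1) @ v
```

Both paths compute the same result; the difference is that the fused kernel streams tiles of keys and values through fast on-chip memory instead of writing the full score matrix to GPU DRAM.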

Key-Value Caching and Paged Attention

As noted in the NVIDIA blog referenced above, key-value caching is fairly standard in Transformer implementations to avoid redundant recomputation of attention matrices. This enables a tradeoff between speed and resource utilization, as the cached keys and values are kept in GPU VRAM. While managing this is fairly straightforward for a single “thread” of inference, a number of complexities arise when considering parallel inference or multiple users for a single hosted model instance.
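Before turning to the multi-user case, here is a minimal sketch of the single-sequence case using Hugging Face transformers' `past_key_values` mechanism; the model and prompt are placeholders.

```python
# Minimal sketch of key-value caching in a greedy decoding loop. Passing
# past_key_values back in means each step only computes attention for the
# newest token's query instead of re-running the whole prefix.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")   # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
past_key_values = None

for _ in range(16):
    # After the first step, only the most recent token is fed in; the cache
    # already holds keys/values for everything that came before.
    step_input = input_ids if past_key_values is None else input_ids[:, -1:]
    with torch.no_grad():
        out = model(step_input, past_key_values=past_key_values, use_cache=True)
    past_key_values = out.past_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))
```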

How can you avoid recomputing values for system prompts and few-shot examples? When should you evict cache elements for a user who may or may not want to continue a chat session? PagedAttention addresses these challenges by leveraging ideas from classical paging in operating systems.

PagedAttention and its popular implementation vLLM have become a standard for self-hosted multi-user inference servers.
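A minimal usage sketch with vLLM's offline engine, which manages the KV cache in paged blocks under the hood; the model id is a placeholder and the sampling settings are illustrative.

```python
# Sketch of batched generation with vLLM, whose engine implements PagedAttention
# for block-based KV-cache management across many concurrent sequences.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")   # placeholder model id
params = SamplingParams(temperature=0.8, max_tokens=64)

prompts = [
    "Explain key-value caching in one sentence.",
    "What problem does PagedAttention solve?",
]
for output in llm.generate(prompts, params):
    print(output.outputs[0].text)
```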

CPU Offloading

The primary method used for running LLMs either partially or entirely on CPU (vs. GPU) is llama.cpp. This approach is particularly valuable for those without access to high-end GPUs or for deployment in resource-constrained environments.
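A small sketch using the llama-cpp-python bindings, assuming a quantized GGUF model file on disk (the path is a placeholder); `n_gpu_layers` controls how many layers are offloaded to a GPU, with 0 keeping everything on CPU.

```python
# Sketch: running a quantized GGUF model with llama.cpp's Python bindings.
from llama_cpp import Llama

llm = Llama(
    model_path="./models/llama-2-7b.Q4_K_M.gguf",  # placeholder quantized model file
    n_ctx=2048,        # context window
    n_gpu_layers=0,    # 0 = run entirely on CPU; raise to offload layers to a GPU
)

result = llm("Q: What is speculative decoding? A:", max_tokens=64)
print(result["choices"][0]["text"])
```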

Key Takeaways

  • Parameter quantization makes large models accessible on consumer hardware with minimal quality loss
  • Speculative decoding accelerates inference by using smaller models to "draft" outputs for larger models
  • FlashAttention significantly speeds up attention computation through hardware-aware implementation
  • Key-value caching avoids redundant computation during autoregressive decoding
  • PagedAttention enables efficient memory management for multi-user inference
  • CPU offloading tools like llama.cpp allow running models without dedicated GPU hardware