https://www.anyscale.com/blog/continuous-batching-llm-inference
Due to the large GPU memory footprint and compute cost of LLMs, serving dominates the compute cost for most real-world applications. ML engineers often treat LLMs like "black boxes" that can only be optimized with internal changes such as quantization and custom CUDA kernels. However, this is not entirely the case. Because LLMs generate their output iteratively, and because LLM inference is often memory-bound rather than compute-bound, there are surprising system-level batching optimizations that can make a 10x or greater difference in real-world workloads.
One such recently proposed optimization is continuous batching, also known as dynamic batching, or batching with iteration-level scheduling. We wanted to see how this optimization performs. We will get into details below, including how we simulate a production workload, but to summarize our findings:
You can try out continuous batching today: see this example to run vLLM on Ray Serve.
The remainder of this blog is structured as follows:
There is a lot to know about LLM inference, and we refer readers to "Efficient Inference on a Single GPU" and "Optimization story: Bloom inference" for more detail. However, at a high level, LLM inference is pretty straightforward.
For each request:
This is an iterative process: each forward pass of the model yields one additional completion token. For example, if you prompt with the sentence "What is the capital of California: ", it takes ten forward-pass iterations to get back the full response of ["S", "a", "c", "r", "a", "m", "e", "n", "t", "o"]. This example simplifies things a little, because in practice tokens do not map 1:1 to ASCII characters (a popular token encoding technique is Byte-Pair Encoding, which is beyond the scope of this blog post), but the iterative nature of generation is the same regardless of how you tokenize your sequences.
Simplified LLM inference. This toy example shows a hypothetical model which supports a maximum sequence length of 8 tokens (T1, T2, …, T8). Starting from the prompt tokens (yellow), the iterative process generates a single token at a time (blue). Once the model generates an end-of-sequence token (red), the generation loop stops. This example shows a batch of only one input sequence, so the batch size is 1.
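To make the loop concrete, here is a minimal Python sketch of the generation process described above. It is an illustration under stated assumptions, not how any real serving framework is implemented: the "model" is a hypothetical stand-in that replays a canned answer one character per forward pass, and the names `make_toy_model`, `generate`, and the `EOS` marker are invented for this example.

```python
from typing import Callable, List

EOS = "<eos>"  # hypothetical end-of-sequence marker, not a real tokenizer symbol


def make_toy_model(prompt_len: int, answer: List[str]) -> Callable[[List[str]], str]:
    """Build a stand-in for a model's forward pass.

    A real LLM would compute a distribution over its vocabulary; this toy
    version just returns the next token of a canned answer, then EOS.
    """
    def forward(tokens: List[str]) -> str:
        n_generated = len(tokens) - prompt_len
        return answer[n_generated] if n_generated < len(answer) else EOS
    return forward


def generate(forward: Callable[[List[str]], str],
             prompt: List[str],
             max_seq_len: int) -> List[str]:
    """Generate one token per forward pass until EOS or max_seq_len."""
    tokens = list(prompt)           # prompt tokens
    completion: List[str] = []      # generated tokens
    while len(tokens) < max_seq_len:
        next_token = forward(tokens)    # one forward pass -> one token
        if next_token == EOS:           # stop token ends the loop
            break
        tokens.append(next_token)       # feed the new token back in
        completion.append(next_token)
    return completion


# Character-level "tokens", as in the simplified example in the text.
prompt = list("What is the capital of California: ")
model = make_toy_model(len(prompt), list("Sacramento"))
completion = generate(model, prompt, max_seq_len=len(prompt) + 64)
print(completion)  # ten tokens, each produced by its own forward pass

# A 4-token prompt with a maximum sequence length of 8, as in the figure:
# generation is cut off once the sequence is full, even without an EOS.
short_model = make_toy_model(4, list("abcdefgh"))
truncated = generate(short_model, ["T1", "T2", "T3", "T4"], max_seq_len=8)
print(truncated)
```

Note that in this sketch the loop runs one extra forward pass after the last character to produce the stop token, and that a sequence which hits the maximum length is truncated without ever emitting one.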