https://dipkumar.dev/posts/gpt-kvcache/
February 12, 2023 | Estimated Reading Time: 8 min | Author: Dipkumar Patel
The common optimization trick for speeding up transformer inference is KV caching 1 2. This technique is so prominent that the huggingface library has the use_cache
flag enabled by default 6. A few days ago, I read an awesome blog post on GPT in 60 Lines of NumPy. So, I thought, why not extend it to use the KV cache technique? Let's roll up our sleeves and start working on it. Before you read further, this blog assumes you have background on transformers; if you don't, read that blog post first. It's awesome, and you will learn a lot from it.
First, let's understand a few things about the GPT code.
def gpt(inputs: list[int]) -> list[list[float]]:
    # inputs has shape [n_seq]
    # output has shape [n_seq, n_vocab]
    output = # beep boop neural network magic
    return output
We can deduce from the input-output signature that we can provide an arbitrarily long input and receive an output of the same length, with each element of the output giving a probability distribution over the next token at that position. So I could just give a single token as input and get the probability of the next token. It should just work, right?
Let's modify the picoGPT decode loop to pass only the last token as input and get the probability of the next token.
for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
    logits = gpt2(inputs[-1:], **params, n_head=n_head)  # model forward pass
    next_id = np.argmax(logits[-1])  # greedy sampling
    inputs = np.append(inputs, [next_id])  # append prediction to input
We are providing inputs[-1:] (a single token) as input to the model. Let's see what happens.
the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the
It didn't work. The main magic is in the attention: to get a good prediction of the next token, we need to provide all the previous tokens. In practice, limited memory and compute force us to cap the context at the last N tokens; for example, ChatGPT has a context of up to 4096 tokens. In summary, we can't just pass a single token and expect a very good prediction of the next token, and this need to attend over all previous tokens is also what gives attention its quadratic complexity.
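To make the cost concrete, here is a rough back-of-the-envelope count of query-key dot products when generating new tokens the naive way (re-running the model on the full prefix every step) versus reusing cached keys and values. The numbers are illustrative only and ignore heads, layers, projections, and the feed-forward blocks.

# Rough count of query-key dot products for generating T tokens from a
# prompt of length P. Illustrative only.
def scores_without_cache(P, T):
    # every step re-runs full self-attention over the whole prefix
    return sum((P + t) ** 2 for t in range(1, T + 1))

def scores_with_cache(P, T):
    # process the prompt once, then each step is one query vs. the cache
    return P ** 2 + sum(P + t for t in range(1, T + 1))

print(scores_without_cache(100, 100))  # 2348350  (~2.3M)
print(scores_with_cache(100, 100))     # 25050    (~25k)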
But if we look at the architecture of GPT, we can see that the only place we interact with previous tokens is the attention block; all other layers, such as the embedding layer, the feed-forward layer, and the layer norm, don't care about previous tokens. So what if we cache the input of the attention block for all previous tokens and pass it in during inference? Then we don't have to pass all those tokens again and again; we can just pass the last token and get the probability of the next token, as sketched below.
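Here is a sketch of what the decode loop could look like once the model accepts and returns a cache. The kvcache argument and the (logits, kvcache) return value are assumptions about how the modified gpt2 might be wired, not the original picoGPT signature.

kvcache = None  # nothing cached before the first forward pass
for _ in tqdm(range(n_tokens_to_generate), "generating"):
    if kvcache is None:
        # first pass: feed the whole prompt and fill the cache
        logits, kvcache = gpt2(inputs, **params, n_head=n_head, kvcache=kvcache)
    else:
        # later passes: feed only the newest token and reuse the cache
        logits, kvcache = gpt2(inputs[-1:], **params, n_head=n_head, kvcache=kvcache)
    next_id = np.argmax(logits[-1])  # greedy sampling
    inputs = np.append(inputs, [next_id])  # append prediction to input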
The inputs of the attention block are q, k, v, and the mask. We could try to cache q, k, and v for all previous tokens, but let's think about what really matters for us: we only need the k and v of the previous tokens to perform attention for the current input token, because we are only passing one token (one query) as input. See the image below for a visual representation of what I mean.
def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
attention with kvcache
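A quick, self-contained shape check of the picture above: a single query attends over keys and values cached from previous tokens and produces exactly one output row. The softmax is redefined locally so the snippet runs on its own; picoGPT has its own version.

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

n_prev, d = 5, 8
q_new   = np.random.randn(1, d)       # query of the current token only
k_cache = np.random.randn(n_prev, d)  # keys of the 5 previous tokens
v_cache = np.random.randn(n_prev, d)  # values of the 5 previous tokens
mask    = np.zeros((1, n_prev))       # the new token may attend to all of them

out = attention(q_new, k_cache, v_cache, mask)
print(out.shape)  # (1, 8) -- one output row for the one input token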
So, we need to compute new_k and new_v for the current input token, append them to the existing cache, and pass the full k and v to the attention block for further processing.
def mha(x, c_attn, c_proj, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projection
    # when we pass kvcache, n_seq = 1, so we compute new_q, new_k and new_v for just that token
    x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]

    # split into qkv
    qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]

    if kvcache:
        # append the new k and v to the cached k and v from previous tokens
        new_q, new_k, new_v = qkv  # new_q, new_k, new_v = [1, n_embd]
        old_k, old_v = kvcache
        k = np.vstack([old_k, new_k])  # k = [n_seq, n_embd], where n_seq = prev_n_seq + 1
        v = np.vstack([old_v, new_v])  # v = [n_seq, n_embd], where n_seq = prev_n_seq + 1
        qkv = [new_q, k, v]
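To see the cache update in isolation, here is a tiny standalone sketch of just that logic (no heads, no projections). The update_kv helper and its return convention are made up for illustration; they are not part of picoGPT.

import numpy as np

def update_kv(qkv, kvcache=None):
    # qkv holds [q, k, v] for the tokens in this forward pass
    q, k, v = qkv
    if kvcache is not None:
        old_k, old_v = kvcache
        k = np.vstack([old_k, k])  # grows to [prev_n_seq + 1, n_embd]
        v = np.vstack([old_v, v])
    return q, k, v, [k, v]  # the new cache is simply all k and v seen so far

n_embd = 4
prompt_qkv = [np.random.randn(3, n_embd) for _ in range(3)]  # 3-token prompt
q, k, v, cache = update_kv(prompt_qkv)        # first pass fills the cache
token_qkv = [np.random.randn(1, n_embd) for _ in range(3)]   # one new token
q, k, v, cache = update_kv(token_qkv, cache)  # k and v now cover 4 tokens
print(k.shape, v.shape)  # (4, 4) (4, 4)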
There is one more thing we need to take care of: the causal mask. When we pass a single token, we want it to attend to all of the previous (cached) tokens.
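Here is a small sketch of the mask shapes in the two cases. The exact expression for the causal mask follows the usual picoGPT-style construction, but treat it as an assumption about the surrounding code; the point is that with a cache and a single query, nothing needs to be masked out.

import numpy as np

# no cache: the usual lower-triangular causal mask over the full sequence
n_seq = 4
causal_mask = (1 - np.tri(n_seq)) * -1e10  # [n_seq, n_seq], future positions get -1e10

# with cache: one query token vs. n_k cached keys -- everything is in the past,
# so the mask is just a row of zeros
n_k = 4
single_token_mask = np.zeros((1, n_k))  # [1, n_k]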