
KV Cache for Efficient Autoregressive Attention

#376 · Deep Learning · Medium

Problem source: deep-ml.com

Problem

Implement a KV (key-value) cache for efficient autoregressive attention. During generation, cache the previously computed key and value tensors so that each new step computes K and V only for the new token, instead of recomputing them for the entire sequence.
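For reference, the per-step computation the cache serves can be written in standard scaled dot-product notation. Here q_t, k_t, v_t are the new token's query, key and value (row vectors), and K_{1:t}, V_{1:t} are our names for the stacked cache after the append; this is a sketch of the math, not notation taken from the problem statement:

```latex
K_{1:t} = \begin{bmatrix} K_{1:t-1} \\ k_t \end{bmatrix}, \qquad
V_{1:t} = \begin{bmatrix} V_{1:t-1} \\ v_t \end{bmatrix}, \qquad
o_t = \operatorname{softmax}\!\left( \frac{q_t K_{1:t}^{\top}}{\sqrt{d_k}} \right) V_{1:t}
```

Only k_t and v_t are new at step t; K_{1:t-1} and V_{1:t-1} are read from the cache.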

Solution

import numpy as np

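class KVCache:
    """Accumulates keys and values along the sequence axis (axis -2)."""

    def __init__(self):
        self.key_cache = None    # shape (..., seq_len, d_k) once populated
        self.value_cache = None  # shape (..., seq_len, d_k) once populated

    def update(self, new_keys: np.ndarray, new_values: np.ndarray):
        """Append the new token(s) and return the full cached keys and values."""
        if self.key_cache is None:
            # First step: copy so later in-place edits by the caller cannot alter the cache
            self.key_cache = new_keys.copy()
            self.value_cache = new_values.copy()
        else:
            # Later steps: grow the cache along the sequence axis
            self.key_cache = np.concatenate([self.key_cache, new_keys], axis=-2)
            self.value_cache = np.concatenate([self.value_cache, new_values], axis=-2)
        return self.key_cache, self.value_cache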

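def cached_attention(
    q: np.ndarray, k: np.ndarray, v: np.ndarray,
    cache: KVCache, d_k: int
) -> tuple[np.ndarray, KVCache]:
    # q: (batch, 1, d_k) query for the single new token
    # k, v: (batch, 1, d_k) the new token's key and value
    # (extra leading axes such as heads also work, since only the last two axes are used)
    keys, values = cache.update(k, v)  # (batch, T, d_k) after appending
    # The new query attends to every cached position; no causal mask is needed
    # because all cached positions are at or before the current token.
    scores = (q @ np.swapaxes(keys, -1, -2)) / np.sqrt(d_k)  # (batch, 1, T)
    # Numerically stable softmax over the key axis
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    output = weights @ values  # (batch, 1, d_k)
    return output, cache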
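To check that caching does not change the result, one can feed a sequence in token by token and compare against full causal attention recomputed from scratch. The sketch below is self-contained, so it repeats the concatenate-based cache update inline rather than importing KVCache; `full_causal_attention` and `cached_decode` are hypothetical helper names, and random arrays stand in for real Q/K/V projections.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def full_causal_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    # Reference: recompute attention for every position, masking out future keys
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)   # (batch, T, T)
    T = scores.shape[-1]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # True strictly above the diagonal
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v


def cached_decode(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    # Process one token per step; the cache grows along the sequence axis (-2)
    d_k = q.shape[-1]
    key_cache = value_cache = None
    outputs = []
    for t in range(q.shape[-2]):
        k_t, v_t = k[..., t:t + 1, :], v[..., t:t + 1, :]
        if key_cache is None:
            key_cache, value_cache = k_t, v_t
        else:
            key_cache = np.concatenate([key_cache, k_t], axis=-2)
            value_cache = np.concatenate([value_cache, v_t], axis=-2)
        q_t = q[..., t:t + 1, :]
        scores = q_t @ np.swapaxes(key_cache, -1, -2) / np.sqrt(d_k)  # (batch, 1, t + 1)
        outputs.append(softmax(scores) @ value_cache)
    return np.concatenate(outputs, axis=-2)


rng = np.random.default_rng(0)
batch, T, d_k = 2, 5, 8
q, k, v = (rng.standard_normal((batch, T, d_k)) for _ in range(3))
print(np.allclose(cached_decode(q, k, v), full_causal_attention(q, k, v)))  # True
```

The check holds because, at step t, the cached query sees exactly the keys at positions 1..t, which is the same set that row t of the causal mask leaves unmasked.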

Explanation

  1. The KV cache stores all previously computed key and value tensors.
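  2. At each generation step, compute K and V only for the new token and append them to the cache.
  3. The new token's query attends to all cached keys (the full history, including the new token itself), so no causal mask is needed for single-token steps.
  4. This avoids recomputing K and V for all previous tokens. Per step, the K/V projection cost drops from O(T * d^2) to O(d^2), and the attention cost drops from O(T^2 * d) (rerunning attention over every position) to O(T * d).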
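The savings in step 4 can be made concrete with a rough multiply-accumulate count. The function below is an illustrative sketch, not a profiler: it assumes a single head, d_k = d_v = d, d x d key/value projections, and an uncached baseline that computes the full t x t score matrix; `attention_macs_per_step` is a hypothetical name.

```python
def attention_macs_per_step(t: int, d: int, cached: bool) -> int:
    """Rough multiply-accumulate count for one attention layer at step t.

    Assumptions (illustrative only): a single head, d_k == d_v == d, d x d key/value
    projections, and an uncached baseline that reruns attention over the whole prefix.
    """
    n = 1 if cached else t      # tokens processed this step
    kv_proj = 2 * n * d * d     # K and V projections for those tokens
    scores = n * t * d          # Q @ K^T against t keys
    weighted_sum = n * t * d    # weights @ V over t values
    return kv_proj + scores + weighted_sum


t, d = 1024, 64
uncached_cost = attention_macs_per_step(t, d, cached=False)
cached_cost = attention_macs_per_step(t, d, cached=True)
print(uncached_cost, cached_cost, uncached_cost // cached_cost)  # 142606336 139264 1024
```

Under these assumptions every term scales with the number of tokens processed, so the uncached step costs exactly t times the cached one.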

Complexity

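  • Time: O(T * d) per generation step for the attention itself (scores plus weighted sum over T cached positions), where T is the current sequence length and d = d_k; the np.concatenate append also copies O(T * d) of cache data each step
  • Space: O(T * d) per sequence and per attention layer for the cached keys and values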
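Because np.concatenate allocates a new array and copies the whole cache on every append, a common alternative is to preallocate a buffer and write each new token into it. The sketch below assumes the maximum sequence length is known in advance and uses the same 3-D (batch, seq_len, d_k) layout as the solution; `PreallocatedKVCache` and its `max_len` parameter are hypothetical names.

```python
import numpy as np


class PreallocatedKVCache:
    """Hypothetical fixed-capacity KV cache that writes into a preallocated buffer."""

    def __init__(self, batch: int, max_len: int, d_k: int, dtype=np.float64):
        self.keys = np.zeros((batch, max_len, d_k), dtype=dtype)
        self.values = np.zeros((batch, max_len, d_k), dtype=dtype)
        self.length = 0  # number of valid positions filled so far

    def update(self, new_keys: np.ndarray, new_values: np.ndarray):
        n = new_keys.shape[-2]
        end = self.length + n
        if end > self.keys.shape[-2]:
            raise ValueError("KV cache capacity exceeded")
        # Copy only the new tokens into the buffer
        self.keys[:, self.length:end, :] = new_keys
        self.values[:, self.length:end, :] = new_values
        self.length = end
        # Return views over the valid prefix; slicing does not copy data
        return self.keys[:, :end, :], self.values[:, :end, :]


cache = PreallocatedKVCache(batch=1, max_len=4, d_k=2)
for step in range(3):
    k_t = np.full((1, 1, 2), float(step))
    keys, values = cache.update(k_t, k_t)
print(keys.shape)     # (1, 3, 2)
print(keys[0, :, 0])  # [0. 1. 2.]
```

The trade-off is memory: the buffer reserves O(max_len * d) per sequence up front, even for short generations, in exchange for copying only O(d) per appended token instead of the whole cache.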