#376 · Deep Learning · Medium
⊣ Solve on deep-ml.com
Implement a KV (Key-Value) cache for efficient autoregressive attention. During generation, cache the previously computed key and value tensors so that at each new step only the new token's K and V need to be computed, rather than recomputing them for the entire sequence.
import numpy as np
class KVCache:
    """Append-only store for the attention keys and values seen so far.

    Both caches start as ``None`` and grow along axis ``-2`` each time
    :meth:`update` is called.
    """

    def __init__(self):
        # None until the first update(); afterwards each holds everything
        # passed to update() so far, concatenated along axis -2.
        self.key_cache = None
        self.value_cache = None

    def update(self, new_keys: np.ndarray, new_values: np.ndarray):
        """Append new keys and values, then return the full cached arrays.

        Args:
            new_keys: Keys to append. Every axis except ``-2`` must match
                the keys already cached.
            new_values: Values to append. Every axis except ``-2`` must
                match the values already cached.

        Returns:
            The tuple ``(key_cache, value_cache)`` after the append.

        Raises:
            ValueError: Raised by ``np.concatenate`` when the new arrays are
                not shape-compatible with the cached ones. The cache is left
                exactly as it was before the call.
        """
        if self.key_cache is None:
            # Copy on first insert so that later in-place edits to the
            # caller's arrays cannot silently rewrite the cache.
            keys = np.array(new_keys)
            values = np.array(new_values)
        else:
            keys = np.concatenate([self.key_cache, new_keys], axis=-2)
            values = np.concatenate([self.value_cache, new_values], axis=-2)

        # Commit only after both arrays were built successfully; assigning
        # keys before the value concatenation could leave the two caches
        # with different lengths if that concatenation raised.
        self.key_cache = keys
        self.value_cache = values
        return self.key_cache, self.value_cache
def cached_attention(
    q: np.ndarray, k: np.ndarray, v: np.ndarray,
    cache: KVCache, d_k: int
) -> tuple[np.ndarray, KVCache]:
    """Attend the new query to every key/value cached so far.

    The new ``k`` and ``v`` are first appended to ``cache``; the query then
    attends to the whole cached sequence, including the entries just added.
    No causal mask is applied, so every query row passed in a single call
    sees all cached positions.

    Args:
        q: Query for the new token, e.g. ``(batch, 1, d_k)``. Any number of
            leading axes is accepted (for instance a heads axis) as long as
            they broadcast against the cached keys under ``@``.
        k: Key for the new token, laid out like ``q``.
        v: Value for the new token; must share the sequence axis (``-2``)
            with ``k``.
        cache: Cache that is updated in place with ``k`` and ``v``.
        d_k: Key dimension used for the ``1 / sqrt(d_k)`` score scaling.

    Returns:
        A tuple ``(output, cache)`` where ``output`` is the attention-weighted
        sum of the cached values and ``cache`` is the same object passed in.
    """
    keys, values = cache.update(k, v)

    # Swap only the last two axes instead of transpose(0, 2, 1): identical
    # for 3-D input, but does not pin the rank, matching the cache's use of
    # axis=-2 for the sequence dimension.
    scores = (q @ np.swapaxes(keys, -1, -2)) / np.sqrt(d_k)

    # Softmax over the cached positions; subtracting the row maximum keeps
    # exp() from overflowing without changing the result.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)

    output = weights @ values
    return output, cache