#131 · Machine Learning · Medium
# Solve on deep-ml.com.
# Implement efficient sparse window attention, where each token attends only to
# tokens within a fixed-size local window rather than to all tokens. This reduces
# the quadratic O(n^2) cost of standard self-attention to O(n * w), which is
# linear in the sequence length n for a fixed window size w.
import numpy as np
def sparse_window_attention(
    Q: np.ndarray, K: np.ndarray, V: np.ndarray, window_size: int
) -> np.ndarray:
    """Compute scaled dot-product attention restricted to a local window.

    Token ``i`` attends only to key/value positions
    ``i - window_size // 2`` through ``i + window_size // 2`` (inclusive),
    clamped to ``[0, seq_len)``. An even ``window_size`` therefore yields a
    window of ``window_size + 1`` positions for interior tokens. The work is
    O(seq_len * window_size) instead of the O(seq_len ** 2) of dense attention.

    Args:
        Q: Queries of shape ``(seq_len, d_k)``.
        K: Keys of shape ``(seq_len, d_k)``.
        V: Values of shape ``(seq_len, d_v)``.
        window_size: Local window width; must be non-negative.

    Returns:
        Array with the same shape as ``V``. A floating-point ``V`` keeps its
        dtype; an integer/bool ``V`` yields ``float64`` so the
        attention-weighted averages are not truncated.

    Raises:
        ValueError: If ``window_size`` is negative (the window would be empty).
    """
    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")
    seq_len, d_k = Q.shape
    half_w = window_size // 2
    scale = np.sqrt(d_k)  # loop-invariant, computed once
    # zeros_like(V) on an integer V would truncate the weighted averages
    # written into it, so allocate a floating-point buffer instead.
    out_dtype = V.dtype if np.issubdtype(V.dtype, np.floating) else np.float64
    output = np.zeros(V.shape, dtype=out_dtype)
    for i in range(seq_len):
        start = max(0, i - half_w)
        end = min(seq_len, i + half_w + 1)
        scores = K[start:end] @ Q[i] / scale  # (w,)
        scores -= np.max(scores)  # shift by the max for a numerically stable softmax
        weights = np.exp(scores)
        weights /= np.sum(weights)
        output[i] = weights @ V[start:end]
    return output
def sparse_window_attention_batched(
    Q: np.ndarray, K: np.ndarray, V: np.ndarray, window_size: int
) -> np.ndarray:
    """Apply ``sparse_window_attention`` independently to each batch element.

    Each token ``i`` attends to positions ``[i - w/2, i + w/2]`` (with
    ``w/2 = window_size // 2``), clamped to sequence boundaries, so it touches
    about ``window_size`` tokens instead of all ``seq_len`` tokens.

    Args:
        Q: Queries of shape ``(batch, seq_len, d_k)``.
        K: Keys of shape ``(batch, seq_len, d_k)``.
        V: Values of shape ``(batch, seq_len, d_v)``.
        window_size: Local window width, passed through unchanged.

    Returns:
        Array of shape ``(batch, seq_len, d_v)``. An empty batch returns an
        empty array of that shape instead of raising.
    """
    batch_size = Q.shape[0]
    if batch_size == 0:
        # np.stack([]) raises ValueError, so build the empty result explicitly.
        out_dtype = V.dtype if np.issubdtype(V.dtype, np.floating) else np.float64
        return np.zeros(V.shape, dtype=out_dtype)
    return np.stack(
        [sparse_window_attention(Q[b], K[b], V[b], window_size) for b in range(batch_size)]
    )