
Implement Efficient Sparse Window Attention

#131 · Machine Learning · Medium

Solve on deep-ml.com

Problem

Implement efficient sparse window attention, where each token attends only to tokens within a fixed-size local window rather than to all tokens. This reduces the O(n²) cost of standard self-attention to O(n · w), which is linear in the sequence length n for a fixed window size w.
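Windowed attention can also be viewed as ordinary full attention with a banded mask. Scores for keys outside the window are set to −∞ before the softmax, so those keys receive zero weight. The sketch below is a dense O(n²) reference built on that idea. It assumes the same centered window of radius window_size // 2 that the solution uses, and the name `banded_attention_reference` is hypothetical. It is meant for checking a windowed implementation, not for speed.

```python
import numpy as np


def banded_attention_reference(Q, K, V, window_size):
    """Dense O(n^2) reference: full attention with a banded mask.

    Hypothetical helper; assumes query i may attend to key j only when
    |i - j| <= window_size // 2 (same rule as the windowed solution).
    """
    n, d_k = Q.shape
    half_w = window_size // 2
    idx = np.arange(n)
    # mask[i, j] is True when key j lies inside query i's window
    mask = np.abs(idx[:, None] - idx[None, :]) <= half_w
    scores = Q @ K.T / np.sqrt(d_k)
    scores = np.where(mask, scores, -np.inf)  # block keys outside the window
    scores -= scores.max(axis=1, keepdims=True)  # stable softmax, row by row
    weights = np.exp(scores)  # exp(-inf) == 0 for masked keys
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
out = banded_attention_reference(Q, K, V, window_size=3)
print(out.shape)  # (6, 4)
```

The diagonal is always inside the mask, so every row has at least one finite score and the softmax is well defined.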

Solution

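import numpy as np

def sparse_window_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, window_size: int) -> np.ndarray:
    # Q, K: (seq_len, d_k); V: (seq_len, d_v)
    # Query i attends to keys in [i - window_size // 2, i + window_size // 2], clipped to the sequence
    seq_len, d_k = Q.shape
    half_w = window_size // 2
    # Floating-point output so integer inputs are not truncated by the weighted sum
    output = np.zeros((seq_len, V.shape[1]), dtype=np.result_type(V.dtype, np.float32))

    for i in range(seq_len):
        # Clamp the window to the sequence; end is exclusive
        start = max(0, i - half_w)
        end = min(seq_len, i + half_w + 1)

        q_i = Q[i]  # (d_k,)
        K_window = K[start:end]  # (w, d_k), a view (no copy)
        V_window = V[start:end]  # (w, d_v), a view (no copy)

        # Scaled dot-product attention within the window
        scores = K_window @ q_i / np.sqrt(d_k)  # (w,)
        scores -= np.max(scores)  # numerical stability
        weights = np.exp(scores)
        weights /= np.sum(weights)

        output[i] = weights @ V_window

    return output

def sparse_window_attention_batched(Q: np.ndarray, K: np.ndarray, V: np.ndarray, window_size: int) -> np.ndarray:
    # Q, K: (batch, seq_len, d_k); V: (batch, seq_len, d_v)
    # Applies the single-sequence version to each batch element independently
    batch_size = Q.shape[0]
    results = []
    for b in range(batch_size):
        results.append(sparse_window_attention(Q[b], K[b], V[b], window_size))
    return np.stack(results)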
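The per-query Python loop above is easy to follow, but it runs seq_len separate NumPy calls. The following is a minimal loop-free sketch under the same centered-window assumption. The name `sparse_window_attention_vectorized` is hypothetical, and the code assumes NumPy ≥ 1.20 for `sliding_window_view`. K and V are zero-padded by window_size // 2 on each side so that every query gets a window of the same length. Each window is taken as a strided view, and the padded slots are masked with −∞ so they receive zero weight.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def sparse_window_attention_vectorized(Q, K, V, window_size):
    """Loop-free sliding-window attention (hypothetical helper).

    Assumes the same centered window of radius window_size // 2 as the
    loop version. Scores are stored as an (n, w) array, i.e. O(n * w).
    """
    n, d_k = Q.shape
    h = window_size // 2
    w = 2 * h + 1  # effective window length

    # Zero-pad so every query has a full-length window, then take strided views
    K_pad = np.pad(K, ((h, h), (0, 0)))
    V_pad = np.pad(V, ((h, h), (0, 0)))
    K_win = sliding_window_view(K_pad, w, axis=0)  # (n, d_k, w)
    V_win = sliding_window_view(V_pad, w, axis=0)  # (n, d_v, w)

    # Absolute key position of slot j in query i's window
    pos = np.arange(n)[:, None] - h + np.arange(w)[None, :]  # (n, w)
    valid = (pos >= 0) & (pos < n)  # False for padded slots

    scores = np.einsum("id,idj->ij", Q, K_win) / np.sqrt(d_k)  # (n, w)
    scores = np.where(valid, scores, -np.inf)  # padded keys get zero weight
    scores -= scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return np.einsum("ij,iej->ie", weights, V_win)  # (n, d_v)


rng = np.random.default_rng(0)
Q, K = rng.standard_normal((8, 4)), rng.standard_normal((8, 4))
V = rng.standard_normal((8, 6))
print(sparse_window_attention_vectorized(Q, K, V, window_size=3).shape)  # (8, 6)
```

Padding to a fixed width trades a small amount of wasted work at the two ends of the sequence for a regular (n, w) score array that NumPy can process in a single call. The masking makes the result match the loop version.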

Explanation

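  1. For each query position i, define the local window [i - ⌊w/2⌋, i + ⌊w/2⌋], where w = window_size, clamped to the sequence boundaries [0, seq_len - 1].
  2. Compute the scaled dot-product scores q_i · k_j / √d_k only for keys j inside this window.
  3. Apply a numerically stable softmax (subtract the maximum score first) to the local scores, then use the resulting weights to compute a weighted sum of the window's value vectors.
  4. Each token therefore attends to at most 2⌊w/2⌋ + 1 tokens instead of all seq_len tokens. This is exactly window_size when window_size is odd and window_size + 1 when it is even. Tokens near either end of the sequence see fewer.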
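The following small worked example traces the window rule from steps 1 and 4. The helper `window_bounds` is hypothetical and reproduces the `start`/`end` computation from the solution. It prints the half-open key range for every query in a length-6 sequence, for one odd and one even window_size.

```python
def window_bounds(i, seq_len, window_size):
    """Half-open [start, end) key range for query i (same rule as the solution)."""
    half_w = window_size // 2
    return max(0, i - half_w), min(seq_len, i + half_w + 1)


for ws in (3, 4):
    spans = [window_bounds(i, 6, ws) for i in range(6)]
    sizes = [end - start for start, end in spans]
    print(ws, spans, sizes)
# 3 [(0, 2), (0, 3), (1, 4), (2, 5), (3, 6), (4, 6)] [2, 3, 3, 3, 3, 2]
# 4 [(0, 3), (0, 4), (0, 5), (1, 6), (2, 6), (3, 6)] [3, 4, 5, 5, 4, 3]
```

With window_size = 4, the interior queries see 5 keys, which illustrates the even-size case from step 4.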

Complexity

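  • Time: O(n · w · d), where n = sequence length, w = window size, and d = head dimension. Full self-attention costs O(n² · d).
  • Space: O(n · d) for the output, plus O(w) per query for the attention scores. The key and value slices are NumPy views, not copies.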
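One way to make the time bounds concrete is to count the query-key dot products each approach computes. The factor d multiplies both counts equally, so it is left out. The helper `count_dot_products` and the window size of 65 are illustrative choices only.

```python
def count_dot_products(seq_len, window_size):
    """Query-key pairs scored by windowed vs. full attention."""
    half_w = window_size // 2
    windowed = sum(
        min(seq_len, i + half_w + 1) - max(0, i - half_w) for i in range(seq_len)
    )
    return windowed, seq_len * seq_len


for n in (1_000, 2_000, 4_000):
    windowed, full = count_dot_products(n, window_size=65)
    print(n, windowed, full)
# 1000 63944 1000000
# 2000 128944 4000000
# 4000 258944 16000000
```

Doubling n roughly doubles the windowed count but quadruples the full count. The windowed total is 65n - 1056 here, because the 32 queries at each end have truncated windows.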