Sliding Window Attention

#388 · Deep Learning · Medium

Problem

Implement sliding window attention (as used in models like Longformer). Instead of full quadratic attention over all tokens, each token attends only to a fixed-size local window of neighboring tokens, reducing the time and memory spent on attention scores from O(n^2) to O(n * w), where n is the sequence length and w is the window size.
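
To make the local window concrete, the sketch below builds the boolean band mask of allowed (query, key) pairs. The helper name `band_mask` is hypothetical, and the convention of `window_size // 2` neighbors on each side is an assumption chosen to match the solution that follows.

```python
import numpy as np

def band_mask(seq_len: int, window_size: int) -> np.ndarray:
    """Boolean (seq_len, seq_len) mask; entry [i, j] is True when query i may attend to key j.

    Assumes window_size // 2 neighbors on each side of the query position.
    """
    half = window_size // 2
    idx = np.arange(seq_len)
    # |i - j| <= half keeps a symmetric band around the diagonal
    return np.abs(idx[:, None] - idx[None, :]) <= half

mask = band_mask(5, 3)
print(mask.astype(int))
```

With seq_len = 5 and window_size = 3, interior rows contain three ones (the token and one neighbor on each side), while the first and last rows contain only two because the window is clipped at the sequence boundary.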

Solution

import numpy as np

def sliding_window_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, window_size: int) -> np.ndarray:
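    # Q, K: shape (seq_len, d_k); V: shape (seq_len, d_v)
    seq_len, d_k = Q.shape
    # Float output so integer-valued V is not truncated on assignment
    output = np.zeros_like(V, dtype=float)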

    for i in range(seq_len):
        # Window bounds: window_size // 2 tokens on each side of i, clipped to the sequence
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)

        # Attend only within window
        q_i = Q[i]  # (d_k,)
        K_window = K[start:end]  # (w, d_k)
        V_window = V[start:end]  # (w, d_k)

        scores = K_window @ q_i / np.sqrt(d_k)  # (w,)
        weights = np.exp(scores - np.max(scores))
        weights = weights / weights.sum()

        output[i] = weights @ V_window

    return output

Explanation

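  1. For each query position i, define a local window centered on i: window_size // 2 tokens on each side, clipped at the sequence boundaries. The window therefore holds at most 2 * (window_size // 2) + 1 tokens, which is w for odd w and w + 1 for even w.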
  2. Compute attention scores only between the query and keys within that window.
  3. Apply softmax to get attention weights, then compute the weighted sum of values in the window.
  4. This restricts each token's receptive field to its local neighborhood, so the cost grows linearly rather than quadratically with sequence length.
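
The steps above can also be expressed with a mask over the full score matrix, which is handy for checking the loop-based solution on small inputs. This is only a sketch: the name `masked_window_attention` is hypothetical, it assumes the same `window_size // 2` neighbors per side, and it deliberately gives up the memory savings by materializing all n x n scores.

```python
import numpy as np

def masked_window_attention(Q, K, V, window_size):
    """Hypothetical vectorized variant: compute full scores, then mask keys outside the window."""
    seq_len, d_k = Q.shape
    half = window_size // 2
    idx = np.arange(seq_len)
    allowed = np.abs(idx[:, None] - idx[None, :]) <= half  # (seq_len, seq_len)

    scores = Q @ K.T / np.sqrt(d_k)
    # Keys outside the window get -inf, so they receive zero weight after softmax
    scores = np.where(allowed, scores, -np.inf)

    # Row-wise softmax; every row keeps at least its diagonal entry, so the max is finite
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
out = masked_window_attention(Q, K, V, window_size=3)
print(out.shape)
```

On small random inputs the result should agree with the per-position loop up to floating-point tolerance, since both apply a softmax over the same set of keys for each query.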

Complexity

  • Time: O(n * w * d), where n is the sequence length, w is the window size, and d is the dimension
  • Space: O(n * d) for the output, plus O(w) for the scores of the current window
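
As a rough illustration of the savings, the sketch below counts how many (query, key) scores are computed with and without the window. The helper `count_score_pairs` and the sizes n = 4096, w = 512 are assumptions picked for illustration, with `window_size // 2` neighbors on each side as in the solution.

```python
def count_score_pairs(seq_len: int, window_size: int) -> int:
    """Number of (query, key) scores computed, assuming window_size // 2 neighbors per side."""
    half = window_size // 2
    # Each query i scores the keys in [max(0, i - half), min(seq_len, i + half + 1))
    return sum(min(seq_len, i + half + 1) - max(0, i - half) for i in range(seq_len))

n, w = 4096, 512
local = count_score_pairs(n, w)
full = n * n
print(local, full, round(full / local, 1))
```

For these sizes the windowed version evaluates roughly 8 times fewer scores than full attention, and the gap widens as n grows while w stays fixed.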