Flash Attention v1 - Forward Pass

#208 · Deep Learning · Hard

Problem

Implement the forward pass of Flash Attention v1. Flash Attention computes exact attention while minimizing memory reads/writes by processing blocks of the query, key, and value matrices using a tiling strategy. Compute the output without materializing the full attention matrix.
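For orientation, the baseline that the tiled version has to reproduce can be sketched as below. This is only a reference under stated assumptions: the name naive_attention is hypothetical (it is not part of the problem statement), and it assumes single-head, unmasked attention on (N, d) NumPy arrays.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference attention that materializes the full (N, N) score matrix."""
    d = Q.shape[1]
    S = Q @ K.T / np.sqrt(d)                # (N, N) scaled scores
    S = S - S.max(axis=1, keepdims=True)    # subtract the row max for stability
    P = np.exp(S)
    P /= P.sum(axis=1, keepdims=True)       # row-wise softmax
    return P @ V                            # (N, d) output
```

The (N, N) arrays S and P are exactly what the tiled forward pass avoids allocating.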

Solution

import numpy as np

def flash_attention_forward(Q, K, V, block_size=32):
    """
    Q, K, V: (N, d) matrices
    Returns: O (N, d) attention output
    """
    N, d = Q.shape
    O = np.zeros((N, d), dtype=float)
    l = np.zeros(N, dtype=float)  # running sum of exp
    m = np.full(N, -np.inf)      # running max

    n_blocks = (N + block_size - 1) // block_size

    for j in range(n_blocks):
        kv_start = j * block_size
        kv_end = min(kv_start + block_size, N)
        Kj = K[kv_start:kv_end]   # (Bk, d)
        Vj = V[kv_start:kv_end]   # (Bk, d)

        for i in range(n_blocks):
            q_start = i * block_size
            q_end = min(q_start + block_size, N)
            Qi = Q[q_start:q_end]  # (Bq, d)

            # Compute block attention scores
            S = Qi @ Kj.T / np.sqrt(d)  # (Bq, Bk)

            # Current block max
            block_max = np.max(S, axis=1)  # (Bq,)
            m_prev = m[q_start:q_end].copy()
            m_new = np.maximum(m_prev, block_max)

            # Stable exp
            exp_s = np.exp(S - m_new[:, None])  # (Bq, Bk)
            exp_correction = np.exp(m_prev - m_new)

            # Update running sums
            l_prev = l[q_start:q_end].copy()
            l_new = exp_correction * l_prev + np.sum(exp_s, axis=1)

            # Update output: rescale the old unnormalized accumulator to the
            # new max and add this block's contribution. This must happen
            # exactly once per (i, j) pair; normalization is deferred to the end.
            O[q_start:q_end] = (
                exp_correction[:, None] * O[q_start:q_end] +
                exp_s @ Vj
            )
            l[q_start:q_end] = l_new
            m[q_start:q_end] = m_new

    # Final normalization
    O = O / (l[:, None] + 1e-12)
    return O

def flash_attention_v1(Q, K, V, block_size=32):
    """Simplified, correct Flash Attention v1 forward pass."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d), dtype=np.float64)
    l = np.zeros(N, dtype=np.float64)
    m = np.full(N, -np.inf, dtype=np.float64)

    n_kv_blocks = (N + block_size - 1) // block_size

    for j in range(n_kv_blocks):
        js = j * block_size
        je = min(js + block_size, N)
        Kj = K[js:je]
        Vj = V[js:je]

        for i in range((N + block_size - 1) // block_size):
            qs = i * block_size
            qe = min(qs + block_size, N)
            Qi = Q[qs:qe]

            S = Qi @ Kj.T * scale                  # (Bq, Bk)
            m_block = np.max(S, axis=1)             # (Bq,)
            m_old = m[qs:qe].copy()
            m_new = np.maximum(m_old, m_block)

            correction = np.exp(m_old - m_new)
            P = np.exp(S - m_new[:, None])

            l[qs:qe] = correction * l[qs:qe] + P.sum(axis=1)
            O[qs:qe] = correction[:, None] * O[qs:qe] + P @ Vj
            m[qs:qe] = m_new

    O /= l[:, None] + 1e-12
    return O

Explanation

  1. Tiling: Process Q and K/V in blocks of at most block_size rows, so only one (Bq, Bk) tile of scores exists at any time.
  2. Online softmax: Track running max m and running sum of exponentials l across KV blocks. When a new block has a larger max, rescale previous accumulators by exp(m_old - m_new).
  3. Accumulate unnormalized output: O = correction * O + exp(S - m_new) @ V. This keeps O in an unnormalized state during accumulation.
  4. Final normalization: Divide O by l to get the properly normalized attention output.
  5. The flash_attention_v1 function is the clean, correct implementation. Its result matches standard attention softmax(QK^T / sqrt(d)) V up to floating-point rounding and the 1e-12 guard in the final division.
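
The online softmax step is easiest to see on a single row of scores. The sketch below is an illustration only: the helper online_softmax_stats and the chunk parameter are hypothetical names, and it tracks just the running max m and running sum l (no output accumulator).

```python
import numpy as np

def online_softmax_stats(x, chunk=4):
    """Running max m and sum l of exp(x - m), seeing x one chunk at a time."""
    m, l = -np.inf, 0.0
    for start in range(0, len(x), chunk):
        block = x[start:start + chunk]
        m_new = max(m, block.max())
        # Rescale the old sum to the new max, then add this chunk's terms.
        l = np.exp(m - m_new) * l + np.exp(block - m_new).sum()
        m = m_new
    return m, l

x = np.array([1.0, 3.0, -2.0, 0.5, 7.0, 2.0, 4.0])
m, l = online_softmax_stats(x)
full = np.exp(x - x.max()).sum()    # same quantity computed in one pass
print(np.isclose(l, full), m == x.max())
```

On the first chunk m is -inf, so exp(m - m_new) evaluates to 0 and the empty running sum drops out, which is why the accumulators can start at -inf and 0.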

Complexity

  • Time: O(N^2 * d), same as standard attention
  • Space: O(N * d + B^2), where B is block_size, instead of O(N * d + N^2) -- avoids materializing the full N x N attention matrix
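
To put rough numbers on the space bound, the sketch below counts how many score entries are live at once in each approach. The helper score_buffer_elements is a hypothetical name, and it counts array elements rather than bytes.

```python
def score_buffer_elements(N, block_size):
    """Score entries held at once: (standard attention, tiled attention)."""
    B = min(block_size, N)   # a tile is never larger than the sequence
    return N * N, B * B

full, tiled = score_buffer_elements(N=4096, block_size=32)
print(full, tiled, full // tiled)   # 16777216 1024 16384
```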