Implement the forward pass of Flash Attention v1. Flash Attention computes exact attention while minimizing memory reads/writes by processing blocks of the query, key, and value matrices using a tiling strategy. Compute the output without materializing the full attention matrix.
import numpy as np
def flash_attention_forward(Q, K, V, block_size=32):
    """
    Compute softmax(Q K^T / sqrt(d)) V tile by tile (Flash Attention v1 forward).

    The full (N, M) score matrix is never built: only one
    (block_size, block_size) tile of scores exists at a time, and per-row
    softmax statistics are carried across key/value blocks.

    Q: (N, d) query matrix
    K: (M, d) key matrix (M == N for self-attention)
    V: (M, d_v) value matrix
    block_size: positive number of rows per query tile and per key/value tile
    Returns: O (N, d_v) attention output; all zeros if K has no rows
    Raises: ValueError if block_size is not positive or K and V disagree
        on their number of rows
    """
    if block_size <= 0:
        raise ValueError("block_size must be a positive integer")
    N, d = Q.shape
    M = K.shape[0]
    if V.shape[0] != M:
        raise ValueError("K and V must have the same number of rows")

    O = np.zeros((N, V.shape[1]), dtype=float)  # unnormalized output accumulator
    l = np.zeros(N, dtype=float)  # running sum of exp
    m = np.full(N, -np.inf)  # running max

    for kv_start in range(0, M, block_size):
        # Slicing clamps at the array end, so the last tile may be shorter.
        Kj = K[kv_start:kv_start + block_size]  # (Bk, d)
        Vj = V[kv_start:kv_start + block_size]  # (Bk, d_v)
        for q_start in range(0, N, block_size):
            rows = slice(q_start, q_start + block_size)

            # Scaled scores for this (query tile, key tile) pair.
            S = Q[rows] @ Kj.T / np.sqrt(d)  # (Bq, Bk)

            # Merge this tile's row maxima into the running maxima.
            m_prev = m[rows]
            m_new = np.maximum(m_prev, np.max(S, axis=1))

            # Exponentials relative to the new max, plus the factor that
            # re-expresses the old accumulators relative to that same max.
            # On the first key tile m_prev is -inf, so the factor is 0.
            exp_s = np.exp(S - m_new[:, None])  # (Bq, Bk)
            exp_correction = np.exp(m_prev - m_new)

            # Apply each update exactly once; applying the output update
            # twice double-counts every key tile after the first.
            l[rows] = exp_correction * l[rows] + np.sum(exp_s, axis=1)
            O[rows] = exp_correction[:, None] * O[rows] + exp_s @ Vj
            m[rows] = m_new

    # Final normalization; the epsilon only matters when no key was processed.
    O = O / (l[:, None] + 1e-12)
    return O
def flash_attention_v1(Q, K, V, block_size=32):
    """Simplified, correct Flash Attention v1 forward pass.

    Computes softmax(Q K^T / sqrt(d)) V one tile at a time, carrying a
    per-row running max and running sum of exponentials across key/value
    blocks so the full score matrix is never materialized.

    Q: (N, d) query matrix
    K: (M, d) key matrix (M == N for self-attention)
    V: (M, d_v) value matrix
    block_size: positive number of rows per query tile and per key/value tile
    Returns: O (N, d_v) attention output; all zeros if K has no rows
    Raises: ValueError if block_size is not positive or K and V disagree
        on their number of rows
    """
    if block_size <= 0:
        raise ValueError("block_size must be a positive integer")
    N, d = Q.shape
    M = K.shape[0]
    if V.shape[0] != M:
        raise ValueError("K and V must have the same number of rows")

    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]), dtype=np.float64)  # unnormalized output
    l = np.zeros(N, dtype=np.float64)  # running sum of exponentials
    m = np.full(N, -np.inf, dtype=np.float64)  # running row max

    # Tile start offsets are loop-invariant, so build them once.
    q_starts = range(0, N, block_size)
    for js in range(0, M, block_size):
        # Slicing clamps at the array end, so the last tile may be shorter.
        Kj = K[js:js + block_size]
        Vj = V[js:js + block_size]
        for qs in q_starts:
            rows = slice(qs, qs + block_size)
            S = Q[rows] @ Kj.T * scale  # (Bq, Bk)

            m_old = m[rows]
            m_new = np.maximum(m_old, np.max(S, axis=1))

            # Re-express the old accumulators relative to the new max;
            # on the first key tile m_old is -inf, so this factor is 0.
            correction = np.exp(m_old - m_new)
            P = np.exp(S - m_new[:, None])

            l[rows] = correction * l[rows] + P.sum(axis=1)
            O[rows] = correction[:, None] * O[rows] + P @ Vj
            m[rows] = m_new

    # Normalize once at the end; the epsilon only matters when no key was seen.
    O /= l[:, None] + 1e-12
    return O


# Notes on the algorithm:
# - Maintain a running max m and a running sum of exponentials l across KV
#   blocks. When a new block has a larger max, rescale the previous
#   accumulators by exp(m_old - m_new).
# - Update the output as O = correction * O + exp(S - m_new) @ V. This keeps O
#   in an unnormalized state during accumulation.
# - The flash_attention_v1 function is the clean, correct implementation. The
#   result is mathematically identical to standard attention
#   softmax(QK^T / sqrt(d)) V.