#448 · Deep Learning · Hard
Simulate context parallelism using ring attention for video diffusion models. Given a long sequence (e.g., from video latents), split it across multiple GPUs. Each GPU processes its local chunk of queries while keys/values are passed around in a ring pattern. Compute the per-GPU memory savings, communication volume, and total steps required for full attention.
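The ring pattern can be sketched in plain Python. This is a minimal sketch that assumes each GPU forwards the KV block it currently holds to GPU (i + 1) mod P after every step. `ring_schedule` is a hypothetical helper, not part of the original solution. It records which KV block each GPU holds at each of the P compute steps, which takes only P - 1 transfers.

```python
def ring_schedule(num_gpus: int) -> list[list[int]]:
    """Return schedule[step][gpu]: the KV block index GPU `gpu` holds at `step`.

    Assumes every GPU starts with its own block and, between steps, sends its
    current block to GPU (i + 1) % num_gpus and receives from GPU (i - 1) % num_gpus.
    """
    held = list(range(num_gpus))  # step 0: GPU i holds its local KV block i
    schedule = [held]
    for _ in range(num_gpus - 1):  # num_gpus - 1 transfers cover the remaining blocks
        held = [held[(i - 1) % num_gpus] for i in range(num_gpus)]
        schedule.append(held)
    return schedule


for step, blocks in enumerate(ring_schedule(4)):
    print(step, blocks)
# 0 [0, 1, 2, 3]
# 1 [3, 0, 1, 2]
# 2 [2, 3, 0, 1]
# 3 [1, 2, 3, 0]
```

Under this assumption GPU i holds block (i - s) mod P at step s. After P steps, every GPU has therefore seen every KV block exactly once.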
```python
def ring_attention_context_parallel(
    total_seq_len: int,
    num_gpus: int,
    hidden_dim: int,
    num_heads: int,
    bytes_per_element: int = 2,  # 2 bytes per element, e.g. fp16/bf16
) -> dict:
    # Each GPU owns one contiguous chunk of the sequence (any remainder is dropped by //).
    chunk_len = total_seq_len // num_gpus

    # Local Q chunk, plus the K and V chunks (factor 2) a GPU holds at any one time.
    per_gpu_q_bytes = chunk_len * hidden_dim * bytes_per_element
    per_gpu_kv_bytes = chunk_len * hidden_dim * bytes_per_element * 2
    full_kv_bytes = total_seq_len * hidden_dim * bytes_per_element * 2

    # One ring step per KV block. The local block needs no transfer, so each GPU
    # sends (and receives) ring_steps - 1 KV blocks in total.
    ring_steps = num_gpus
    per_step_comm_bytes = per_gpu_kv_bytes
    total_comm_bytes = per_step_comm_bytes * (ring_steps - 1)

    # Compare against a single device holding Q, K and V for the full sequence.
    memory_per_gpu = per_gpu_q_bytes + per_gpu_kv_bytes
    no_parallel_memory = total_seq_len * hidden_dim * bytes_per_element * 3
    memory_savings = 1 - (memory_per_gpu / no_parallel_memory)

    # Attention-score matrix size if materialized: local queries vs. all keys (per GPU)
    # and the full sequence (single device). Computed for reference only; not returned.
    attn_per_gpu_full = total_seq_len * chunk_len * num_heads * bytes_per_element
    attn_no_parallel = total_seq_len * total_seq_len * num_heads * bytes_per_element

    return {
        "chunk_length": chunk_len,
        "ring_steps": ring_steps,
        "memory_per_gpu_bytes": memory_per_gpu,
        "memory_per_gpu_mb": round(memory_per_gpu / (1024**2), 2),
        "full_kv_memory_bytes": full_kv_bytes,
        "per_step_comm_bytes": per_step_comm_bytes,
        "total_comm_bytes": total_comm_bytes,
        "total_comm_mb": round(total_comm_bytes / (1024**2), 2),
        "memory_savings_fraction": round(memory_savings, 4),
    }
```

The sequence is split into num_gpus chunks, and each GPU holds the queries for its local chunk. After num_gpus ring steps, every GPU has attended to all KV blocks, achieving full causal or bidirectional attention. Each GPU's total communication is (P-1) KV block transfers, where P = num_gpus. This trades communication for memory, enabling much longer sequences for video models.
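The claim that passing KV blocks around the ring yields exact full attention can be checked numerically. The solution above only counts bytes, so the following is a minimal pure-Python sketch under several assumptions. It uses a single head with unmasked (bidirectional) attention. It merges each GPU's per-block partial results with an online softmax, a running max and denominator, which is the technique commonly used for blockwise/ring attention. `full_attention` and `simulated_ring_attention` are hypothetical names.

```python
import math
import random


def full_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v over the whole sequence (one head, no mask)."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / z for c in range(len(v[0]))])
    return out


def simulated_ring_attention(q, k, v, num_gpus):
    """Each simulated GPU keeps its query chunk and folds in one KV chunk per ring step."""
    d, dv = len(q[0]), len(v[0])
    chunk = len(q) // num_gpus  # assumes the sequence length is divisible by num_gpus
    q_chunks = [q[g * chunk:(g + 1) * chunk] for g in range(num_gpus)]
    kv_chunks = [(k[g * chunk:(g + 1) * chunk], v[g * chunk:(g + 1) * chunk]) for g in range(num_gpus)]

    # Running state per local query row: (max score, softmax denominator, unnormalized output).
    state = [[(-math.inf, 0.0, [0.0] * dv) for _ in range(chunk)] for _ in range(num_gpus)]

    for step in range(num_gpus):
        for g in range(num_gpus):
            # KV block GPU g holds at this step if blocks move from GPU i to GPU i + 1.
            kb, vb = kv_chunks[(g - step) % num_gpus]
            for r, qi in enumerate(q_chunks[g]):
                m, z, acc = state[g][r]
                scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in kb]
                m_new = max(m, max(scores))
                scale = math.exp(m - m_new)  # rescales earlier partial sums to the new max
                weights = [math.exp(s - m_new) for s in scores]
                z = z * scale + sum(weights)
                acc = [a * scale + sum(w * vj[c] for w, vj in zip(weights, vb))
                       for c, a in enumerate(acc)]
                state[g][r] = (m_new, z, acc)

    # Normalize each row and concatenate the GPUs' outputs in sequence order.
    return [[a / z for a in acc] for g in range(num_gpus) for (_, z, acc) in state[g]]


random.seed(0)
n, d, P = 8, 4, 4
q, k, v = ([[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)] for _ in range(3))
ref = full_attention(q, k, v)
ring = simulated_ring_attention(q, k, v, P)
max_err = max(abs(a - b) for ra, rb in zip(ref, ring) for a, b in zip(ra, rb))
print(f"max abs difference vs. full attention: {max_err:.1e}")
```

The loops run the simulated GPUs in lockstep. A real implementation would typically overlap each step's compute with sending the current KV block to the next GPU. The online-softmax rescaling is what lets each GPU visit KV blocks one at a time without ever holding the full key/value sequence.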