Context Parallelism with Ring Attention for Video Models

#448 · Deep Learning · Hard

Solve on deep-ml.com

Problem

Simulate context parallelism using ring attention for video diffusion models. Given a long sequence (e.g., from video latents), split it across multiple GPUs. Each GPU processes its local chunk of queries while keys/values are passed around in a ring pattern. Compute the per-GPU memory savings, communication volume, and total steps required for full attention.
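To give a sense of scale, the sketch below turns a video latent into a token count. The shape (16 latent frames of 90 × 160, patchified with 2 × 2 patches) and the helper name video_latent_seq_len are hypothetical and chosen only for illustration. Real models differ in VAE compression, patch size and whether frames are also patched in time.

```python
def video_latent_seq_len(latent_frames: int, latent_h: int, latent_w: int, patch_size: int) -> int:
    """Tokens produced by patchifying a video latent (hypothetical helper).

    Assumes non-overlapping spatial patches of patch_size x patch_size and one
    token per patch per latent frame (no temporal patching).
    """
    if latent_h % patch_size or latent_w % patch_size:
        raise ValueError("latent height and width must be divisible by patch_size")
    return latent_frames * (latent_h // patch_size) * (latent_w // patch_size)


seq_len = video_latent_seq_len(16, 90, 160, 2)
print(seq_len)        # 57600 tokens
print(seq_len // 8)   # 7200 query tokens per GPU with 8 GPUs
print(seq_len ** 2)   # 3317760000 score entries per head for full attention
```

Because the score matrix of full attention grows as S^2, even modest clips quickly exceed what one device can hold. This is what motivates sharding the sequence.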

Solution

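def ring_attention_context_parallel(
    total_seq_len: int,
    num_gpus: int,
    hidden_dim: int,
    num_heads: int,
    bytes_per_element: int = 2
) -> dict:
    """Estimate per-GPU memory and communication for ring attention.

    Only the Q, K and V activations are counted (attention score buffers are
    not), so num_heads does not change the result; it is kept for the problem's
    interface. Assumes total_seq_len is divisible by num_gpus; any remainder
    is dropped by the floor division below.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")

    # Each GPU owns one contiguous chunk of the sequence
    chunk_len = total_seq_len // num_gpus

    # Local activations: Q for the chunk, and K plus V for the chunk (hence * 2)
    per_gpu_q_bytes = chunk_len * hidden_dim * bytes_per_element
    per_gpu_kv_bytes = chunk_len * hidden_dim * bytes_per_element * 2

    # K and V for the whole sequence, i.e. what one device holds without context parallelism
    full_kv_bytes = total_seq_len * hidden_dim * bytes_per_element * 2

    # One attention step per KV block; a block is forwarded after every step except the last
    ring_steps = num_gpus
    per_step_comm_bytes = per_gpu_kv_bytes
    total_comm_bytes = per_step_comm_bytes * (ring_steps - 1)  # bytes sent by each GPU

    memory_per_gpu = per_gpu_q_bytes + per_gpu_kv_bytes

    # Baseline: a single device holding Q, K and V for the full sequence
    no_parallel_memory = total_seq_len * hidden_dim * bytes_per_element * 3
    memory_savings = 1 - (memory_per_gpu / no_parallel_memory)

    return {
        "chunk_length": chunk_len,
        "ring_steps": ring_steps,
        "memory_per_gpu_bytes": memory_per_gpu,
        "memory_per_gpu_mb": round(memory_per_gpu / (1024**2), 2),
        "full_kv_memory_bytes": full_kv_bytes,
        "per_step_comm_bytes": per_step_comm_bytes,
        "total_comm_bytes": total_comm_bytes,
        "total_comm_mb": round(total_comm_bytes / (1024**2), 2),
        "memory_savings_fraction": round(memory_savings, 4)
    }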
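When P divides S, the quantities the function returns reduce to simple closed forms. Write c = S/P for the chunk length, d for hidden_dim and b for bytes_per_element:

```latex
\begin{aligned}
M_{\text{GPU}} &= \underbrace{c\,d\,b}_{Q} + \underbrace{2\,c\,d\,b}_{K,\,V} = \frac{3\,S\,d\,b}{P},
\qquad M_{0} = 3\,S\,d\,b,\\
\text{savings} &= 1 - \frac{M_{\text{GPU}}}{M_{0}} = 1 - \frac{1}{P},\\
C_{\text{step}} &= 2\,c\,d\,b = \frac{2\,S\,d\,b}{P},
\qquad C_{\text{total}} = (P-1)\,C_{\text{step}} = \frac{2\,(P-1)\,S\,d\,b}{P}.
\end{aligned}
```

The savings fraction depends only on P. For example, P = 8 gives memory_savings_fraction = 0.875. Over the whole ring pass, each GPU sends (P - 1)/P = 7/8 of the full-sequence KV size.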

Explanation

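  1. The sequence is split evenly into num_gpus chunks of length S/P. Each GPU keeps the queries, keys and values for its local chunk.
  2. In ring attention, KV blocks are passed around in a ring. Each GPU computes partial attention between its local queries and the KV block it currently holds, then sends that block to the next GPU and receives one from the previous GPU. Partial results are merged with an online (running) softmax, so the final output equals exact attention.
  3. After num_gpus compute steps (num_gpus - 1 transfers), every GPU has attended to all KV blocks. This gives exact bidirectional attention, or causal attention if a mask is applied to each block.
  4. KV storage per GPU drops from O(S · d) to O(S · d / P), where S is the total sequence length, P is the GPU count and d is the hidden dimension. In practice a second KV buffer is usually kept so the next block can be received while the current one is in use; this doubles the KV term but keeps it O(S · d / P).
  5. Each GPU sends P - 1 KV blocks of 2 · (S/P) · d · bytes_per_element bytes each. This trades communication for memory and enables much longer sequences for video models. The transfers can be overlapped with the attention computation on the current block.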
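Steps 2 and 3 can be checked numerically. The sketch below simulates ring attention for a single head in pure Python on a toy problem. It relies on several simplifying assumptions: one head, no masking, and no real GPUs, with a list rotation standing in for send/receive. The function names are hypothetical. Each simulated GPU keeps a running max, normaliser and unnormalised output per query (online softmax), and the result is compared with ordinary softmax attention.

```python
import math
import random


def full_attention(q, k, v):
    """Reference single-head softmax attention over lists of vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


def simulate_ring_attention(q, k, v, num_gpus):
    """Simulate ring attention; returns (output rows, KV block index held by each GPU per step)."""
    seq_len = len(q)
    if seq_len % num_gpus:
        raise ValueError("sequence length must be divisible by num_gpus")
    c = seq_len // num_gpus
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])

    q_local = [q[r * c:(r + 1) * c] for r in range(num_gpus)]
    # Each GPU starts with its own KV block, tagged with the block index
    held = [(r, k[r * c:(r + 1) * c], v[r * c:(r + 1) * c]) for r in range(num_gpus)]

    # Online-softmax state per GPU and local query: running max, normaliser, output
    run_max = [[-math.inf] * c for _ in range(num_gpus)]
    run_sum = [[0.0] * c for _ in range(num_gpus)]
    run_out = [[[0.0] * dv for _ in range(c)] for _ in range(num_gpus)]
    schedule = []

    for step in range(num_gpus):
        schedule.append([blk for blk, _, _ in held])
        for r in range(num_gpus):
            _, kb, vb = held[r]
            for i, qi in enumerate(q_local[r]):
                scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kb]
                new_max = max(run_max[r][i], max(scores))
                correction = math.exp(run_max[r][i] - new_max)  # 0.0 on the first step
                w = [math.exp(s - new_max) for s in scores]
                run_sum[r][i] = run_sum[r][i] * correction + sum(w)
                run_out[r][i] = [
                    o * correction + sum(wj * vj[col] for wj, vj in zip(w, vb))
                    for col, o in enumerate(run_out[r][i])
                ]
                run_max[r][i] = new_max
        if step < num_gpus - 1:
            # Ring pass: GPU r receives the block GPU r - 1 just used
            held = [held[(r - 1) % num_gpus] for r in range(num_gpus)]

    out = [[o / run_sum[r][i] for o in run_out[r][i]] for r in range(num_gpus) for i in range(c)]
    return out, schedule


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
S, D, P = 8, 4, 4
q, k, v = rand_matrix(S, D), rand_matrix(S, D), rand_matrix(S, D)
ring_out, schedule = simulate_ring_attention(q, k, v, P)
ref = full_attention(q, k, v)
# Largest deviation from full attention; should be at floating-point round-off level
print(max(abs(a - b) for ra, rb in zip(ring_out, ref) for a, b in zip(ra, rb)))
print(schedule)  # [[0, 1, 2, 3], [3, 0, 1, 2], [2, 3, 0, 1], [1, 2, 3, 0]]
```

At step t, GPU r holds block (r - t) mod P, so after P steps every GPU has seen every block exactly once, while only P - 1 transfers were needed.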

Complexity

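  • Time: O(S^2 · d / P) attention compute per GPU, spread over P ring steps with P - 1 communication rounds
  • Space: O(S · d / P) per GPU for Q, K and V, where d is the hidden dimension, plus O(h · (S/P)^2) for one block of attention scores if they are materialized (h = num_heads)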
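The complexity figures can be made concrete with a rough per-GPU estimate. The sketch below assumes head_dim = hidden_dim / num_heads, counts 2 FLOPs per multiply-add for the QK^T and PV matmuls, and ignores softmax and projection costs. The helper name attention_cost_per_gpu is hypothetical. It also compares materializing the local queries' scores against all keys with materializing only one (S/P) × (S/P) block per ring step.

```python
def attention_cost_per_gpu(total_seq_len: int, num_gpus: int, hidden_dim: int,
                           num_heads: int, bytes_per_element: int = 2) -> dict:
    """Rough per-GPU attention cost under ring attention (hypothetical helper)."""
    chunk = total_seq_len // num_gpus
    # QK^T and PV each cost 2 * chunk * S * head_dim FLOPs per head; summed over
    # heads that is 2 * chunk * S * hidden_dim per matmul, i.e. O(S^2 * d / P) overall
    flops = 2 * 2 * chunk * total_seq_len * hidden_dim
    # Scores of the local queries against every key, if materialized all at once
    scores_all_keys_bytes = num_heads * chunk * total_seq_len * bytes_per_element
    # Scores for a single KV block, as in one blockwise / online-softmax ring step
    scores_one_block_bytes = num_heads * chunk * chunk * bytes_per_element
    return {
        "flops": flops,
        "scores_all_keys_bytes": scores_all_keys_bytes,
        "scores_one_block_bytes": scores_one_block_bytes,
    }


cost = attention_cost_per_gpu(total_seq_len=1024, num_gpus=4, hidden_dim=64, num_heads=4)
print(cost)
# {'flops': 67108864, 'scores_all_keys_bytes': 2097152, 'scores_one_block_bytes': 524288}
```

The per-step score buffer is P times smaller than the all-keys version. Computing attention block by block keeps the score term small next to the O(S · d / P) activations.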