Latent Space Patchification for Diffusion Transformers

#420 · Inference · Medium

Problem

Implement the patchification step used in Diffusion Transformers (DiT). Given a latent image tensor of shape (C, H, W), divide it into non-overlapping patches of size (p, p), and flatten each patch into a 1D token vector. Return the sequence of patch tokens suitable for input to a Transformer.

Solution

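def patchify(latent: list[list[list[float]]], patch_size: int) -> list[list[float]]:
    """Split a (C, H, W) latent into flattened, non-overlapping p x p patches."""
    C = len(latent)
    H = len(latent[0])
    W = len(latent[0][0])
    # Patches tile the grid exactly, so both spatial dims must be multiples of patch_size.
    assert H % patch_size == 0 and W % patch_size == 0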

    n_patches_h = H // patch_size
    n_patches_w = W // patch_size
    tokens = []

    for ph in range(n_patches_h):
        for pw in range(n_patches_w):
            token = []
            for c in range(C):
                for i in range(patch_size):
                    for j in range(patch_size):
                        row = ph * patch_size + i
                        col = pw * patch_size + j
                        token.append(latent[c][row][col])
            tokens.append(token)

    return tokens  # shape: (n_patches, C * patch_size * patch_size)


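def unpatchify(tokens: list[list[float]], C: int, H: int, W: int, patch_size: int) -> list[list[list[float]]]:
    """Inverse of patchify: rebuild the (C, H, W) latent from its patch tokens."""
    assert H % patch_size == 0 and W % patch_size == 0
    n_patches_h = H // patch_size
    n_patches_w = W // patch_size
    assert len(tokens) == n_patches_h * n_patches_w
    latent = [[[0.0] * W for _ in range(H)] for _ in range(C)]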

    for idx, token in enumerate(tokens):
        ph = idx // n_patches_w
        pw = idx % n_patches_w
        ptr = 0
        for c in range(C):
            for i in range(patch_size):
                for j in range(patch_size):
                    row = ph * patch_size + i
                    col = pw * patch_size + j
                    latent[c][row][col] = token[ptr]
                    ptr += 1
    return latent

Explanation

  1. Patchification divides the spatial dimensions (H, W) into a grid of (H/p) x (W/p) non-overlapping patches.
  2. Each patch collects C * p * p values (all channels, all spatial positions within the patch) into a single flat vector (token).
  3. The resulting sequence has (H/p) * (W/p) tokens, each of dimension C * p * p.
  4. In Diffusion Transformers, patchification is applied to the VAE latent (e.g., C=4, H=W=32 for a 256x256 image, since the VAE downsamples each spatial dimension by a factor of 8), not to raw pixels.
  5. Unpatchify reverses the process to reconstruct the spatial tensor from tokens.
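
To make the shape arithmetic in the steps above concrete, here is a minimal sketch that computes the token count and token dimension, plus the position of a value inside a token. The helper names token_grid_shape and token_offset are hypothetical (they are not part of the solution), and the patch size p=2 is chosen only for illustration.

```python
def token_grid_shape(C: int, H: int, W: int, p: int) -> tuple[int, int]:
    """Return (number of tokens, token dimension) for a (C, H, W) latent."""
    if H % p != 0 or W % p != 0:
        raise ValueError("H and W must be divisible by the patch size")
    n_tokens = (H // p) * (W // p)  # one token per patch in the grid
    token_dim = C * p * p           # all channels, all positions in the patch
    return n_tokens, token_dim


def token_offset(c: int, i: int, j: int, p: int) -> int:
    """Index inside a token of channel c, patch row i, patch column j.

    Assumes the channel-major (c, i, j) ordering used by the nested loops
    in the solution above.
    """
    return (c * p + i) * p + j


# Latent from the explanation (C=4, H=W=32) with an assumed patch size of 2.
print(token_grid_shape(4, 32, 32, 2))  # (256, 16)
```

With this ordering, the first p * p entries of every token come from channel 0, the next p * p from channel 1, and so on.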

Complexity

  • Time: O(C · H · W), since every latent element is read exactly once
  • Space: O(C · H · W) for the output tokens, which together hold each input element exactly once
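
The "exactly once" claim behind both bounds can be checked with a small sketch that enumerates the coordinates in the same order as the nested patch loops. The generator visited_coords is a hypothetical helper written for this check, and the small non-square shape is an arbitrary choice.

```python
from itertools import product


def visited_coords(C: int, H: int, W: int, p: int):
    """Yield (c, row, col) in the order the nested patch loops visit them."""
    for ph, pw in product(range(H // p), range(W // p)):
        for c, i, j in product(range(C), range(p), range(p)):
            yield c, ph * p + i, pw * p + j


# Small non-square latent so that row/column mix-ups would show up.
C, H, W, p = 2, 4, 6, 2
coords = list(visited_coords(C, H, W, p))

# Total work equals the number of latent elements, with no repeats.
assert len(coords) == C * H * W
assert len(set(coords)) == C * H * W
```

Because the visit order covers every (c, row, col) once and only once, the loop count and the output size both equal C * H * W.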