#420 · Inference · Medium
# Solve on deep-ml.com
#
# Implement the patchification step used in Diffusion Transformers (DiT).
# Given a latent image tensor of shape (C, H, W), divide it into
# non-overlapping patches of size (p, p), and flatten each patch into a 1D
# token vector. Return the sequence of patch tokens suitable for input to a
# Transformer.
def patchify(latent: list[list[list[float]]], patch_size: int) -> list[list[float]]:
    """Split a (C, H, W) latent into flattened, non-overlapping patch tokens.

    Patches are visited in row-major order (top to bottom, then left to
    right). Inside a token the values are ordered channel first, then by row
    within the patch, then by column.

    Args:
        latent: Nested sequence indexed as ``latent[channel][row][col]``.
        patch_size: Side length ``p`` of each square patch.

    Returns:
        ``(H // p) * (W // p)`` tokens, each of length ``C * p * p``. An empty
        list is returned when the latent has no channels, rows, or columns.

    Raises:
        ValueError: If ``patch_size`` is not positive, the latent is not
            rectangular, or H or W is not a multiple of ``patch_size``.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    # Use len() rather than truth-testing: truthiness is ambiguous for some
    # array-like inputs, while len() and indexing work for plain lists too.
    if len(latent) == 0 or len(latent[0]) == 0:
        return []

    height = len(latent[0])
    width = len(latent[0][0])

    for channel in latent:
        if len(channel) != height or any(len(row) != width for row in channel):
            raise ValueError(
                "latent must be rectangular: every channel needs the same H and W"
            )

    # Explicit raise instead of `assert`, which is stripped under `python -O`
    # and would let a non-divisible latent be silently truncated.
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"H ({height}) and W ({width}) must both be divisible by "
            f"patch_size ({patch_size})"
        )

    tokens = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            token = []
            for channel in latent:
                for row in channel[top:top + patch_size]:
                    token.extend(row[left:left + patch_size])
            tokens.append(token)
    return tokens  # shape: (n_patches, C * patch_size * patch_size)
def unpatchify(tokens: list[list[float]], C: int, H: int, W: int, patch_size: int) -> list[list[list[float]]]:
    """Rebuild a (C, H, W) latent from a sequence of flattened patch tokens.

    Token ``idx`` is placed at patch row ``idx // (W // p)`` and patch column
    ``idx % (W // p)``. Inside a token the values are read channel first,
    then by row within the patch, then by column.

    Args:
        tokens: ``(H // p) * (W // p)`` tokens, each of length ``C * p * p``.
        C: Number of channels in the output.
        H: Height of the output.
        W: Width of the output.
        patch_size: Side length ``p`` of each square patch.

    Returns:
        A nested list indexed as ``latent[channel][row][col]``.

    Raises:
        ValueError: If ``patch_size`` is not positive, a dimension is
            negative, H or W is not a multiple of ``patch_size``, or the
            number or length of the tokens does not match the requested
            shape.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if C < 0 or H < 0 or W < 0:
        raise ValueError(f"C, H and W must be non-negative, got {C}, {H}, {W}")
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"H ({H}) and W ({W}) must both be divisible by "
            f"patch_size ({patch_size})"
        )

    patches_per_row = W // patch_size
    expected_tokens = (H // patch_size) * patches_per_row
    token_dim = C * patch_size * patch_size

    # Reject mismatched input up front: otherwise missing tokens would leave
    # silent zero regions and surplus values would be silently dropped.
    if len(tokens) != expected_tokens:
        raise ValueError(f"expected {expected_tokens} tokens, got {len(tokens)}")
    for idx, token in enumerate(tokens):
        if len(token) != token_dim:
            raise ValueError(
                f"token {idx} has length {len(token)}, expected {token_dim}"
            )

    latent = [[[0.0] * W for _ in range(H)] for _ in range(C)]
    for idx, token in enumerate(tokens):
        patch_row, patch_col = divmod(idx, patches_per_row)
        top = patch_row * patch_size
        left = patch_col * patch_size
        ptr = 0
        for c in range(C):
            for i in range(patch_size):
                # Copy one patch row at a time; the length checks above
                # guarantee the slice assignment cannot resize the row.
                latent[c][top + i][left:left + patch_size] = token[ptr:ptr + patch_size]
                ptr += patch_size
    return latent


# Patchification summary (p = patch_size):
#   1. Divide the (C, H, W) latent into (H/p) x (W/p) non-overlapping patches.
#   2. Flatten each patch's C * p * p values (all channels, all spatial
#      positions within the patch) into a single flat vector (token).
#   3. The result is (H/p) * (W/p) tokens, each of dimension C * p * p.