
QLoRA: Quantized Low-Rank Adaptation Forward Pass

#223 · Deep Learning · Medium · deep-ml.com

Problem

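Implement the QLoRA (Quantized Low-Rank Adaptation) forward pass. QLoRA extends LoRA by storing the frozen weight matrix in a quantized format (e.g., 4-bit NormalFloat, NF4) to reduce memory, while keeping the low-rank adapters A and B in higher precision. The output is y = dequant(W_q) · x + (alpha / r) · B · A · x, where W_q is d_out × d_in, x is d_in × n (one input per column), A is r × d_in, and B is d_out × r.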
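To see where W_quantized, scale and zero_point might come from, here is a minimal sketch of the inverse step. It assumes a uniform affine (asymmetric) 4-bit scheme with one scale and zero point for the whole matrix; this is not NF4, which uses non-uniform normal-quantile levels and per-block scales. The function name `quantize_affine_4bit` is hypothetical.

```python
def quantize_affine_4bit(W, num_bits=4):
    """Hypothetical helper: map a float matrix to integer codes in [0, 2**num_bits - 1].

    Returns (W_q, scale, zero_point) such that scale * (W_q - zero_point) ~= W.
    Assumption: one per-tensor scale/zero_point (NOT QLoRA's block-wise NF4).
    """
    qmax = 2 ** num_bits - 1
    flat = [v for row in W for v in row]
    # Include 0 in the range so that zero is exactly representable.
    w_min = min(min(flat), 0.0)
    w_max = max(max(flat), 0.0)
    # Guard against an all-zero matrix (zero range).
    scale = (w_max - w_min) / qmax if w_max > w_min else 1.0
    zero_point = round(-w_min / scale)
    # Round to the nearest code, then clamp to the valid code range.
    W_q = [[min(qmax, max(0, round(v / scale) + zero_point)) for v in row] for row in W]
    return W_q, scale, zero_point
```

With this scheme each dequantized entry scale * (q - zero_point) lies within scale / 2 of the original value, which is the rounding error the frozen path carries into the forward pass.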

Solution

def qlora_forward(
    x: list[list[float]],
    W_quantized: list[list[float]],
    A: list[list[float]],
    B: list[list[float]],
    alpha: float,
    r: int,
    scale: float = 1.0,
    zero_point: float = 0.0,
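) -> list[list[float]]:
    # Shapes: W_quantized (d_out, d_in), x (d_in, n) with one input per column,
    # A (r, d_in), B (d_out, r); the result is (d_out, n).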
    def matmul(a, b):
        rows_a, cols_a = len(a), len(a[0])
        cols_b = len(b[0])
        result = [[0.0] * cols_b for _ in range(rows_a)]
        for i in range(rows_a):
            for k in range(cols_a):
                for j in range(cols_b):
                    result[i][j] += a[i][k] * b[k][j]
        return result

    def add_mat(a, b):
        return [[a[i][j] + b[i][j] for j in range(len(a[0]))] for i in range(len(a))]

    def scale_mat(m, s):
        return [[m[i][j] * s for j in range(len(m[0]))] for i in range(len(m))]

    # Dequantize: W_fp = scale * (W_quantized - zero_point)
    W_dequant = [
        [(W_quantized[i][j] - zero_point) * scale for j in range(len(W_quantized[0]))]
        for i in range(len(W_quantized))
    ]

    # Frozen output with dequantized weights
    Wx = matmul(W_dequant, x)

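    # LoRA delta: (alpha / r) * B @ (A @ x). Computing A @ x first keeps the
    # cost at O(r * (d_in + d_out) * n) and never forms the d_out x d_in matrix B @ A.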
    Ax = matmul(A, x)
    BAx = matmul(B, Ax)
    delta = scale_mat(BAx, alpha / r)

    return add_mat(Wx, delta)
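For comparison, here is a vectorized sketch of the same computation, assuming NumPy is available. The name `qlora_forward_np` is hypothetical, and shapes follow the list-based solution (x holds one input per column).

```python
import numpy as np

def qlora_forward_np(x, W_q, A, B, alpha, r, scale=1.0, zero_point=0.0):
    """Hypothetical vectorized version of qlora_forward.

    W_q: (d_out, d_in), x: (d_in, n), A: (r, d_in), B: (d_out, r).
    """
    x, W_q, A, B = (np.asarray(m, dtype=np.float64) for m in (x, W_q, A, B))
    W = (W_q - zero_point) * scale          # dequantize the frozen weights
    base = W @ x                            # frozen path
    delta = (alpha / r) * (B @ (A @ x))     # low-rank path, evaluated as B(Ax)
    return base + delta

# Hand-checkable example: W = 0.5 * ([[9, 7]] - 8) = [[0.5, -0.5]], so base = -0.5;
# A x = 3, B A x = 6, and alpha / r = 2 gives delta = 12.
y = qlora_forward_np([[1.0], [2.0]], [[9, 7]], A=[[1.0, 1.0]], B=[[2.0]],
                     alpha=2.0, r=1, scale=0.5, zero_point=8)
print(y)  # [[11.5]]
```

Because B(Ax) is evaluated right to left, the d_out × d_in product B A is never formed. With B set to zeros (the usual LoRA initialization), the output reduces to the dequantized base path alone.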

Explanation

  1. QLoRA stores frozen weights in a low-bit quantized format (e.g., 4-bit NF4).
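  2. During the forward pass, the quantized weights are dequantized on the fly: W_fp = scale * (W_q - zero_point). This solution uses a simple affine scheme with a single scale and zero point; NF4 instead maps each 4-bit code to a fixed level (a quantile of the standard normal distribution) and multiplies it by a per-block absmax scale.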
  3. The dequantized weights are used for the base computation, while the LoRA adapters A and B remain in higher precision (e.g., BF16).
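  4. Storing the frozen weights in 4 bits instead of 16 cuts their memory by roughly 4x (slightly less once the quantization constants are counted), while the higher-precision adapters preserve adaptation quality.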
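The ~4x figure can be checked with a quick byte count. This sketch assumes a 4096 × 4096 weight, blocks of 64 elements, and one FP32 scale per block (a common block-wise setup); `weight_memory_bytes` is a hypothetical helper.

```python
def weight_memory_bytes(d_out, d_in, bits, block_size=None, scale_bytes=4):
    """Hypothetical helper: bytes to store a d_out x d_in weight at `bits` per element.

    If block_size is given, add one scale constant (scale_bytes each) per block
    of block_size elements, as in block-wise quantization.
    """
    n = d_out * d_in
    total = n * bits // 8
    if block_size is not None:
        num_blocks = -(-n // block_size)  # ceiling division
        total += num_blocks * scale_bytes
    return total

fp16 = weight_memory_bytes(4096, 4096, bits=16)
q4 = weight_memory_bytes(4096, 4096, bits=4, block_size=64)
print(fp16, q4, round(fp16 / q4, 2))  # 33554432 9437184 3.56
```

The per-block constants pull the ratio below 4x; QLoRA's double quantization reduces that overhead by quantizing the constants themselves.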

Complexity

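  • Time: O(d_out * d_in * n) for the base matmul, plus O(d_out * d_in) to dequantize W; O(r * (d_in + d_out) * n) for the LoRA path, since A x is computed before multiplying by B
  • Space: O(d_out * d_in) for the quantized weights (4 bits each in a real QLoRA setup); this solution also materializes the full dequantized matrix (O(d_out * d_in) floats), which block-wise dequantization can reduce to one block at a time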
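Block-wise dequantization can be sketched as follows, assuming the same per-tensor scale and zero point as the solution: only `block_rows` rows are dequantized at a time, so the full float matrix never exists in memory at once. Real QLoRA blocks are contiguous runs of elements, each with its own scale; `blockwise_dequant_matmul` and `block_rows` are hypothetical names.

```python
def blockwise_dequant_matmul(W_q, x, scale, zero_point, block_rows=2):
    """Hypothetical helper: compute dequant(W_q) @ x one row block at a time.

    Dequantized weights held at once: O(block_rows * d_in) instead of O(d_out * d_in).
    Assumption: a single per-tensor scale/zero_point, as in the solution.
    """
    d_in, n = len(x), len(x[0])
    out = []
    for start in range(0, len(W_q), block_rows):
        # Dequantize just this slice of rows.
        block = [[(v - zero_point) * scale for v in row]
                 for row in W_q[start:start + block_rows]]
        # Multiply each dequantized row by x, then let the block go out of scope.
        for w_row in block:
            out.append([sum(w_row[k] * x[k][j] for k in range(d_in)) for j in range(n)])
    return out
```

The result is identical to dequantizing everything up front; the block size only trades peak memory against the number of loop iterations.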