Implement a Simple CNN Training Function with Backpropagation

#130 · Deep Learning · Hard

Problem

Implement a simple CNN training function with forward pass and backpropagation. The CNN should have a convolutional layer followed by a fully connected layer, trained on a small dataset using MSE loss.

Solution

import numpy as np

def conv2d_forward(x, kernel, bias, stride=1, pad=0):
    if pad > 0:
        x = np.pad(x, ((0,0),(0,0),(pad,pad),(pad,pad)), mode='constant')
    N, C_in, H, W = x.shape
    C_out, C_in_k, kH, kW = kernel.shape
    out_h = (H - kH) // stride + 1
    out_w = (W - kW) // stride + 1
    out = np.zeros((N, C_out, out_h, out_w))
    for n in range(N):
        for co in range(C_out):
            for i in range(out_h):
                for j in range(out_w):
                    h_start = i * stride
                    w_start = j * stride
                    receptive = x[n, :, h_start:h_start+kH, w_start:w_start+kW]
                    out[n, co, i, j] = np.sum(receptive * kernel[co]) + bias[co]
    return out

def relu(x):
    return np.maximum(0, x)

def relu_backward(dout, x):
    return dout * (x > 0)

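def train_cnn(X, y, epochs=100, lr=0.01):
    # X: inputs of shape (N, C, H, W); y: targets of shape (N,) or (N, output_size).
    # Returns the trained parameters plus the loss from the last epoch's forward pass
    # (computed before that epoch's parameter update).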
    N, C, H, W = X.shape
    # Conv layer: 1 filter of size 3x3
    n_filters = 1
    kH, kW = 3, 3
    kernel = np.random.randn(n_filters, C, kH, kW) * 0.1
    bias_conv = np.zeros(n_filters)

    out_h = H - kH + 1
    out_w = W - kW + 1
    fc_input_size = n_filters * out_h * out_w
    output_size = y.shape[1] if y.ndim > 1 else 1

    W_fc = np.random.randn(fc_input_size, output_size) * 0.1
    b_fc = np.zeros(output_size)

    loss = None  # stays None if epochs == 0, so the return below never hits an undefined name
    for epoch in range(epochs):
        # Forward
        conv_out_raw = conv2d_forward(X, kernel, bias_conv)
        conv_out = relu(conv_out_raw)
        flat = conv_out.reshape(N, -1)
        logits = flat @ W_fc + b_fc

        # MSE loss
        target = y.reshape(N, output_size)
        loss = np.mean((logits - target) ** 2)

        # Backward
        d_logits = 2.0 * (logits - target) / (N * output_size)

        # FC backward
        dW_fc = flat.T @ d_logits
        db_fc = np.sum(d_logits, axis=0)
        d_flat = d_logits @ W_fc.T

        # Reshape and ReLU backward
        d_conv_out = d_flat.reshape(conv_out.shape)
        d_conv_raw = relu_backward(d_conv_out, conv_out_raw)

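        # Conv backward (stride 1, no padding, matching the forward call above):
        # each output position (i, j) adds its upstream gradient times its input patch.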
        d_kernel = np.zeros_like(kernel)
        d_bias_conv = np.sum(d_conv_raw, axis=(0, 2, 3))
        for n in range(N):
            for co in range(n_filters):
                for i in range(out_h):
                    for j in range(out_w):
                        h_s = i
                        w_s = j
                        d_kernel[co] += d_conv_raw[n, co, i, j] * X[n, :, h_s:h_s+kH, w_s:w_s+kW]

        # Update
        kernel -= lr * d_kernel
        bias_conv -= lr * d_bias_conv
        W_fc -= lr * dW_fc
        b_fc -= lr * db_fc

    return kernel, bias_conv, W_fc, b_fc, loss
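
One way to sanity-check a hand-written conv backward pass like the one above is a finite-difference gradient check. The sketch below is self-contained and does not call train_cnn; the helper names (conv_valid, forward_loss, analytic_kernel_grad, numeric_kernel_grad), the shapes, and the seed are all assumptions made for illustration. It also uses np.tensordot in place of the innermost loops, which is a substitution and not the solution's own approach.

```python
import numpy as np

def conv_valid(x, kernel, bias):
    """Stride-1, no-padding convolution with the same math as conv2d_forward."""
    N, C, H, W = x.shape
    C_out, _, kH, kW = kernel.shape
    out_h, out_w = H - kH + 1, W - kW + 1
    out = np.zeros((N, C_out, out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, :, i:i + kH, j:j + kW]  # (N, C, kH, kW)
            # Contract over (C, kH, kW) for every sample and filter -> (N, C_out)
            out[:, :, i, j] = np.tensordot(patch, kernel, axes=([1, 2, 3], [1, 2, 3])) + bias
    return out

def forward_loss(x, y, kernel, bias, W_fc, b_fc):
    """Conv -> ReLU -> flatten -> FC -> MSE; returns loss and intermediates."""
    z = conv_valid(x, kernel, bias)
    a = np.maximum(0, z)
    pred = a.reshape(x.shape[0], -1) @ W_fc + b_fc
    return np.mean((pred - y) ** 2), z, a, pred

def analytic_kernel_grad(x, y, kernel, bias, W_fc, b_fc):
    """Backprop gradient of the loss with respect to the conv kernel."""
    _, z, a, pred = forward_loss(x, y, kernel, bias, W_fc, b_fc)
    d_pred = 2.0 * (pred - y) / pred.size        # derivative of the mean
    d_a = (d_pred @ W_fc.T).reshape(a.shape)     # back through FC and flatten
    d_z = d_a * (z > 0)                          # back through ReLU
    kH, kW = kernel.shape[2:]
    d_kernel = np.zeros_like(kernel)
    for i in range(z.shape[2]):
        for j in range(z.shape[3]):
            patch = x[:, :, i:i + kH, j:j + kW]
            # Sum over samples of d_z[n, co, i, j] * patch[n] -> (C_out, C, kH, kW)
            d_kernel += np.tensordot(d_z[:, :, i, j], patch, axes=([0], [0]))
    return d_kernel

def numeric_kernel_grad(x, y, kernel, bias, W_fc, b_fc, eps=1e-6):
    """Central-difference estimate of the same gradient."""
    grad = np.zeros_like(kernel)
    for idx in np.ndindex(*kernel.shape):
        orig = kernel[idx]
        kernel[idx] = orig + eps
        loss_plus = forward_loss(x, y, kernel, bias, W_fc, b_fc)[0]
        kernel[idx] = orig - eps
        loss_minus = forward_loss(x, y, kernel, bias, W_fc, b_fc)[0]
        kernel[idx] = orig  # restore the original weight
        grad[idx] = (loss_plus - loss_minus) / (2 * eps)
    return grad

# Arbitrary small problem: 2 samples, 1 channel, 5x5 inputs, one 3x3 filter.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 1, 5, 5))
y = rng.standard_normal((2, 1))
kernel = rng.standard_normal((1, 1, 3, 3)) * 0.1
bias = np.zeros(1)
W_fc = rng.standard_normal((9, 1)) * 0.1
b_fc = np.zeros(1)

analytic = analytic_kernel_grad(x, y, kernel, bias, W_fc, b_fc)
numeric = numeric_kernel_grad(x, y, kernel, bias, W_fc, b_fc)
max_diff = np.max(np.abs(analytic - numeric))
print("max |analytic - numeric| =", max_diff)
```

If the two gradients disagree by much more than the finite-difference error, the backward pass is the likely culprit. A ReLU input that sits almost exactly at zero can also cause a mismatch, because the function is not differentiable there.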

Explanation

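  1. Forward pass: The input goes through a 2D convolution (stride 1, no padding), a ReLU activation, flattening to shape (N, fc_input_size), and a fully connected layer.
  2. Loss: Mean squared error between predictions and targets, averaged over all N * output_size entries.
  3. Backward pass: Gradients follow from the chain rule, starting with d_logits = 2 * (logits - target) / (N * output_size). The FC gradients are plain matrix products (for example dW_fc = flat.T @ d_logits). ReLU passes the gradient only where its input was positive. The conv kernel gradient loops over output positions, accumulating each upstream gradient times the input patch that produced that output.
  4. Update: Plain gradient descent on the full batch. Each parameter moves by -lr times its gradient once per epoch.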
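
Written out, the four steps above take roughly the following form. The symbols are notation introduced here for illustration and are not from the original solution: $Z$, $A$, $F$, $\hat{Y}$ correspond to conv_out_raw, conv_out, flat, logits; $K$ and $b_c$ are the conv kernel and bias; $W$ and $b$ are the FC weights and bias; $D$ is output_size; $\star$ denotes the stride-1, unpadded cross-correlation computed by conv2d_forward.

```latex
\begin{aligned}
Z &= X \star K + b_c, \qquad A = \max(0, Z), \qquad F = \operatorname{flatten}(A), \qquad \hat{Y} = F W + b \\
L &= \frac{1}{N D} \sum_{n=1}^{N} \sum_{d=1}^{D} \left( \hat{Y}_{nd} - Y_{nd} \right)^2 \\
\frac{\partial L}{\partial \hat{Y}} &= \frac{2}{N D} \left( \hat{Y} - Y \right) \\
\frac{\partial L}{\partial W} &= F^{\top} \frac{\partial L}{\partial \hat{Y}}, \qquad
\frac{\partial L}{\partial b_d} = \sum_{n} \frac{\partial L}{\partial \hat{Y}_{nd}}, \qquad
\frac{\partial L}{\partial F} = \frac{\partial L}{\partial \hat{Y}} W^{\top} \\
\frac{\partial L}{\partial Z} &= \frac{\partial L}{\partial A} \odot \mathbb{1}[Z > 0] \\
\frac{\partial L}{\partial K_{o,c,u,v}} &= \sum_{n,i,j} \frac{\partial L}{\partial Z_{n,o,i,j}} \, X_{n,c,i+u,j+v}, \qquad
\frac{\partial L}{\partial (b_c)_o} = \sum_{n,i,j} \frac{\partial L}{\partial Z_{n,o,i,j}} \\
\theta &\leftarrow \theta - \mathrm{lr} \cdot \frac{\partial L}{\partial \theta} \quad \text{for each } \theta \in \{K, b_c, W, b\}
\end{aligned}
```

Here $\partial L / \partial A$ is $\partial L / \partial F$ reshaped back to the shape of $A$, which is what d_flat.reshape(conv_out.shape) does in the code.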

Complexity

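  • Time: O(N * C_out * C_in * out_h * out_w * kH * kW) per epoch for the conv forward and backward passes, so O(epochs * N * C_out * C_in * out_h * out_w * kH * kW) for the whole training run
  • Space: O(N * C_out * out_h * out_w) for intermediate activations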
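
To make these bounds concrete, the sketch below counts multiply-accumulate operations and stored activations for a single conv forward pass at stride 1 with no padding. The function names and the example shapes are hypothetical, chosen only to illustrate the formulas above.

```python
def conv_forward_macs(N, C_in, H, W, C_out=1, kH=3, kW=3):
    """Multiply-accumulates for one stride-1, unpadded conv forward pass."""
    out_h = H - kH + 1
    out_w = W - kW + 1
    # Each of the N * C_out * out_h * out_w outputs sums C_in * kH * kW products.
    return N * C_out * out_h * out_w * C_in * kH * kW

def conv_activation_count(N, H, W, C_out=1, kH=3, kW=3):
    """Number of values held in the conv output tensor."""
    return N * C_out * (H - kH + 1) * (W - kW + 1)

# Example: 4 single-channel 8x8 images, one 3x3 filter -> 6x6 output maps.
macs = conv_forward_macs(N=4, C_in=1, H=8, W=8)   # 4 * 1 * 6 * 6 * 1 * 3 * 3 = 1296
acts = conv_activation_count(N=4, H=8, W=8)       # 4 * 1 * 6 * 6 = 144
print(macs, acts)
```

The kernel-gradient loop in the backward pass performs the same number of multiply-accumulates, so the per-epoch conv cost is about twice the forward count.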