
Implement a Simple RNN with Backpropagation Through Time (BPTT)

#62 · Deep Learning · Hard


Problem

Implement a simple Recurrent Neural Network (RNN) trained with Backpropagation Through Time (BPTT). Given an input sequence, a target for each time step, weights, and biases, run a forward pass, compute gradients by unrolling the network through time and backpropagating the error between each output and its target, and then update the weights with one gradient-descent step.
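The problem statement does not spell out the array shapes. Reading them off the matrix products in the solution, a consistent setup looks like the sketch below. The sizes d (input), n (hidden), k (output), and T (steps), and the random initialization, are illustrative assumptions rather than part of the original problem.

```python
import numpy as np

# Hypothetical sizes: d = input size, n = hidden size, k = output size, T = steps.
d, n, k, T = 3, 5, 2, 4
rng = np.random.default_rng(42)

# Shapes inferred from how the solution multiplies these arrays.
Wxh = rng.normal(scale=0.1, size=(n, d))   # input  -> hidden
Whh = rng.normal(scale=0.1, size=(n, n))   # hidden -> hidden (recurrent)
Why = rng.normal(scale=0.1, size=(k, n))   # hidden -> output
bh = np.zeros((n, 1))                      # hidden bias, column vector
by = np.zeros((k, 1))                      # output bias, column vector

input_seq = rng.normal(size=(T, d))        # one length-d vector per step
targets = rng.normal(size=(T, k))          # one length-k target per step

# A single forward step from a zero initial state, to confirm the shapes line up.
x0 = input_seq[0].reshape(-1, 1)           # (d, 1) column vector
h0 = np.tanh(Wxh @ x0 + Whh @ np.zeros((n, 1)) + bh)
y0 = Why @ h0 + by
print(h0.shape, y0.shape)                  # (5, 1) (2, 1)
```

Because each row of `input_seq` is one-dimensional, the solution's `reshape(-1, 1)` branch turns it into the column vector the weight matrices expect.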

Solution

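```python
import numpy as np

def rnn_bptt(input_seq, targets, Wxh, Whh, Why, bh, by, learning_rate=0.01):
    """One forward pass, full BPTT, and one gradient-descent step.

    The output gradient dy = y_t - target_t is the gradient of the loss
    0.5 * sum_t ||y_t - target_t||^2. The weight arrays are updated in
    place (so they must be float arrays) and are also returned.
    """
    hidden_size = Whh.shape[0]
    T = len(input_seq)

    # Forward pass: cache inputs, hidden states, and outputs for every step.
    # hs is a dict, so hs[-1] is the zero initial state, not negative indexing.
    hs = {-1: np.zeros((hidden_size, 1))}
    xs, ys_pred = {}, {}

    for t in range(T):
        # Treat each input as a column vector.
        xs[t] = input_seq[t].reshape(-1, 1) if input_seq[t].ndim == 1 else input_seq[t]
        hs[t] = np.tanh(np.dot(Wxh, xs[t]) + np.dot(Whh, hs[t - 1]) + bh)
        ys_pred[t] = np.dot(Why, hs[t]) + by

    # Backward pass (BPTT): walk the unrolled network from the last step to the first.
    dWxh = np.zeros_like(Wxh)
    dWhh = np.zeros_like(Whh)
    dWhy = np.zeros_like(Why)
    dbh = np.zeros_like(bh)
    dby = np.zeros_like(by)
    dh_next = np.zeros_like(hs[0])  # gradient reaching h_t from step t+1 (zero at the end)

    for t in reversed(range(T)):
        target = targets[t].reshape(-1, 1) if targets[t].ndim == 1 else targets[t]
        dy = ys_pred[t] - target           # dL/dy_t for the squared-error loss
        dWhy += np.dot(dy, hs[t].T)
        dby += dy

        dh = np.dot(Why.T, dy) + dh_next   # from the output at t and from step t+1
        dtanh = (1 - hs[t] ** 2) * dh      # through tanh: d tanh(z)/dz = 1 - tanh(z)^2
        dWxh += np.dot(dtanh, xs[t].T)
        dWhh += np.dot(dtanh, hs[t - 1].T)
        dbh += dtanh
        dh_next = np.dot(Whh.T, dtanh)     # pass the gradient back to h_{t-1}

    # Gradient update: one in-place SGD step on each parameter.
    for param, dparam in zip([Wxh, Whh, Why, bh, by],
                             [dWxh, dWhh, dWhy, dbh, dby]):
        param -= learning_rate * dparam

    return Wxh, Whh, Why, bh, by
```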
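A common way to sanity-check a BPTT implementation is a central finite-difference gradient check. The sketch below assumes the loss is 0.5 * sum_t ||y_t - target_t||^2, the loss whose output gradient is the `dy = ys_pred[t] - target` used above. `forward_loss`, `bptt_grads`, and `numerical_grad` are hypothetical helpers; `bptt_grads` repeats the solution's backward pass but returns the gradients instead of applying the update, so they can be compared.

```python
import numpy as np

def forward_loss(xs, targets, Wxh, Whh, Why, bh, by):
    # Assumed loss: 0.5 * sum_t ||y_t - target_t||^2, whose gradient with
    # respect to y_t is y_t - target_t (the dy used in rnn_bptt).
    h = np.zeros((Whh.shape[0], 1))
    loss = 0.0
    for x, target in zip(xs, targets):
        h = np.tanh(Wxh @ x + Whh @ h + bh)
        y = Why @ h + by
        loss += 0.5 * np.sum((y - target) ** 2)
    return loss

def bptt_grads(xs, targets, Wxh, Whh, Why, bh, by):
    # Same forward and backward pass as rnn_bptt, but returns the
    # gradients [dWxh, dWhh, dWhy, dbh, dby] instead of updating weights.
    hs = {-1: np.zeros((Whh.shape[0], 1))}
    ys = {}
    for t, x in enumerate(xs):
        hs[t] = np.tanh(Wxh @ x + Whh @ hs[t - 1] + bh)
        ys[t] = Why @ hs[t] + by
    grads = [np.zeros_like(p) for p in (Wxh, Whh, Why, bh, by)]
    dWxh, dWhh, dWhy, dbh, dby = grads  # aliases: += updates the list entries
    dh_next = np.zeros_like(hs[0])
    for t in reversed(range(len(xs))):
        dy = ys[t] - targets[t]
        dWhy += dy @ hs[t].T
        dby += dy
        dtanh = (1 - hs[t] ** 2) * (Why.T @ dy + dh_next)
        dWxh += dtanh @ xs[t].T
        dWhh += dtanh @ hs[t - 1].T
        dbh += dtanh
        dh_next = Whh.T @ dtanh
    return grads

def numerical_grad(params, idx, xs, targets, eps=1e-5):
    # Central differences: (L(p + eps) - L(p - eps)) / (2 * eps), one entry at a time.
    p = params[idx]
    grad = np.zeros_like(p)
    for i in np.ndindex(p.shape):
        old = p[i]
        p[i] = old + eps
        plus = forward_loss(xs, targets, *params)
        p[i] = old - eps
        minus = forward_loss(xs, targets, *params)
        p[i] = old  # restore the original value
        grad[i] = (plus - minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
d, n, k, T = 3, 4, 2, 5  # small, illustrative sizes
shapes = [(n, d), (n, n), (k, n), (n, 1), (k, 1)]
params = [rng.normal(scale=0.5, size=s) for s in shapes]
xs = [rng.normal(size=(d, 1)) for _ in range(T)]
targets = [rng.normal(size=(k, 1)) for _ in range(T)]

analytic = bptt_grads(xs, targets, *params)
errors = {}
for i, name in enumerate(["Wxh", "Whh", "Why", "bh", "by"]):
    numeric = numerical_grad(params, i, xs, targets)
    scale = max(1e-12, np.max(np.abs(numeric)))
    errors[name] = np.max(np.abs(analytic[i] - numeric)) / scale
    print(f"{name}: max relative error {errors[name]:.1e}")
```

If the backward pass is correct, every relative error should be on the order of float64 finite-difference noise; a large error for one parameter points at the matching line of the backward loop.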

Explanation

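  1. Forward pass: for each time step, compute the hidden state h_t = tanh(Wxh * x_t + Whh * h_{t-1} + bh), starting from h_{-1} = 0, and the output y_t = Why * h_t + by. All inputs, hidden states, and outputs are cached for the backward pass.
  2. Backward pass: at each step, the output error dy = y_t - target_t is the gradient of the squared-error loss 0.5 * ||y_t - target_t||^2. It contributes dy * h_t^T to dWhy and dy to dby, and reaches the hidden state as Why^T * dy.
  3. The key to BPTT is dh_next: the gradient that later time steps send back to h_t through the recurrent weights (dh_next = Whh^T * dtanh, computed at step t+1). The total hidden-state gradient is dh = Why^T * dy + dh_next, and dh_next starts at zero because no step follows the last one.
  4. The tanh derivative (1 - h_t^2), applied elementwise, turns dh into the pre-activation gradient dtanh, which feeds dWxh, dWhh, and dbh. Units whose h_t is close to ±1 (saturated) pass almost no gradient back.
  5. Because the weights are shared across time, their gradients are summed over all time steps; each parameter then takes one in-place gradient-descent step, param -= learning_rate * dparam.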
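The five steps can be written as one gradient recursion over t = 0, ..., T-1, as in the code. The notation is introduced here for illustration: y*_t is the target at step t, δ^y_t corresponds to `dy`, δ^z_t to `dtanh`, the term W_hh^T δ^z_{t+1} to `dh_next`, and the loss is assumed to be the squared error whose output gradient is `dy`.

$$
\begin{aligned}
h_t &= \tanh\!\left(W_{xh} x_t + W_{hh} h_{t-1} + b_h\right), \qquad h_{-1} = 0, \\
y_t &= W_{hy} h_t + b_y, \qquad L = \tfrac{1}{2} \sum_{t=0}^{T-1} \lVert y_t - y_t^{*} \rVert^2, \\
\delta^{y}_t &= y_t - y_t^{*}, \\
\frac{\partial L}{\partial h_t} &= W_{hy}^{\top} \delta^{y}_t + W_{hh}^{\top} \delta^{z}_{t+1}, \qquad \delta^{z}_{T} = 0, \\
\delta^{z}_t &= \left(1 - h_t \odot h_t\right) \odot \frac{\partial L}{\partial h_t}, \\
\frac{\partial L}{\partial W_{xh}} &= \sum_{t} \delta^{z}_t x_t^{\top}, \quad
\frac{\partial L}{\partial W_{hh}} = \sum_{t} \delta^{z}_t h_{t-1}^{\top}, \quad
\frac{\partial L}{\partial b_h} = \sum_{t} \delta^{z}_t, \\
\frac{\partial L}{\partial W_{hy}} &= \sum_{t} \delta^{y}_t h_t^{\top}, \quad
\frac{\partial L}{\partial b_y} = \sum_{t} \delta^{y}_t.
\end{aligned}
$$

The W_hh^T δ^z_{t+1} term is what makes this "through time": h_t affects the loss both directly through y_t and indirectly through every later hidden state.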

Complexity

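  • Time: O(T * n * (n + d + k)) for the forward and backward passes, where T is the sequence length, n the hidden size, d the input size, and k the output size; this is usually quoted as O(T * n^2) when n dominates.
  • Space: O(T * (n + d + k)) for the inputs, hidden states, and outputs cached for the backward pass, i.e. O(T * n) when n dominates.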
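The n^2 term comes from the recurrent product Whh @ h_{t-1}. A rough count of multiply-accumulates (MACs) makes this concrete. `rnn_matmul_macs` and `stored_floats` are hypothetical helpers, d and k denote the input and output sizes, and the count covers only the three matrix-vector products of the forward pass (biases and tanh are ignored). The backward pass performs matrix products of the same sizes, so it changes only the constant factor.

```python
def rnn_matmul_macs(T, d, n, k):
    # Forward-pass MACs: Wxh @ x_t (n*d), Whh @ h_{t-1} (n*n), Why @ h_t (k*n) per step.
    per_step = n * d + n * n + k * n
    return T * per_step

def stored_floats(T, d, n, k):
    # Values cached for BPTT: xs, hs, ys_pred for every step, plus h_{-1}.
    return T * (d + n + k) + n

print(rnn_matmul_macs(100, 10, 128, 10))  # 1894400
print(stored_floats(100, 10, 128, 10))    # 14928
```

With T = 100, d = k = 10, and n = 128, the recurrent product accounts for 16,384 of the 18,944 MACs per step, which is why the bound is usually stated in terms of n^2.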