Implementing a Simple RNN

#54 · Deep Learning · Medium

Problem

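Implement a simple Recurrent Neural Network (RNN) cell. Given an input sequence, compute the hidden state at each time step using the recurrence h_t = tanh(X_t @ W_x + h_{t-1} @ W_h + b), starting from a zero hidden state h_0 = 0. After the last time step T, compute the output from the final hidden state as y = h_T @ W_y + b_y.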
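To make the recurrence concrete, here is a minimal sketch of a single time step on hand-picked numbers. The helper name rnn_step and all values are illustrative assumptions, not part of the problem statement; they are chosen so the pre-activation comes out to 1.0.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_x, W_h, b):
    """One application of h_t = tanh(x_t @ W_x + h_prev @ W_h + b)."""
    return np.tanh(x_t @ W_x + h_prev @ W_h + b)

# Illustrative values only: batch of 1, input_size = 2, hidden_size = 1.
x_t = np.array([[1.0, 2.0]])      # (1, 2)
h_prev = np.array([[0.5]])        # (1, 1)
W_x = np.array([[0.1], [0.2]])    # (2, 1)
W_h = np.array([[0.4]])           # (1, 1)
b = np.array([[0.3]])             # (1, 1)

# Pre-activation: (1*0.1 + 2*0.2) + 0.5*0.4 + 0.3 = 0.5 + 0.2 + 0.3 = 1.0
h_t = rnn_step(x_t, h_prev, W_x, W_h, b)
print(h_t.shape)                          # (1, 1)
print(np.allclose(h_t, np.tanh(1.0)))     # True
```

Because tanh squashes its argument into (-1, 1), every hidden state stays bounded regardless of the input magnitude.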

Solution

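import numpy as np

class SimpleRNN:
    """Vanilla RNN: tanh recurrence followed by a linear readout of the final state."""

    def __init__(self, input_size, hidden_size, output_size, seed=None):
        # Seed NumPy's global RNG so the weight initialization is reproducible.
        if seed is not None:
            np.random.seed(seed)
        self.hidden_size = hidden_size
        # Small random weights and zero biases.
        self.Wx = np.random.randn(input_size, hidden_size) * 0.01   # input-to-hidden
        self.Wh = np.random.randn(hidden_size, hidden_size) * 0.01  # hidden-to-hidden
        self.bh = np.zeros((1, hidden_size))
        self.Wy = np.random.randn(hidden_size, output_size) * 0.01  # hidden-to-output
        self.by = np.zeros((1, output_size))

    def forward(self, X):
        X = np.array(X, dtype=np.float64)
        # Treat a single (seq_len, input_size) sequence as a batch of one.
        if X.ndim == 2:
            X = X[np.newaxis, :, :]  # (batch, seq_len, input_size)

        batch_size, seq_len, _ = X.shape
        h = np.zeros((batch_size, self.hidden_size))  # h_0 = 0
        hidden_states = []

        for t in range(seq_len):
            # h_t = tanh(X_t @ Wx + h_{t-1} @ Wh + bh); bh broadcasts over the batch.
            h = np.tanh(X[:, t, :] @ self.Wx + h @ self.Wh + self.bh)
            hidden_states.append(h)

        # Linear readout from the final hidden state only.
        output = h @ self.Wy + self.by
        # Return nested lists: output is (batch, output_size);
        # hidden_states holds seq_len entries of shape (batch, hidden_size).
        return output.tolist(), [hs.tolist() for hs in hidden_states]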

Explanation

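  1. Initialize the input-to-hidden (Wx), hidden-to-hidden (Wh), and hidden-to-output (Wy) weight matrices with small random values (standard normal samples scaled by 0.01), and the biases bh and by with zeros.
  2. Promote a single 2-D sequence to a batch of one, then start from a zero hidden state of shape (batch_size, hidden_size).
  3. At each time step, compute the new hidden state by applying tanh to the sum of the input contribution (X_t @ Wx), the previous hidden state contribution (h @ Wh), and the bias bh. Store each hidden state.
  4. After processing all time steps, compute the output from the final hidden state as h @ Wy + by, and return it together with the list of per-step hidden states.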
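The four steps above can be traced in a compact functional sketch. The function name rnn_forward, the sizes, and the use of np.random.default_rng are illustrative assumptions (the solution itself seeds the global RNG with np.random.seed), so the numbers will not match the class's outputs; the point is the shapes and one easy sanity check.

```python
import numpy as np

def rnn_forward(X, Wx, Wh, bh, Wy, by):
    """Functional sketch; X has shape (batch, seq_len, input_size)."""
    batch_size, seq_len, _ = X.shape
    h = np.zeros((batch_size, Wh.shape[0]))       # step 2: zero initial state
    hidden_states = []
    for t in range(seq_len):                      # step 3: recurrence
        h = np.tanh(X[:, t, :] @ Wx + h @ Wh + bh)
        hidden_states.append(h)
    return h @ Wy + by, hidden_states             # step 4: readout of final state

# Step 1, with illustrative sizes (assumed, not from the problem).
rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 3, 4, 2
Wx = rng.standard_normal((input_size, hidden_size)) * 0.01
Wh = rng.standard_normal((hidden_size, hidden_size)) * 0.01
bh = np.zeros((1, hidden_size))
Wy = rng.standard_normal((hidden_size, output_size)) * 0.01
by = np.zeros((1, output_size))

X = rng.standard_normal((2, 5, input_size))       # batch of 2, 5 time steps
output, hidden_states = rnn_forward(X, Wx, Wh, bh, Wy, by)
print(output.shape, len(hidden_states), hidden_states[0].shape)  # (2, 2) 5 (2, 4)

# Sanity check: zero input with zero biases keeps every state at tanh(0) = 0.
zero_out, zero_states = rnn_forward(np.zeros_like(X), Wx, Wh, bh, Wy, by)
print(all(np.all(s == 0) for s in zero_states))   # True
```

Only the last hidden state feeds the output here, which suits tasks that need one prediction per sequence; a per-step output would apply Wy to every stored state instead.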

Complexity

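  • Time: O(B T (d_in d_h + d_h^2) + B d_h d_out), where B is batch size, T is sequence length, d_in is input size, d_h is hidden size, and d_out is output size. For a single sequence with the output projection ignored, this reduces to O(T (d_in d_h + d_h^2)).
  • Space: O(B T d_h) for storing the hidden states, plus O(d_in d_h + d_h^2 + d_h d_out) for the weights.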
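As a rough illustration of these bounds, the sketch below counts the multiply-accumulate operations in the matrix products of one forward pass and the number of stored hidden-state values. Both helper names and the batch_size and output_size parameters are hypothetical additions for illustration; bias additions and tanh evaluations are ignored because they are lower-order terms.

```python
def rnn_forward_macs(seq_len, input_size, hidden_size, output_size=0, batch_size=1):
    """Multiply-accumulates in the matrix products of one forward pass."""
    # Per step and per sequence: X_t @ Wx costs d_in * d_h, h @ Wh costs d_h * d_h.
    per_step = input_size * hidden_size + hidden_size * hidden_size
    recurrent = seq_len * per_step
    # The readout h_T @ Wy runs once per sequence, after the last step.
    readout = hidden_size * output_size
    return batch_size * (recurrent + readout)

def stored_hidden_values(seq_len, hidden_size, batch_size=1):
    """Number of floats kept when every hidden state is stored."""
    return batch_size * seq_len * hidden_size

print(rnn_forward_macs(seq_len=10, input_size=3, hidden_size=4))                 # 280
print(rnn_forward_macs(seq_len=10, input_size=3, hidden_size=4, output_size=2))  # 288
print(stored_hidden_values(seq_len=10, hidden_size=4))                           # 40
```

Doubling the sequence length doubles the recurrent term, while doubling the hidden size roughly quadruples it once d_h^2 dominates d_in d_h.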