Implement GRU Cell

#287 · Deep Learning · Medium

Problem

Implement the forward pass of a GRU (Gated Recurrent Unit) cell. Given the input x_t, the previous hidden state h_{t-1}, and the weight matrices and bias vectors for each gate, compute the update gate, the reset gate, the candidate hidden state, and the new hidden state h_t.
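
For reference, the four quantities can be written out as equations. This is a sketch under the convention used in the solution below, where the update gate z_t weights the previous state; some references swap the roles of z_t and 1 - z_t. The symbols \sigma (logistic sigmoid), \odot (element-wise product), and \tilde{h}_t (candidate hidden state) are notation assumed here rather than taken from the problem statement.

```latex
\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) \\
r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) \\
\tilde{h}_t &= \tanh\bigl(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h\bigr) \\
h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
\end{aligned}
```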

Solution

import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-np.clip(x, -500, 500)))

def gru_cell(x_t: np.ndarray, h_prev: np.ndarray,
             W_z: np.ndarray, U_z: np.ndarray, b_z: np.ndarray,
             W_r: np.ndarray, U_r: np.ndarray, b_r: np.ndarray,
             W_h: np.ndarray, U_h: np.ndarray, b_h: np.ndarray) -> np.ndarray:
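    """Compute one GRU step and return the new hidden state h_t.

    Assumed shapes: x_t (d,), h_prev (h,), W_* (h, d), U_* (h, h), b_* (h,).
    """
    # Update gate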
    z_t = sigmoid(W_z @ x_t + U_z @ h_prev + b_z)

    # Reset gate
    r_t = sigmoid(W_r @ x_t + U_r @ h_prev + b_r)

    # Candidate hidden state
    h_candidate = np.tanh(W_h @ x_t + U_h @ (r_t * h_prev) + b_h)

    # New hidden state
    h_t = z_t * h_prev + (1 - z_t) * h_candidate

    return h_t

Explanation

  1. Update gate z_t controls how much of the previous hidden state to retain vs. how much to replace with the candidate.
  2. Reset gate r_t controls how much of the previous hidden state flows into the candidate computation, allowing the model to forget past information.
  3. Candidate hidden state is computed using the input and the reset-gated previous hidden state, passed through tanh.
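  4. Final hidden state is an element-wise interpolation between the previous state and the candidate, controlled by the update gate: where z_t is close to 1 the previous value is kept, and where it is close to 0 the candidate replaces it.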
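
As a rough check of points 1 and 4, the sketch below repeats the solution's gru_cell so that it runs on its own, then pushes the update gate toward 1 and toward 0 with a large bias. The sizes, the random seed, and the bias magnitude of 100 are illustrative assumptions, not part of the problem.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-np.clip(x, -500, 500)))

def gru_cell(x_t, h_prev, W_z, U_z, b_z, W_r, U_r, b_r, W_h, U_h, b_h):
    # Same computation as the solution, repeated so this sketch is self-contained.
    z_t = sigmoid(W_z @ x_t + U_z @ h_prev + b_z)
    r_t = sigmoid(W_r @ x_t + U_r @ h_prev + b_r)
    h_candidate = np.tanh(W_h @ x_t + U_h @ (r_t * h_prev) + b_h)
    return z_t * h_prev + (1 - z_t) * h_candidate

rng = np.random.default_rng(0)
d, h = 3, 4  # illustrative input and hidden sizes
x_t = rng.normal(size=d)
h_prev = rng.normal(size=h)
W_z, W_r, W_h = (rng.normal(size=(h, d)) for _ in range(3))
U_z, U_r, U_h = (rng.normal(size=(h, h)) for _ in range(3))
b_r = np.zeros(h)
b_h = np.zeros(h)

# Large positive update-gate bias: z_t saturates near 1, so h_prev is retained.
h_keep = gru_cell(x_t, h_prev, W_z, U_z, np.full(h, 100.0),
                  W_r, U_r, b_r, W_h, U_h, b_h)

# Large negative update-gate bias: z_t saturates near 0, so the candidate wins.
h_replace = gru_cell(x_t, h_prev, W_z, U_z, np.full(h, -100.0),
                     W_r, U_r, b_r, W_h, U_h, b_h)

# Recompute the candidate by hand to compare against h_replace.
r_t = sigmoid(W_r @ x_t + U_r @ h_prev + b_r)
h_candidate = np.tanh(W_h @ x_t + U_h @ (r_t * h_prev) + b_h)

print(np.allclose(h_keep, h_prev))          # True
print(np.allclose(h_replace, h_candidate))  # True
```

With less extreme gate values, each component of h_t lies between the corresponding components of h_prev and the candidate.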

Complexity

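  • Time: O(h^2 + h*d), where h is the hidden size and d is the input size; the cost is dominated by the six matrix-vector products
  • Space: O(h) auxiliary space for the gate and candidate vectors, not counting the O(h^2 + h*d) needed to store the weights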
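
To make the time bound concrete, the following sketch counts the scalar multiplications in the six matrix-vector products of one step. The helper name gru_matvec_multiplies is hypothetical, and the count deliberately ignores the O(h) element-wise operations and activations.

```python
def gru_matvec_multiplies(h: int, d: int) -> int:
    """Count scalar multiplications in the matrix-vector products of one GRU step."""
    input_products = 3 * h * d   # W_z @ x_t, W_r @ x_t, W_h @ x_t
    hidden_products = 3 * h * h  # U_z @ h_prev, U_r @ h_prev, U_h @ (r_t * h_prev)
    return input_products + hidden_products

# Example sizes chosen only for illustration.
print(gru_matvec_multiplies(256, 128))  # 294912
```

Doubling h roughly quadruples the hidden-to-hidden term, which is where the h^2 in the bound comes from.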