
Calculate Number of Parameters in Neural Network

#291 · Deep Learning · Medium


Problem

Calculate the total number of trainable parameters in a neural network from a list of layer specifications. Each layer is given as a dictionary describing one layer, such as a dense (fully connected), 2D convolutional, batch-normalization, or embedding layer.
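For a sense of scale, here is a minimal sketch that counts the parameters of a hypothetical two-layer MLP (784 → 128 → 10; the sizes are illustrative, not part of the problem). It uses the dense-layer keys input_size, output_size, and optional bias that the solution reads. The helper dense_params is hypothetical.

```python
# Hypothetical spec for a 784 -> 128 -> 10 MLP, using the dense-layer keys
# (input_size, output_size, optional bias) read by the solution.
mlp = [
    {"type": "dense", "input_size": 784, "output_size": 128},
    {"type": "dense", "input_size": 128, "output_size": 10},
]


def dense_params(layer: dict) -> int:
    """Hypothetical helper: weight matrix plus an optional bias vector."""
    params = layer["input_size"] * layer["output_size"]
    if layer.get("bias", True):
        params += layer["output_size"]  # one bias per output unit
    return params


total = sum(dense_params(layer) for layer in mlp)
print(total)  # 101770
```

Most of the count comes from the first weight matrix (784 × 128 = 100,352), which is typical when a dense layer receives a large input.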

Solution

def count_parameters(layers: list[dict]) -> int:
    """Return the total number of trainable parameters across all layers."""
    total = 0
    for layer in layers:
        layer_type = layer.get("type", "dense")  # untyped layers default to dense
        if layer_type == "dense":
            input_size = layer["input_size"]
            output_size = layer["output_size"]
            has_bias = layer.get("bias", True)
            params = input_size * output_size  # weight matrix
            if has_bias:
                params += output_size  # one bias per output unit
            total += params
        elif layer_type == "conv2d":
            in_channels = layer["in_channels"]
            out_channels = layer["out_channels"]
            kernel_h = layer["kernel_size"]
            kernel_w = layer.get("kernel_size_w", kernel_h)  # square kernel unless a width is given
            has_bias = layer.get("bias", True)
            # One kernel_h x kernel_w filter per (input channel, output channel) pair
            params = in_channels * out_channels * kernel_h * kernel_w
            if has_bias:
                params += out_channels  # one bias per output channel
            total += params
        elif layer_type == "batchnorm":
            num_features = layer["num_features"]
            total += 2 * num_features  # gamma and beta; running stats are not trainable
        elif layer_type == "embedding":
            num_embeddings = layer["num_embeddings"]
            embedding_dim = layer["embedding_dim"]
            total += num_embeddings * embedding_dim  # lookup table, no bias
        # Any other type (e.g. activations, pooling, dropout) is treated as
        # parameter-free and adds nothing to the total.
    return total
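When a total looks wrong, a per-layer breakdown makes it easier to find the culprit. The sketch below is a hypothetical variant that applies the same per-type formulas but returns one row per layer. The "name" key, the "global_avg_pool" type, and the layer sizes are illustrative assumptions; like the solution, it counts any unrecognized type as parameter-free.

```python
def parameter_breakdown(layers: list[dict]) -> list[tuple[str, str, int]]:
    """Hypothetical helper: (name, type, trainable-parameter count) per layer."""
    rows = []
    for i, layer in enumerate(layers):
        kind = layer.get("type", "dense")
        bias = layer.get("bias", True)
        if kind == "dense":
            n = layer["input_size"] * layer["output_size"]
            n += layer["output_size"] if bias else 0
        elif kind == "conv2d":
            kh = layer["kernel_size"]
            kw = layer.get("kernel_size_w", kh)
            n = layer["in_channels"] * layer["out_channels"] * kh * kw
            n += layer["out_channels"] if bias else 0
        elif kind == "batchnorm":
            n = 2 * layer["num_features"]
        elif kind == "embedding":
            n = layer["num_embeddings"] * layer["embedding_dim"]
        else:
            n = 0  # treated as parameter-free (activation, pooling, ...)
        # "name" is a hypothetical optional key used only for display.
        rows.append((layer.get("name", f"layer{i}"), kind, n))
    return rows


# Hypothetical toy CNN: a bias-free conv feeding batchnorm, then a dense head.
model = [
    {"name": "conv1", "type": "conv2d", "in_channels": 3, "out_channels": 32,
     "kernel_size": 3, "bias": False},
    {"name": "bn1", "type": "batchnorm", "num_features": 32},
    {"name": "pool", "type": "global_avg_pool"},
    {"name": "head", "type": "dense", "input_size": 32, "output_size": 10},
]

rows = parameter_breakdown(model)
for name, kind, n in rows:
    print(f"{name:<6} {kind:<16} {n:>6}")
print("total", sum(n for _, _, n in rows))
```

The conv layer omits its bias here because the following batchnorm's beta already provides a per-channel shift, a common pattern that the "bias": False flag captures.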

Explanation

  1. Dense layer: weights = input_size * output_size, bias = output_size (if present).
  2. Conv2D layer: weights = in_channels * out_channels * kernel_h * kernel_w, bias = out_channels (if present).
  3. BatchNorm layer: 2 * num_features (scale gamma and shift beta; running mean/variance are not trainable).
  4. Embedding layer: num_embeddings * embedding_dim (a lookup table).
  5. Sum all parameters across all layers for the total count.
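As a worked check of these formulas, take a toy stack made of a 3 → 16 channel 3×3 convolution, a batchnorm over its 16 channels, a parameter-free global pooling step, and a 16 → 10 dense head. Separately, consider an embedding table of 1000 tokens × 64 dimensions. All of these sizes are illustrative assumptions.

```latex
\begin{aligned}
\text{Conv2D}(3 \to 16,\ 3 \times 3) &: 3 \cdot 16 \cdot 3 \cdot 3 + 16 = 432 + 16 = 448 \\
\text{BatchNorm}(16) &: 2 \cdot 16 = 32 \\
\text{Dense}(16 \to 10) &: 16 \cdot 10 + 10 = 170 \\
\text{Total} &: 448 + 32 + 170 = 650 \\[4pt]
\text{Embedding}(1000,\ 64) &: 1000 \cdot 64 = 64{,}000
\end{aligned}
```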

Complexity

  • Time: O(L) where L is the number of layers
  • Space: O(1)