Batch Iterator for Dataset

#30 · Machine Learning · Easy

⊣ Solve on deep-ml.com

Problem

Implement a batch iterator for a dataset. Given a 2D NumPy array X and an integer batch_size, yield successive batches of rows from X. If the dataset size is not evenly divisible by the batch size, the last batch should contain the remaining rows.

Solution

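import numpy as np

def batch_iterator(X, batch_size=64):
    """Yield successive batches of at most batch_size rows from X."""
    n_samples = X.shape[0]
    # Step through the starting row index of each batch.
    for i in range(0, n_samples, batch_size):
        # The slice end is clipped at n_samples, so the last batch may be shorter.
        yield X[i:i + batch_size]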

Explanation

  1. Get the total number of samples from the first dimension of the array.
  2. Use range with step batch_size to iterate over starting indices.
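  3. Slice the array from index i to i + batch_size. A slice whose end runs past the array is clipped to the array length instead of raising an error, so the last batch simply contains the remaining rows when n_samples is not a multiple of batch_size.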
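
The steps above can be checked with a short usage sketch. The 10-row array and the batch size of 4 are illustrative choices, not part of the original problem; the function is repeated so the block runs on its own.

```python
import numpy as np

def batch_iterator(X, batch_size=64):
    n_samples = X.shape[0]
    for i in range(0, n_samples, batch_size):
        yield X[i:i + batch_size]

# Illustrative data (an assumption for this example): 10 rows, 2 columns.
X = np.arange(20).reshape(10, 2)

# Collect the shape of every batch; 10 rows with batch_size=4 leaves a remainder of 2.
batch_shapes = [batch.shape for batch in batch_iterator(X, batch_size=4)]
print(batch_shapes)  # [(4, 2), (4, 2), (2, 2)]
```

The final batch has 2 rows because the slice X[8:12] is clipped to the 10 available rows.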

Complexity

  • Time: O(n) across all yields, where n is the number of rows
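  • Space: O(1) extra per yielded batch, because basic slicing in NumPy returns a view of X rather than a copy; each batch references up to batch_size rows of the original array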
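
The point about views can be verified directly. The sketch below is a minimal check, assuming a small zero-filled array chosen only for illustration; it uses np.shares_memory to confirm that a yielded batch aliases X, and shows that writing to the batch changes X.

```python
import numpy as np

def batch_iterator(X, batch_size=64):
    n_samples = X.shape[0]
    for i in range(0, n_samples, batch_size):
        yield X[i:i + batch_size]

# Illustrative data (an assumption for this example): 5 rows, 3 columns of zeros.
X = np.zeros((5, 3))

first_batch = next(batch_iterator(X, batch_size=2))

# A basic slice is a view, so the batch and X share the same underlying buffer.
shares = np.shares_memory(first_batch, X)

# Writing into the view writes through to the original array.
first_batch[0, 0] = 1.0
written_through = X[0, 0]

# An explicit copy is independent of X.
independent = first_batch.copy()
independent[0, 0] = 7.0
```

If a caller needs to modify batches without touching the dataset, copying each batch as above is one option, at the cost of O(batch_size) extra memory per batch.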