Train a Simple GAN on 1D Gaussian Data

#174 · Deep Learning · Hard

Problem

Train a simple Generative Adversarial Network (GAN) on 1D Gaussian data. Implement a generator that learns to produce samples from a target distribution and a discriminator that distinguishes real from generated samples.
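
For reference, here are the two objectives written out for a linear generator and a logistic discriminator (the model the solution uses), together with the partial derivatives a gradient-ascent update needs. The labels J_D and J_G are introduced here only for convenience. Every derivative follows from σ'(u) = σ(u)(1 − σ(u)), and the generator derivatives treat d_w and d_b as fixed, as in alternating updates.

```latex
\begin{aligned}
G(z) &= g_w z + g_b, \quad z \sim \mathcal{N}(0, 1), \qquad D(x) = \sigma(d_w x + d_b) \\
J_D &= \mathbb{E}_{x \sim p_{\mathrm{real}}}\big[\log D(x)\big] + \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big] && \text{(maximized over } d_w, d_b) \\
J_G &= \mathbb{E}_{z}\big[\log D(G(z))\big] && \text{(maximized over } g_w, g_b) \\
\frac{\partial J_D}{\partial d_w} &= \mathbb{E}\big[(1 - D(x))\,x\big] - \mathbb{E}\big[D(G(z))\,G(z)\big], \qquad
\frac{\partial J_D}{\partial d_b} = \mathbb{E}\big[1 - D(x)\big] - \mathbb{E}\big[D(G(z))\big] \\
\frac{\partial J_G}{\partial g_w} &= \mathbb{E}\big[(1 - D(G(z)))\,d_w\,z\big], \qquad
\frac{\partial J_G}{\partial g_b} = \mathbb{E}\big[(1 - D(G(z)))\,d_w\big]
\end{aligned}
```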

Solution

import numpy as np

def train_gan_1d(real_mean: float = 5.0, real_std: float = 1.0,
                 n_epochs: int = 5000, batch_size: int = 64,
                 lr: float = 0.01):
    # Generator G(z) = g_w * z + g_b; discriminator D(x) = sigmoid(d_w * x + d_b)
    g_w = np.random.randn() * 0.1
    g_b = np.random.randn() * 0.1
    d_w = np.random.randn() * 0.1
    d_b = np.random.randn() * 0.1

    def sigmoid(x):
        # Clip the logit so np.exp cannot overflow
        return 1.0 / (1.0 + np.exp(-np.clip(x, -500, 500)))

    def generator(z):
        return g_w * z + g_b

    def discriminator(x):
        return sigmoid(d_w * x + d_b)

    for _ in range(n_epochs):
        # Discriminator step: gradient ascent on
        # mean(log D(real)) + mean(log(1 - D(fake)))
        real_data = np.random.normal(real_mean, real_std, batch_size)
        z = np.random.randn(batch_size)
        fake_data = generator(z)

        d_real = discriminator(real_data)
        d_fake = discriminator(fake_data)

        # d log D / d logit = 1 - D   and   d log(1 - D) / d logit = -D
        grad_d_w = np.mean(real_data * (1 - d_real)) - np.mean(fake_data * d_fake)
        grad_d_b = np.mean(1 - d_real) - np.mean(d_fake)

        d_w += lr * grad_d_w
        d_b += lr * grad_d_b

        # Generator step: gradient ascent on mean(log D(G(z))),
        # the non-saturating generator objective
        z = np.random.randn(batch_size)
        fake_data = generator(z)
        d_fake = discriminator(fake_data)

        # Chain rule: d log D(G(z)) / dG = (1 - D) * d_w, dG/dg_w = z, dG/dg_b = 1
        grad_g_w = np.mean((1 - d_fake) * d_w * z)
        grad_g_b = np.mean((1 - d_fake) * d_w)

        g_w += lr * grad_g_w
        g_b += lr * grad_g_b

    return g_w, g_b, d_w, d_b
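
Hand-derived GAN gradients are easy to get wrong (for example, writing D(1 - D) * x where the log-loss gives (1 - D) * x), so a central finite-difference comparison is a cheap sanity check. The sketch below is not part of the original solution: the helper names `d_objective`, `g_objective`, `analytic_grads` and `numeric_grads`, the seed, and the parameter values are all hypothetical choices for illustration.

```python
import numpy as np

# Hypothetical helpers for checking the 1D GAN gradients (not from the original solution).

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def d_objective(d_w, d_b, real, fake):
    # Discriminator objective: mean log D(real) + mean log(1 - D(fake))
    return (np.mean(np.log(sigmoid(d_w * real + d_b)))
            + np.mean(np.log(1.0 - sigmoid(d_w * fake + d_b))))

def g_objective(g_w, g_b, d_w, d_b, z):
    # Non-saturating generator objective: mean log D(G(z)) with G(z) = g_w * z + g_b
    return np.mean(np.log(sigmoid(d_w * (g_w * z + g_b) + d_b)))

def analytic_grads(g_w, g_b, d_w, d_b, real, z):
    # Closed-form gradients, ordered [d/dd_w, d/dd_b] of the discriminator
    # objective followed by [d/dg_w, d/dg_b] of the generator objective
    fake = g_w * z + g_b
    d_real = sigmoid(d_w * real + d_b)
    d_fake = sigmoid(d_w * fake + d_b)
    grad_d_w = np.mean(real * (1 - d_real)) - np.mean(fake * d_fake)
    grad_d_b = np.mean(1 - d_real) - np.mean(d_fake)
    grad_g_w = np.mean((1 - d_fake) * d_w * z)
    grad_g_b = np.mean((1 - d_fake) * d_w)
    return np.array([grad_d_w, grad_d_b, grad_g_w, grad_g_b])

def numeric_grads(g_w, g_b, d_w, d_b, real, z, eps=1e-6):
    # Central finite differences of the same two objectives, same ordering
    fake = g_w * z + g_b  # held fixed while differentiating the discriminator

    def central(f, x):
        return (f(x + eps) - f(x - eps)) / (2 * eps)

    return np.array([
        central(lambda t: d_objective(t, d_b, real, fake), d_w),
        central(lambda t: d_objective(d_w, t, real, fake), d_b),
        central(lambda t: g_objective(t, g_b, d_w, d_b, z), g_w),
        central(lambda t: g_objective(g_w, t, d_w, d_b, z), g_b),
    ])

rng = np.random.default_rng(0)
real = rng.normal(5.0, 1.0, 64)
z = rng.standard_normal(64)
params = (0.8, 1.5, 0.3, -0.7)  # arbitrary (g_w, g_b, d_w, d_b)
print(analytic_grads(*params, real, z))
print(numeric_grads(*params, real, z))
```

If the two printed arrays agree to about 1e-6, the closed-form expressions match the objectives; a gradient that drops or adds a factor of D would show up as a clear mismatch.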

Explanation

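  1. Generator: a linear map G(z) = g_w * z + g_b that turns noise z ~ N(0, 1) into generated samples.
  2. Discriminator: D(x) = sigmoid(d_w * x + d_b), a logistic classifier that outputs the probability that x is real.
  3. Discriminator training: gradient ascent on mean log D(real) + mean log(1 - D(fake)). Because D is a sigmoid of a linear logit, each real sample contributes (1 - D) and each fake sample contributes -D to the gradient with respect to the logit; multiplying by x gives the d_w gradient.
  4. Generator training: gradient ascent on mean log D(G(z)), the non-saturating objective, which gives stronger gradients than minimizing log(1 - D(G(z))) when the discriminator easily rejects fakes. By the chain rule each sample contributes (1 - D(G(z))) * d_w * z to the g_w gradient and (1 - D(G(z))) * d_w to the g_b gradient.
  5. After training, g_b should sit close to real_mean, but g_w should not be expected to reach real_std. A logistic discriminator on raw x separates real from fake with a single threshold, so once the means match its best response is d_w ≈ 0 and it carries no information about spread. Moreover, for z ~ N(0, 1) Stein's lemma gives E[g_w gradient] = -d_w² * g_w * E[D(G(z)) * (1 - D(G(z)))], which always points toward g_w = 0, so the generated samples tend to collapse toward a point near the mean. Matching real_std would need a discriminator that can use a nonlinear feature such as x².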
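
One way to test whether a linear generator can be expected to reach g_w = real_std under a logistic discriminator is to look at the sign of the expected g_w gradient. For z ~ N(0, 1), Stein's lemma, E[z f(z)] = E[f'(z)], turns that expectation into -d_w² * g_w * E[D(G(z)) * (1 - D(G(z)))]. The sketch below compares this prediction with a Monte Carlo estimate. The helper name `g_w_gradient_check`, the sample size, and the discriminator state (d_w = 0.8, d_b = -4.0, which puts the decision threshold at x = 5) are hypothetical choices for illustration.

```python
import numpy as np

# Hypothetical helper (not from the original solution).

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def g_w_gradient_check(g_w, g_b, d_w, d_b, n=200_000, seed=0):
    """Return (Monte Carlo E[g_w gradient], Stein's-lemma prediction) for z ~ N(0, 1)."""
    z = np.random.default_rng(seed).standard_normal(n)
    d_fake = sigmoid(d_w * (g_w * z + g_b) + d_b)
    # Expected g_w gradient of the non-saturating objective mean log D(G(z))
    mc = np.mean((1 - d_fake) * d_w * z)
    # Stein's lemma: E[z f(z)] = E[f'(z)] gives -d_w^2 * g_w * E[D (1 - D)]
    predicted = -d_w**2 * g_w * np.mean(d_fake * (1 - d_fake))
    return mc, predicted

# Generator centred on the real mean 5.0, discriminator threshold at 5.0 as well
for g_w in (0.5, 1.0, 2.0, -1.0):
    mc, predicted = g_w_gradient_check(g_w, g_b=5.0, d_w=0.8, d_b=-4.0)
    print(f"g_w={g_w:+.1f}  MC={mc:+.4f}  Stein={predicted:+.4f}")
```

In each row the estimate and the prediction should agree closely and have the opposite sign to g_w. Since the identity holds for any d_w and d_b, a discriminator that is linear in x pushes g_w toward zero in expectation rather than toward real_std.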

Complexity

  • Time: O(n_epochs * batch_size) for training
  • Space: O(batch_size) for storing samples per batch