Implement Relativistic Critic Rewards for Adversarial Reasoning

#268 · Deep Learning · Medium

Problem

Implement Relativistic Critic Rewards for adversarial reasoning. In a standard GAN, the critic scores each sample in isolation (an absolute score). A relativistic critic instead measures how much more realistic a real sample looks than the fake samples, and how much more realistic a fake sample looks than the real ones. Compute the relativistic losses for both the generator and the discriminator.
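To make the absolute-versus-relative distinction concrete, here is a small stdlib-only sketch. It assumes, as the Solution does, that critic scores are raw logits turned into probabilities by a sigmoid. The helper names `softplus`, `standard_d_loss`, and `relative_scores` are hypothetical, not part of the problem. Shifting every critic score by the same constant changes the standard discriminator loss but leaves the relativistic scores untouched.

```python
import math


def softplus(x: float) -> float:
    """log(1 + e^x), computed without overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def standard_d_loss(real: list[float], fake: list[float]) -> float:
    # Absolute critic: each raw score goes through the sigmoid on its own.
    # Uses -log(sigmoid(x)) = softplus(-x) and -log(1 - sigmoid(x)) = softplus(x).
    terms = [softplus(-r) for r in real] + [softplus(f) for f in fake]
    return sum(terms) / len(terms)


def relative_scores(
    real: list[float], fake: list[float]
) -> tuple[list[float], list[float]]:
    # Relativistic critic: compare each sample with the mean of the other set.
    mean_real = sum(real) / len(real)
    mean_fake = sum(fake) / len(fake)
    return [r - mean_fake for r in real], [f - mean_real for f in fake]


real, fake = [2.0, 1.0], [0.0, -1.0]
shifted_real = [r + 5.0 for r in real]
shifted_fake = [f + 5.0 for f in fake]

# The standard loss changes when every score moves by the same constant...
print(standard_d_loss(real, fake), standard_d_loss(shifted_real, shifted_fake))
# ...but the relative scores do not.
print(relative_scores(real, fake))                   # ([2.5, 1.5], [-1.5, -2.5])
print(relative_scores(shifted_real, shifted_fake))   # ([2.5, 1.5], [-1.5, -2.5])
```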

Solution

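Instead of using D(x) alone, compare each sample with the average score of the opposite set: D(x_real) - mean(D(x_fake)) for real samples and D(x_fake) - mean(D(x_real)) for fake samples. Then apply sigmoid cross-entropy to these relative scores. The discriminator labels the relative real scores 1 and the relative fake scores 0; the generator uses the same loss with the labels swapped.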
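Written out, the relative scores and the two losses the solution code computes are as follows. Here σ is the sigmoid, n_r and n_f are the numbers of real and fake samples, and both losses are averaged over all n_r + n_f terms, which is the normalization the code uses (when n_r = n_f this equals half the sum of the two per-set means).

```latex
\tilde{D}(x_r) = D(x_r) - \frac{1}{n_f}\sum_{j=1}^{n_f} D(x_{f,j}),
\qquad
\tilde{D}(x_f) = D(x_f) - \frac{1}{n_r}\sum_{i=1}^{n_r} D(x_{r,i})

\mathcal{L}_D = -\frac{1}{n_r + n_f}\left[\sum_{i=1}^{n_r} \log \sigma\big(\tilde{D}(x_{r,i})\big)
              + \sum_{j=1}^{n_f} \log\Big(1 - \sigma\big(\tilde{D}(x_{f,j})\big)\Big)\right]

\mathcal{L}_G = -\frac{1}{n_r + n_f}\left[\sum_{j=1}^{n_f} \log \sigma\big(\tilde{D}(x_{f,j})\big)
              + \sum_{i=1}^{n_r} \log\Big(1 - \sigma\big(\tilde{D}(x_{r,i})\big)\Big)\right]
```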

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, use e^x / (1 + e^x) so exp() cannot overflow.
    ex = math.exp(x)
    return ex / (1.0 + ex)


def relativistic_critic_rewards(
    real_scores: list[float],
    fake_scores: list[float],
) -> dict:
    # Guard against empty batches (the means below would divide by zero).
    if not real_scores or not fake_scores:
        raise ValueError("real_scores and fake_scores must both be non-empty")

    n_real = len(real_scores)
    n_fake = len(fake_scores)

    mean_real = sum(real_scores) / n_real
    mean_fake = sum(fake_scores) / n_fake

    # Relativistic scores: each sample is compared with the mean of the opposite set.
    real_rel = [r - mean_fake for r in real_scores]
    fake_rel = [f - mean_real for f in fake_scores]

    eps = 1e-12  # floor on probabilities so log() never sees 0

    # Discriminator: label real_rel as 1 and fake_rel as 0.
    # d_loss = -[sum log(sig(real_rel)) + sum log(1 - sig(fake_rel))] / (n_real + n_fake)
    d_loss = 0.0
    for rel in real_rel:
        d_loss -= math.log(max(sigmoid(rel), eps))
    for rel in fake_rel:
        d_loss -= math.log(max(1.0 - sigmoid(rel), eps))
    d_loss /= (n_real + n_fake)

    # Generator: same cross-entropy with the labels swapped (fake_rel -> 1, real_rel -> 0).
    g_loss = 0.0
    for rel in fake_rel:
        g_loss -= math.log(max(sigmoid(rel), eps))
    for rel in real_rel:
        g_loss -= math.log(max(1.0 - sigmoid(rel), eps))
    g_loss /= (n_real + n_fake)

    return {
        "d_loss": round(d_loss, 6),
        "g_loss": round(g_loss, 6),
        "real_relative": [round(x, 6) for x in real_rel],
        "fake_relative": [round(x, 6) for x in fake_rel],
    }
```
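The solution guards `log` with a `1e-12` floor, which caps every term at -log(1e-12) ≈ 27.63. One alternative, sketched here as a swapped-in technique rather than the reference solution, uses the identities -log σ(x) = softplus(-x) and -log(1 - σ(x)) = softplus(x), which stay finite and accurate for large |x|. The name `relativistic_losses_softplus` is hypothetical; the sketch keeps the reference's normalization by n_real + n_fake and returns unrounded values.

```python
import math


def softplus(x: float) -> float:
    # log(1 + e^x) without overflow: max(x, 0) + log1p(e^(-|x|))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def relativistic_losses_softplus(
    real_scores: list[float], fake_scores: list[float]
) -> tuple[float, float]:
    """Return (d_loss, g_loss) via softplus; no epsilon clamp is needed."""
    if not real_scores or not fake_scores:
        raise ValueError("both score lists must be non-empty")
    mean_real = sum(real_scores) / len(real_scores)
    mean_fake = sum(fake_scores) / len(fake_scores)
    real_rel = [r - mean_fake for r in real_scores]
    fake_rel = [f - mean_real for f in fake_scores]
    n = len(real_rel) + len(fake_rel)  # same normalization as the reference

    # Discriminator: -log sig(real_rel) and -log(1 - sig(fake_rel))
    d_loss = (sum(softplus(-x) for x in real_rel) + sum(softplus(x) for x in fake_rel)) / n
    # Generator: labels swapped
    g_loss = (sum(softplus(-x) for x in fake_rel) + sum(softplus(x) for x in real_rel)) / n
    return d_loss, g_loss


print(relativistic_losses_softplus([1.0], [-1.0]))
print(relativistic_losses_softplus([100.0], [-100.0]))
```

For moderate scores the two versions agree: the first call gives d_loss ≈ 0.126928 and g_loss ≈ 2.126928, matching the reference after rounding. The second call is where they diverge: the softplus form returns g_loss = 200.0, while the clamped reference saturates at about 27.63.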

Explanation

  1. Compute the mean critic score for the real and the fake samples separately.
  2. Relativistic real score: D(x_real) - mean(D(x_fake)), i.e. how much more realistic a real sample looks than the average fake.
  3. Relativistic fake score: D(x_fake) - mean(D(x_real)), i.e. how much more realistic a fake sample looks than the average real.
  4. The discriminator's loss is minimized by pushing the relativistic real scores up and the relativistic fake scores down.
  5. The generator's loss is the same sigmoid cross-entropy with the labels swapped: it pushes fake samples to look more realistic than the average real sample, and real samples to look less realistic than the average fake. Unlike the standard GAN generator loss, it therefore depends on the real scores as well.
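Steps 4 and 5 can be checked numerically. The sketch below is a hypothetical experiment, not part of the problem: it restates the solution's losses compactly in softplus form (equivalent to the clamped version for moderate scores) under the hypothetical name `rel_losses`, then shifts every fake score upward to mimic a generator that is getting better at fooling the critic. The discriminator loss rises and the generator loss falls at every step.

```python
import math


def softplus(x: float) -> float:
    # log(1 + e^x) without overflow
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def rel_losses(real: list[float], fake: list[float]) -> tuple[float, float]:
    # Compact restatement of the solution's d_loss / g_loss (softplus form).
    mean_real = sum(real) / len(real)
    mean_fake = sum(fake) / len(fake)
    real_rel = [r - mean_fake for r in real]
    fake_rel = [f - mean_real for f in fake]
    n = len(real) + len(fake)
    d = (sum(softplus(-x) for x in real_rel) + sum(softplus(x) for x in fake_rel)) / n
    g = (sum(softplus(-x) for x in fake_rel) + sum(softplus(x) for x in real_rel)) / n
    return d, g


real = [2.0, 1.5, 2.5]
base_fake = [-1.0, -2.0, 0.0]

# Simulate an improving generator: shift every fake score upward.
history = []
for shift in [0.0, 1.0, 2.0, 3.0, 4.0]:
    fake = [f + shift for f in base_fake]
    d, g = rel_losses(real, fake)
    history.append((shift, d, g))
    print(f"shift={shift:.1f}  d_loss={d:.4f}  g_loss={g:.4f}")
```

Because every relative real score falls and every relative fake score rises as the shift grows, each discriminator term increases and each generator term decreases, so the trend is strictly monotone rather than an artifact of these particular numbers.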

Complexity

  • Time: O(n) where n is the total number of samples
  • Space: O(n) for storing relative scores