Chain Rule for Composite Functions

#214 · Calculus · Medium

Solve on deep-ml.com

Problem

Implement the Chain Rule for computing derivatives of composite functions. Given a composition f(g(x)), compute the derivative f'(g(x)) * g'(x) numerically and symbolically for simple cases.
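A short worked instance may help fix the notation. The choice f(u) = sin u, g(x) = x^2 is illustrative and is not taken from the original problem:

```latex
\begin{aligned}
h(x)  &= f(g(x)) = \sin(x^2), \qquad f(u) = \sin u, \quad g(x) = x^2 \\
h'(x) &= f'(g(x))\, g'(x) = \cos(x^2) \cdot 2x \\
h'(1) &= 2\cos 1 \approx 1.0806
\end{aligned}
```

The numerical routines in the solution should reproduce this value, up to finite-difference error, for f = math.sin, g(x) = x * x, and x = 1.0.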

Solution

import math

def chain_rule_numerical(f, g, x, epsilon=1e-7):
    """
    Compute d/dx f(g(x)) using the chain rule numerically.
    f, g: callable functions
    """
    # g'(x)
    g_prime = (g(x + epsilon) - g(x - epsilon)) / (2 * epsilon)

    # f'(g(x))
    gx = g(x)
    f_prime_at_gx = (f(gx + epsilon) - f(gx - epsilon)) / (2 * epsilon)

    # Chain rule: f'(g(x)) * g'(x)
    return f_prime_at_gx * g_prime

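def chain_rule_multi(funcs: list, x: float, epsilon: float = 1e-7):
    """
    Compute the derivative of f_n(f_{n-1}(...f_1(x)...)) at x.
    funcs = [f_1, f_2, ..., f_n], listed in the order they are applied.
    Uses the generalized chain rule: the product of the local derivatives f_i'(v_{i-1}).
    """
    # Forward pass: values[i] is the input to funcs[i]; values[-1] is the final output
    values = [x]
    for f in funcs:
        values.append(f(values[-1]))

    # Backward pass: walk from the outermost function (f_n) to the innermost (f_1),
    # multiplying each local derivative evaluated at that function's own input
    derivative = 1.0
    for i in reversed(range(len(funcs))):
        v = values[i]
        f = funcs[i]
        f_prime = (f(v + epsilon) - f(v - epsilon)) / (2 * epsilon)
        derivative *= f_prime

    return derivative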

def verify_chain_rule(f, g, x, epsilon=1e-7):
    """Verify chain rule by comparing with direct numerical differentiation."""
    composite = lambda t: f(g(t))
    direct = (composite(x + epsilon) - composite(x - epsilon)) / (2 * epsilon)
    chain = chain_rule_numerical(f, g, x, epsilon)
    return {
        "chain_rule_result": round(chain, 8),
        "direct_derivative": round(direct, 8),
        "match": abs(chain - direct) < 1e-5,
    }
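The solution above covers only the numerical side of the problem statement. For the "symbolically for simple cases" part, here is one minimal sketch. It assumes a small hand-written table of elementary functions with known derivatives. The `Elementary` class, `SIN`, `EXP`, `SQUARE`, and `chain_rule_symbolic` are hypothetical names introduced here and are not part of the original solution. The sketch builds the derivative as text by substituting g(x) into f' and multiplying by g'(x):

```python
import math
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True)
class Elementary:
    """Hypothetical helper: an elementary function with a hand-coded derivative.

    `expr` and `deriv_expr` are text templates in which "{u}" stands for the argument.
    """
    expr: str
    deriv_expr: str
    f: Callable[[float], float]
    df: Callable[[float], float]


# A tiny, illustrative table of elementary functions and their derivatives.
SIN = Elementary("sin({u})", "cos({u})", math.sin, math.cos)
EXP = Elementary("exp({u})", "exp({u})", math.exp, math.exp)
SQUARE = Elementary("{u}**2", "2*{u}", lambda u: u * u, lambda u: 2 * u)


def chain_rule_symbolic(
    outer: Elementary, inner: Elementary, var: str = "x"
) -> Tuple[str, Callable[[float], float]]:
    """Return (derivative text, derivative function) for outer(inner(var)).

    Applies h'(x) = f'(g(x)) * g'(x) to the templates. Substitution is plain text,
    so this is only meant for simple cases like the table above.
    """
    inner_text = inner.expr.format(u=var)                  # g(x)
    outer_deriv_text = outer.deriv_expr.format(u=inner_text)  # f'(g(x))
    inner_deriv_text = inner.deriv_expr.format(u=var)     # g'(x)
    text = f"{outer_deriv_text} * {inner_deriv_text}"

    def derivative(x: float) -> float:
        # Same chain rule, evaluated with the exact derivative functions
        return outer.df(inner.f(x)) * inner.df(x)

    return text, derivative


text, d = chain_rule_symbolic(SIN, SQUARE)
print(text)    # cos(x**2) * 2*x
print(d(1.0))  # equals 2 * cos(1)
```

Comparing `d(x)` with `chain_rule_numerical(SIN.f, SQUARE.f, x)` is a quick cross-check between the two approaches. A computer-algebra system such as SymPy handles arbitrary expressions and simplification; this sketch only isolates the chain-rule step.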

Explanation

  1. Chain rule: For composite function h(x) = f(g(x)), the derivative is h'(x) = f'(g(x)) * g'(x).
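  2. Numerical derivatives: Approximate each derivative with a central difference, f'(a) ≈ (f(a+eps) - f(a-eps)) / (2*eps). Its truncation error is O(eps^2), but floating-point round-off grows roughly like 1/eps, so eps should not be made arbitrarily small.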
  3. Multi-function composition: For f_n(f_{n-1}(...f_1(x)...)), compute the intermediate values v_0 = x, v_i = f_i(v_{i-1}) in a forward pass. Then, in a backward pass from f_n down to f_1, multiply the local derivatives f_i'(v_{i-1}), each evaluated at its own input.
  4. Verification: Compare the chain rule result with a direct numerical derivative of the composite function.
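Item 2's O(eps^2) claim can be checked empirically. The sketch below halves the step and compares both errors against the exact derivative. The helper names `central_difference` and `error_ratio`, and the choice of sin at a = 1, are illustrative assumptions. For a second-order method the ratio should come out close to 4:

```python
import math
from typing import Callable


def central_difference(f: Callable[[float], float], a: float, eps: float) -> float:
    """Central-difference estimate of f'(a), the same formula the solution uses."""
    return (f(a + eps) - f(a - eps)) / (2 * eps)


def error_ratio(
    f: Callable[[float], float], df: Callable[[float], float], a: float, eps: float
) -> float:
    """Return error(eps) / error(eps / 2) measured against the exact derivative df.

    With O(eps^2) truncation error, halving eps should cut the error by roughly
    2**2 = 4, provided eps is large enough that round-off does not dominate.
    """
    err_full = abs(central_difference(f, a, eps) - df(a))
    err_half = abs(central_difference(f, a, eps / 2) - df(a))
    return err_full / err_half


ratio = error_ratio(math.sin, math.cos, 1.0, 1e-2)
print(f"error ratio when halving eps: {ratio:.3f}")
```

Trying a much smaller step, such as eps = 1e-12, typically breaks this pattern because round-off error starts to dominate.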

Complexity

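  • Time: O(n) function evaluations, where n is the number of composed functions. chain_rule_multi calls each function once in the forward pass and twice for its central difference (3 calls per function).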
  • Space: O(n) to store intermediate values
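The O(n) time bound can be made concrete by counting calls. In this sketch, `CountingFunction` is a hypothetical wrapper. `chain_rule_multi` restates the solution's algorithm so that the block runs on its own. Composing sin five times is an arbitrary choice:

```python
import math
from typing import Callable, List


class CountingFunction:
    """Hypothetical wrapper that counts how often the wrapped function is called."""

    def __init__(self, fn: Callable[[float], float]):
        self.fn = fn
        self.calls = 0

    def __call__(self, x: float) -> float:
        self.calls += 1
        return self.fn(x)


def chain_rule_multi(
    funcs: List[Callable[[float], float]], x: float, epsilon: float = 1e-7
) -> float:
    """Forward pass for intermediate values, then a product of local central differences."""
    values = [x]
    for f in funcs:
        values.append(f(values[-1]))  # one call per function
    derivative = 1.0
    for i in reversed(range(len(funcs))):
        v, f = values[i], funcs[i]
        derivative *= (f(v + epsilon) - f(v - epsilon)) / (2 * epsilon)  # two calls
    return derivative


n = 5
funcs = [CountingFunction(math.sin) for _ in range(n)]
result = chain_rule_multi(funcs, 0.5)
total_calls = sum(f.calls for f in funcs)
print(total_calls)  # 3 * n = 15
```

The list `values` holds n + 1 floats, which matches the O(n) space bound.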