"""Implement the chain rule for computing derivatives of composite functions.

Given a composition f(g(x)), compute the derivative f'(g(x)) * g'(x)
numerically and symbolically for simple cases.
"""
import math
def chain_rule_numerical(f, g, x, epsilon=1e-7):
    """
    Compute d/dx f(g(x)) using the chain rule numerically.

    Both factors of the chain rule, g'(x) and f'(g(x)), are estimated with a
    central difference, f'(a) ~= (f(a + h) - f(a - h)) / (2 * h), whose
    truncation error is O(h**2).

    Args:
        f: Outer function, a callable taking one float.
        g: Inner function, a callable taking one float.
        x: Point at which the composite f(g(x)) is differentiated.
        epsilon: Base step size. The step actually used at a point ``a`` is
            ``epsilon * max(1, |a|)``, so for ``|a| <= 1`` it is exactly
            ``epsilon``.

    Returns:
        The estimate f'(g(x)) * g'(x).
    """
    def central_diff(func, a):
        # Scale the step with |a|: a fixed absolute step falls below the float
        # spacing once |a| is large (e.g. a ~ 1e13), so a + h == a and the
        # difference quotient would silently collapse to 0.
        h = epsilon * max(1.0, abs(a))
        return (func(a + h) - func(a - h)) / (2 * h)

    g_prime = central_diff(g, x)           # g'(x)
    f_prime_at_gx = central_diff(f, g(x))  # f'(g(x))
    # Chain rule: f'(g(x)) * g'(x)
    return f_prime_at_gx * g_prime
def chain_rule_multi(funcs: list, x: float, epsilon: float = 1e-7) -> float:
    """
    Compute derivative of composed functions f_n(f_{n-1}(...f_1(x)...)).

    Uses the generalized chain rule: the derivative is the product of
    f_i'(v_{i-1}) for i = 1..n, where v_0 = x and v_i = f_i(v_{i-1}).
    ``funcs`` is ordered innermost first, i.e. ``funcs[0]`` is applied to
    ``x`` first.

    Each factor is estimated with a central difference
    (f(a + h) - f(a - h)) / (2 * h), using the step
    h = epsilon * max(1, |a|). Scaling the step keeps a + h distinguishable
    from a when an intermediate value is large.

    Args:
        funcs: Callables of one float, innermost first.
        x: Point at which the composition is differentiated.
        epsilon: Base step size for the central differences.

    Returns:
        The derivative estimate. For an empty ``funcs`` (the identity map)
        this is 1.0.
    """
    derivative = 1.0
    v = x  # input to the current function, i.e. v_{i-1}
    # Single forward sweep: differentiate each function at its input,
    # then advance to the next intermediate value.
    for f in funcs:
        h = epsilon * max(1.0, abs(v))
        derivative *= (f(v + h) - f(v - h)) / (2 * h)
        v = f(v)
    return derivative
def verify_chain_rule(f, g, x, epsilon=1e-7, *, rel_tol=1e-5, abs_tol=1e-5):
    """Verify chain rule by comparing with direct numerical differentiation.

    The composite f(g(t)) is differentiated directly at ``x`` with a central
    difference. The result is compared with ``chain_rule_numerical(f, g, x,
    epsilon)``.

    Args:
        f: Outer function, a callable taking one float.
        g: Inner function, a callable taking one float.
        x: Point at which both derivatives are evaluated.
        epsilon: Base step size. The direct difference uses
            ``epsilon * max(1, |x|)``.
        rel_tol: Relative tolerance for the ``"match"`` comparison.
        abs_tol: Absolute tolerance for the ``"match"`` comparison. It is
            needed for derivatives at or near zero.

    Returns:
        A dict with these keys:
        ``"chain_rule_result"``: the chain-rule estimate, rounded to 8 places.
        ``"direct_derivative"``: the direct estimate, rounded to 8 places.
        ``"match"``: True when the unrounded estimates agree within the
        tolerances.
    """
    def composite(t):
        return f(g(t))

    # Central difference: f'(a) = (f(a+eps) - f(a-eps)) / (2*eps) for
    # O(eps^2) accuracy. The step is scaled by |x| so that x + h stays
    # distinguishable from x in floating point.
    h = epsilon * max(1.0, abs(x))
    direct = (composite(x + h) - composite(x - h)) / (2 * h)
    chain = chain_rule_numerical(f, g, x, epsilon)
    return {
        "chain_rule_result": round(chain, 8),
        "direct_derivative": round(direct, 8),
        # A relative tolerance keeps large-magnitude derivatives from being
        # reported as mismatches because of ordinary rounding error. abs_tol
        # preserves the original absolute check.
        "match": math.isclose(chain, direct, rel_tol=rel_tol, abs_tol=abs_tol),
    }