author

Kien Duong

July 12, 2024

The Chain Rule

1. What is the chain rule?

The chain rule is a fundamental theorem in calculus used to compute the derivative of a composite function.

For example, suppose we have two functions \(y = f(u)\) and \(u = g(x)\), so that \(y = f(g(x))\) is a composite function of \(x\). To find the derivative of \(y\) with respect to \(x\), the chain rule gives:

\[ \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} \]
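As a quick sanity check of the formula, the sketch below compares the chain-rule derivative with a numerical estimate. The functions \(f(u) = \sin(u)\) and \(g(x) = x^2\), the evaluation point, and the helper names are illustrative assumptions, not part of the original example.

```python
import math


def g(x):
    """Inner function u = g(x) (chosen for illustration)."""
    return x ** 2


def f(u):
    """Outer function y = f(u) (chosen for illustration)."""
    return math.sin(u)


def chain_rule_derivative(x):
    """dy/dx = dy/du * du/dx for y = sin(x^2)."""
    dy_du = math.cos(g(x))  # derivative of sin(u), evaluated at u = g(x)
    du_dx = 2 * x           # derivative of x^2
    return dy_du * du_dx


def central_difference(func, x, step=1e-6):
    """Numerical estimate of func'(x) using a symmetric difference."""
    return (func(x + step) - func(x - step)) / (2 * step)


x0 = 1.3
analytic = chain_rule_derivative(x0)
numeric = central_difference(lambda x: f(g(x)), x0)
print(f"chain rule: {analytic:.8f}, numerical: {numeric:.8f}")
```

The two printed values should agree to several decimal places, which is what the chain rule predicts.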

2. Proof

Let \(h(x) = f(g(x))\). Based on the definition of the derivative, we have:

\[ h'(x) = \lim_{\Delta x \to 0} \frac{h(x + \Delta x) - h(x)}{\Delta x} \]

\[ \Rightarrow h'(x) = \lim_{\Delta x \to 0} \frac{f(g(x + \Delta x)) - f(g(x))}{\Delta x} \]

Let \(\Delta u = g(x + \Delta x) - g(x)\)

\[ \Rightarrow h'(x) = \lim_{\Delta x \to 0} \left( \frac{f(g(x + \Delta x)) - f(g(x))}{\Delta u} \cdot \frac{\Delta u}{\Delta x} \right) ~~~ (1) \]

By the product rule for limits, the limit of a product is equal to the product of the limits, provided both limits exist:

\[ \lim_{x \to a} [f(x) \cdot g(x)] = \left( \lim_{x \to a} f(x) \right) \cdot \left( \lim_{x \to a} g(x) \right) \]

Since \(g\) is differentiable at \(x\), it is also continuous there, so as \(\Delta x\) approaches 0, \(\Delta u\) also approaches 0. Noting that \(g(x + \Delta x) = g(x) + \Delta u\), and assuming \(\Delta u \neq 0\) for small nonzero \(\Delta x\), (1) can be transformed to:

\[ h'(x) = \left( \lim_{\Delta u \to 0} \frac{f(g(x) + \Delta u) - f(g(x))}{\Delta u} \right) \cdot \left( \lim_{\Delta x \to 0} \frac{\Delta u}{\Delta x} \right) \]

\[\Rightarrow h'(x) = f'(g(x)) \cdot g'(x) \]

\[ \Rightarrow \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} \]
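The factorization used in the proof can also be observed numerically. The sketch below is only an illustration under assumed functions, \(f(u) = u^3\) and \(g(x) = 3x + 1\), chosen so that \(\Delta u\) is never zero. It splits the difference quotient into the two factors from (1) and shows each one approaching its limit as \(\Delta x\) shrinks.

```python
def g(x):
    """Inner function (illustrative choice): g'(x) = 3."""
    return 3 * x + 1


def f(u):
    """Outer function (illustrative choice): f'(u) = 3 * u**2."""
    return u ** 3


def difference_quotient_factors(x, dx):
    """Return the two factors of the difference quotient of h(x) = f(g(x))."""
    du = g(x + dx) - g(x)                  # change in the inner function
    outer = (f(g(x) + du) - f(g(x))) / du  # tends to f'(g(x))
    inner = du / dx                        # tends to g'(x)
    return outer, inner


x0 = 2.0
for dx in (1e-1, 1e-3, 1e-5):
    outer, inner = difference_quotient_factors(x0, dx)
    print(f"dx={dx:g}: outer={outer:.5f}, inner={inner:.5f}, product={outer * inner:.5f}")
```

At \(x = 2\) the exact values are \(f'(g(2)) = 3 \cdot 7^2 = 147\) and \(g'(2) = 3\), so the product should approach \(441\).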

3. Chain Rule in Machine Learning

The chain rule plays a central role in Machine Learning, particularly in training neural networks, where it is the basis of backpropagation.

During the forward pass, inputs are passed through the network layers to compute the output. The output is compared with the true labels to calculate the loss.

During the backward pass, the chain rule is used to compute the gradient of the loss function with respect to each weight, working backward through the layers. These gradients are then used to update the weights so that the loss decreases.
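To make the forward and backward passes concrete, here is a minimal sketch for a single neuron. The sigmoid activation, the squared-error loss, the sample values, and the learning rate are all assumptions chosen for illustration. A real network repeats the same chain-rule step across many weights and layers.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(w, b, x):
    """Forward pass: z = w*x + b, then activation a = sigmoid(z)."""
    z = w * x + b
    a = sigmoid(z)
    return z, a


def loss_fn(a, y):
    """Squared-error loss between prediction a and true label y."""
    return 0.5 * (a - y) ** 2


def backward(w, b, x, y):
    """Backward pass: apply the chain rule dL/dw = dL/da * da/dz * dz/dw."""
    _, a = forward(w, b, x)
    dL_da = a - y        # derivative of 0.5 * (a - y)^2 with respect to a
    da_dz = a * (1 - a)  # derivative of the sigmoid with respect to z
    dz_dw = x            # derivative of w*x + b with respect to w
    dz_db = 1.0          # derivative of w*x + b with respect to b
    return dL_da * da_dz * dz_dw, dL_da * da_dz * dz_db


# Hypothetical sample values and learning rate.
w, b, x, y, lr = 0.5, -0.2, 1.5, 1.0, 0.1

loss_before = loss_fn(forward(w, b, x)[1], y)
dL_dw, dL_db = backward(w, b, x, y)

# One gradient-descent step: move each parameter against its gradient.
w_new = w - lr * dL_dw
b_new = b - lr * dL_db
loss_after = loss_fn(forward(w_new, b_new, x)[1], y)

print(f"loss before: {loss_before:.6f}, loss after: {loss_after:.6f}")
```

After one update step the loss should be slightly lower than before, since both parameters were moved against their gradients.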

 
