
Nesterov Momentum Equivalence Derivation

That Missing Piece in the Papers and the Slides

2020-02-08 · 5 min read

During the Winter Quarter of 2020 at UCLA, I was taking a class on neural networks and deep learning. The course is numbered ECE C147/247 and is adapted from the famous Stanford CS 231n.

Note: I assume you (the reader) already know what a momentum update is, and how Nesterov momentum differs from plain momentum.

In the chapter on optimization for training (the gradient descent stage), we learned about a technique called Nesterov momentum, which is used in the weight update to improve training speed.

The formulation of the Nesterov momentum update is given as

Updating momentum:

$$\mathbf{v} \leftarrow \alpha\mathbf{v} - \epsilon\nabla_{\theta} J(\theta + \alpha\mathbf{v})$$

Updating parameter:

$$\theta \leftarrow \theta + \mathbf{v}$$

Here $J$ is the loss function. However, this formulation is difficult to implement in code. In code, we usually have $\nabla_\theta J(\theta)$ available to us instead of $\nabla_{\theta} J(\theta + \alpha\mathbf{v})$.
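To make the difficulty concrete, here is a minimal sketch (my own illustration, not code from the course) of what implementing the original formulation directly would require. `grad_J` is a hypothetical function that can evaluate the gradient at any point we ask for, including the lookahead point $\theta + \alpha\mathbf{v}$, which a standard backward pass does not give us:

```python
import numpy as np

def nesterov_step_literal(theta, v, grad_J, alpha=0.9, epsilon=1e-2):
    """One Nesterov step following the original formulation.

    Needs the gradient at the lookahead point theta + alpha * v,
    but a normal backward pass only gives us the gradient at the
    current parameters theta.
    """
    v = alpha * v - epsilon * grad_J(theta + alpha * v)  # gradient at the lookahead point
    theta = theta + v
    return theta, v
```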

To solve this problem, an alternative formulation is given. We first let

$$\tilde{\theta}_\text{old} = \theta_\text{old} + \alpha\mathbf{v}_\text{old}$$

Then, the formulation becomes

$$\mathbf{v}_\text{new} = \alpha\mathbf{v}_\text{old} - \epsilon\nabla_{\tilde{\theta}_\text{old}} J(\tilde{\theta}_\text{old})$$
$$\tilde{\theta}_\text{new} = \tilde{\theta}_\text{old} + \mathbf{v}_\text{new} + \alpha(\mathbf{v}_\text{new} - \mathbf{v}_\text{old})$$
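This version only ever needs the gradient at the stored parameter, which is exactly what a backward pass computes. As a minimal sketch (my own variable names, assuming `dtheta` holds $\nabla_{\tilde\theta_\text{old}} J(\tilde\theta_\text{old})$):

```python
import numpy as np

def nesterov_step_alternative(theta_tilde, v, dtheta, alpha=0.9, epsilon=1e-2):
    """One Nesterov step in the alternative formulation.

    theta_tilde is the stored parameter and dtheta is the ordinary
    gradient of the loss evaluated at theta_tilde.
    """
    v_old = v
    v = alpha * v_old - epsilon * dtheta                 # v_new
    theta_tilde = theta_tilde + v + alpha * (v - v_old)  # theta_tilde_new
    return theta_tilde, v
```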

In class, no explanation was given on how the two formulations are equivalent. Instead, we were given a hint that the proof can be done through a substitution of variables. This formulation still confused me a lot.

There are multiple points of confusion here:

  1. How do we calculate $\nabla_{\tilde\theta_\text{old}}J(\tilde\theta_\text{old})$?
  2. What is the actual meaning of $\tilde\theta$?
  3. What are we storing as our parameter?

I then sought help from the Stanford course webpage on this chapter. It pointed to two papers on the derivation of Nesterov momentum. Only one of them (section 3.5) discusses the alternative formulation, but it simply states the two formulations without proof. I understand the target audience of the paper is very likely capable of figuring out the equivalence, but I am not one of them.

The paper did offer some insight:

the parameters $\tilde\theta$ are a completely equivalent replacement of $\theta$.

This answers confusion 3: we are storing $\tilde\theta$ as the parameter. In code, what is given to us is the derivative of $J(\tilde\theta)$ with respect to $\theta$. Using the chain rule, we can show that

$$\begin{aligned} \tilde\theta &= \theta + \alpha\mathbf{v}\\ \therefore \nabla_{\theta}\tilde\theta &= 1 \end{aligned}$$
$$\begin{aligned} \nabla_{\theta}J(\tilde\theta) &= \nabla_{\theta}\tilde\theta \, \nabla_{\tilde\theta}J(\tilde\theta)\\ \therefore \nabla_{\theta}J(\tilde\theta) &= \nabla_{\tilde\theta}J(\tilde\theta) \end{aligned}$$
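As a quick sanity check of this chain-rule step (a toy example of my own, not from the paper), the finite-difference gradient of $J(\theta + \alpha\mathbf{v})$ with respect to $\theta$ should match the analytic gradient of $J$ evaluated at $\tilde\theta = \theta + \alpha\mathbf{v}$:

```python
import numpy as np

# Toy loss J(x) = 0.5 * ||x||^2, so grad_J(x) = x.
J = lambda x: 0.5 * np.sum(x ** 2)
grad_J = lambda x: x

theta = np.array([1.0, -2.0])
v = np.array([0.3, 0.1])
alpha = 0.9

# Finite-difference gradient of J(theta + alpha * v) with respect to theta.
eps = 1e-6
fd = np.zeros_like(theta)
for i in range(theta.size):
    e = np.zeros_like(theta)
    e[i] = eps
    fd[i] = (J(theta + e + alpha * v) - J(theta - e + alpha * v)) / (2 * eps)

# Matches the gradient of J taken at theta_tilde = theta + alpha * v.
print(np.allclose(fd, grad_J(theta + alpha * v)))  # True
```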

This answers confusion 1: in code, we are given the gradient of the loss evaluated at $\tilde\theta$ with respect to $\theta$, and it is equivalent to the gradient with respect to $\tilde\theta$. Now back to confusion 2: what is the meaning of $\tilde\theta$? Is it okay to use it as the parameter for prediction and testing? The answer is yes. By definition, $\tilde\theta$ is simply the parameter after one more momentum step. Therefore, we can think of it as a parameter value that is closer to the optimum.

I had to take the time to do the derivation on my own. Here, I offer the derivation of the equivalence. First, we have the substitution and the original formulation.

$$\tilde{\theta}_\text{old} = \theta_\text{old} + \alpha\mathbf{v}_\text{old} \tag{1}$$
$$\mathbf{v}_\text{new} = \alpha\mathbf{v}_\text{old} - \epsilon\nabla_{\theta_\text{old}} J(\theta_\text{old} + \alpha\mathbf{v}_\text{old}) \tag{2}$$
$$\theta_\text{new} = \theta_\text{old} + \mathbf{v}_\text{new} \tag{3}$$

We substitute $(1)$ into $(2)$ and apply the relationship $\nabla_{\theta}J(\tilde\theta) = \nabla_{\tilde\theta}J(\tilde\theta)$:

$$\mathbf{v}_\text{new} = \alpha\mathbf{v}_\text{old} - \epsilon\nabla_{\theta_\text{old}} J(\tilde\theta_\text{old})$$
$$\mathbf{v}_\text{new} = \alpha\mathbf{v}_\text{old} - \epsilon\nabla_{\tilde\theta_\text{old}} J(\tilde\theta_\text{old}) \tag{4}$$

We can rewrite $(1)$ as $\theta = \tilde\theta - \alpha\mathbf{v}$. Then, we substitute $(1)$ into $(3)$. Note that we have to substitute both $\theta_\text{new}$ and $\theta_\text{old}$.

$$\tilde\theta_\text{new} - \alpha\mathbf{v}_\text{new} = \tilde\theta_\text{old} - \alpha\mathbf{v}_\text{old} + \mathbf{v}_\text{new}$$
$$\tilde\theta_\text{new} = \tilde\theta_\text{old} + \mathbf{v}_\text{new} + \alpha(\mathbf{v}_\text{new} - \mathbf{v}_\text{old}) \tag{5}$$

Now we have arrived at the alternative formulation with $(4)$ and $(5)$. Therefore, in the code, the parameter we return is $\tilde\theta_\text{new}$ instead of $\theta_\text{new}$. Again, the idea is that the actual parameter stored and used for evaluation has always already been updated with the momentum term. Therefore, the gradient of the loss is taken at the parameter that already includes the momentum update.
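To double-check the algebra, here is a small numerical experiment (a toy quadratic loss of my own choosing) showing that the two formulations stay in lockstep: the stored parameter of the alternative update equals $\theta + \alpha\mathbf{v}$ from the original update at every step, and the velocities are identical.

```python
import numpy as np

# Toy loss J(x) = 0.5 * ||x||^2, so grad_J(x) = x.
grad_J = lambda x: x

alpha, epsilon = 0.9, 1e-2
theta = np.array([1.0, -2.0])
v1 = np.zeros_like(theta)

theta_tilde = theta + alpha * v1  # substitution (1)
v2 = v1.copy()

for _ in range(50):
    # Original formulation: gradient at the lookahead point.
    v1 = alpha * v1 - epsilon * grad_J(theta + alpha * v1)
    theta = theta + v1

    # Alternative formulation: gradient at the stored parameter.
    v2_old = v2
    v2 = alpha * v2_old - epsilon * grad_J(theta_tilde)
    theta_tilde = theta_tilde + v2 + alpha * (v2 - v2_old)

# theta_tilde always equals theta + alpha * v, and the velocities agree.
print(np.allclose(theta_tilde, theta + alpha * v1), np.allclose(v1, v2))  # True True
```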