
SGD with Momentum

Cornell University

Recall from the previous lecture that the running time of GD on a strongly convex (or PL-condition) function depends only on the condition number $\kappa$. When the condition number is high, convergence can become slow:

$$T \ge \kappa \log\left( \frac{f(w_0) - f^*}{\epsilon} \right)$$

So how can we speed up gradient descent when the condition number is high? There are a few common solutions; here we introduce momentum. A direct analysis would be messy, so we use a very simple example to build some intuition.

Simple Quadratic Function

The simplest possible setting with a condition number larger than 1 is a 2D quadratic. To get a high condition number, we want the larger eigenvalue of the Hessian to be very big and the smaller one to be very small. Consider the following example with $a > b$, where $a$ is the curvature of the first dimension and $b$ is the curvature of the second.

$$f(w) = f(w_1, w_2) = \frac{a}{2} w_1^2 + \frac{b}{2} w_2^2 = \frac{1}{2} \begin{bmatrix} w_1 \\ w_2 \end{bmatrix}^T \begin{bmatrix} a & 0 \\ 0 & b \end{bmatrix} \begin{bmatrix} w_1 \\ w_2 \end{bmatrix}$$

The second derivative matrix is the constant

$$\nabla^2 f(w) = \begin{bmatrix} a & 0 \\ 0 & b \end{bmatrix}$$

By the definition of the condition number in linear algebra, we know $\kappa = \frac{a}{b}$. By its definition in ML optimization, we know $\kappa = \frac{L}{\mu}$, so we can simply set $a = L$ and $b = \mu$, giving

$$\begin{align} f(w) = \frac{L}{2} w_1^2 + \frac{\mu}{2} w_2^2 &= \frac{1}{2} \begin{bmatrix} w_1 \\ w_2 \end{bmatrix}^T \begin{bmatrix} L & 0 \\ 0 & \mu \end{bmatrix} \begin{bmatrix} w_1 \\ w_2 \end{bmatrix} \\ \nabla f(w) &= \begin{bmatrix} L & 0 \\ 0 & \mu \end{bmatrix} w \\ \nabla^2 f(w) &= \begin{bmatrix} L & 0 \\ 0 & \mu \end{bmatrix} \end{align}$$
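For concreteness, here is a minimal sketch of this quadratic in code; the values $L = 100$ and $\mu = 1$ are arbitrary illustrative choices (giving $\kappa = 100$):

```python
import numpy as np

L, mu = 100.0, 1.0            # assumed illustrative curvatures, kappa = L / mu = 100
A = np.diag([L, mu])          # the (constant) Hessian of f

def f(w):
    """f(w) = (L/2) w_1^2 + (mu/2) w_2^2 = (1/2) w^T A w."""
    return 0.5 * w @ A @ w

def grad_f(w):
    """grad f(w) = A w."""
    return A @ w
```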

With this, we can write our update step as

$$\begin{align} w_{t+1} &= w_t - \alpha \nabla f(w_t) \\ &= w_t - \alpha \begin{bmatrix} L & 0 \\ 0 & \mu \end{bmatrix} w_t \\ &= \begin{bmatrix} 1 - \alpha L & 0 \\ 0 & 1 - \alpha \mu \end{bmatrix} w_t \end{align}$$

If we run this update for $T$ steps, we get (from now on, $T$ will denote the number of iterations rather than transpose):

$$\begin{align} w_{T} &= \begin{bmatrix} 1 - \alpha L & 0 \\ 0 & 1 - \alpha \mu \end{bmatrix}^T w_0 \\ &= \begin{bmatrix} (1 - \alpha L)^T & 0 \\ 0 & (1 - \alpha \mu)^T \end{bmatrix} w_0 \\ &= \begin{bmatrix} (1 - \alpha L)^{T} (w_0)_1 \\ (1 - \alpha \mu)^{T} (w_0)_2 \end{bmatrix} \end{align}$$

Feeding this into $f$, we have

$$f(w_T) = \frac{L}{2} (1 - \alpha L)^{2T} \left( (w_0)_1 \right)^2 + \frac{\mu}{2} (1 - \alpha \mu)^{2T} \left( (w_0)_2 \right)^2$$
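A quick numerical sanity check of this closed form, reusing the assumed values $L = 100$, $\mu = 1$ from the sketch above (the step size, horizon, and starting point below are arbitrary choices):

```python
import numpy as np

L, mu = 100.0, 1.0                     # same assumed curvatures as before
A = np.diag([L, mu])
alpha, T = 0.015, 50                   # arbitrary step size (with alpha < 2/L) and horizon
w0 = np.array([1.0, 1.0])              # arbitrary starting point

w = w0.copy()
for _ in range(T):
    w = w - alpha * (A @ w)            # gradient step: grad f(w) = A w

f_T = 0.5 * w @ A @ w                  # f(w_T) from actually running GD
closed_form = 0.5 * L * (1 - alpha * L) ** (2 * T) * w0[0] ** 2 \
            + 0.5 * mu * (1 - alpha * mu) ** (2 * T) * w0[1] ** 2
print(f_T, closed_form)                # the two values agree up to floating-point error
```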

Therefore, the final value of $f(w_T)$ will be dominated by the larger of the exponential terms: one of $(1 - \alpha L)^{2T}$ or $(1 - \alpha \mu)^{2T}$.

To minimize $f(w_T)$, we have to make both $|1 - \alpha L|$ and $|1 - \alpha \mu|$ small at the same time; that is, we minimize $\max(|1 - \alpha L|, |1 - \alpha \mu|)$. Note that if we only looked at the first dimension and minimized $|1 - \alpha L|$ alone, we would set $\alpha = \frac{1}{L}$. With $L > \mu$, this is always a small number. Therefore, for the first dimension, with the larger curvature $L$, we want a smaller learning rate. On the other hand, if we only looked at the second dimension, where the curvature is the smaller $\mu$, we would want a larger learning rate.

Since we have to minimize both at the same time, we are forced to choose something in the middle. The maximum is minimized when $\alpha L - 1 = 1 - \alpha \mu$, i.e. $\alpha = \frac{2}{L + \mu}$. Substituting this $\alpha$ in,

$$\max(|1 - \alpha L|, |1 - \alpha \mu|) = \frac{L - \mu}{L + \mu} = \frac{\kappa - 1}{\kappa + 1} = 1 - \frac{2}{\kappa + 1}$$

As we said, the bigger of these two terms will dominate the final value of $f(w_T)$, so

$$f(w_T) = \mathcal{O}\left( \left(1 - \frac{2}{\kappa + 1}\right)^{2T} \right)$$

We know $1 - x \le e^{-x}$, with near-equality when $x$ is small (here $x = \frac{2}{\kappa + 1}$, which is small when $\kappa$ is large), so

$$f(w_T) = \mathcal{O}\left( \exp\left(- \frac{4T}{\kappa + 1}\right) \right)$$

Therefore, even in this simplest setting, we can't get rid of the $\frac{1}{\kappa + 1}$ factor in the exponent: reaching accuracy $\epsilon$ requires on the order of $\kappa \log(1/\epsilon)$ iterations.

Polyak Momentum

We want to somehow detect high vs. low curvature during GD. The idea: along a high-curvature direction, the gradient tends to flip sign from step to step (we overshoot back and forth across the valley), while along a low-curvature direction the gradient keeps pointing the same way. Therefore, we want to make steps smaller when gradients reverse sign and larger when gradients consistently point in the same direction. Polyak momentum does this:

$$w_{t+1} = w_t - \alpha \nabla f(w_t) + \beta (w_t - w_{t-1})$$

The intuition is that the extra term $\beta (w_t - w_{t-1})$ reuses the previous step: when consecutive steps point in the same direction it lengthens the step, and when they alternate in sign it partially cancels them.

This is equivalent to the following, where we define the momentum buffer $m_{t+1} := w_{t+1} - w_t$:

$$\begin{align*} w_{t+1} &= w_t - \alpha \nabla f(w_t) + \beta (w_t - w_{t-1}) \\ w_{t+1} - w_t &= - \alpha \nabla f(w_t) + \beta (w_t - w_{t-1}) \\ m_{t+1} &= -\alpha \nabla f(w_t) + \beta m_t \\ w_{t+1} &= w_t + m_{t+1} \end{align*}$$
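Here is a minimal sketch of this momentum-buffer form of the Polyak update; the arguments grad_f, alpha, and beta are placeholders supplied by the caller:

```python
import numpy as np

def polyak_momentum(grad_f, w0, alpha, beta, num_steps):
    """Gradient descent with Polyak (heavy-ball) momentum in the buffer form above."""
    w = np.asarray(w0, dtype=float)
    m = np.zeros_like(w)                  # momentum buffer, m_0 = 0 (i.e. w_0 = w_{-1})
    for _ in range(num_steps):
        m = beta * m - alpha * grad_f(w)  # m_{t+1} = beta * m_t - alpha * grad f(w_t)
        w = w + m                         # w_{t+1} = w_t + m_{t+1}
    return w
```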

State Transition Matrix

Going back to our simple example, denote $A = \begin{bmatrix} L & 0 \\ 0 & \mu \end{bmatrix}$, so

$$w_{t+1} = w_t - \alpha A w_t + \beta (w_t - w_{t-1})$$

We can write this update process as a matrix operation too. The matrix here will be a block matrix, where each entry is actually a $2 \times 2$ matrix.

$$\begin{align} \begin{bmatrix} w_{t+1} \\ w_t \end{bmatrix} &= \begin{bmatrix} w_t - \alpha A w_t + \beta (w_t - w_{t-1}) \\ w_t \end{bmatrix} \\ &= \begin{bmatrix} (1 + \beta) I - \alpha A & -\beta I \\ I & 0 \end{bmatrix} \begin{bmatrix} w_{t} \\ w_{t-1} \end{bmatrix} \\ \begin{bmatrix} w_{T+1} \\ w_T \end{bmatrix} &= \begin{bmatrix} (1 + \beta) I - \alpha A & -\beta I \\ I & 0 \end{bmatrix}^T \begin{bmatrix} w_{1} \\ w_{0} \end{bmatrix} \end{align}$$

This block matrix as a whole is a $4 \times 4$ matrix that maps $\mathbb{R}^4$ to $\mathbb{R}^4$. If we write the initial state in an eigenbasis of this matrix (assuming it is diagonalizable),

$$\begin{bmatrix} w_{1} \\ w_{0} \end{bmatrix} = c_1 u_1 + c_2 u_2 + c_3 u_3 + c_4 u_4$$

and denote the corresponding eigenvalues of this block matrix as $\lambda_1, \lambda_2, \lambda_3, \lambda_4$, then we can write

$$\begin{bmatrix} w_{T+1} \\ w_T \end{bmatrix} = \lambda_1^T c_1 u_1 + \lambda_2^T c_2 u_2 + \lambda_3^T c_3 u_3 + \lambda_4^T c_4 u_4$$

As before, what dominates the value of $f(w_T)$ is the largest of the exponential $\lambda_i^T$ terms. Therefore, we want to minimize all of these eigenvalues at the same time.

In addition, recall the squared terms in $f(w)$:

$$f(w_T) = \frac{L}{2} (w_T)_1^2 + \frac{\mu}{2} (w_T)_2^2$$

Therefore, we actually only care about the magnitude of $w_T$. That is, we want to minimize the magnitudes of all these eigenvalues at the same time.

$$\begin{align} &\min \|w_T\| \\ \iff &\min_{\lambda_{1,2,3,4}} \; \max \left(\| \lambda_1^T c_1 u_1 \|, \|\lambda_2^T c_2 u_2\|, \|\lambda_3^T c_3 u_3\|, \|\lambda_4^T c_4 u_4\| \right) \\ \iff &\min_{\lambda_{1,2,3,4}} \max \left( |\lambda_1|^T, |\lambda_2|^T, |\lambda_3|^T, |\lambda_4|^T \right) \end{align}$$
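To make this concrete, here is a small sketch that builds the $4 \times 4$ transition matrix for given values of $\alpha, \beta, L, \mu$ (the numbers below are arbitrary assumed choices) and inspects the eigenvalue magnitudes numerically:

```python
import numpy as np

def transition_matrix(alpha, beta, L, mu):
    """The 4x4 block matrix [[(1+beta)I - alpha*A, -beta*I], [I, 0]] with A = diag(L, mu)."""
    A = np.diag([L, mu])
    I = np.eye(2)
    top = np.hstack([(1 + beta) * I - alpha * A, -beta * I])
    bottom = np.hstack([I, np.zeros((2, 2))])
    return np.vstack([top, bottom])

# Assumed illustrative values; the eigenvalue magnitudes govern the convergence rate.
M = transition_matrix(alpha=0.03, beta=0.6, L=100.0, mu=1.0)
print(np.abs(np.linalg.eigvals(M)))   # magnitudes of the four eigenvalues
```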

Analyzing Eigenvalues

We now analyze the eigenvalues of this block matrix. Writing it out in full:

$$\begin{bmatrix} 1+\beta - \alpha L & 0 & -\beta & 0 \\ 0 & 1+\beta - \alpha \mu & 0 & -\beta \\ 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix}$$

Since we are analyzing eigenvalues and the particular basis is not important, we swap the 2nd and 3rd columns and also the 2nd and 3rd rows. This is equivalent to swapping the 2nd and 3rd basis vectors in both the domain and the codomain (a similarity transformation), so the eigenvalues are unchanged.

$$\begin{bmatrix} 1+\beta - \alpha L & -\beta & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 1+\beta - \alpha \mu & -\beta \\ 0 & 0 & 1 & 0 \end{bmatrix}$$

We write this new matrix also in block matrix form,

$$\begin{bmatrix} B & 0 \\ 0 & B \end{bmatrix}, \quad \text{where } B = \begin{bmatrix} 1+\beta - \alpha \chi & -\beta \\ 1 & 0 \end{bmatrix}, \ \chi = L \text{ or } \mu \text{ respectively}$$

This new matrix is block-diagonal, so to find its eigenvalues, we just have to find the eigenvalues of $B$ (for each choice of $\chi$).

Recall that we want to minimize all of the eigenvalues' norms at the same time. We achieve this when they all have the same norm. For $B$ specifically, it has two eigenvalues $\lambda_1, \lambda_2$, and we want $|\lambda_1| = |\lambda_2|$.

Writing out the characteristic polynomial of $B$:

$$\begin{align} \det(B - \lambda I) &= \det\left( \begin{bmatrix} 1+\beta - \alpha \chi - \lambda & -\beta \\ 1 & -\lambda \end{bmatrix} \right) \\ &= \lambda \left( \lambda - (1 + \beta - \alpha \chi) \right) + \beta \\ &= \lambda^2 - (1 + \beta - \alpha \chi) \lambda + \beta = 0 \end{align}$$

Solving this quadratic equation, we have

$$\lambda = \frac{ (1 + \beta - \alpha \chi) \pm \sqrt{ (1 + \beta - \alpha \chi)^2 - 4 \beta } }{ 2 }$$

The two solutions have the same norm when they are equal or when they form a complex-conjugate pair, i.e. when

$$(1 + \beta - \alpha \chi)^2 - 4 \beta \le 0$$

To find the exact value of $|\lambda|$, we don't have to compute the norm of a complex number directly. We just recall that the product of the eigenvalues equals the determinant of the matrix, so

$$\begin{align} |\lambda|^2 &= |\det(B)| = \beta \\ |\lambda| &= \sqrt{\beta} \end{align}$$
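A quick numerical sanity check of this fact (the values of $\alpha$, $\beta$, $\chi$ below are arbitrary, chosen so that the discriminant condition above holds):

```python
import numpy as np

# Arbitrary assumed values chosen so that (1 + beta - alpha*chi)^2 - 4*beta <= 0.
alpha, beta, chi = 0.02, 0.8, 100.0
assert (1 + beta - alpha * chi) ** 2 - 4 * beta <= 0

B = np.array([[1 + beta - alpha * chi, -beta],
              [1.0,                     0.0]])
print(np.abs(np.linalg.eigvals(B)))   # both eigenvalue moduli ...
print(np.sqrt(beta))                  # ... equal sqrt(beta)
```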

Therefore, to minimize $|\lambda|$, we need to minimize $\sqrt{\beta}$. So we have a new constrained optimization problem to solve:

$$\begin{align} &\min_{\alpha, \beta} \sqrt{\beta} \\ &\text{s.t. } (1 + \beta - \alpha L)^2 - 4 \beta \le 0, \quad (1 + \beta - \alpha \mu)^2 - 4 \beta \le 0 \end{align}$$

This can be analyzed with the Karush–Kuhn–Tucker (KKT) conditions. As in linear programming, where with $k$ variables typically $k$ of the $n$ inequality constraints hold with equality at the optimum, here the minimum is attained when both of the 2 inequality constraints hold with equality. So we solve

$$\begin{align} (1 + \beta - \alpha L)^2 - 4 \beta &= 0 \\ (1 + \beta - \alpha \mu)^2 - 4 \beta &= 0 \end{align}$$

which gives us

$$\begin{align} 2 \sqrt{\beta} &= | 1 + \beta - \alpha L | \\ &= | 1 + \beta - \alpha \mu | \end{align}$$

Since $L > \mu$, for the two absolute values to be equal we must have

$$\begin{align} -2 \sqrt{\beta} &= 1 + \beta - \alpha L \\ 2 \sqrt{\beta} &= 1 + \beta - \alpha \mu \end{align}$$

Solving for both $\alpha$ and $\beta$, we have

$$\alpha = \frac{2 + 2 \beta}{L + \mu} \quad\text{ and }\quad \sqrt{\beta} = \frac{\sqrt{\kappa} - 1}{\sqrt{\kappa} + 1} = 1 - \frac{2}{\sqrt{\kappa} + 1}.$$
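As a quick numerical check under the same assumed values $L = 100$, $\mu = 1$ as before, we can evaluate these formulas and verify that both constraints indeed hold with equality:

```python
import numpy as np

L, mu = 100.0, 1.0                               # same assumed values as before
kappa = L / mu
beta = ((np.sqrt(kappa) - 1) / (np.sqrt(kappa) + 1)) ** 2
alpha = (2 + 2 * beta) / (L + mu)

print((1 + beta - alpha * L) ** 2 - 4 * beta)    # ~0 up to floating-point error
print((1 + beta - alpha * mu) ** 2 - 4 * beta)   # ~0 up to floating-point error
```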

Recall that the norm of $w_T$ will be dominated by the term with the largest eigenvalue magnitude (write $C = \max_i \|c_i u_i\|$):

$$\|w_T\| \approx \max_{i \in \{1,\dots,4\}} \| \lambda_i^T c_i u_i\| \le \left(\max_{i} |\lambda_i|^T\right) C = \sqrt{\beta}^{\,T} C$$

Therefore, with $C' = \frac{\max(L, \mu)}{2} C^2$, we have

$$\begin{align} f(w_T) &\le \frac{\max(L, \mu)}{2} \|w_T\|^2 \approx C' \sqrt{\beta}^{\,2T} \\ &= C' \left(1 - \frac{2}{\sqrt{\kappa} + 1}\right)^{2T} \\ &\le C' \exp\left(- \frac{4T}{\sqrt{\kappa} + 1}\right) \end{align}$$

Therefore, the function value drops below a given accuracy $\epsilon$ (remember this is a quadratic, so $f^* = 0$) once $T$ satisfies the following condition:

$$f(w_T) \le \epsilon \quad \text{whenever} \quad T \ge \frac{\sqrt{\kappa} + 1}{4} \log\left(\frac{C'}{\epsilon}\right)$$

So

$$T = \mathcal{O}\left(\sqrt{\kappa} \log(1/\epsilon)\right)$$

Recall that the corresponding requirement for vanilla GD is

$$T = \mathcal{O}\left(\kappa \log(1/\epsilon)\right)$$

Therefore, we have shown that GD with momentum converges faster than vanilla GD on this simple quadratic example. The result generalizes to some extent beyond this example, though, as discussed below, Polyak momentum's acceleration guarantee is specific to quadratics.
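A minimal sketch comparing the two on the same quadratic (everything here reuses the assumed $L = 100$, $\mu = 1$ example; the tolerance and starting point are arbitrary). The two printed iteration counts differ by roughly a factor of $\sqrt{\kappa}$:

```python
import numpy as np

# Assumed illustrative values (kappa = 100), reusing the earlier quadratic example.
L, mu = 100.0, 1.0
A = np.diag([L, mu])

def f(w):
    return 0.5 * w @ A @ w

def grad_f(w):
    return A @ w

def steps_to_tolerance(alpha, beta, eps=1e-8, w0=(1.0, 1.0), max_steps=100_000):
    """Iterations until f(w_t) <= eps; beta = 0 recovers plain GD."""
    w, m = np.array(w0), np.zeros(2)
    for t in range(max_steps):
        if f(w) <= eps:
            return t
        m = beta * m - alpha * grad_f(w)   # Polyak momentum step
        w = w + m
    return max_steps

kappa = L / mu
beta_opt = ((np.sqrt(kappa) - 1) / (np.sqrt(kappa) + 1)) ** 2
print(steps_to_tolerance(alpha=2 / (L + mu), beta=0.0))                        # plain GD: grows like kappa
print(steps_to_tolerance(alpha=(2 + 2 * beta_opt) / (L + mu), beta=beta_opt))  # momentum: grows like sqrt(kappa)
```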

When we use momentum with SGD, it is not guaranteed to give a better result, but it is still widely used in practice.

Nesterov Momentum

One disadvantage of Polyak momentum is that the momentum term is not guaranteed to point in the right direction. Also, it is only guaranteed to give this nice acceleration for quadratics. Therefore, we introduce Nesterov momentum, which works for general strongly convex objectives.

Polyak:

$$\begin{align*} m_{t+1} &= \beta m_t - \alpha \nabla f(w_t) \\ w_{t+1} &= w_t + m_{t+1} \end{align*}$$

Nesterov:

$$\begin{align*} m_{t+1} &= \beta m_t - \alpha \nabla f(w_t + \beta m_t) \\ w_{t+1} &= w_t + m_{t+1}. \end{align*}$$

The difference: instead of evaluating the gradient at the current position, we pretend to have already taken the momentum step and evaluate the gradient at that look-ahead point.
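A minimal sketch of the Nesterov update, mirroring the Polyak sketch earlier (again, grad_f, alpha, and beta are placeholders supplied by the caller):

```python
import numpy as np

def nesterov_momentum(grad_f, w0, alpha, beta, num_steps):
    """Gradient descent with Nesterov momentum: the gradient is evaluated at the
    look-ahead point w_t + beta * m_t instead of at w_t."""
    w = np.asarray(w0, dtype=float)
    m = np.zeros_like(w)
    for _ in range(num_steps):
        m = beta * m - alpha * grad_f(w + beta * m)  # look-ahead gradient
        w = w + m
    return w
```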