
Stochastic Gradient Descent

Cornell University

Introduction

Recall we defined the loss function as

$$f(w) = \frac{1}{n} \sum_{i=1}^n f_i(w)$$

GD: $w_{t+1} = w_t - \alpha \nabla f(w_t) = w_t - \frac{\alpha}{n} \sum_{i=1}^n \nabla f_i(w_t)$

SGD: Sample a random $i \in [1, n]$ uniformly, then update $w_{t+1} = w_t - \alpha \nabla f_i(w_t)$

Minibatch SGD: Sample $B$ indices $i_b$ i.i.d. from $[1, n]$ with replacement (in practice sampling is usually done without replacement, but sampling with replacement makes the proof below easier), then update $w_{t+1} = w_t - \frac{\alpha}{B} \sum_{b=1}^B \nabla f_{i_b}(w_t)$
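To make the three update rules concrete, here is a minimal numpy sketch on a toy least-squares loss $f_i(w) = \frac{1}{2}(x_i^T w - y_i)^2$; the data, step size, and batch size are illustrative choices, not part of the analysis below.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 100, 5
X = rng.normal(size=(n, d))          # one row per training example
y = rng.normal(size=n)
alpha, B = 0.01, 10                  # illustrative step size and batch size

def grad_i(w, i):
    """Gradient of f_i(w) = 0.5 * (x_i^T w - y_i)^2."""
    return (X[i] @ w - y[i]) * X[i]

def gd_step(w):
    # full gradient: average of all n per-example gradients
    g = np.mean([grad_i(w, i) for i in range(n)], axis=0)
    return w - alpha * g

def sgd_step(w):
    # one uniformly sampled example
    i = rng.integers(n)
    return w - alpha * grad_i(w, i)

def minibatch_sgd_step(w):
    # B examples sampled i.i.d. with replacement, matching the proof setup
    idx = rng.integers(n, size=B)
    g = np.mean([grad_i(w, i) for i in idx], axis=0)
    return w - alpha * g
```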

Does Minibatch SGD Converge?

It’s not too difficult to show that SGD converges just like GD by arguing in expectation, so we’ll focus on minibatch SGD. How do we know it converges? We want to show that, in expectation, it behaves the same as GD: if we perform one minibatch SGD step, we have

$$\mathbf{E}[w_{t+1} \mid w_t] = w_t - \alpha \nabla f(w_t)$$

Note that in this context the $i_b$ are the base random variables; $w_t$ and $w_{t+1}$ are functions of the $i_b$ and are therefore also random variables, while $\alpha$ and $B$ are constants.

The Easier Part

Similar to what we did in the GD convergence proof, we make some assumptions on our function $f$:

  1. $f$ is $L$-smooth: $\| \nabla f(w) - \nabla f(u) \|_2 \le L \| w - u \|_2$
  2. A global minimum exists: $\exists f^*$ s.t. $f(w) \ge f^*$ for all $w$

Call the average batch gradient

$$g_t = \frac{1}{B} \sum_{b=1}^B \nabla f_{i_b}(w_t)$$

The first half of the proof is the same as in the GD convergence proof, where we have

\begin{align} f(w_{t+1}) &= f(w_t - \alpha g_t) \\ &= f(w_t) + \int_0^\alpha \frac{\partial}{\partial \eta} f(w_t - \eta g_t) \, d\eta \\ & \dots \\ &\le f(w_t) - \alpha \nabla f(w_t)^T g_t + \frac{\alpha^2 L}{2} \| g_t \|^2 \end{align}

In the GD proof, we had $g_t = \nabla f(w_t)$, so by adding the constraint $\alpha L \le 1$ we could conclude that the loss decreases at each step and therefore eventually converges. Here, however, we can only hope this happens in expectation. Take the expectation of both sides given $w_t$, and note that the only random quantity is $g_t$: everything else is either deterministic to begin with (the constants $\alpha$, $L$) or deterministic once $w_t$ is given ($\nabla f(w_t)$).

\begin{align} \mathbf{E}[f(w_{t+1}) \mid w_t] &\le \mathbf{E}\left[ f(w_t) - \alpha \nabla f(w_t)^T g_t + \frac{\alpha^2 L}{2} \| g_t \|^2 \,\middle|\, w_t \right] \\ \mathbf{E}[f(w_{t+1}) \mid w_t] &\le f(w_t) - \alpha \nabla f(w_t)^T \mathbf{E}[g_t \mid w_t] + \frac{\alpha^2 L}{2} \mathbf{E}[\| g_t \|^2 \mid w_t] \end{align}

Observe that for an arbitrary sample $i$, its expected gradient is just the full gradient. To show this, apply the definition of expectation:

$$\mathbf{E}[\nabla f_i(w_t) \mid w_t] = \sum_{j=1}^n P(i = j) \nabla f_j(w_t) = \sum_{j=1}^n \frac{1}{n} \nabla f_j(w_t) = \nabla f(w_t)$$

So, by linearity of expectation, the expected batch-average gradient is also just the full gradient:

$$\mathbf{E}[g_t \mid w_t] = \mathbf{E}\left[\frac{1}{B} \sum_{b=1}^B \nabla f_{i_b}(w_t) \,\middle|\, w_t\right] = \frac{1}{B} \sum_{b=1}^B \mathbf{E}[\nabla f_{i_b}(w_t) \mid w_t] = \nabla f(w_t)$$
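As a sanity check, one can estimate $\mathbf{E}[g_t \mid w_t]$ by Monte Carlo and compare it to the full gradient. The sketch below reuses the hypothetical least-squares setup (`X`, `y`, `grad_i`, `B`, `rng`) from the earlier code block:

```python
# Monte Carlo check that E[g_t | w_t] equals the full gradient
# (reusing the hypothetical X, y, grad_i, B, rng from the earlier sketch).
w = rng.normal(size=d)
full_grad = np.mean([grad_i(w, i) for i in range(n)], axis=0)

trials = 20000
batch_mean = np.zeros(d)
for _ in range(trials):
    idx = rng.integers(n, size=B)               # B samples with replacement
    batch_mean += np.mean([grad_i(w, i) for i in idx], axis=0)
batch_mean /= trials

print(np.linalg.norm(batch_mean - full_grad))   # should be close to 0
```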

Therefore, we can rewrite the above inequality as

\begin{align} \mathbf{E}[f(w_{t+1}) \mid w_t] &\le f(w_t) - \alpha \nabla f(w_t)^T \nabla f(w_t) + \frac{\alpha^2 L}{2} \mathbf{E}[\| g_t \|^2 \mid w_t] \\ \mathbf{E}[f(w_{t+1}) \mid w_t] &\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \mathbf{E}[\| g_t \|^2 \mid w_t] \end{align}

Variance of Gradient

Assumption on Variance

Therefore, the only remaining difference is the last squared-norm term. Compare:

\begin{align} \text{GD: } f(w_{t+1}) &\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \| \nabla f(w_t) \|^2 \\ \text{SGD: } \mathbf{E}[f(w_{t+1}) \mid w_t] &\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \mathbf{E}[\| g_t \|^2 \mid w_t] \end{align}

To bound this term, we need an extra assumption on the variance of the gradient:

  3. The variance of the per-example gradients is bounded: there exists a constant $\sigma > 0$ such that, for a uniformly drawn sample $i$ and any $w$,
     $$\mathbf{E}\left[ \| \nabla f_i(w) - \nabla f(w) \|^2 \right] \le \sigma^2$$
     (A quick empirical estimate of this quantity is sketched right after this list.)
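For intuition, $\sigma^2$ at a particular $w$ can be estimated by averaging $\|\nabla f_i(w) - \nabla f(w)\|^2$ over the training examples. The sketch below does this on the hypothetical least-squares setup from the first code block; note the assumption itself asks for a bound over all $w$, which a single-point estimate does not establish.

```python
# Empirical per-example gradient variance at one particular w
# (the assumption requires a bound over all w; this only probes a single point).
w = rng.normal(size=d)
full_grad = np.mean([grad_i(w, i) for i in range(n)], axis=0)
sigma2_hat = np.mean([np.linalg.norm(grad_i(w, i) - full_grad) ** 2
                      for i in range(n)])
print(sigma2_hat)
```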

Meaning of this Bound

Note this is just saying

$$\mathrm{Var}(\nabla f_i(w)) \le \sigma^2$$

But we cannot literally use the $\mathrm{Var}$ notation, because it is defined for scalar random variables, not vector-valued ones. For the same reason, we need to establish the vector analogue of the classic identity $\mathrm{Var}(X) = \mathbf{E}[X^2] - \mathbf{E}^2[X]$ before using it. We do so below, using the fact that $i$ is drawn uniformly from $[1, n]$.

\begin{align} \mathbf{E}\left[ \| \nabla f_i(w) - \nabla f(w) \|^2 \right] &= \frac{1}{n} \sum_{i=1}^n \| \nabla f_i(w) - \nabla f(w) \|^2 \\ &= \frac{1}{n} \sum_{i=1}^n \| \nabla f_i(w) \|^2 - \frac{2}{n} \sum_{i=1}^n \nabla f_i(w)^T \nabla f(w) + \frac{1}{n} \sum_{i=1}^n \| \nabla f(w) \|^2 \end{align}

Note that in the second term, from the expectation of a specific sample’s gradient, $\mathbf{E}[\nabla f_i(w)] = \nabla f(w)$, we have

$$2 \cdot \frac{1}{n} \sum_{i=1}^n \nabla f_i(w)^T (\text{something}) = 2 \nabla f(w)^T (\text{something})$$

so we can rewrite the above equation as

\begin{align} \mathbf{E}\left[ \| \nabla f_i(w) - \nabla f(w) \|^2 \right] &= \frac{1}{n} \sum_{i=1}^n \| \nabla f_i(w) \|^2 - 2 \nabla f(w)^T \nabla f(w) + \| \nabla f(w) \|^2 \\ &= \left( \frac{1}{n} \sum_{i=1}^n \| \nabla f_i(w) \|^2 \right) - \| \nabla f(w) \|^2 \\ &= \mathbf{E}\left[ \| \nabla f_i(w) \|^2 \right] - \| \nabla f(w) \|^2 \\ &= \mathbf{E}\left[ \| \nabla f_i(w) \|^2 \right] - \| \mathbf{E}[\nabla f_i(w)] \|^2 \end{align}
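This vector analogue of the variance identity is easy to verify numerically on the same hypothetical setup; both sides are computed below at a random $w$:

```python
# Numerical check of the vector "variance" identity at a random w:
# E[||grad_i - grad||^2] = E[||grad_i||^2] - ||grad||^2 (hypothetical setup as before).
w = rng.normal(size=d)
grads = np.stack([grad_i(w, i) for i in range(n)])   # shape (n, d)
full_grad = grads.mean(axis=0)

lhs = np.mean(np.sum((grads - full_grad) ** 2, axis=1))
rhs = np.mean(np.sum(grads ** 2, axis=1)) - np.sum(full_grad ** 2)
print(lhs, rhs)   # the two numbers agree up to floating-point error
```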

??? Add On to Variance ???

Applying Variance Bound

Now we are prepared to look at this term

\begin{align} \mathbf{E}[\| g_t \|^2 \mid w_t] &= \mathbf{E}\left[ \left\| \frac{1}{B} \sum_{b=1}^B \nabla f_{i_b}(w_t) \right\|^2 \,\middle|\, w_t \right] \\ &= \frac{1}{B^2} \mathbf{E}\left[ \left( \sum_{b=1}^B \nabla f_{i_b}(w_t) \right)^T \left( \sum_{c=1}^B \nabla f_{i_c}(w_t) \right) \,\middle|\, w_t \right] \\ &= \frac{1}{B^2} \sum_{b=1}^B \sum_{c=1}^B \mathbf{E}\left[ \nabla f_{i_b}(w_t)^T \nabla f_{i_c}(w_t) \,\middle|\, w_t \right] \end{align}

Among these $B^2$ pairs, $B$ have equal subscripts and $B(B-1)$ have distinct subscripts. For pairs with distinct subscripts, the two samples are drawn independently, so the expectation of the product factors into the product of the expectations. For pairs with the same subscript, note again that the expectation of $\| \nabla f_i(w_t) \|^2$ is just its average over all $n$ training examples.

\begin{align} \mathbf{E}[\| g_t \|^2 \mid w_t] &= \frac{1}{B^2} \left( \sum_{b=1}^B \sum_{c \ne b} \mathbf{E}\left[ \nabla f_{i_b}(w_t)^T \nabla f_{i_c}(w_t) \mid w_t \right] + \sum_{b=1}^B \mathbf{E}\left[ \nabla f_{i_b}(w_t)^T \nabla f_{i_b}(w_t) \mid w_t \right] \right) \\ &= \frac{1}{B^2} \left( \sum_{b=1}^B \sum_{c \ne b} \mathbf{E}\left[ \nabla f_{i_b}(w_t) \mid w_t \right]^T \mathbf{E}\left[ \nabla f_{i_c}(w_t) \mid w_t \right] + \sum_{b=1}^B \mathbf{E}\left[ \| \nabla f_{i_b}(w_t) \|^2 \mid w_t \right] \right) \\ &= \frac{1}{B^2} \left( (B^2 - B) \| \nabla f(w_t) \|^2 + B \cdot \frac{1}{n} \sum_{i=1}^n \| \nabla f_i(w_t) \|^2 \right) \\ &= \frac{B-1}{B} \| \nabla f(w_t) \|^2 + \frac{1}{B} \cdot \frac{1}{n} \sum_{i=1}^n \| \nabla f_i(w_t) \|^2 \end{align}

Using the variance identity from above and then the assumed bound, we can control the last term:

\begin{align} \mathbf{E}[\| g_t \|^2 \mid w_t] &= \frac{B-1}{B} \| \nabla f(w_t) \|^2 + \frac{1}{B} \cdot \frac{1}{n} \sum_{i=1}^n \| \nabla f_i(w_t) \|^2 \\ &= \frac{B-1}{B} \| \nabla f(w_t) \|^2 + \frac{1}{B} \left( \mathbf{E}\left[ \| \nabla f_i(w_t) - \nabla f(w_t) \|^2 \right] + \| \nabla f(w_t) \|^2 \right) \\ &\le \frac{B-1}{B} \| \nabla f(w_t) \|^2 + \frac{1}{B} \left( \sigma^2 + \| \nabla f(w_t) \|^2 \right) \\ &= \| \nabla f(w_t) \|^2 + \frac{\sigma^2}{B} \end{align}
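This bound can also be checked numerically: draw many minibatches, average $\|g_t\|^2$, and compare against $\|\nabla f(w_t)\|^2 + \sigma^2/B$ with $\sigma^2$ replaced by its empirical value (same hypothetical setup as before). With the exact empirical variance the inequality is tight, so the two numbers should agree up to Monte Carlo noise:

```python
# Monte Carlo check of E[||g_t||^2 | w_t] <= ||grad f(w_t)||^2 + sigma^2 / B
# on the hypothetical least-squares setup, with sigma^2 estimated empirically.
w = rng.normal(size=d)
grads = np.stack([grad_i(w, i) for i in range(n)])
full_grad = grads.mean(axis=0)
sigma2_hat = np.mean(np.sum((grads - full_grad) ** 2, axis=1))

trials = 20000
second_moment = 0.0
for _ in range(trials):
    idx = rng.integers(n, size=B)
    g = grads[idx].mean(axis=0)
    second_moment += g @ g
second_moment /= trials

bound = full_grad @ full_grad + sigma2_hat / B
# With the exact empirical sigma^2 the bound holds with equality,
# so the Monte Carlo estimate should match it up to sampling noise.
print(second_moment, bound)
```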

Finishing Up Derivation

We can finally plug this back into the inequality from before:

\begin{align} \mathbf{E}[f(w_{t+1}) \mid w_t] &\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \mathbf{E}[\| g_t \|^2 \mid w_t] \\ &\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \left( \| \nabla f(w_t) \|^2 + \frac{\sigma^2}{B} \right) \\ &= f(w_t) - \alpha \left( 1 - \frac{\alpha L}{2} \right) \| \nabla f(w_t) \|^2 + \frac{\alpha^2 \sigma^2 L}{2B} \end{align}

Recall our earlier assumption that $\alpha L \le 1$, which gives $1 - \frac{\alpha L}{2} \ge \frac{1}{2}$, so

$$\mathbf{E}[f(w_{t+1}) \mid w_t] \le f(w_t) - \frac{\alpha}{2} \| \nabla f(w_t) \|^2 + \frac{\alpha^2 \sigma^2 L}{2B}$$

PL Condition

Let’s focus on functions $f$ that satisfy the PL (Polyak–Łojasiewicz) condition. So we add our 4th assumption:

  4. $\| \nabla f(x) \|^2 \ge 2 \mu \left( f(x) - f^* \right)$ for some constant $\mu > 0$

With this, we can write

\begin{align} \mathbf{E}[f(w_{t+1}) \mid w_t] &\le f(w_t) - \frac{\alpha}{2} \| \nabla f(w_t) \|^2 + \frac{\alpha^2 \sigma^2 L}{2B} \\ &\le f(w_t) - \alpha \mu \left( f(w_t) - f^* \right) + \frac{\alpha^2 \sigma^2 L}{2B} \end{align}

Subtract $f^*$ from both sides of the inequality. On the left side, it moves inside the expectation by linearity of expectation:

\begin{align} \mathbf{E}[f(w_{t+1}) \mid w_t] - f^* &\le (f(w_t) - f^*) - \alpha \mu \left( f(w_t) - f^* \right) + \frac{\alpha^2 \sigma^2 L}{2B} \\ \mathbf{E}[f(w_{t+1}) - f^* \mid w_t] &\le (1 - \alpha \mu) \left( f(w_t) - f^* \right) + \frac{\alpha^2 \sigma^2 L}{2B} \end{align}

To get rid of the conditioning on $w_t$, we apply the law of total expectation, $\mathbf{E}[\mathbf{E}[X \mid Y]] = \mathbf{E}[X]$, and take the expectation of both sides:

\begin{align} \mathbf{E}\left[ \mathbf{E}[f(w_{t+1}) - f^* \mid w_t] \right] &\le \mathbf{E}\left[ (1 - \alpha \mu) \left( f(w_t) - f^* \right) + \frac{\alpha^2 \sigma^2 L}{2B} \right] \\ \mathbf{E}[f(w_{t+1}) - f^*] &\le (1 - \alpha \mu)\, \mathbf{E}\left[ f(w_t) - f^* \right] + \frac{\alpha^2 \sigma^2 L}{2B} \end{align}

Call $\rho_t = \mathbf{E}\left[ f(w_t) - f^* \right]$. Then

$$\rho_{t+1} \le (1 - \alpha \mu) \rho_t + \frac{\alpha^2 \sigma^2 L}{2B}$$
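Treating the recursion as an equality (the worst case) and plugging in some made-up constants shows numerically where $\rho_t$ settles; all values here ($\alpha$, $\mu$, $L$, $\sigma^2$, $B$, $\rho_0$) are purely illustrative:

```python
# Iterate rho_{t+1} = (1 - alpha*mu) * rho_t + alpha^2 * sigma^2 * L / (2B)
# for illustrative constants; rho_t settles at alpha*sigma^2*L / (2*mu*B).
alpha, mu, L_smooth, sigma2, B = 0.01, 1.0, 10.0, 4.0, 10   # made-up constants
rho = 1.0                                                    # rho_0
for t in range(5000):
    rho = (1 - alpha * mu) * rho + alpha ** 2 * sigma2 * L_smooth / (2 * B)

print(rho, alpha * sigma2 * L_smooth / (2 * mu * B))   # both ~= noise-ball level 0.02
```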

To see analytically where this recursion settles as $t \to \infty$, look at its fixed point:

\begin{align} \rho_\infty &= (1 - \alpha \mu) \rho_\infty + \frac{\alpha^2 \sigma^2 L}{2B} \\ \rho_\infty &= \frac{\alpha \sigma^2 L}{2 \mu B} \end{align}

If we subtract this value from both sides of the inequality,

\begin{align} \rho_{t+1} - \frac{\alpha \sigma^2 L}{2 \mu B} &\le (1 - \alpha \mu) \rho_t + \frac{\alpha^2 \sigma^2 L}{2B} - \frac{\alpha \sigma^2 L}{2 \mu B} \\ &= (1 - \alpha \mu) \left( \rho_t - \frac{\alpha \sigma^2 L}{2 \mu B} \right) \end{align}

Now we are ready to look at what happens when we start from the initialization at $t = 0$ and run $T$ iterations. Unrolling the recursion,

\begin{align} \rho_T - \frac{\alpha \sigma^2 L}{2 \mu B} &\le (1 - \alpha \mu)^T \left( \rho_0 - \frac{\alpha \sigma^2 L}{2 \mu B} \right) \\ &\le (1 - \alpha \mu)^T \rho_0 \end{align}

We apply the same trick $1 - x \le e^{-x}$:

\begin{align} \mathbf{E}\left[ f(w_T) - f^* \right] - \frac{\alpha \sigma^2 L}{2 \mu B} &\le (1 - \alpha \mu)^T\, \mathbf{E}\left[ f(w_0) - f^* \right] \\ &\le \exp(-\alpha \mu T)\, \mathbf{E}\left[ f(w_0) - f^* \right] \end{align}

In conclusion (note that $f(w_0)$ depends only on the initialization, so the expectation around it can be dropped):

$$\mathbf{E}\left[ f(w_T) - f^* \right] \le \exp(-\alpha \mu T)\, (f(w_0) - f^*) + \frac{\alpha \sigma^2 L}{2 \mu B}$$
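For a feel of the two terms in this bound, here is a quick evaluation with the same made-up constants as in the recursion sketch above, taking $f(w_0) - f^* = 1$:

```python
# Plugging illustrative constants into the final bound: the exponential term
# dies out with T, leaving the noise-ball term alpha*sigma^2*L / (2*mu*B).
import numpy as np

alpha, mu, L_smooth, sigma2, B, rho0 = 0.01, 1.0, 10.0, 4.0, 10, 1.0
for T in [0, 100, 1000, 10000]:
    bound = np.exp(-alpha * mu * T) * rho0 + alpha * sigma2 * L_smooth / (2 * mu * B)
    print(T, bound)   # decays from ~1.02 toward the noise-ball level 0.02
```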

Compare this with what we had for plain gradient descent:

$$f(w_T) - f^* \le \exp\left(-\frac{\mu T}{L}\right) \left( f(w_0) - f^* \right)$$

In GD, the loss gap goes to 0 exponentially. In SGD, there is an extra $\frac{\alpha \sigma^2 L}{2 \mu B}$ error term caused by running GD stochastically. The loss gap flattens out once it reaches this noise ball.

One thing to notice is that with a fixed learning rate, this noise ball is a nonzero constant, so we can never reach the minimum. However, with a learning rate that decays as a function of the step count, we can push the noise ball to 0 and actually reach the minimum in the limit.
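To see the noise ball and the effect of a decaying step size in action, here is a sketch that runs minibatch SGD on the same hypothetical least-squares problem from the first code block, once with a fixed step size and once with a decaying schedule $\alpha_t = \alpha_0 / (1 + t/100)$; the schedule and all constants are illustrative choices, not prescriptions.

```python
# Minibatch SGD on the hypothetical least-squares problem: a fixed step size
# plateaus at a noise floor, while a decaying step size keeps shrinking the gap.
def loss(w):
    r = X @ w - y
    return 0.5 * np.mean(r ** 2)

w_star = np.linalg.lstsq(X, y, rcond=None)[0]   # minimizer of the average loss
f_star = loss(w_star)

def run(decay, steps=20000, alpha0=0.05):
    w = np.zeros(d)
    for t in range(steps):
        a = alpha0 / (1 + t / 100) if decay else alpha0
        idx = rng.integers(n, size=B)
        g = np.mean([grad_i(w, i) for i in idx], axis=0)
        w = w - a * g
    return loss(w) - f_star

print("fixed step:  ", run(decay=False))   # typically stalls near a nonzero noise floor
print("decayed step:", run(decay=True))    # typically much smaller
```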