# Stochastic Gradient Descent

September 19, 2022
## Introduction

Recall we defined the loss function as
$$
f(w) = \frac{1}{n} \sum_{i=1}^n f_i(w)
$$

- GD: $w_{t+1} = w_t - \alpha \nabla f(w_t) = w_t - \frac{\alpha}{n} \sum_{i=1}^n \nabla f_i(w_t)$
- SGD: sample a random index $i$ uniformly from $\{1, \dots, n\}$ and update $w_{t+1} = w_t - \alpha \nabla f_i(w_t)$
- Minibatch SGD: sample $B$ indices $i_1, \dots, i_B$ i.i.d. uniformly from $\{1, \dots, n\}$ with replacement (in practice we usually sample without replacement, but sampling with replacement makes the proof below easier), and update $w_{t+1} = w_t - \frac{\alpha}{B} \sum_{b=1}^B \nabla f_{i_b}(w_t)$
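To make the three update rules concrete, here is a minimal NumPy sketch. It assumes, purely for illustration, a least-squares per-sample loss $f_i(w) = \frac{1}{2}(x_i^T w - y_i)^2$; the function names (`grad_fi`, `full_grad`, `gd_step`, `sgd_step`, `minibatch_sgd_step`) and the synthetic data are hypothetical, not part of the notes.

```python
import numpy as np

# Assumed per-sample loss for illustration: f_i(w) = 0.5 * (x_i^T w - y_i)^2 (least squares).
def grad_fi(w, X, y, i):
    """Gradient of f_i at w."""
    return (X[i] @ w - y[i]) * X[i]

def full_grad(w, X, y):
    """Gradient of f(w) = (1/n) * sum_i f_i(w)."""
    return X.T @ (X @ w - y) / X.shape[0]

def gd_step(w, X, y, alpha):
    return w - alpha * full_grad(w, X, y)

def sgd_step(w, X, y, alpha, rng):
    i = rng.integers(X.shape[0])              # one index, uniform over {0, ..., n-1}
    return w - alpha * grad_fi(w, X, y, i)

def minibatch_sgd_step(w, X, y, alpha, B, rng):
    idx = rng.integers(X.shape[0], size=B)    # B i.i.d. indices, drawn with replacement
    g = np.mean([grad_fi(w, X, y, i) for i in idx], axis=0)   # g_t in the notes
    return w - alpha * g

# Usage on a small synthetic, noiseless problem (hypothetical data).
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5])
w = np.zeros(3)
for t in range(2000):
    w = minibatch_sgd_step(w, X, y, alpha=0.05, B=10, rng=rng)
```

With a single training example ($n = 1$) all three rules coincide, which is a quick sanity check on the implementation.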
## Minibatch SGD Converges?

It is not too difficult to show that SGD converges in expectation just like GD, and plain SGD is the special case $B = 1$ of minibatch SGD, so we focus on minibatch SGD. How do we know it converges? We first want to show that, in expectation, a minibatch SGD step behaves the same as a GD step. If we perform one step of minibatch SGD, we have
$$
\mathbf E[w_{t+1} \mid w_t] = w_t - \alpha \nabla f(w_t)
$$

Note that in this context the sampled indices $i_b$ are the base random variables. The iterates $w_t$ and $w_{t+1}$ are functions of the indices sampled so far, so they are random variables too, while $\alpha$ and $B$ are constants.
## The Easier Part

Similar to what we did in the GD convergence proof, we make some assumptions on our function $f$:
1. $f$ is $L$-smooth: $\| \nabla f(w) - \nabla f(u) \|_2 \le L \| w - u \|_2$ for all $w, u$.
2. A global minimum exists: there is an $f^*$ such that $f(w) \ge f^*$ for all $w$.

Define the average batch gradient

$$
g_t = \frac{1}{B} \sum_{b=1}^B \nabla f_{i_b}(w_t),
$$

so that the minibatch SGD update is $w_{t+1} = w_t - \alpha g_t$. The first half of the proof is the same as in the GD convergence proof, with $g_t$ in place of $\nabla f(w_t)$:
$$
\begin{align}
f(w_{t+1}) &= f(w_t - \alpha g_t) \\
&= f(w_t) + \int_0^\alpha \frac{\partial}{\partial\eta} f(w_t - \eta g_t) \, d\eta \\
& \dots \\
&\le f(w_t) - \alpha \nabla f(w_t)^T g_t + \frac{\alpha^2 L}{2} \| g_t \|^2
\end{align}
$$

In the GD proof we had $g_t = \nabla f(w_t)$, so adding the constraint $\alpha L \le 1$ let us conclude that the loss decreases at every step, which gave convergence. Here we can only hope for that in expectation. Take the expectation of both sides conditioned on $w_t$. The only random quantity on the right-hand side is $g_t$: every other value is either a constant ($\alpha$, $L$) or deterministic once $w_t$ is given ($f(w_t)$, $\nabla f(w_t)$). By linearity of expectation:
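For reference, the elided steps can be filled in with the usual descent-lemma argument. This is our reconstruction of the standard derivation, using only $L$-smoothness and the Cauchy-Schwarz inequality, with $\frac{\partial}{\partial\eta} f(w_t - \eta g_t) = -\nabla f(w_t - \eta g_t)^T g_t$:

$$
\begin{align}
f(w_{t+1})
&= f(w_t) - \int_0^\alpha \nabla f(w_t - \eta g_t)^T g_t \, d\eta \\
&= f(w_t) - \alpha \nabla f(w_t)^T g_t - \int_0^\alpha \left( \nabla f(w_t - \eta g_t) - \nabla f(w_t) \right)^T g_t \, d\eta \\
&\le f(w_t) - \alpha \nabla f(w_t)^T g_t + \int_0^\alpha \left\| \nabla f(w_t - \eta g_t) - \nabla f(w_t) \right\| \, \| g_t \| \, d\eta \\
&\le f(w_t) - \alpha \nabla f(w_t)^T g_t + \int_0^\alpha L \eta \, \| g_t \|^2 \, d\eta \\
&= f(w_t) - \alpha \nabla f(w_t)^T g_t + \frac{\alpha^2 L}{2} \| g_t \|^2
\end{align}
$$

The third line is Cauchy-Schwarz, and the fourth applies $L$-smoothness to the points $w_t - \eta g_t$ and $w_t$, which are $\eta \| g_t \|$ apart.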
$$
\begin{align}
\mathbf E [ f(w_{t+1}) \mid w_t ]
&\le \mathbf E \left[ f(w_t) - \alpha \nabla f(w_t)^T g_t + \frac{\alpha^2 L}{2} \| g_t \|^2 \mid w_t \right] \\
&= f(w_t) - \alpha \nabla f(w_t)^T \mathbf E [ g_t \mid w_t ] + \frac{\alpha^2 L}{2} \mathbf E [ \| g_t \|^2 \mid w_t ]
\end{align}
$$

Observe that for an arbitrary sample $i$, its expected gradient is just the global gradient. To show this, apply the definition of expectation:
$$
\mathbf E[\nabla f_i(w_t) \mid w_t] = \sum_{j=1}^n P(i=j) \nabla f_j(w_t) = \sum_{j=1}^n \frac{1}{n} \nabla f_j(w_t) = \nabla f(w_t)
$$

So the expected batch-average gradient is also the global gradient. We can see this by applying linearity of expectation:
$$
\mathbf E[g_t \mid w_t] = \mathbf E\left[\frac{1}{B} \sum_{b=1}^B \nabla f_{i_b}(w_t) \mid w_t\right] = \frac{1}{B} \sum_{b=1}^B \mathbf E\left[ \nabla f_{i_b}(w_t) \mid w_t\right] = \nabla f(w_t)
$$

In particular, $\mathbf E[w_{t+1} \mid w_t] = w_t - \alpha \mathbf E[g_t \mid w_t] = w_t - \alpha \nabla f(w_t)$, which is exactly the GD step. Therefore, we can rewrite the above inequality as
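The unbiasedness of $g_t$ can be checked exactly on a tiny example by enumerating every possible minibatch. This sketch assumes the per-sample gradients at $w_t$ are already stacked into an array (random numbers stand in for them here); `expected_batch_gradient` is a hypothetical helper, not from the notes. Since indices are drawn with replacement, each of the $n^B$ ordered tuples $(i_1, \dots, i_B)$ has probability $1/n^B$.

```python
import itertools
import numpy as np

def expected_batch_gradient(per_sample_grads, B):
    """Exact E[g_t | w_t] when i_1, ..., i_B are drawn i.i.d. uniformly with replacement.

    per_sample_grads has shape (n, d); row j stands for grad f_j(w_t).
    Each ordered tuple (i_1, ..., i_B) has probability 1 / n**B, so we average g_t over all of them.
    """
    n, d = per_sample_grads.shape
    total = np.zeros(d)
    for batch in itertools.product(range(n), repeat=B):
        total += per_sample_grads[list(batch)].mean(axis=0)   # g_t for this minibatch
    return total / n**B

rng = np.random.default_rng(0)
grads = rng.normal(size=(5, 3))   # hypothetical per-sample gradients, n = 5, d = 3
full = grads.mean(axis=0)         # grad f(w_t)
for B in (1, 2, 3):
    print(B, np.allclose(expected_batch_gradient(grads, B), full))
```

Each line prints `True`: averaging over all minibatches recovers the full gradient for every batch size, as the linearity argument predicts.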
$$
\begin{align}
\mathbf E [ f(w_{t+1}) \mid w_t ]
&\le f(w_t) - \alpha \nabla f(w_t)^T \nabla f(w_t) + \frac{\alpha^2 L}{2} \mathbf E [ \| g_t \|^2 \mid w_t ] \\
&= f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \mathbf E [ \| g_t \|^2 \mid w_t ]
\end{align}
$$

## Variance of Gradient

### Assumption on Variance

Now the only difference from GD is the last, squared-norm term. Compare
$$
\begin{align}
\text{GD: } f(w_{t+1}) &\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \| \nabla f(w_t) \|^2 \\
\text{SGD: } \mathbf E [ f(w_{t+1}) \mid w_t ] &\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \mathbf E [ \| g_t \|^2 \mid w_t ]
\end{align}
$$

To control this term, we need an extra assumption on the variance of the gradient:
3. The variance of the gradients is bounded: there exists a constant $\sigma > 0$ such that, for every $w$ and a uniformly randomly drawn sample $i$,

$$
\mathbf{E}\left[ \| \nabla f_i(w) - \nabla f(w) \|^2 \right] \le \sigma^2
$$

### Meaning of this Bound

This is just saying

$$
\mathrm{Var}(\nabla f_i(w)) \le \sigma^2
$$

However, we cannot use the $\mathrm{Var}$ notation directly, because it is defined for scalar random variables, while $\nabla f_i(w)$ is a random vector (the left-hand side above is its total variance, the trace of its covariance matrix). For the same reason, we need to prove the vector analogue of the classic identity $\mathrm{Var}(X) = \mathbf E[X^2] - \mathbf E^2[X]$ before using it. The proof uses the fact that $i$ is drawn uniformly from $\{1, \dots, n\}$; we write $j$ for the summation index.
$$
\begin{align}
\mathbf{E}\left[ \| \nabla f_i(w) - \nabla f(w) \|^2 \right]
&= \frac{1}{n} \sum_{j=1}^n \| \nabla f_j(w) - \nabla f(w) \|^2 \\
&= \frac{1}{n} \sum_{j=1}^n \|\nabla f_j(w)\|^2
- \frac{2}{n} \sum_{j=1}^n \nabla f_j(w)^T \nabla f(w)
+ \frac{1}{n} \sum_{j=1}^n \|\nabla f(w)\|^2
\end{align}
$$

In the second term, since the expected gradient of a single sample is $\mathbf E[\nabla f_i(w)] = \nabla f(w)$, we have for any fixed vector $v$

$$
2 \cdot \frac{1}{n} \sum_{j=1}^n \nabla f_j(w)^T v = 2 \nabla f(w)^T v
$$

The third term does not depend on $j$, so it equals $\|\nabla f(w)\|^2$. Taking $v = \nabla f(w)$, we can rewrite the above equation as
$$
\begin{align}
\mathbf{E}\left[ \| \nabla f_i(w) - \nabla f(w) \|^2 \right]
&= \frac{1}{n} \sum_{j=1}^n \|\nabla f_j(w)\|^2
- 2 \nabla f(w)^T \nabla f(w)
+ \|\nabla f(w)\|^2 \\
&= \left(\frac{1}{n} \sum_{j=1}^n \|\nabla f_j(w)\|^2 \right) - \|\nabla f(w)\|^2 \\
&= \mathbf{E}\left[ \| \nabla f_i(w) \|^2 \right] - \| \nabla f(w) \|^2 \\
&= \mathbf{E}\left[ \| \nabla f_i(w) \|^2 \right] - \left\| \mathbf{E}\left[ \nabla f_i(w) \right] \right\|^2
\end{align}
$$

### Applying Variance Bound

Now we are prepared to look at the term $\mathbf E [ \| g_t \|^2 \mid w_t ]$:
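A quick numerical check of this identity, as a sketch: random vectors stand in for the per-sample gradients $\nabla f_j(w)$ (hypothetical data), and the helper names are ours. It also confirms that the left-hand side equals the trace of the population covariance matrix, which is why it plays the role of a "variance" for a random vector.

```python
import numpy as np

def total_variance(per_sample_grads):
    """E||grad f_i - grad f||^2 for i uniform over the n rows (population variance summed over coordinates)."""
    mean = per_sample_grads.mean(axis=0)
    return np.mean(np.sum((per_sample_grads - mean) ** 2, axis=1))

def second_moment_minus_sq_mean(per_sample_grads):
    """E||grad f_i||^2 - ||E[grad f_i]||^2, the right-hand side of the identity."""
    mean = per_sample_grads.mean(axis=0)
    return np.mean(np.sum(per_sample_grads ** 2, axis=1)) - mean @ mean

rng = np.random.default_rng(1)
grads = rng.normal(size=(8, 4))   # hypothetical grad f_j(w), n = 8, d = 4
print(np.isclose(total_variance(grads), second_moment_minus_sq_mean(grads)))
# Trace of the covariance with 1/n normalization (bias=True), one column per coordinate.
print(np.isclose(total_variance(grads), np.trace(np.cov(grads, rowvar=False, bias=True))))
```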
$$
\begin{align}
\mathbf E [ \| g_t \|^2 \mid w_t ]
&= \mathbf E \left[\left\| \frac{1}{B} \sum_{b=1}^B \nabla f_{i_b}(w_t) \right\|^2 \mid w_t \right] \\
&= \frac{1}{B^2} \mathbf E \left[\left( \sum_{b=1}^B \nabla f_{i_b}(w_t) \right)^T \left( \sum_{c=1}^B \nabla f_{i_c}(w_t) \right) \mid w_t \right] \\
&= \frac{1}{B^2} \sum_{b=1}^B \sum_{c=1}^B \mathbf E \left[ \nabla f_{i_b}(w_t)^T \nabla f_{i_c}(w_t) \mid w_t \right]
\end{align}
$$

Among these $B^2$ pairs $(b, c)$, $B$ pairs have the same subscript ($b = c$) and $B(B-1)$ pairs have distinct subscripts. For a pair with distinct subscripts, $i_b$ and $i_c$ are drawn independently, so the expectation of the product factors into the product of the expectations, $\mathbf E[X^T Y] = \mathbf E[X]^T \mathbf E[Y]$. For a pair with the same subscript, note again that the expectation of $\| \nabla f_{i_b}(w_t)\|^2$ is just its average over all $n$ training examples.
$$
\begin{align}
\mathbf E [ \| g_t \|^2 \mid w_t ]
&= \frac{1}{B^2} \left( \sum_{b=1}^B \sum_{c \ne b} \mathbf E \left[ \nabla f_{i_b}(w_t)^T \nabla f_{i_c}(w_t) \mid w_t \right] + \sum_{b=1}^B \mathbf E \left[ \nabla f_{i_b}(w_t)^T \nabla f_{i_b}(w_t) \mid w_t \right] \right) \\
&= \frac{1}{B^2} \left( \sum_{b=1}^B \sum_{c \ne b} \mathbf E \left[ \nabla f_{i_b}(w_t) \mid w_t \right]^T \mathbf E\left[\nabla f_{i_c}(w_t) \mid w_t \right] + \sum_{b=1}^B \mathbf E \left[ \| \nabla f_{i_b}(w_t)\|^2 \mid w_t \right] \right) \\
&= \frac{1}{B^2} \left( (B^2-B) \| \nabla f(w_t)\|^2 + B \cdot \frac{1}{n} \sum_{j=1}^n \| \nabla f_j(w_t)\|^2 \right) \\
&= \frac{B-1}{B} \| \nabla f(w_t)\|^2 + \frac{1}{B} \cdot \frac{1}{n} \sum_{j=1}^n \| \nabla f_j(w_t)\|^2
\end{align}
$$

Rearranging the identity proved above, $\frac{1}{n} \sum_{j=1}^n \| \nabla f_j(w_t)\|^2 = \mathbf{E}\left[ \| \nabla f_i(w_t) - \nabla f(w_t) \|^2 \right] + \|\nabla f(w_t)\|^2$, and then applying our bound on the variance to the last term, we get
$$
\begin{align}
\mathbf E [ \| g_t \|^2 \mid w_t ]
&= \frac{B-1}{B} \| \nabla f(w_t)\|^2 + \frac{1}{B} \cdot \frac{1}{n} \sum_{j=1}^n \| \nabla f_j(w_t)\|^2 \\
&= \frac{B-1}{B} \| \nabla f(w_t)\|^2 + \frac{1}{B} \left( \mathbf{E}\left[ \| \nabla f_i(w_t) - \nabla f(w_t) \|^2 \right] + \|\nabla f(w_t)\|^2 \right) \\
&\le \frac{B-1}{B} \| \nabla f(w_t)\|^2 + \frac{1}{B} \left( \sigma^2 + \|\nabla f(w_t)\|^2 \right) \\
&= \| \nabla f(w_t)\|^2 + \frac{\sigma^2}{B}
\end{align}
$$

## Finishing Up Derivation

We can finally plug this back into our earlier bound:
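Before the last inequality the derivation is exact: $\mathbf E [ \| g_t \|^2 \mid w_t ] = \| \nabla f(w_t)\|^2 + V/B$, where $V = \mathbf{E}\left[ \| \nabla f_i(w_t) - \nabla f(w_t) \|^2 \right]$. The sketch below checks this by enumerating all $n^B$ equally likely ordered minibatches on a tiny, hypothetical set of per-sample gradients; the helper names are ours.

```python
import itertools
import numpy as np

def exact_second_moment(per_sample_grads, B):
    """Exact E[||g_t||^2 | w_t], enumerating all n**B equally likely ordered minibatches
    drawn with replacement. Row j of per_sample_grads plays the role of grad f_j(w_t)."""
    n = per_sample_grads.shape[0]
    vals = [np.sum(per_sample_grads[list(batch)].mean(axis=0) ** 2)
            for batch in itertools.product(range(n), repeat=B)]
    return np.mean(vals)

def predicted_second_moment(per_sample_grads, B):
    """||grad f||^2 + V / B, with V = E||grad f_i - grad f||^2 the per-sample variance."""
    mean = per_sample_grads.mean(axis=0)
    V = np.mean(np.sum((per_sample_grads - mean) ** 2, axis=1))
    return mean @ mean + V / B

rng = np.random.default_rng(2)
grads = rng.normal(size=(4, 3))   # hypothetical grad f_j(w_t), n = 4, d = 3
for B in (1, 2, 3, 4):
    print(B, exact_second_moment(grads, B), predicted_second_moment(grads, B))
```

The two columns agree for every $B$, and the variance part shrinks like $1/B$; replacing $V$ by its upper bound $\sigma^2$ gives the inequality above.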
$$
\begin{align}
\mathbf E [ f(w_{t+1}) \mid w_t ]
&\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \mathbf E [ \| g_t \|^2 \mid w_t ] \\
&\le f(w_t) - \alpha \| \nabla f(w_t) \|^2 + \frac{\alpha^2 L}{2} \left( \| \nabla f(w_t) \|^2 + \frac{\sigma^2}{B} \right) \\
&= f(w_t) - \alpha \left( 1 - \frac{\alpha L}{2} \right) \| \nabla f(w_t) \|^2 + \frac{\alpha^2 \sigma^2 L}{2B}
\end{align}
$$

Recall our earlier assumption $\alpha L \le 1$, which gives $1 - \frac{\alpha L}{2} \ge \frac{1}{2}$, so

$$
\mathbf{E}[ f(w_{t+1}) \mid w_t] \le f(w_t) - \frac{\alpha}{2} \| \nabla f(w_t) \|^2 + \frac{\alpha^2 \sigma^2 L}{2B}
$$

## PL Condition

Let us now focus on functions $f$ that satisfy the Polyak-Łojasiewicz (PL) condition, so we add our 4th assumption:
$$
\left\| \nabla f(x) \right\|^2 \ge 2 \mu \left( f(x) - f^* \right)
$$

for some constant $\mu > 0$ and all $x$. With this, we can write
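As an illustration of the PL condition (our example, not from the notes): a quadratic $f(w) = \frac{1}{2} w^T H w$ with $H$ symmetric positive definite has $f^* = 0$ and satisfies PL with $\mu = \lambda_{\min}(H)$, because $\|\nabla f(w)\|^2 = w^T H^2 w \ge \lambda_{\min}(H)\, w^T H w = 2 \lambda_{\min}(H) f(w)$. Its smoothness constant is $L = \lambda_{\max}(H)$, so $\mu \le L$ here. The sketch below checks the inequality at random points; the matrix `H` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.normal(size=(5, 5))
H = A.T @ A + 0.1 * np.eye(5)        # hypothetical symmetric positive definite matrix
eigs = np.linalg.eigvalsh(H)         # eigenvalues of a symmetric matrix, ascending
mu, L = eigs.min(), eigs.max()       # PL constant and smoothness constant of f below

def f(w):
    return 0.5 * w @ H @ w           # minimum value f* = 0, attained at w = 0

def grad_f(w):
    return H @ w

ratios = []
for _ in range(1000):
    w = rng.normal(size=5)
    # PL condition: ||grad f(w)||^2 >= 2 * mu * (f(w) - f*), i.e. this ratio is >= 1.
    ratios.append((grad_f(w) @ grad_f(w)) / (2 * mu * f(w)))
print(min(ratios) >= 1 - 1e-9, mu <= L)
```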
$$
\begin{align}
\mathbf{E}[ f(w_{t+1}) \mid w_t]
&\le f(w_t) - \frac{\alpha}{2} \| \nabla f(w_t) \|^2 + \frac{\alpha^2 \sigma^2 L}{2B} \\
&\le f(w_t) - \alpha \mu \left( f(w_t) - f^* \right) + \frac{\alpha^2 \sigma^2 L}{2B}
\end{align}
$$

Subtract $f^*$ from both sides of the inequality. On the left-hand side, $f^*$ is a constant, so it moves straight into the conditional expectation by linearity of expectation:
$$
\begin{align}
\mathbf{E}[ f(w_{t+1}) \mid w_t] - f^*
&\le (f(w_t) - f^*) - \alpha \mu \left( f(w_t) - f^* \right) + \frac{\alpha^2 \sigma^2 L}{2B} \\
\mathbf{E}[ f(w_{t+1}) - f^* \mid w_t]
&\le (1 - \alpha \mu) \left( f(w_t) - f^* \right) + \frac{\alpha^2 \sigma^2 L}{2B}
\end{align}
$$

To get rid of the conditioning on $w_t$, we take the expected value of both sides and apply the law of total expectation, $\mathbf E[\mathbf E [X \mid Y]] = \mathbf E[X]$:
$$
\begin{align}
\mathbf{E}\left[ \mathbf{E}[ f(w_{t+1}) - f^* \mid w_t] \right]
&\le \mathbf{E}\left[ (1 - \alpha \mu) \left( f(w_t) - f^* \right) + \frac{\alpha^2 \sigma^2 L}{2B} \right] \\
\mathbf{E}[ f(w_{t+1}) - f^*]
&\le (1 - \alpha \mu) \, \mathbf{E}\left[f(w_t) - f^* \right] + \frac{\alpha^2 \sigma^2 L}{2B}
\end{align}
$$

Define $\rho_t = \mathbf{E}\left[f(w_t) - f^* \right]$. Then
$$
\rho_{t+1} \le (1 - \alpha \mu) \rho_t + \frac{\alpha^2 \sigma^2 L}{2B}
$$

To see where this recursion settles as $t \to \infty$, look at its fixed point $\rho_\infty$, the value that the recursion (taken with equality) maps to itself:

$$
\begin{align}
\rho_\infty &= (1 - \alpha \mu) \rho_\infty + \frac{\alpha^2 \sigma^2 L}{2B} \\
\rho_\infty &= \frac{\alpha \sigma^2 L}{2 \mu B}
\end{align}
$$

Subtracting this value from both sides of the inequality gives
$$
\begin{align}
\rho_{t+1} - \frac{\alpha \sigma^2 L}{2 \mu B}
&\le (1 - \alpha \mu) \rho_t + \frac{\alpha^2 \sigma^2 L}{2B} - \frac{\alpha \sigma^2 L}{2 \mu B} \\
&= (1 - \alpha \mu) \left(\rho_t - \frac{\alpha \sigma^2 L}{2 \mu B}\right)
\end{align}
$$

Now we are ready to look at what happens when we start from the initialization $t = 0$ and run $T$ iterations. Combining $L$-smoothness with the PL condition gives $\mu \le L$, so $\alpha L \le 1$ also ensures $0 \le 1 - \alpha \mu \le 1$, which lets us apply this inequality $T$ times in a row. Since $\frac{\alpha \sigma^2 L}{2 \mu B} \ge 0$, we can then drop it on the right:
$$
\begin{align}
\rho_T - \frac{\alpha \sigma^2 L}{2 \mu B}
&\le (1 - \alpha \mu)^T \left(\rho_0 - \frac{\alpha \sigma^2 L}{2 \mu B}\right) \\
&\le (1 - \alpha \mu)^T \rho_0
\end{align}
$$

We apply the same trick as in the GD proof, $1 - x \le e^{-x}$:
$$
\begin{align}
\mathbf{E}\left[f(w_T) - f^* \right] - \frac{\alpha \sigma^2 L}{2 \mu B}
&\le (1 - \alpha \mu)^T \, \mathbf{E}\left[f(w_0) - f^* \right] \\
&\le \exp(-\alpha \mu T) \, \mathbf{E}\left[f(w_0) - f^* \right]
\end{align}
$$

In conclusion ($f(w_0)$ depends only on the initialization, so its expectation is just $f(w_0)$):
$$
\mathbf{E}\left[f(w_T) - f^* \right] \le \exp(-\alpha \mu T) \left(f(w_0) - f^*\right) + \frac{\alpha \sigma^2 L}{2 \mu B}
$$

Compare this with what we had for plain gradient descent:
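To see the final bound and the fixed point at work, this sketch iterates the worst case allowed by the recursion, $\rho_{t+1} = (1 - \alpha\mu)\rho_t + \frac{\alpha^2 \sigma^2 L}{2B}$, and compares it with the closed-form bound. All constants are hypothetical, chosen so that $\alpha L \le 1$; the helper names are ours.

```python
import math

def iterate_recursion(rho0, alpha, mu, L, sigma2, B, T):
    """Run rho_{t+1} = (1 - alpha*mu) * rho_t + alpha^2 * sigma2 * L / (2B) with equality,
    i.e. the worst case the inequality allows. Returns [rho_0, ..., rho_T]."""
    rhos = [rho0]
    for _ in range(T):
        rhos.append((1 - alpha * mu) * rhos[-1] + alpha**2 * sigma2 * L / (2 * B))
    return rhos

def final_bound(rho0, alpha, mu, L, sigma2, B, T):
    """exp(-alpha*mu*T) * rho_0 + alpha*sigma2*L / (2*mu*B)."""
    return math.exp(-alpha * mu * T) * rho0 + alpha * sigma2 * L / (2 * mu * B)

# Hypothetical constants satisfying alpha * L <= 1 (hence alpha * mu <= 1).
rho0, mu, L, sigma2, B = 10.0, 0.5, 4.0, 2.0, 8
alpha = 1 / L
noise_ball = alpha * sigma2 * L / (2 * mu * B)   # 0.25 with these constants
rhos = iterate_recursion(rho0, alpha, mu, L, sigma2, B, T=200)
print(noise_ball, rhos[-1])
```

The iterates decay geometrically toward the noise ball and never go below it, and every $\rho_t$ stays under the closed-form bound.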
$$
f(w_T) - f^* \le \exp\left(-\frac{\mu T}{L}\right) \left( f(w_0) - f^* \right)
$$

In GD, the loss gap goes to 0 exponentially fast; note that $\exp(-\mu T / L)$ is exactly the first term of the SGD bound with $\alpha = 1/L$. In SGD, there is an extra error term $\frac{\alpha \sigma^2 L}{2 \mu B}$ caused by using stochastic gradients. The expected loss gap stops shrinking once it reaches this noise ball, whose size grows with the step size $\alpha$ and the gradient variance $\sigma^2$, and shrinks with the batch size $B$.
Notice that with a fixed learning rate, this noise ball is a constant, so the bound never guarantees reaching the minimum. However, with a learning rate that decreases as a function of the step count, we can push this noise ball to 0 and reach the minimum in the limit.
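A sketch of this effect on the worst-case recursion, assuming the one-step inequality holds with a per-step learning rate $\alpha_t$ (the one-step derivation only used the current step's $\alpha$ together with $\alpha_t L \le 1$). The schedule $\alpha_t = \frac{2}{\mu (t + t_0)}$ is one illustrative choice of ours, with $t_0$ picked so that the first step is exactly $1/L$; the constants and the helper `run` are hypothetical.

```python
def run(rho0, step_sizes, mu, L, sigma2, B):
    """Iterate rho_{t+1} = (1 - a_t*mu) * rho_t + a_t^2 * sigma2 * L / (2B) for a step-size sequence."""
    rho = rho0
    for a in step_sizes:
        assert a * L <= 1, "each step must satisfy a_t * L <= 1"
        rho = (1 - a * mu) * rho + a**2 * sigma2 * L / (2 * B)
    return rho

rho0, mu, L, sigma2, B, T = 10.0, 0.5, 4.0, 2.0, 8, 20000
constant = [1 / L] * T
t0 = 2 * L / mu                                       # makes the first step exactly 1/L
decaying = [2 / (mu * (t + t0)) for t in range(T)]    # hypothetical O(1/t) schedule
print(run(rho0, constant, mu, L, sigma2, B))   # settles at the noise ball (0.25 here)
print(run(rho0, decaying, mu, L, sigma2, B))   # keeps shrinking as T grows
```

With the constant rate the gap stalls at $\frac{\alpha \sigma^2 L}{2 \mu B}$, while the decaying schedule keeps driving it toward 0, roughly like $1/T$ for this choice.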