We borrow the parameter-count results for decoder-only transformer models from Section 2.1 of OpenAI's paper Scaling Laws for Neural Language Models.
We use the following notation:
- $L$ = number of layers of transformer blocks ($N$ in Attention is All You Need)
- $d_{model}$ = dimension of the input & output of a transformer block, also the output of the text encoder and the input of the decoder
- $d_{ff}$ = dimension of the feed-forward network's inner layer. We define the feed-forward network as `fc1 = fc(d_model, d_ff)`, `fc2 = fc(d_ff, d_model)`
- $d_{attn}$ = dimension of the multi-head attention output. (In Attention is All You Need there are $h$ heads; queries and keys have dimension $d_k$, values have dimension $d_v$, and in practice we usually have $d_k = d_v$. The $d_{attn}$ used here is defined as $d_k \times h$.)
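To make the notation concrete, here is a minimal PyTorch sketch (the variable names are my own, not from either paper) of the weight matrices counted in the table below; biases are omitted because the table only counts the weight matrices:

```python
import torch.nn as nn

d_model, d_attn, d_ff = 512, 512, 512 * 4

# Q/K/V projections, written with all heads fused: each has shape (d_model, d_attn).
w_q = nn.Linear(d_model, d_attn, bias=False)
w_k = nn.Linear(d_model, d_attn, bias=False)
w_v = nn.Linear(d_model, d_attn, bias=False)
# Output projection W^O from the concatenated heads back to d_model.
w_o = nn.Linear(d_attn, d_model, bias=False)
# Feed-forward network as defined above: fc1 expands to d_ff, fc2 maps back to d_model.
fc1 = nn.Linear(d_model, d_ff, bias=False)
fc2 = nn.Linear(d_ff, d_model, bias=False)

per_block = sum(p.numel() for m in (w_q, w_k, w_v, w_o, fc1, fc2) for p in m.parameters())
assert per_block == 2 * d_model * (2 * d_attn + d_ff)  # matches the per-layer total in the table below
```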
Part | Parameters | Explanation |
---|---|---|
Embed | $n_{vocab} \times d_{model} + n_{ctx} \times d_{model}$ | One word embedding matrix (mapping each token to its embedding) and one positional embedding matrix |
Attention: Q, K, V Matrices | $3 L d_{model} d_{attn}$ | $W^Q$ has shape $(d_{model}, d_{attn})$; there are also $W^K$ and $W^V$ |
Attention: Multi-head Projection | $L d_{attn} d_{model}$ | After we concatenate the outputs from all heads, there is one projection from the all-head output to the final output. This is that matrix; it was defined as $W^O$ in Attention is All You Need Section 3.2.2. |
Feedforward Network | $2 L d_{model} d_{ff}$ | Explained in the definition of $d_{ff}$ above. |
Total (Non-Embedding) | $2 L d_{model} (2 d_{attn} + d_{ff})$ | Sum of the three per-layer terms above. |
If we have the standard $d_{attn} = d_{model} = d_{ff}/4$, the non-embedding parameter count simplifies to $N = 12 L d_{model}^2$.
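As a quick sanity check on that simplification, here is the closed form in plain Python (the function name is mine):

```python
def non_embedding_params(L, d_model, d_attn, d_ff):
    # 3 L d_model d_attn (Q, K, V) + L d_attn d_model (W^O) + 2 L d_model d_ff (FFN)
    return 2 * L * d_model * (2 * d_attn + d_ff)

# Standard shape: d_attn = d_model and d_ff = 4 * d_model  =>  N = 12 L d_model^2
L, d_model = 6, 512
assert non_embedding_params(L, d_model, d_model, 4 * d_model) == 12 * L * d_model ** 2
```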
Putting this into practice, let's make a rough estimate of the number of parameters the vanilla transformer has. Per Table 3 of Attention is All You Need, the vanilla transformer base has $L = 6$, $d_{model} = 512$, $d_{ff} = 2048$, $d_{attn} = h \times d_k = 8 \times 64 = 512$, and $n_{vocab} = 37000$. I didn't find info about $n_{ctx}$, but it is probably 512.
Note that, unlike OpenAI's favorite decoder-only transformers, the vanilla transformer has an encoder-decoder architecture, and each decoder block has an additional (cross-)attention block. Therefore, the encoder has $2 L d_{model} (2 d_{attn} + d_{ff})$ parameters in total, the decoder has $2 L d_{model} (4 d_{attn} + d_{ff})$, and the embedding part has $n_{vocab} \times d_{model} + n_{ctx} \times d_{model}$. The final result is $\sim 63 \times 10^6$. I tried hard to figure out where my count went off from the paper's $65 \times 10^6$ but had no luck; adding the LayerNorm parameters still didn't even out the numbers. But it's close enough, so I'll call it a day.
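For reference, here is that arithmetic as a short script (variable names are mine; like the table above, it treats the positional embedding as a learned $n_{ctx} \times d_{model}$ matrix, and $n_{ctx} = 512$ is an assumption):

```python
L, d_model, d_ff, d_attn = 6, 512, 2048, 512
n_vocab, n_ctx = 37_000, 512  # n_ctx = 512 is a guess; the paper doesn't state it

encoder = 2 * L * d_model * (2 * d_attn + d_ff)  # self-attention + FFN per layer
decoder = 2 * L * d_model * (4 * d_attn + d_ff)  # extra cross-attention per layer
embed = n_vocab * d_model + n_ctx * d_model      # word + positional embeddings

print(encoder, decoder, embed, encoder + decoder + embed)
# 18874368 25165824 19206144 63246336  ->  ~63e6, vs. the paper's ~65e6
```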