Yao Lirong's Blog

Parameter and FLOP Count in Transformer Model

2024/02/22

We borrow the results for decoder-only transformer models from Section 2.1 of OpenAI's paper Scaling Laws for Neural Language Models.

We use the following notation:

  • $L$ = number of layers of transformer blocks ($N$ in Attention is All You Need)

  • $d_{model}$ = dimension of the input & output of a transformer block, also the output of the text encoder and the input of the decoder

  • $d_{ff}$ = dimension of the feed-forward network's inner layer. We define the feed-forward network as fc1 = fc(d_model, d_ff), fc2 = fc(d_ff, d_model) (see the code sketch after this list)

  • $d_{attn}$ = dimension of the multi-head attention output (In Attention is All You Need, we have $h$ heads. Queries and keys have dimension $d_k$; values have dimension $d_v$. In practice, we usually have $d_k = d_v$. The $d_{attn}$ we use here is defined as $d_k \times h$)
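To make these dimensions concrete, here is a minimal PyTorch sketch of the weight matrices in one transformer block (the class and attribute names are mine, not from either paper); it only declares the linear layers so their shapes match the notation above.

```python
import torch.nn as nn

class SingleLayerWeights(nn.Module):
    """Weight matrices of one transformer block, shapes only (illustrative sketch)."""
    def __init__(self, d_model: int, d_attn: int, d_ff: int):
        super().__init__()
        # Attention: Q, K, V projections, each of shape (d_model, d_attn)
        self.w_q = nn.Linear(d_model, d_attn, bias=False)
        self.w_k = nn.Linear(d_model, d_attn, bias=False)
        self.w_v = nn.Linear(d_model, d_attn, bias=False)
        # Multi-head output projection W^O, shape (d_attn, d_model)
        self.w_o = nn.Linear(d_attn, d_model, bias=False)
        # Feed-forward network: fc1 = fc(d_model, d_ff), fc2 = fc(d_ff, d_model)
        self.fc1 = nn.Linear(d_model, d_ff, bias=False)
        self.fc2 = nn.Linear(d_ff, d_model, bias=False)
```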

| Part | Parameters | Explanation |
| --- | --- | --- |
| Embed | $n_{vocab} \times d_{model} + n_{ctx} \times d_{model}$ | One word embedding matrix (mapping each token to its embedding) and one positional embedding matrix |
| Attention: Q K V Matrices | $3 L d_{model} d_{attn}$ | $W^Q$ has shape $(d_{model}, d_{attn})$; there are also $W^K$ and $W^V$ |
| Attention: Multi-head Projection | $L d_{attn} d_{model}$ | After we concat the outputs from all heads, one projection maps the all-head output to the final output. This is that matrix, defined as $W^O$ in Attention is All You Need 3.2.2. |
| Feed-forward Network | $2 L d_{model} d_{ff}$ | Explained in the definition of $d_{ff}$ above. |
| Total (Non-Embedding) | $2 L d_{model} (2 d_{attn} + d_{ff})$ | |

If we have the standard $d_{attn} = d_{model} = d_{ff}/4$, we get $N = 12 L d_{model}^2$.
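As a quick sanity check, here is a small Python sketch (the helper name is mine) that evaluates the table's total and compares it with the $12 L d_{model}^2$ shortcut:

```python
def non_embedding_params(L: int, d_model: int, d_attn: int, d_ff: int) -> int:
    """Total non-embedding parameters per the table: 2 * L * d_model * (2 * d_attn + d_ff)."""
    return 2 * L * d_model * (2 * d_attn + d_ff)

# Illustrative "standard" shape: d_attn = d_model, d_ff = 4 * d_model
L, d_model = 12, 768
exact = non_embedding_params(L, d_model, d_attn=d_model, d_ff=4 * d_model)
approx = 12 * L * d_model ** 2
print(exact, approx, exact == approx)  # the two formulas agree exactly
```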

Putting this into practice, let's calculate a rough estimate of the number of parameters the vanilla transformer has. For the vanilla transformer base, per Attention is All You Need Table 3, $L = 6$, $d_{model} = 512$, $d_{ff} = 2048$, $d_{attn} = h \times d_k = 8 \times 64 = 512$, $n_{vocab} = 37000$. I didn't find info about $n_{ctx}$, but it is probably 512.

Note that, unlike OpenAI's favorite decoder-only transformer, the vanilla transformer has an encoder-decoder architecture, and each decoder block has an additional (cross-)attention block. Therefore, the encoder has $2 L d_{model} (2 d_{attn} + d_{ff})$ parameters in total, the decoder has $2 L d_{model} (4 d_{attn} + d_{ff})$, and the embedding part has $n_{vocab} \times d_{model} + n_{ctx} \times d_{model}$. The final result is $\sim 63 \times 10^6$. I tried hard to figure out where I went off from the paper's $65 \times 10^6$ but had no luck. Adding the parameters of LayerNorm still didn't even out the numbers. But it's close enough, so I'll call it a day.
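Below is a minimal sketch of that back-of-the-envelope calculation (variable names are mine), using the formulas above and the assumed $n_{ctx} = 512$:

```python
L, d_model, d_ff, d_attn = 6, 512, 2048, 512
n_vocab, n_ctx = 37_000, 512  # n_ctx = 512 is an assumption, not stated in the paper

encoder = 2 * L * d_model * (2 * d_attn + d_ff)  # self-attention + FFN per encoder layer
decoder = 2 * L * d_model * (4 * d_attn + d_ff)  # self-attention + cross-attention + FFN
embed = n_vocab * d_model + n_ctx * d_model      # token + positional embeddings

total = encoder + decoder + embed
print(f"{total / 1e6:.1f}M parameters")  # roughly 63M, vs. the paper's reported 65M
```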
