Before this, see [2024/06/17 Conducting Multi-Round Conversation with Transformers](#2024/06/17-Conducting-Multi-Round-Conversation-with-Transformers) for why we need a cache at all. But attention has three matrices: query, key, and value. Why do we only cache past keys and values? What about past queries?
<h2 id="attention-mechanism-in-detail">Attention Mechanism inDetail</h2>Recall the attention process in transformer can be written in thefollowing matrix form: \(Z = \text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V\) If we look a particular output at position <spanclass=”math inline”>i</span>, it can be written as: \(z_i =({}<p>)</p><pre><code>\begin{bmatrix}v_1 \\v_2 \\\vdots \\v_n\end{bmatrix}</code></pre><p>\) A simple example can be found in the famous <ahref=”https://jalammar.github.io/illustrated-transformer/”>IllustratedTransformer</a></p>
In a decoder-only generative model, at each decoding step we only care about the output at the very last position, because that is the only position used to predict the next token. Karpathy's nanoGPT, for example, only forwards the `lm_head` on the last position at inference time:

```python
if targets is None:
    # inference-time mini-optimization: only forward the lm_head on the very last position
    logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
    loss = None
```

Looking back at the formula for $z_i$, computing the last position's output only needs the current query together with the `k`, `v` values in the previous positions, and it's useless to cache the `q` values.

To give a more detailed example, let's consider the whole process of generating a sequence of tokens of length $n$: $t_1, \ldots, t_n$. With causal masking, position $i$ only attends to positions $1, \ldots, i$:

$$
\begin{aligned}
z_1 &= \text{softmax}\!\left(\frac{q_1 k_1^T}{\sqrt{d_k}}\right) v_1 \\
z_2 &= \text{softmax}\!\left(\frac{1}{\sqrt{d_k}}\begin{bmatrix} q_2 k_1^T & q_2 k_2^T \end{bmatrix}\right)\begin{bmatrix} v_1 \\ v_2 \end{bmatrix} \\
&\;\;\vdots \\
z_n &= \text{softmax}\!\left(\frac{1}{\sqrt{d_k}}\begin{bmatrix} q_n k_1^T & q_n k_2^T & \cdots & q_n k_n^T \end{bmatrix}\right)\begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_n \end{bmatrix}
\end{aligned}
$$

We can see the previous queries are never used in the computation, while every step reuses all the previous keys and values.
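To see this in code, here is a minimal sketch (plain PyTorch, single head; the function and variable names are mine, purely for illustration) that generates position by position, caching only keys and values, and checks the result against full causal attention computed in one shot:

```python
import torch
import torch.nn.functional as F

def attend(q, K, V, d_k):
    """Single-head attention for one query against all cached keys/values."""
    return F.softmax(q @ K.T / d_k**0.5, dim=-1) @ V   # (1, d_k)

torch.manual_seed(0)
n, d_k = 6, 8
# pretend per-token projections q_i, k_i, v_i for tokens t_1 .. t_n
q_all, k_all, v_all = (torch.randn(n, d_k) for _ in range(3))

K_cache, V_cache, outputs = [], [], []
for i in range(n):
    K_cache.append(k_all[i:i+1])            # cache k_i
    V_cache.append(v_all[i:i+1])            # cache v_i
    K, V = torch.cat(K_cache), torch.cat(V_cache)
    # only the current query q_i is needed; q_1 .. q_{i-1} are never reused
    outputs.append(attend(q_all[i:i+1], K, V, d_k))
z_cached = torch.cat(outputs)

# reference: full causal attention over the whole sequence in one shot
scores = q_all @ k_all.T / d_k**0.5
scores = scores.masked_fill(~torch.tril(torch.ones(n, n, dtype=torch.bool)), float('-inf'))
z_full = F.softmax(scores, dim=-1) @ v_all
print(torch.allclose(z_cached, z_full))     # True
```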
## Time Complexity Boost

People complain about the slow inference time of generative transformer models, which have a quadratic sequence-length term $O(s^2)$. This quadratic term comes from the $QK^T$ matrix multiplication in attention, where both matrices have shape $s \times d$. Recall that the running time of the matmul $AB$ where $A \in \mathbb{R}^{m \times p}, B \in \mathbb{R}^{p \times n}$ is $O(mpn)$, so this matmul of the query and key matrices has time complexity $O(s^2 d)$.

However, by observing that we only need the output at the very last position in a generative model and utilizing the KV-cache, we reduce our matrix $Q \in \mathbb{R}^{s \times d}$ to a single vector $q \in \mathbb{R}^{1 \times d}$ and effectively reduce the time complexity of this operation to $O(sd)$. Therefore, we can eliminate the quadratic term from our per-step inference time and only need time linear in $s$ instead.
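A quick, unscientific way to feel the difference is to time the score matmul with the full $Q$ against the single-query version; the sizes below are arbitrary choices of mine:

```python
import time
import torch

s, d = 4096, 64                      # sequence length, head dimension
Q = torch.randn(s, d)
K = torch.randn(s, d)
q = Q[-1:]                           # with a KV-cache, only the newest query is needed

def bench(f, repeat=20):
    f()                              # warm-up
    start = time.perf_counter()
    for _ in range(repeat):
        f()
    return (time.perf_counter() - start) / repeat

t_full = bench(lambda: Q @ K.T)      # O(s^2 d) work, scores of shape (s, s)
t_cached = bench(lambda: q @ K.T)    # O(s d) work, scores of shape (1, s)
print(f"full Q K^T: {t_full*1e3:.2f} ms, single-query q K^T: {t_cached*1e3:.2f} ms")
```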
## What about Encoder Based Transformer Model?

Encoder-based transformer models process the whole input sequence in a single forward pass, so they never repeatedly compute the same past tokens' attention scores and do not need a KV-cache.

## Code Implementation

Facebook's [cross-lingual language model (XLM)](https://github.com/facebookresearch/XLM) gives a fantastic example of how to implement a KV-cache (or transformers in general: it provides abundant comments on tensor shapes at each step). A condensed toy version of the same pattern follows the steps below.

1. At inference time, do not recompute the positions that have already been processed (`cache['slen']`, for which a more descriptive name would be `cached_sequence_length`, records how many previous positions have been cached): [link](https://github.com/facebookresearch/XLM/blob/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/model/transformer.py#L373-L380)
```python
if cache is not None:
    _slen = slen - cache['slen']       # number of new, not-yet-cached positions
    x = x[:, -_slen:]                  # keep only the new positions
    positions = positions[:, -_slen:]
    if langs is not None:
        langs = langs[:, -_slen:]
    mask = mask[:, -_slen:]
    attn_mask = attn_mask[:, -_slen:]
```
2. Concatenate the cached keys and values with those computed for the new positions, then update the cache and the cached length (excerpted from the same file; the `...` marks omitted code):

```python
if self.layer_id in cache:
    k_, v_ = cache[self.layer_id]
    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
cache[self.layer_id] = (k, v)
...
cache['slen'] += tensor.size(1)
```

3. The cache itself is created and carried across decoding steps in the `generate()`
function, see [code](https://github.com/facebookresearch/XLM/blob/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/model/transformer.py#L482-L498).
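Putting these pieces together, here is a stripped-down toy version of the same pattern (my own sketch, not XLM's code: single layer, single head, no masking, and a plain dict as the cache):

```python
import torch
import torch.nn.functional as F

class ToyAttention(torch.nn.Module):
    """Single-head attention that stores past (k, v) in a cache dict, XLM-style."""
    def __init__(self, dim, layer_id=0):
        super().__init__()
        self.layer_id = layer_id
        self.dim = dim
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)

    def forward(self, x, cache=None):
        # x holds only the new positions: (bs, new_len, dim)
        q = self.q_proj(x)                       # queries of the new positions only
        k = self.k_proj(x)
        v = self.v_proj(x)
        if cache is not None and self.layer_id in cache:
            k_, v_ = cache[self.layer_id]
            k = torch.cat([k_, k], dim=1)        # (bs, cached_len + new_len, dim)
            v = torch.cat([v_, v], dim=1)
        if cache is not None:
            cache[self.layer_id] = (k, v)        # keys/values are cached, queries are not
        scores = q @ k.transpose(1, 2) / self.dim ** 0.5
        return F.softmax(scores, dim=-1) @ v     # (bs, new_len, dim)

# usage: feed one new token at a time; the cache grows with k/v only
attn = ToyAttention(dim=16)
cache = {'slen': 0}
tokens = torch.randn(1, 5, 16)                   # pretend embeddings of 5 tokens
for t in range(tokens.size(1)):
    x_new = tokens[:, t:t+1]                     # only the newest position, like x[:, -_slen:]
    out = attn(x_new, cache=cache)
    cache['slen'] += x_new.size(1)
print(cache['slen'], cache[0][0].shape)          # 5, torch.Size([1, 5, 16])
```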
XLM also has a `Memory` module that implements [Product-Key Memory Layers](https://github.com/facebookresearch/XLM#v-product-key-memory-layers-pkm), whose mechanism rings very familiar to me, but I can't recall where I've encountered something similar before. Anyway, you can ignore the `Memory` implementation and focus on the attention part if you use XLM as a source to learn the KV-cache or the basics of attention.

## More Code Examples

- This Medium post, [KV caching explained](https://medium.com/@plienhar/llm-inference-series-3-kv-caching-unveiled-048152e461c8), leads the way to where to find Hugging Face's implementation in general, which can be too modular and abstract nowadays. It's hidden in the `forward` function of `XXXForCausalLM`. Take [`LlamaForCausalLM`](https://huggingface.co/docs/transformers/v4.42.0/en/model_doc/llama2#transformers.LlamaForCausalLM) as an example: from its `forward` function, we still need to go down the abstraction through `LlamaModel` -> `LlamaDecoderLayer` -> [`LlamaAttention`](https://github.com/huggingface/transformers/blob/6c1d0b069de22d7ed8aa83f733c25045eea0585d/src/transformers/models/llama/modeling_llama.py#L337-L340), where we can see the `past_key_value` argument implementing the `Cache` class. I didn't read into how Hugging Face did it (a minimal usage sketch follows this list).
- This Zhihu post explaining KV-Cache leads the way to [Hugging Face's GPT-2](https://github.com/huggingface/transformers/blob/d1a1bcf56aeb8593b9cc613b21422e6311875599/src/transformers/models/gpt2/modeling_gpt2.py#L318-L321). The [original GPT-2 code](https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L105-L108) is in fact more straightforward, but you'd better just read XLM: it simply has more comments and the naming is more self-explanatory.
## PS

I initially couldn't find where Hugging Face implements the KV-cache in the current version (transformers `4.40`), only this [`Cache` class](https://github.com/huggingface/transformers/blob/aec1ca3a588bc6c65f7886e3d3eaa74901a6356f/src/transformers/cache_utils.py#L293), and failed to find where it's used. So I followed the recommendation under [this Zhihu post](https://zhuanlan.zhihu.com/p/601044938) and went to transformers 2.5.0 instead. A quick search for "kv" or "cache" led me to [`modeling_xlm.py`](https://github.com/huggingface/transformers/blob/v2.5.0/src/transformers/modeling_xlm.py). I was surprised to find that early Hugging Face model code was more of a rename of the original implementation than the refactor they do now.

I then read the [KV caching explained](https://medium.com/@plienhar/llm-inference-series-3-kv-caching-unveiled-048152e461c8) post. Its graph isn't super straightforward, but it introduces how the KV-cache reduces time complexity and where to find Hugging Face's implementation.