Before this, see <a href="#2024/06/17-Conducting-Multi-Round-Conversation-with-Transformers">2024/06/17 Conducting Multi-Round Conversation with Transformers</a> for why we need a cache. But attention has three matrices: query, key, and value. Why do we only cache past keys and values? What about past queries?

<h2 id="attention-mechanism-in-detail">Attention Mechanism inDetail</h2>Recall the attention process in transformer can be written in thefollowing matrix form: \(Z = \text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V\) If we look a particular output at position <spanclass=”math inline”>i</span>, it can be written as: \(z_i =({}<p>)</p><pre><code>\begin&#123;bmatrix&#125;v_1 \\v_2 \\\vdots \\v_n\end&#123;bmatrix&#125;</code></pre><p>\) A simple example can be found in the famous <ahref=”https://jalammar.github.io/illustrated-transformer/”>IllustratedTransformer</a></p>

<figure>
<img src="https://jalammar.github.io/images/t/self-attention-output.png" alt="self attention output" />
<figcaption aria-hidden="true">self attention output</figcaption>
</figure>
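<p>To make the per-position formula above concrete, here is a minimal sketch (my own toy example with random vectors, not taken from any of the linked code) that computes \(z_i\) for a single query against all keys and values and checks it against the matrix form:</p>
<pre><code>import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k = 4, 8                        # toy sequence length and head dimension

Q = torch.randn(n, d_k)              # queries, one row per position
K = torch.randn(n, d_k)              # keys
V = torch.randn(n, d_k)              # values

i = 2                                # look at the output for position i
q_i = Q[i]                           # shape (d_k,)

scores = q_i @ K.T / d_k ** 0.5      # relatedness of q_i to every key, shape (n,)
weights = F.softmax(scores, dim=-1)  # attention weights sum to 1
z_i = weights @ V                    # weighted average of the value rows, shape (d_k,)

# the same position taken from the full matrix form matches
Z = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1) @ V
assert torch.allclose(z_i, Z[i])</code></pre>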
<p>From the formula and the example, we can see that keys and values always come in pairs in the calculation. In fact, this is aligned with the very concept of a soft dictionary behind attention: we get a query from somewhere and look at all the keys in the dictionary to find, for each key, how much it relates to this query, then output the weighted average of each key's value based on that relatedness.</p>
<h2 id="generative-transformer-decoder-based">Generative Transformer (Decoder Based)</h2>
<figure>
<img src="https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/encoder_decoder/EncoderDecoder.png" alt="Autoregressive Decoder" />
<figcaption aria-hidden="true">Autoregressive Decoder</figcaption>
</figure>
<p>Let's consider a causal language model, aka a transformer's autoregressive generative decoder. At inference time, we only care about the output at the last position: the model is autoregressive, and the tokens at all previous positions are already known, so their predictions are not needed. (See the above graph from the blog post <a href="https://huggingface.co/blog/encoder-decoder">Transformers-based Encoder-Decoder Models</a>.) Therefore, if the current sequence has length \(s\), we only care about \(z_s\). All the other outputs \(z_1, \dots, z_{s-1}\) are useless.</p>
<p><a href="https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L188-L191">Inference code in Karpathy's nanoGPT</a> corroborates this in its inference-time implementation:</p>
<table><tr><td class="gutter"><pre>1
2
3
4
</pre></td><td class="code"><pre>if targets is None:
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
</pre></td></tr></table>
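<p>Before going back to the formula, here is a toy single-head sketch of a decoding loop with a KV-cache (my own illustration, not the nanoGPT or XLM code): at every step only the newest token's query, key, and value are computed, the new key and value are appended to the cache, and the query is used once and then thrown away.</p>
<pre><code>import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16                                           # toy model/head dimension
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

cached_k = torch.empty(0, d)                     # grows by one row per generated token
cached_v = torch.empty(0, d)

def decode_step(x_new, cached_k, cached_v):
    """x_new: (1, d) embedding of the newest token only."""
    q = x_new @ W_q                              # (1, d) used once, never cached
    k = x_new @ W_k                              # (1, d)
    v = x_new @ W_v                              # (1, d)
    cached_k = torch.cat([cached_k, k], dim=0)   # (t, d)
    cached_v = torch.cat([cached_v, v], dim=0)   # (t, d)
    scores = q @ cached_k.T / d ** 0.5           # (1, t): one query against all cached keys
    z = F.softmax(scores, dim=-1) @ cached_v     # (1, d)
    return z, cached_k, cached_v

for step in range(5):                            # pretend each step produced a new token
    x_new = torch.randn(1, d)
    z, cached_k, cached_v = decode_step(x_new, cached_k, cached_v)
    print(step, z.shape, cached_k.shape)         # z stays (1, d); the cache keeps growing</code></pre>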
<p>Now revisit the formula to calculate the output \(z_s\): \[z_s = \text{softmax}\left(\frac{q_s K^T}{\sqrt{d_k}}\right) \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_s \end{bmatrix}\] It should be clear that, to save computation, we only need to cache the key and value vectors from the previous positions, and it is useless to cache the query vectors.</p>
<p>To give a more detailed example, let's consider the whole process of generating a sequence of tokens of length \(n\): \(t_1, \dots, t_n\). At every step \(t\), the only output we need is
\[z_t = \text{softmax}\left(\frac{q_t K_{1:t}^T}{\sqrt{d_k}}\right) V_{1:t}, \qquad K_{1:t} = \begin{bmatrix} k_1 \\ \vdots \\ k_t \end{bmatrix}, \quad V_{1:t} = \begin{bmatrix} v_1 \\ \vdots \\ v_t \end{bmatrix}.\] We can see that the previous queries \(q_1, \dots, q_{t-1}\) are never used in the computation; only the latest query \(q_t\) and all past keys and values are needed.</p>
<h2 id="time-complexity-boost">Time Complexity Boost</h2>
<p>People complain about the slow inference time of generative transformer models, which have a quadratic sequence-length term \(O(s^2)\). This quadratic term comes from the \(QK^T\) matrix multiplication in attention, where both matrices have shape \(s \times d\). Recall that the running time of the matmul \(AB\), where \(A \in \mathbb{R}^{m \times p}, B \in \mathbb{R}^{p \times n}\), is \(O(mpn)\), so this matmul of the query and key matrices has time complexity \(O(s^2 d)\).</p>
<p>However, by observing that we only need the output at the very last position in a generative model and utilizing the KV-cache, we reduce the matrix \(Q \in \mathbb{R}^{s \times d}\) to a single vector \(q \in \mathbb{R}^{1 \times d}\) and effectively reduce the time complexity of this operation to \(O(sd)\). Therefore, we eliminate the quadratic term from the per-step inference cost and only need time linear in \(s\) instead.</p>
<h2 id="what-about-encoder-based-transformer-model">What about Encoder Based Transformer Model?</h2>
<p>Encoder-based transformer models process the whole input in a single forward pass instead of generating token by token, so they never recompute past tokens' attention scores and do not need a KV-cache.</p>
<h2 id="code-implementation">Code Implementation</h2>
<p>Facebook's <a href="https://github.com/facebookresearch/XLM">cross-lingual language model (XLM)</a> gives a fantastic example of how to implement the KV-cache (and transformers in general; it provides abundant comments on tensor shapes at each step).</p>
<ol type="1">
<li><p>At inference time, do not recompute cached elements (<code>cache['slen']</code> records how many previous positions have been cached; a more descriptive name would be <code>cached_sequence_length</code>): <a href="https://github.com/facebookresearch/XLM/blob/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/model/transformer.py#L373-L380">link</a></p>
<pre><code>if cache is not None:
    _slen = slen - cache['slen']
    x = x[:, -_slen:]
    positions = positions[:, -_slen:]
    if langs is not None:
        langs = langs[:, -_slen:]
    mask = mask[:, -_slen:]
    attn_mask = attn_mask[:, -_slen:]</code></pre>
</li>
<li><p>Retrieve, use, and update the cache: <a href="https://github.com/facebookresearch/XLM/blob/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/model/transformer.py#L199-L207">link1</a> <a href="https://github.com/facebookresearch/XLM/blob/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/model/transformer.py#L423">link2</a></p>
<pre><code>if self.layer_id in cache:
    k_, v_ = cache[self.layer_id]
    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
cache[self.layer_id] = (k, v)

cache['slen'] += tensor.size(1)</code></pre>
</li>
</ol>
<p>XLM can serve multiple purposes, including as a generative causal language model, a masked language model, or a translation language model. It uses the KV-cache only for the causal language model, inside the <code>generate()</code> function; see the <a href="https://github.com/facebookresearch/XLM/blob/cd281d32612d145c6742b4d3f048f80df8669c30/xlm/model/transformer.py#L482-L498">code</a>.</p>
<p>XLM also has a Memory module that implements <a href="https://github.com/facebookresearch/XLM#v-product-key-memory-layers-pkm">Product-Key Memory Layers</a>, whose mechanism rings very familiar to me, though I can't recall where I've encountered something similar before. Anyway, you can ignore those Memory implementations and focus on the attention part if you use XLM as a source to learn about the cache or the basics of attention.</p>
<h2 id="more-code-examples">More Code Examples</h2>
<ul>
<li>This Medium post, <a href="https://medium.com/@plienhar/llm-inference-series-3-kv-caching-unveiled-048152e461c8">KV caching explained</a>, points the way to Hugging Face's implementation in general, which can be too modular and abstract nowadays. It's hidden in the <code>forward</code> function of <code>XXXForCausalLM</code>. Take <a href="https://huggingface.co/docs/transformers/v4.42.0/en/model_doc/llama2#transformers.LlamaForCausalLM">LlamaForCausalLM</a> as an example: from its <code>forward</code> function, we still need to go down the abstraction to <code>LlamaModel</code> -> <code>LlamaDecoderLayer</code> -> <a href="https://github.com/huggingface/transformers/blob/6c1d0b069de22d7ed8aa83f733c25045eea0585d/src/transformers/models/llama/modeling_llama.py#L337-L340">LlamaAttention</a>, where we can see <code>past_key_value</code> implementing the <code>Cache</code> class. I didn't read into how Hugging Face did it (a short sketch of driving the cache by hand is at the end of this post).</li>
<li>This Zhihu post explaining KV-Cache points the way to <a href="https://github.com/huggingface/transformers/blob/d1a1bcf56aeb8593b9cc613b21422e6311875599/src/transformers/models/gpt2/modeling_gpt2.py#L318-L321">Hugging Face's GPT-2</a>. The <a href="https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/model.py#L105-L108">original GPT-2 code</a> is in fact more straightforward, but you'd better just read XLM: it simply has more comments, and the naming is more self-explanatory.</li>
</ul>
<h2 id="ps">PS</h2>
<p>I initially couldn't find where Hugging Face implements the KV-cache in the current version (transformers 4.40); I only found this <a href="https://github.com/huggingface/transformers/blob/aec1ca3a588bc6c65f7886e3d3eaa74901a6356f/src/transformers/cache_utils.py#L293">Cache class</a> and failed to find where it's used. So I followed the recommendation under <a href="https://zhuanlan.zhihu.com/p/601044938">this Zhihu post</a> and went to transformers 2.5.0 instead. A quick search for "kv" or "cache" led me to <a href="https://github.com/huggingface/transformers/blob/v2.5.0/src/transformers/modeling_xlm.py">modeling_xlm.py</a>. I was surprised to find that early Hugging Face model code was more of a rename of the original implementation than the refactor they do now.</p>
<p>I then read the <a href="https://medium.com/@plienhar/llm-inference-series-3-kv-caching-unveiled-048152e461c8">KV caching explained</a> post. Its graphs aren't super straightforward, but it covers how the KV-cache reduces time complexity and where to find Hugging Face's implementation.</p>
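<p>For completeness, here is a minimal sketch of driving Hugging Face's KV-cache by hand (my own example, assuming a reasonably recent <code>transformers</code> and GPT-2 weights are available; in practice <code>generate()</code> does this for you when <code>use_cache=True</code>). After the first forward pass, we feed only the newest token plus <code>past_key_values</code>, which plays exactly the role of the cached keys and values discussed above.</p>
<pre><code>import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("KV caching lets us", return_tensors="pt").input_ids
past_key_values = None

with torch.no_grad():
    for _ in range(10):
        # after the first step, feed only the newest token; the cache covers the rest
        step_input = input_ids if past_key_values is None else input_ids[:, -1:]
        out = model(step_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values   # per-layer cached keys and values
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)

print(tokenizer.decode(input_ids[0]))</code></pre>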