I was using LLaVA to ask how many characters there are in an image. For higher accuracy, I decided to employ Chain of Thought (CoT), but struggled to implement it. CoT is conducted through a multi-round conversation. That is easy to do in a graphical chat interface, but how is it done internally with code?

<h2 id="token-level">Token Level</h2><p>Before diving into instruct / chat models, let’s go to the lowest level and think about how transformers generate text. A transformer is an autoregressive model: it uses its own output as input for the next round. Look at <a href="https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L328">nanoGPT’s generate function</a>:</p>

<table><tr><td class="code"><pre>def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    """
    for _ in range(max_new_tokens):
        # crop the context to the last block_size tokens if it has grown too long
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
        # forward pass over the whole (cropped) sequence
        logits, _ = self(idx_cond)
        # keep only the logits at the final position and scale by temperature
        logits = logits[:, -1, :] / temperature
        # (optional top-k filtering omitted in this excerpt)
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        # append the sampled token so it becomes part of the next input
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
</pre></td></tr></table>
<p>If we ignore the details, this for loop is effectively doing:</p>
<table><tr><td class="code"><pre>token0 = tokenizer(text)
output1 = model(token0)

token1 = get_response(output1)
output2 = model(token0 + token1)

token2 = get_response(output2)
output3 = model(token0 + token1 + token2)
</pre></td></tr></table>
<p>Written out like this, it is clear that at every generation step we feed the previous step’s input back into the model as if it were new, even though it is exactly the same. When we call model(token0 + token1), we have discarded all the attention we computed in model(token0), even though the attention for the token0 part is exactly the same. This is why people complain that transformer inference is slow, and it is where inference speed-up techniques such as the KV cache come in.</p><p>This also shows that the very popular diagram illustrating how transformer inference works lied (at least to me). When calculating <span class="math inline">\(y_{i+1}\)</span>, we do not reuse anything computed for <span class="math inline">\(y_0, \ldots, y_i\)</span>, neither the attention nor the intermediate activations. We just feed those tokens back into the model as something completely new.</p>
<img src="https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/encoder_decoder/EncoderDecoder.png" alt="Autoregressive Decoder" /><figcaption aria-hidden="true">Autoregressive Decoder</figcaption>
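<p>To make the point about recomputation concrete, here is a minimal sketch in plain Python. It is not nanoGPT’s code: it assumes a toy single-head attention in which queries, keys and values are the input vectors themselves, and attend, causal_attention and CachedAttention are names made up for this illustration. It shows that re-feeding the prefix reproduces the same outputs for the prefix positions, and that keeping a cache of keys and values gives the same result for the new token with a single attention call.</p>

```python
import math


def attend(query, keys, values):
    """Softmax attention of one query vector over lists of key and value vectors."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(score - peak) for score in scores]
    total = sum(weights)
    return [
        sum(w * value[d] for w, value in zip(weights, values)) / total
        for d in range(len(values[0]))
    ]


def causal_attention(xs):
    """Recompute from scratch: position i attends to positions 0..i."""
    # Toy assumption: no learned projections, so q = k = v = x.
    return [attend(xs[i], xs[: i + 1], xs[: i + 1]) for i in range(len(xs))]


class CachedAttention:
    """Keeps the keys and values of earlier tokens so only the new token is processed."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, x):
        self.keys.append(x)
        self.values.append(x)
        return attend(x, self.keys, self.values)


prefix = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
new_token = [0.2, 0.8]

out_prefix = causal_attention(prefix)              # like model(token0)
out_full = causal_attention(prefix + [new_token])  # like model(token0 + token1)

# The rows for the prefix are recomputed, yet they come out the same.
prefix_rows_unchanged = all(
    math.isclose(a, b)
    for old_row, new_row in zip(out_prefix, out_full)
    for a, b in zip(old_row, new_row)
)

cache = CachedAttention()
for x in prefix:
    cache.step(x)
cached_last = cache.step(new_token)  # one attention call for the new token

print(prefix_rows_unchanged)
print(cached_last, out_full[-1])
```

<p>The cached path touches each token once, while the recompute path redoes the whole prefix at every step. That redundant work is what a KV cache removes in a real model.</p>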
<h2 id="conversation-level">Conversation Level</h2><p>A chat model is also just a text continuation model, except that it follows a chat template distinguishing which text was entered by the user and which was generated by the assistant. At the lowest abstraction level, the token level, the model outputs one token per step and uses it as part of the input for the next step. One abstraction level higher, at the conversation level, a chat model similarly outputs one response to one user input and uses that response as part of the input for the next turn. Therefore, to hold a multi-round conversation with a chat model, we just append the model’s response at each turn to its corresponding input.</p>
<table><tr><td class="code"><pre>input1 = tokenizer(text1)
output1 = model(input1)
# output1 contains input1 followed by the model's response 1
response1 = get_response(output1)
input2 = tokenizer(text2)
output2 = model(input1 + response1 + input2)
</pre></td></tr></table>
<p>And yes, this means that to get output2 we feed input1 + response1 to the model as new input, but this shouldn’t be a concern anymore since we feed every token as new anyway.</p><h2 id="get_response">get_response</h2><p>The question now is how to implement get_response so that it extracts the assistant’s response from the text-continuation model’s output.</p><ul><li><p>Find the indicator (prefix) that marks the start of the assistant’s message. Note that when the model doesn’t follow the instruction and fails to generate such a prefix, this method fails.</p>
<table><tr><td class="code"><pre>import re

import torch

prefix = re.escape("[/INST]")  # escape special characters for regex
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=300)
detoked_output = processor.decode(output[0], skip_special_tokens=True)
# the answer starts right after the last occurrence of the prefix
answer_idx = [m.end() for m in re.finditer(prefix, detoked_output)][-1]
answer = detoked_output[answer_idx:]
</pre></td></tr></table>
</li><li><p>Recommended: take the substring that comes after the input (prompt). Hugging Face uses this approach in their <a href="https://github.com/huggingface/transformers/blob/1c1aec2ef1d6822fae3ffbb973b4c941f65f4ddf/src/transformers/pipelines/text_generation.py#L369-L387">TextGenerationPipeline</a>. The decode function has a clean_up_tokenization_spaces argument whose default comes from the tokenizer’s configuration (False in my case). (For what it does, see <a href="https://discuss.huggingface.co/t/what-does-the-parameter-clean-up-tokenization-spaces-do-in-the-tokenizer-decode-function/17399">this discussion</a>.) Hugging Face sets it to True in both calls. I tried setting both to False, and setting one to True and the other to False, and either way gave correct results. That said, it’s still best to follow what Hugging Face wrote. After all, they know their code best.</p>
<table><tr><td class="code"><pre>with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=300)
detoked_output = processor.decode(
    output[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
# decode the prompt tokens the same way so that the two lengths are comparable
cutoff = len(text_processor.decode(
    inputs["input_ids"][0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
))
answer = detoked_output[cutoff:]
</pre></td></tr></table>
</li></ul><h2 id="detours-when-taking-the-recommended-approach">Detours when Taking the Recommended Approach</h2><p>I had some trouble with this recommended approach at first:</p>
<table><tr><td class="code"><pre>chat = [
    {"role": "user", "content": "<image>\nHow many animated characters are there in this image?"}
]
prompt = text_processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, image, return_tensors="pt").to(device)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=300)
detoked_output = processor.decode(output[0], skip_special_tokens=True)
# problematic: prompt still contains the special tokens that decode() skipped
cutoff = len(prompt)
</pre></td></tr></table>
<p>Here cutoff lands many characters after the real starting point of the assistant’s response. That is because apply_chat_template adds special tokens such as <s> and </s> to mark the start and end of one conversation turn with the assistant, but when we detokenize the output we pass skip_special_tokens=True to keep only the text, which causes the discrepancy.</p><p>At first I thought the discrepancy came from LLaVA replacing the <image> token with the image embeddings (or pixel_values, as Hugging Face calls them), because <image> also disappeared from detoked_output. However, after reading Figure 1 (LLaVA network architecture) of LLaVA’s paper, <a href="https://arxiv.org/abs/2304.08485">Visual Instruction Tuning</a>, I realized that LLaVA actually puts the image in front of the text input instead of inserting it in the middle.</p>
<img src="https://arxiv.org/html/2304.08485v2/x1.png" alt="LLaVA architecture" /><figcaption aria-hidden="true">LLaVA architecture</figcaption>
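<p>The cutoff discrepancy can be reproduced without loading a model. The following is a toy sketch, not the Hugging Face tokenizer: decode is a hypothetical stand-in that joins token strings and optionally drops an assumed set of special tokens, and the prompt is only loosely modeled on the [INST] ... [/INST] format. It compares len(prompt) with the length of the prompt decoded the same way as the output.</p>

```python
# Assumption: a toy special-token set, chosen to mirror the tokens discussed here.
SPECIAL_TOKENS = {"<s>", "</s>", "<image>"}


def decode(tokens, skip_special_tokens=True):
    """Toy stand-in for tokenizer.decode: join tokens, optionally dropping specials."""
    kept = [
        token
        for token in tokens
        if not (skip_special_tokens and token in SPECIAL_TOKENS)
    ]
    return "".join(kept)


prompt_tokens = ["<s>", "[INST] ", "<image>", "\nHow many characters? ", "[/INST]"]
response_tokens = [" There are three.", "</s>"]

# The prompt string as the chat template would produce it, special tokens included.
prompt = decode(prompt_tokens, skip_special_tokens=False)
# What we get back after generation when special tokens are skipped.
detoked_output = decode(prompt_tokens + response_tokens)

wrong_cutoff = len(prompt)                 # counts "<s>" and "<image>"
right_cutoff = len(decode(prompt_tokens))  # decoded like the output

wrong_answer = detoked_output[wrong_cutoff:]
right_answer = detoked_output[right_cutoff:]

print(repr(wrong_answer))
print(repr(right_answer))
```

<p>In this toy setup the wrong cutoff overshoots by exactly the combined length of the special tokens that were dropped from the prompt, so the start of the answer is cut off.</p>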
<p><image> disappeared from the decoded output because it is also a special token. However, it is not listed in tokenizer.all_special_tokens. Reading the tokenizer’s source code, I could not work out how it was added as a special token, so I was not able to debug why it is missing from all_special_tokens. For this specific behavior, I submitted <a href="https://discuss.huggingface.co/t/additional-special-tokens-are-not-added/93192">an issue on the Hugging Face forum</a>.</p><p>You can find the chat template definition in tokenizer_config.json -> "chat_template". In the same file, the "added_tokens_decoder" attribute defines <image> as a special token.</p><h2 id="the-complete-code">The Complete Code</h2><p>I referenced the Hugging Face conversation pipeline for <a href="https://huggingface.co/docs/transformers/main/conversations#what-happens-inside-the-pipeline">the general structure</a> and the TextGenerationPipeline for <a href="https://github.com/huggingface/transformers/blob/1c1aec2ef1d6822fae3ffbb973b4c941f65f4ddf/src/transformers/pipelines/text_generation.py#L369-L387">the response extractor</a>:</p>
<table><tr><td class="code"><pre>queries = [
    "<image>\nHow many animated characters are there in this image?",
    "Answer with a single number in decimal format. Give no explanations."
]

def generate_response(image):
    chat = []
    for query in queries:
        chat.append({"role": "user", "content": query})
        prompt = text_processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = processor(prompt, image, return_tensors="pt").to(device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=300)
        output = processor.decode(
            output[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # length of the prompt once it is decoded the same way as the output
        input_ids = inputs["input_ids"]
        cutoff = len(text_processor.decode(
            input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ))
        answer = output[cutoff:]
        # feed the response back in as part of the next turn's input
        chat.append({"role": "assistant", "content": answer})
    return answer
</pre></td></tr></table>
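<p>The same loop can be tried without a GPU. The sketch below swaps the model for a stub, so everything in it is an assumption made for illustration: render_prompt is a simplified [INST]-style template rather than the real chat template, and stub_generate returns canned replies. It only demonstrates how each response is appended to the chat and fed back in at the next turn.</p>

```python
def render_prompt(chat):
    """Hypothetical, simplified chat template (not the real one from tokenizer_config.json)."""
    parts = []
    for message in chat:
        if message["role"] == "user":
            parts.append(f"[INST] {message['content']} [/INST]")
        else:
            parts.append(f" {message['content']} ")
    return "".join(parts)


def stub_generate(prompt):
    """Stand-in for model.generate + decode: returns the prompt plus a canned reply."""
    turn = prompt.count("[/INST]")
    reply = "I can see two cats and one dog." if turn == 1 else "3"
    return prompt + " " + reply


def converse(queries):
    chat, prompts = [], []
    for query in queries:
        chat.append({"role": "user", "content": query})
        prompt = render_prompt(chat)
        prompts.append(prompt)
        output = stub_generate(prompt)
        # Same idea as above: the answer is whatever comes after the prompt.
        answer = output[len(prompt):].strip()
        chat.append({"role": "assistant", "content": answer})
    return chat, prompts


chat, prompts = converse([
    "How many animated characters are there in this image?",
    "Answer with a single number in decimal format. Give no explanations.",
])

for message in chat:
    print(message["role"], ":", message["content"])
```

<p>The second prompt starts with the entire first prompt and also contains the first response, which is the re-feeding described in the Conversation Level section.</p>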
<h2 id="ps">PS</h2><p>As written at the start of this blog post, it all began with me trying to hold a multi-round conversation with a transformer. A web search took me to these discussions (<a href="https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/19">link1</a>, link2). It’s obvious that this accepted approach of appending the output to the previous messages wastes a lot of computing resources, which made me realize that how a transformer works <a href="#token-level">internally at the lowest level</a> is itself a waste of resources.</p>