Yao Lirong's Blog

Conducting Multi-Round Conversation with Transformers

2024/06/17

I was using LLaVA to query how many characters there are in an image. For higher accuracy, I decided to employ Chain of Thought (CoT), but struggled to implement it. CoT is conducted through a multi-round conversation. That is easily done in a graphical chat interface, but how is it done internally, in code?

Token Level

Before diving into instruct / chat models, let's go to the lowest level and think about how transformers do generation. A transformer is an autoregressive model: it uses its own output as input for the next round. Looking at nanoGPT's generate function:

def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    """
    for _ in range(max_new_tokens):
        # crop the context to the last block_size tokens if the sequence is too long
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
        # forward the model and keep only the logits at the final position
        logits, _ = self(idx_cond)
        logits = logits[:, -1, :] / temperature
        # sample the next token and append it, so the whole (old + new) sequence
        # is fed back in on the next iteration (top_k handling omitted in this excerpt)
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

If we ignore the details, this for loop is effectively doing:

token0 = tokenizer(text)
output1 = model(token0)

token1 = get_response(output1)
output2 = model(token0 + token1)

token2 = get_response(output2)
output3 = model(token0 + token1 + token2)
...

By writing it out like this, it is clear that at each generation step we feed the previous step's input back into the model as if it were something new, even though it is exactly the same. So when we call model(token0 + token1), we forget all the attention we computed in model(token0), even though the attention for the token0 part is exactly the same. This is why people complain that transformer inference is slow, and this is where inference speed-up techniques like the KV cache come in.
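As a rough illustration of what the KV cache buys us, here is a minimal sketch using Hugging Face's use_cache / past_key_values interface. GPT-2 and greedy decoding are my choices purely to keep the sketch small; this is not the nanoGPT code above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

with torch.no_grad():
    # first pass: attention is computed for the whole prompt once,
    # and the key/value states are returned as past_key_values
    out = model(input_ids, use_cache=True)
    past = out.past_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

    for _ in range(10):
        # later passes: feed only the newest token; the cached keys/values
        # stand in for the prefix, so its attention is not recomputed
        out = model(next_token, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)

print(tokenizer.decode(input_ids[0]))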

This also reveals that the very popular diagram demonstrating the theory behind transformer inference lied (at least to me). When calculating y_{i+1}, we do not re-use y_0 ... y_i, nor the attention or the intermediate activations; we just re-feed them into the model as something completely new.

[Figure: Autoregressive Decoder]

Conversation Level

A chat model is also just a text-continuation model, except that it follows a chat template that distinguishes which texts are input by the user and which are generated by the assistant. At the lowest abstraction level, the token level, the model outputs one token per step and uses it as part of the input for the next step. One abstraction level higher, at the conversation level, a chat model similarly outputs one response to each user input and uses that response as part of the input for the next turn. Therefore, to conduct a multi-round conversation with a chat model, we just append the model's response at each turn to its corresponding input.

input1 = tokenizer(text1)
output1 = model(input1)
# output1 contains input1 and the model's response 1
response1 = get_response(output1)
input2 = tokenizer(text2)
output2 = model(input1 + response1 + input2)

And yes, this means that to get output2 we feed both input1 and response1 to the model as brand-new input, but this should not be a concern anymore, since we feed every token back as new anyway.
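To make this concrete, here is a minimal two-round sketch with a text-only Hugging Face chat model. The model id is a placeholder, and cutting the prompt off the decoded output anticipates the recommended get_response approach discussed next.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-instruct-model"  # placeholder; any chat model with a chat template works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def ask(chat):
    # render the whole history (user and assistant turns) into one prompt string
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=100)
    full = tokenizer.decode(output[0], skip_special_tokens=True)
    # the output contains the prompt followed by the new response; cut the prompt off
    cutoff = len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True))
    return full[cutoff:]

chat = [{"role": "user", "content": "Name a prime number greater than 100."}]
response1 = ask(chat)

# append the response so the next round's prompt contains the full history
chat.append({"role": "assistant", "content": response1})
chat.append({"role": "user", "content": "Is it also a Mersenne prime?"})
response2 = ask(chat)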

get_response

The question now is how to implement get_response to extract the assistant's response from the text-continuation model's output.

  • Find the indicator (prefix) that marks the start of the assistant's message: Note that when the model does not follow the instruction format and fails to generate such a prefix, this method fails.

    import re

    prefix = r"\[/INST\]"  # escape special characters for regex
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=300)
    detoked_output = processor.decode(output[0], skip_special_tokens=True)
    answer_idx = [m.end() for m in re.finditer(prefix, detoked_output)][-1]
    answer = detoked_output[answer_idx:]
  • Recommended - Take the substring that comes after the input (prompt): Hugging Face uses this approach in their TextGenerationPipeline. The decode function has a clean_up_tokenization_spaces argument, which defaults to False (for what it does, see this discussion). Hugging Face sets it to True in both calls, but I tried setting both to False, and one to True with the other False; either gives correct results. That said, it is still best to follow what Hugging Face wrote; after all, they know their code best.

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=300)
    detoked_output = processor.decode(output[0], skip_special_tokens=True,
                                      clean_up_tokenization_spaces=True)
    cutoff = len(text_processor.decode(
        inputs["input_ids"][0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    ))
    answer = detoked_output[cutoff:]

Detours when Taking the Recommended Approach

I had some trouble with this recommended approach at first:

chat = [
    {"role": "user", "content": "<image>\nHow many animated characters are there in this image?"}
]
prompt = text_processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, image, return_tensors="pt").to(device)
...
detoked_output = processor.decode(output[0], skip_special_tokens=True)
cutoff = len(prompt)

And cutoff actually lands many indices after the real starting point of the assistant's response. That is because apply_chat_template adds special tokens such as <s> and </s> to mark the start and end of each turn of the conversation, but when we detokenize the output we pass skip_special_tokens=True to get the response only, so the decoded text is shorter than the raw prompt string and the lengths no longer match.
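To see the discrepancy concretely, a quick check like the following (a sketch reusing the chat and text_processor objects from above) prints two different lengths:

# sketch: compare the raw prompt string with its round trip through the tokenizer
prompt = text_processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
ids = text_processor(prompt, return_tensors="pt")["input_ids"][0]
decoded = text_processor.decode(ids, skip_special_tokens=True)
print(len(prompt), len(decoded))  # len(prompt) > len(decoded): the special tokens were dropped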

At first I thought this discrepancy came from LLaVA replacing the <image> token with the image embeddings (or pixel_values, as Hugging Face calls them), because <image> also disappeared in detoked_output. However, after reading LLaVA's paper (Visual Instruction Tuning, Figure 1: LLaVA network architecture), I realized LLaVA actually puts the image in front of the text input instead of inserting it in the middle.

[Figure: LLaVA architecture]

And <image> disappeared because it is also a special token. However, it was not in tokenizer.all_special_tokens. Reading the tokenizer's source code, I am actually not sure how it was added as a special token, so I was not able to debug why it is missing from all_special_tokens. For this specific behavior, I submitted an issue on the Hugging Face forum.

You can find the chat template definition in tokenizer_config.json under "chat_template". In the same file, the "added_tokens_decoder" attribute defines <image> as a special token.
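If you prefer to inspect these fields from code instead of opening the JSON, something like this should work. The checkpoint id is my assumption, and depending on the transformers version the chat template may live on the processor rather than the tokenizer.

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # assumed checkpoint
tok = processor.tokenizer
print(tok.chat_template)         # the Jinja chat template (may be None on newer versions)
print(tok.all_special_tokens)    # <image> does not show up here
print(tok.added_tokens_decoder)  # ...but it is listed here as an added token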

The Complete Code

I referenced the Hugging Face conversation pipeline for the general structure and the response extractor.

queries = [
    "<image>\nHow many animated characters are there in this image?",
    "Answer with a single number in decimal format. Give no explanations."
]

def generate_response(image):
    chat = []
    for query in queries:
        # add the new user turn and re-render the whole history into one prompt
        chat.append({"role": "user", "content": query})
        prompt = text_processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = processor(prompt, image, return_tensors="pt").to(device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=300)
        output = processor.decode(output[0], skip_special_tokens=True)

        # cut off the (decoded) prompt to keep only the assistant's new response
        input_ids = inputs["input_ids"]
        cutoff = len(text_processor.decode(
            input_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ))
        answer = output[cutoff:]
        # append the response so the next round's prompt contains it
        chat.append({"role": "assistant", "content": answer})
    return answer
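Calling it would look something like this (the image path is just a placeholder):

# hypothetical usage: load an image with PIL and run the two-round query
from PIL import Image

image = Image.open("characters.png")  # placeholder path
print(generate_response(image))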

PS

As written at the start of this blog post, it all began with me trying to do multi-round conversation with a transformer. A web search took me to these discussions (link 1, link 2). It is obvious that this accepted approach of appending the output to the previous message wastes a lot of computing resources, which made me realize that the way a transformer works internally at the lowest level is itself a waste of resources.
