Yao Lirong's Blog

Matryoshka Representation Learning, Adaptive Retrieval and Binary Vector Search

2024/12/25

Intro to Matryoshka Representation Learning

In Matryoshka Representation Learning (MRL), we want to construct an encoding e_d with dimension d such that its truncations of different lengths (e_{d/16}, e_{d/8}, e_{d/4}, e_{d/2}) are each (somewhat) valid representations. Suppose you're training on a classification problem with the classic encoder + classifier head architecture. At train time:

  • classic setting: you just use the full vector e_d as input to the classifier head
  • MRL: construct multiple classifier heads (in our case 5), put one on top of the encoding of each length (e_{d/16}, …, e_d), and average the losses of the classifier heads. So we build heads of size [d, num_class], [d/2, num_class], …, [d/16, num_class]. Note these classifier heads share weights; see the sketch below.
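
As a concrete illustration, here is a minimal PyTorch sketch of this training step. The `encoder`, `d`, and `num_classes` names are placeholders, and realizing the shared heads by slicing one [num_classes, d] weight matrix is one reasonable reading of the setup above, not the paper's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MRLClassifier(nn.Module):
    def __init__(self, encoder: nn.Module, d: int, num_classes: int):
        super().__init__()
        self.encoder = encoder
        self.dims = [d // 16, d // 8, d // 4, d // 2, d]  # the 5 nested lengths
        # One shared weight matrix of shape [num_classes, d]; each head
        # uses only its first `dim` columns, so the heads share weights.
        self.head = nn.Linear(d, num_classes, bias=False)

    def forward(self, x, labels):
        e = self.encoder(x)  # full embedding e_d, shape [batch, d]
        losses = []
        for dim in self.dims:
            # Truncated encoding e_{dim} fed to its (weight-shared) head.
            logits = e[:, :dim] @ self.head.weight[:, :dim].T
            losses.append(F.cross_entropy(logits, labels))
        return torch.stack(losses).mean()  # average the per-head losses
```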

Application: Adaptive Retrieval

Online retrieval is one of the tasks where latency matters the most. Given a user query q, it is slow to compute KNN over a dataset of 1M (10^6) indexes if each index has dimension 3072. With MRL, we can decompose the process into two stages:

  1. Shortlist: first retrieve 2K indexes, where the distance is computed using only the 1024-d vector (the first 1024 elements of the 3072-d vector)
  2. Rerank: find KNN among these 2K indexes, where the distance is computed using the full-length 3072-d vector (see the sketch after this list)
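
A rough NumPy sketch of these two stages, assuming unit-normalized embeddings scored by dot product; the function and variable names are illustrative assumptions, not code from the cited experiments:

```python
import numpy as np

def adaptive_retrieval(query, db, short_dim=1024, shortlist=2000, k=10):
    """query: [3072]; db: [N, 3072]; both unit-normalized."""
    # Stage 1 (shortlist): score with only the first `short_dim` dimensions.
    coarse = db[:, :short_dim] @ query[:short_dim]
    candidates = np.argpartition(-coarse, shortlist)[:shortlist]
    # Stage 2 (rerank): rescore only the shortlist with full-length vectors.
    fine = db[candidates] @ query
    return candidates[np.argsort(-fine)[:k]]  # top-k after reranking
```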

The FLOPs are therefore reduced from 3072 × 10^6 to 1024 × 10^6 + 3072 × 2K ≈ 1.03 × 10^9, roughly a 3× reduction. Ce Gao tested full-length 3072-dim vectors against adaptive retrieval with Matryoshka 1024-dim vectors: accuracy dropped from 99% to 89%, while Requests Per Second (RPS) rose from 300 to 1000.

Find more details on Matryoshka Representation Learning and its applications in this wonderful blog post; read from the section "What is MRL? (Really this Time)".

Binary Vector Search

Ce Gao suggested another way to reduce memory and FLOP usage. He proposes turning the length-d FP32 vector into a length-d binary vector, where each originally positive value is set to 1 and each negative value is set to 0; a sketch follows.
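
A small sketch of the binarization, plus a Hamming-distance helper for searching over the packed bits (both are illustrative assumptions about how one might implement it). `np.packbits` stores 8 components per byte, so a 3072-d vector shrinks from 12 KB of FP32 to 384 bytes:

```python
import numpy as np

def binarize(vecs: np.ndarray) -> np.ndarray:
    """Map FP32 vectors [N, d] to packed binary vectors [N, d // 8]."""
    bits = (vecs > 0).astype(np.uint8)  # positive -> 1, negative (and 0) -> 0
    return np.packbits(bits, axis=-1)   # pack 8 bits into each uint8 byte

def hamming_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Count differing bits between packed vectors via a popcount table."""
    popcount = np.unpackbits(np.arange(256, dtype=np.uint8)[:, None], axis=1).sum(1)
    return popcount[np.bitwise_xor(a, b)].sum(axis=-1)
```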

Without adaptive retrieval, accuracy dropped from 99% to 83%, but latency (RPS = 3000) and memory improved significantly: a single vector/encoding previously consisted of d 32-bit numbers, whereas now it consists of d 1-bit numbers, a 32× reduction in memory.

If you apply the Adaptive Retrieval setup mentioned earlier (sketched after the list):

  1. Shortlist: retrieve 2K indexes using the full-length but binary vectors
  2. Rerank: find KNN among these 2K indexes using the full-length FP32 vectors

precision drops only from 99% to 96%, with an RPS of 1700.
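
Combining the two ideas, a hedged sketch of this binary-shortlist + FP32-rerank pipeline, reusing the hypothetical `binarize` and `hamming_distance` helpers from the previous sketch:

```python
import numpy as np

def binary_adaptive_retrieval(query, db_fp32, db_bits, shortlist=2000, k=10):
    """query: [d] FP32; db_fp32: [N, d]; db_bits: [N, d // 8] packed binary."""
    q_bits = binarize(query[None, :])[0]
    # Stage 1: cheap Hamming-distance shortlist over the whole dataset.
    dist = hamming_distance(db_bits, q_bits)
    candidates = np.argpartition(dist, shortlist)[:shortlist]
    # Stage 2: exact rerank of the shortlist with full-precision vectors.
    fine = db_fp32[candidates] @ query
    return candidates[np.argsort(-fine)[:k]]
```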

P.S. I discovered this method on Simon Willison’s blog.
