Intro to Matryoshka Representation Learning

In Matryoshka Representation Learning (MRL), we want to construct an encoding $e_d$ with dimension $d$ such that its truncations of different lengths ($e_{d/16}$, $e_{d/8}$, $e_{d/4}$, $e_{d/2}$) are each (somewhat) valid representations. Suppose you're training on a classification problem with the classic encoder + classifier head architecture. At train time, you attach a classifier head to each truncation length and sum the classification losses across all of them, so every nested prefix of the embedding is trained to be a usable representation on its own.
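Here is a minimal PyTorch sketch of that training objective. The sizes (3072-dim embedding, nested dims {192, 384, 768, 1536, 3072}, 1000 classes) and the stand-in encoder are illustrative choices, not from the original post:

```python
import torch
import torch.nn as nn

# Illustrative sizes: full dim d = 3072 with nested prefixes d/16 ... d.
NESTED_DIMS = [192, 384, 768, 1536, 3072]
NUM_CLASSES = 1000

encoder = nn.Linear(512, 3072)  # stand-in encoder; any backbone works
# One classifier head per truncation length.
heads = nn.ModuleList([nn.Linear(m, NUM_CLASSES) for m in NESTED_DIMS])
loss_fn = nn.CrossEntropyLoss()

def mrl_loss(x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum the classification loss over every nested prefix of the embedding."""
    e = encoder(x)  # (batch, 3072)
    return sum(loss_fn(head(e[:, :m]), labels)
               for m, head in zip(NESTED_DIMS, heads))

# Usage:
# loss = mrl_loss(torch.randn(8, 512), torch.randint(0, NUM_CLASSES, (8,)))
```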

Application: Adaptive Retrieval

Online retrieval is one of the tasks where latency matters the most. Given a user query $q$, it is slow to compute KNN over a dataset of 1M ($10^6$) indexed vectors if each vector has dimension 3072. With MRL, we can decompose the process into two stages (sketched in code after the list):

  1. Shortlist: first retrieve 2K candidates, where the distance is computed using only the 1024-d prefix (the first 1024 elements of the 3072-d vector)
  2. Rerank: find the KNN among these 2K candidates, where the distance is computed using the full-length 3072-d vector
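Below is a minimal NumPy sketch of this two-stage process. The function name and default parameters are mine, and it assumes all vectors are L2-normalized so that a dot product acts as cosine similarity:

```python
import numpy as np

def adaptive_retrieval(query, db, shortlist_dim=1024, shortlist_size=2000, k=10):
    """Two-stage KNN: cheap shortlist on a prefix, exact rerank on full vectors.

    query: (d,) float32; db: (n, d) float32; assumes n > shortlist_size and
    that all vectors are L2-normalized (dot product ~ cosine similarity).
    """
    # Stage 1: shortlist by similarity on the first `shortlist_dim` dimensions.
    coarse = db[:, :shortlist_dim] @ query[:shortlist_dim]
    shortlist = np.argpartition(-coarse, shortlist_size)[:shortlist_size]
    # Stage 2: rerank only the shortlist with the full-length vectors.
    fine = db[shortlist] @ query
    return shortlist[np.argsort(-fine)[:k]]
```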

The FLOPs per query are therefore reduced from $3072 \times 10^6$ to $1024 \times 10^6 + 3072 \times 2000 \approx 1.03 \times 10^9$, roughly a 3× reduction. Ce Gao tested full-length 3072-dim vectors against adaptive retrieval with the Matryoshka 1024-dim prefix: accuracy dropped from 99% to 89%, while Requests Per Second (RPS) rose from 300 to 1000.

Find more details on Matryoshka Representation Learning and its applications in this wonderful <a href="https://aniketrege.github.io/blog/2024/mrl/#what-is-mrl-really-this-time">blog post</a>; read from the section "What is MRL? (Really this Time)".

Binary Vector Search

<a href="https://blog.pgvecto.rs/my-binary-vector-search-is-better-than-your-fp32-vectors">Ce Gao suggested</a> another way to reduce memory and FLOP usage. He proposes turning the length-$d$ FP32 vector into a length-$d$ binary vector, where each positive value is set to 1 and each negative value is set to 0.
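As a concrete illustration, here is a small NumPy sketch of the binarization and of Hamming-distance scoring over packed bits (the function names are mine, not Ce Gao's):

```python
import numpy as np

def binarize(vecs):
    """Quantize FP32 vectors to bits: positive -> 1, else 0.

    vecs: (n, d) float32 -> (n, d // 8) uint8, 8 bits packed per byte.
    """
    return np.packbits(vecs > 0, axis=-1)

def hamming_distances(query_bits, db_bits):
    """Hamming distance between one packed query and every packed DB row."""
    xor = np.bitwise_xor(db_bits, query_bits)          # differing bits
    return np.unpackbits(xor, axis=-1).sum(axis=-1)    # popcount per row
```

The XOR + popcount inner loop is far cheaper than FP32 dot products, which is where the latency win comes from.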

Without adaptive retrieval, accuracy dropped from 99% to 83%, but latency (RPS = 3000) and memory improved significantly: previously a single vector/encoding consisted of $d$ 32-bit numbers, whereas now it consists of $d$ 1-bit values, a 32× memory reduction.

If you adapt the Adaptive Retrieval setup mentioned earlier:

  1. Shortlist: retrieve 2K candidates using the full-length but binary vectors
  2. Rerank: find the KNN among these 2K candidates using the full-length FP32 vectors

you get an accuracy drop from 99% to only 96%, with an RPS of 1700 (a sketch of this combination follows).
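A minimal sketch combining the two ideas, with the same illustrative names and defaults as before:

```python
import numpy as np

def binary_adaptive_retrieval(query, db, db_bits, shortlist_size=2000, k=10):
    """Hamming-distance shortlist on packed bits, FP32 rerank on the shortlist.

    query: (d,) float32; db: (n, d) float32, assumed L2-normalized;
    db_bits: (n, d // 8) uint8, e.g. produced by binarize(db) above.
    """
    query_bits = np.packbits(query > 0)
    # Stage 1: shortlist by Hamming distance (smaller is closer).
    dists = np.unpackbits(np.bitwise_xor(db_bits, query_bits),
                          axis=-1).sum(axis=-1)
    shortlist = np.argpartition(dists, shortlist_size)[:shortlist_size]
    # Stage 2: rerank the shortlist with full-precision dot products.
    fine = db[shortlist] @ query
    return shortlist[np.argsort(-fine)[:k]]
```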

P.S. I discovered this method on <a href="https://simonwillison.net/2024/Mar/26/binary-vector-search/">Simon Willison's blog</a>.