ColPali and PLAID - Part 2

Neeraj Kumar
8 min read · Sep 20, 2024


PLAID architecture:

PLAID improves on ColBERTv2 primarily in terms of latency.

Latency is split into several stages: query encoding, candidate generation, index lookups (which involve retrieving compressed vector representations for candidate passages), residual decompression, and scoring (i.e., final MaxSim computations).

In vanilla ColBERTv2, the primary bottlenecks are the index lookups and residual decompression. Index lookups are expensive because they demand significant memory bandwidth: each vector is encoded as a 4-byte centroid ID plus a 32-byte residual, each passage can contain many vectors, and data may need to be gathered for up to 2¹⁶ (65,536) candidate passages. Furthermore, vanilla ColBERTv2 must dynamically construct padded tensors to accommodate passages of varying lengths, which adds further overhead.
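As a rough size check (assuming ColBERTv2's common configuration of 128-dimensional token embeddings quantized at 2 bits per dimension; other settings change the numbers), each stored vector costs about

4 \text{ bytes (centroid ID)} + \frac{128 \times 2}{8} \text{ bytes (residual)} = 36 \text{ bytes},

so gathering every token vector for tens of thousands of candidate passages quickly becomes memory-bandwidth bound.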

Residual decompression is another time-consuming task, as it involves several computationally expensive operations such as unpacking bits and summing large values.

While ColBERTv2 reduces space usage by employing centroids, PLAID demonstrates that centroids can also accelerate search without sacrificing quality by acting as proxies for passage embeddings. This approach allows the model to skip low-scoring passages, avoiding the need to retrieve or decompress their residuals.

Although this adds a small overhead in the candidate generation stage, it leads to substantial time savings in later stages. The hypothesis is that retrieval based solely on centroids can identify high-scoring passages that would have been retrieved by the full ColBERTv2 method.

The PLAID inference pipeline comprises several consecutive stages for retrieval, filtering, and ranking:

  • The first stage, candidate generation, builds an initial candidate set by computing relevance scores between each centroid and the query embeddings.
  • In the intermediate stages, PLAID filters the candidate passages efficiently using centroid interaction and centroid pruning.
  • Finally, PLAID scores and ranks the remaining candidates using fully reconstructed passage embeddings. (A minimal sketch of the whole flow follows below.)
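The sketch below puts these three stages together in NumPy-style Python. It is only a minimal illustration under simplifying assumptions: the function name plaid_search, the ivf/doc_codes/decompress structures, and the default parameters are all hypothetical, and the centroid-pruning threshold is omitted for brevity.

```python
import numpy as np

def plaid_search(Q, centroids, ivf, doc_codes, decompress, t=4, ndocs=256, k=10):
    """Illustrative PLAID-style pipeline (not the real PLAID API):
    candidate generation, centroid interaction, then exact MaxSim scoring."""
    # Stage 1: candidate generation. Score every centroid against the query
    # tokens, then collect passages whose tokens fall in the top-t centroids.
    S_cq = centroids @ Q.T                           # (num_centroids, num_query_tokens)
    top = np.argsort(-S_cq, axis=0)[:t]              # top-t centroid ids per query token
    candidates = sorted({pid for c in np.unique(top) for pid in ivf[c]})

    # Stage 2: centroid interaction. Approximate each candidate by the S_cq
    # rows of its tokens' centroids and keep the best ndocs by approximate MaxSim.
    approx = np.array([S_cq[doc_codes[pid]].max(axis=0).sum() for pid in candidates])
    keep = [candidates[i] for i in np.argsort(-approx)[:ndocs]]

    # Stage 3: scoring. Decompress the surviving passages' embeddings and
    # rank them with the exact MaxSim over full token vectors.
    scores = {pid: (decompress(pid) @ Q.T).max(axis=0).sum() for pid in keep}
    return sorted(scores.items(), key=lambda kv: -kv[1])[:k]
```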

Candidate Generation Step:

1. Query Embedding Matrix (Q): This is a matrix representing the query in a high-dimensional vector space.

2. List of Centroid Vectors: These are pre-computed vectors that represent clusters of token embeddings, created using k-means clustering.

3. Token-level Query-Centroid Relevance Scores (S_c,q): These scores indicate how relevant each centroid is to each token in the query.

A passage is considered “close” to a centroid if one or more of its tokens were assigned to that centroid during the k-means clustering process.

Inverted List Structure:
PLAID stores an inverted list mapping each centroid to the unique IDs of the passages whose tokens were assigned to it. This is more space-efficient than mapping centroids to individual embeddings, because there are fewer passages than embeddings.
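A minimal sketch of how such an inverted list could be built from k-means token assignments; the passage_centroids layout and the tiny example data are hypothetical, not PLAID's actual storage format.

```python
from collections import defaultdict

# passage_centroids[pid] holds the centroid id assigned (by k-means) to each
# token of passage pid; the real PLAID index uses a different on-disk layout.
passage_centroids = {
    1: [0, 2],   # Passage 1: tokens assigned to C1, C3
    2: [4, 1],   # Passage 2: C5, C2
    3: [3, 0],   # Passage 3: C4, C1
}

ivf = defaultdict(set)
for pid, codes in passage_centroids.items():
    for c in codes:
        ivf[c].add(pid)          # map each centroid to its unique passage ids

print(dict(ivf))  # {0: {1, 3}, 2: {1}, 4: {2}, 1: {2}, 3: {3}}
```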

Centroid Interaction:

Recall from Equation 2 that \(S_{c,q}\) contains the relevance scores of each centroid with respect to the query tokens. Let \(I\) denote the list of centroid indices associated with the tokens in the candidate set, and let \(S_{c,q}[i]\) denote the \(i\)-th row of \(S_{c,q}\). PLAID then constructs the approximate, centroid-based score matrix \(\tilde{D}\) row by row as

\tilde{D}[i] = S_{c,q}[I[i]]

To rank the candidate passages using \(\tilde{D}\), PLAID computes the MaxSim scores \(S_{\tilde{D}}\) as

S_{\tilde{D}} = \sum_{j=1}^{|Q|} \max_{i} \tilde{D}[i, j]
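In NumPy terms, the centroid interaction is just fancy indexing followed by a max-then-sum reduction. This small sketch uses random scores and illustrative variable names.

```python
import numpy as np

# S_cq[i, j] = relevance score of centroid i for query token j.
rng = np.random.default_rng(0)
S_cq = rng.random((5, 4))              # 5 centroids, 4 query tokens
I = np.array([4, 1])                   # centroid indices of one candidate passage's tokens

D_tilde = S_cq[I]                      # \tilde{D}: one row of S_cq per passage token
S_Dtilde = D_tilde.max(axis=0).sum()   # MaxSim: max over passage tokens, sum over query tokens
```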

Example

1. Query: “Best Italian restaurants NYC”
2. Query Embedding Matrix (Q) (4 tokens, 6 dimensions):

Q = \begin{bmatrix}
0.1 & 0.2 & 0.3 & 0.4 & 0.5 & 0.6 \\
0.2 & 0.3 & 0.4 & 0.5 & 0.6 & 0.7 \\
0.3 & 0.4 & 0.5 & 0.6 & 0.7 & 0.8 \\
0.4 & 0.5 & 0.6 & 0.7 & 0.8 & 0.9
\end{bmatrix}

3. Centroid Vectors (C) (5 centroids, 6 dimensions):

C = \begin{bmatrix}
0.1 & 0.2 & 0.1 & 0.2 & 0.1 & 0.2 \\
0.3 & 0.4 & 0.3 & 0.4 & 0.3 & 0.4 \\
0.5 & 0.6 & 0.5 & 0.6 & 0.5 & 0.6 \\
0.7 & 0.8 & 0.7 & 0.8 & 0.7 & 0.8 \\
0.9 & 1.0 & 0.9 & 1.0 & 0.9 & 1.0
\end{bmatrix}

4. Passage Tokens and Centroid Assignments:

- Passage 1: [0.1, 0.2, 0.1, 0.2, 0.1, 0.2] (C1), [0.5, 0.6, 0.5, 0.6, 0.5, 0.6] (C3)
- Passage 2: [0.9, 1.0, 0.9, 1.0, 0.9, 1.0] (C5), [0.3, 0.4, 0.3, 0.4, 0.3, 0.4] (C2)
- Passage 3: [0.7, 0.8, 0.7, 0.8, 0.7, 0.8] (C4), [0.1, 0.2, 0.1, 0.2, 0.1, 0.2] (C1)
- Passage 4: [0.3, 0.4, 0.3, 0.4, 0.3, 0.4] (C2), [0.7, 0.8, 0.7, 0.8, 0.7, 0.8] (C4)
- Passage 5: [0.5, 0.6, 0.5, 0.6, 0.5, 0.6] (C3), [0.9, 1.0, 0.9, 1.0, 0.9, 1.0] (C5)
- Passage 6: [0.7, 0.8, 0.7, 0.8, 0.7, 0.8] (C4), [0.5, 0.6, 0.5, 0.6, 0.5, 0.6] (C3)

Step-by-Step Process

1. Compute Relevance Scores (\(S_{c,q}\))
Calculate the relevance scores between the query tokens and the centroids:

S_{c,q} = C \cdot Q^T

S_{c,q} = \begin{bmatrix}
0.33 & 0.42 & 0.51 & 0.60 \\
0.75 & 0.96 & 1.17 & 1.38 \\
1.17 & 1.50 & 1.83 & 2.16 \\
1.59 & 2.04 & 2.49 & 2.94 \\
2.01 & 2.58 & 3.15 & 3.72
\end{bmatrix}

Here row \(i\) holds centroid \(C_i\)'s scores against the four query tokens.

2. Identify Top-\(t\) Centroids:
Suppose \(t = 2\) (top-2 centroids per query token):

- For the first query token ("Best"):
  - Relevance scores (column 1 of \(S_{c,q}\)): [0.33, 0.75, 1.17, 1.59, 2.01]
  - Top-2 centroids: \(C_5, C_4\)

- For the second query token ("Italian"):
  - Relevance scores: [0.42, 0.96, 1.50, 2.04, 2.58]
  - Top-2 centroids: \(C_5, C_4\)

- For the third query token ("restaurants"):
  - Relevance scores: [0.51, 1.17, 1.83, 2.49, 3.15]
  - Top-2 centroids: \(C_5, C_4\)

- For the fourth query token ("NYC"):
  - Relevance scores: [0.60, 1.38, 2.16, 2.94, 3.72]
  - Top-2 centroids: \(C_5, C_4\)

3. Map Centroids to Passages

Identify passages that have tokens assigned to the top-2 centroids (\(C_4\) and \(C_5\)):
- Passage 1: [C1, C3] (not selected)
- Passage 2: [C5, C2] (selected)
- Passage 3: [C4, C1] (selected)
- Passage 4: [C2, C4] (selected)
- Passage 5: [C3, C5] (selected)
- Passage 6: [C4, C3] (selected)

4. Construct Centroid-based Approximate Scores (\(\tilde{D}\))

For each selected passage, take the rows of \(S_{c,q}\) corresponding to the centroids assigned to its tokens. Denote these per-passage score matrices as \(\tilde{D}\):

- Passage 2 (Centroids \(C_5, C_2\)):

\tilde{D}_2 = \begin{bmatrix}
2.01 & 2.58 & 3.15 & 3.72 \\
0.75 & 0.96 & 1.17 & 1.38
\end{bmatrix}

- Passage 3 (Centroids \(C_4, C_1\)):

\tilde{D}_3 = \begin{bmatrix}
1.59 & 2.04 & 2.49 & 2.94 \\
0.33 & 0.42 & 0.51 & 0.60
\end{bmatrix}

- Passage 4 (Centroids \(C_2, C_4\)):

\tilde{D}_4 = \begin{bmatrix}
0.75 & 0.96 & 1.17 & 1.38 \\
1.59 & 2.04 & 2.49 & 2.94
\end{bmatrix}

- Passage 5 (Centroids \(C_3, C_5\)):

\tilde{D}_5 = \begin{bmatrix}
1.17 & 1.50 & 1.83 & 2.16 \\
2.01 & 2.58 & 3.15 & 3.72
\end{bmatrix}

- Passage 6 (Centroids \(C_4, C_3\)):

\tilde{D}_6 = \begin{bmatrix}
1.59 & 2.04 & 2.49 & 2.94 \\
1.17 & 1.50 & 1.83 & 2.16
\end{bmatrix}

5. Compute MaxSim Scores

To compute the MaxSim score for each passage, take the maximum over its rows (passage tokens) for each query token and sum over the query tokens:

S_{\tilde{D}_p} = \sum_{j=1}^{|Q|} \max_{i} \tilde{D}_p[i, j]

Let's compute this for each passage:

- Passage 2:
  S_{\tilde{D}_2} = \max(2.01, 0.75) + \max(2.58, 0.96) + \max(3.15, 1.17) + \max(3.72, 1.38) = 2.01 + 2.58 + 3.15 + 3.72 = 11.46
- Passage 3:
  S_{\tilde{D}_3} = \max(1.59, 0.33) + \max(2.04, 0.42) + \max(2.49, 0.51) + \max(2.94, 0.60) = 1.59 + 2.04 + 2.49 + 2.94 = 9.06
- Passage 4:
  S_{\tilde{D}_4} = \max(0.75, 1.59) + \max(0.96, 2.04) + \max(1.17, 2.49) + \max(1.38, 2.94) = 1.59 + 2.04 + 2.49 + 2.94 = 9.06
- Passage 5:
  S_{\tilde{D}_5} = \max(1.17, 2.01) + \max(1.50, 2.58) + \max(1.83, 3.15) + \max(2.16, 3.72) = 2.01 + 2.58 + 3.15 + 3.72 = 11.46
- Passage 6:
  S_{\tilde{D}_6} = \max(1.59, 1.17) + \max(2.04, 1.50) + \max(2.49, 1.83) + \max(2.94, 2.16) = 1.59 + 2.04 + 2.49 + 2.94 = 9.06

6. Rank the Passages

Rank the candidate passages by their approximate MaxSim scores:
1. Passage 2: \(S_{\tilde{D}_2} = 11.46\)
2. Passage 5: \(S_{\tilde{D}_5} = 11.46\)
3. Passage 3: \(S_{\tilde{D}_3} = 9.06\)
4. Passage 4: \(S_{\tilde{D}_4} = 9.06\)
5. Passage 6: \(S_{\tilde{D}_6} = 9.06\)
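The toy example above can be reproduced end to end with a few lines of NumPy. This is only a sanity-check sketch; the dictionary layout of passages and the 0-indexed centroid IDs are choices made here for illustration, not PLAID's implementation.

```python
import numpy as np

Q = np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
              [0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
              [0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
              [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]])
C = np.array([[0.1, 0.2, 0.1, 0.2, 0.1, 0.2],
              [0.3, 0.4, 0.3, 0.4, 0.3, 0.4],
              [0.5, 0.6, 0.5, 0.6, 0.5, 0.6],
              [0.7, 0.8, 0.7, 0.8, 0.7, 0.8],
              [0.9, 1.0, 0.9, 1.0, 0.9, 1.0]])

S_cq = C @ Q.T                                   # centroid-query relevance scores (5 x 4)

# Centroid assignments per passage (0-indexed: C1 -> 0, ..., C5 -> 4)
passages = {1: [0, 2], 2: [4, 1], 3: [3, 0], 4: [1, 3], 5: [2, 4], 6: [3, 2]}

top2 = np.argsort(-S_cq, axis=0)[:2]             # top-2 centroid indices per query token
top_set = set(top2.ravel().tolist())             # -> {3, 4}, i.e. C4 and C5
selected = [pid for pid, cs in passages.items() if set(cs) & top_set]

# Approximate MaxSim: max over the passage's centroid rows, sum over query tokens
scores = {pid: S_cq[passages[pid]].max(axis=0).sum() for pid in selected}
for pid, s in sorted(scores.items(), key=lambda kv: -kv[1]):
    print(pid, round(s, 2))   # Passages 2 and 5 score 11.46; 3, 4 and 6 score 9.06
```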

Centroid Pruning:

PLAID also filters out tokens whose associated centroids do not meet a relevance threshold. In this filtering phase, only centroids whose highest score across the query tokens exceeds the threshold t_cs are kept; mathematically, a centroid \(i\) is retained only if \(\max_j S_{c,q}[i, j] \geq t_{cs}\). This reduces the computational load by focusing on the most relevant centroids and skipping computations for low-scoring ones.
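A minimal sketch of this filter, reusing the toy \(S_{c,q}\) from the example; the threshold value t_cs = 2.0 is arbitrary and purely illustrative.

```python
import numpy as np

# Keep only centroids whose best score across the query tokens reaches t_cs.
S_cq = np.array([[0.33, 0.42, 0.51, 0.60],
                 [0.75, 0.96, 1.17, 1.38],
                 [1.17, 1.50, 1.83, 2.16],
                 [1.59, 2.04, 2.49, 2.94],
                 [2.01, 2.58, 3.15, 3.72]])
t_cs = 2.0
keep = np.where(S_cq.max(axis=1) >= t_cs)[0]   # centroid indices that survive pruning
print(keep)                                    # [2 3 4] -> C3, C4, C5 under this threshold
```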

Scoring:

After generating an initial set of candidate passages, PLAID reconstructs the original embeddings for these passages using residual decompression. This step restores the full passage embeddings from the compressed format used during retrieval.

Once the embeddings are reconstructed, PLAID applies the MaxSim operation to rank the candidates based on their similarity to the query. Here, the term MaxSim refers to the method used to compute the similarity scores between the query and the candidate passages. The decompressed embedding vectors for the final candidate set are denoted by D.

Finally, the scores for each query-passage pair are calculated based on Equation 1. This ensures that the final ranking is based on fully reconstructed passage representations, allowing for a high-quality retrieval of relevant passages.
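A small sketch of this final stage, assuming a decompress(pid) helper (hypothetical here) that stands in for PLAID's residual decompression and returns a passage's full token-embedding matrix.

```python
import numpy as np

def maxsim(Q, D):
    """Full MaxSim (Equation 1): for each query token take the best-matching
    passage token, then sum over the query tokens."""
    return (D @ Q.T).max(axis=0).sum()

def rank(Q, candidate_ids, decompress, k=10):
    # Score every surviving candidate on its fully reconstructed embeddings.
    scores = {pid: maxsim(Q, decompress(pid)) for pid in candidate_ids}
    return sorted(scores.items(), key=lambda kv: -kv[1])[:k]
```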


References:

  1. ColPali: Efficient Document Retrieval with Vision Language Models
  2. PLAID: An Efficient Engine for Late Interaction Retrieval
  3. ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction
