Logsumexp trick and Flash attention: Part 1

Neeraj Kumar
4 min read · Sep 5, 2024


Softmax Concept:

Input Sequences (Q, K, V):
Q: Query matrix, a table of numbers with \(N\) rows and \(d\) columns.
K: Key matrix, also a table of numbers with \(N\) rows and \(d\) columns.
V: Value matrix, another table of numbers with \(N\) rows and \(d\) columns.

Here, \(N\) is the sequence length (number of tokens), and \(d\) is the size of each token's embedding.

Attention Output (O):
The output matrix \(O\) also has \(N\) rows and \(d\) columns.
Intermediate Matrices:
S: Matrix of attention scores with \(N\) rows and \(N\) columns.
P: Matrix of softmax-normalized attention probabilities, also with \(N\) rows and \(N\) columns.

Computation Steps
Computing attention involves three main steps:
Compute Scores (S):
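Multiply the query matrix \(Q\) by the transpose of the key matrix \(K\) to get the scores matrix \(S = QK^\top\). (Standard scaled dot-product attention also divides the scores by \(\sqrt{d}\), i.e. \(S = QK^\top / \sqrt{d}\).)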

Apply Softmax (P):
Apply the softmax function to each row of the scores matrix \(S\) to get the attention probabilities matrix \(P\), i.e. \(P_{ij} = \exp(S_{ij}) / \sum_{k=1}^{N} \exp(S_{ik})\).

Compute Output (O):
Multiply the attention probabilities matrix \(P\) by the value matrix \(V\) to get the final output matrix \(O = PV\).
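The three steps can be sketched in plain Python without external libraries. This is a minimal illustration, not an optimized or library implementation. The helper names `transpose`, `matmul`, `softmax_rows` and `attention` are hypothetical. The optional `scale` flag applies the standard \(1/\sqrt{d}\) factor. The softmax here is the naive version, which is exactly what the overflow issue discussed below is about.

```python
import math
from typing import List

Matrix = List[List[float]]


def transpose(A: Matrix) -> Matrix:
    # Rows of the result are the columns of A
    return [list(col) for col in zip(*A)]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    # (n x m) @ (m x p) -> (n x p): entry (i, j) is row i of A dotted with column j of B
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax_rows(S: Matrix) -> Matrix:
    # Naive row-wise softmax: exponentiate every score, divide by the row sum
    out = []
    for row in S:
        exps = [math.exp(s) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(Q: Matrix, K: Matrix, V: Matrix, scale: bool = True) -> Matrix:
    d = len(Q[0])
    S = matmul(Q, transpose(K))  # step 1: scores, N x N
    if scale:
        S = [[s / math.sqrt(d) for s in row] for row in S]
    P = softmax_rows(S)  # step 2: probabilities, N x N
    return matmul(P, V)  # step 3: output, N x d


# Usage with the example matrices from the next section (no sqrt(d) scaling)
Q = [[1, 0], [0, 1], [1, 1]]
K = [[1, 0], [1, 1], [0, 1]]
V = [[1, 2], [2, 3], [3, 4]]
print(attention(Q, K, V, scale=False))
```

With `scale=False`, the last output row comes out as approximately [2.0, 3.0]. It differs from the hand-worked example below because that example uses rounded, illustrative probabilities rather than the exact softmax.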

Example

Let’s take an example with small matrices for simplicity:
Suppose we have a sequence length \(N = 3\) and embedding dimension \(d = 2\).
Let's define the matrices \(Q\), \(K\), and \(V\) as follows. For simplicity of calculation, we omit the \(1/\sqrt{d}\) scaling, i.e. we use \(S = QK^\top\) instead of \(QK^\top / \sqrt{d}\).

Q = [ [1, 0],
      [0, 1],
      [1, 1] ]

K = [ [1, 0],
      [1, 1],
      [0, 1] ]

V = [ [1, 2],
      [2, 3],
      [3, 4] ]

1. Compute Scores (S):
Multiply \(Q\) by the transpose of \(K\) to get the scores matrix \(S = QK^\top\):

S = [ [1, 1, 0],
      [0, 1, 1],
      [1, 2, 1] ]
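A quick way to double-check the score matrix is to compute each entry as a dot product. \(S_{ij}\) is row \(i\) of \(Q\) dotted with row \(j\) of \(K\), which is the same as multiplying \(Q\) by \(K^\top\). A minimal Python sketch:

```python
Q = [[1, 0], [0, 1], [1, 1]]
K = [[1, 0], [1, 1], [0, 1]]

# S[i][j] = dot(Q[i], K[j]); looping over the rows of K is equivalent to using K^T
S = [[sum(q * k for q, k in zip(q_row, k_row)) for k_row in K] for q_row in Q]
print(S)  # [[1, 1, 0], [0, 1, 1], [1, 2, 1]]
```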

2. Apply Softmax (P):

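Apply the softmax function to each row of \(S\) to get the probabilities matrix \(P\). To keep the arithmetic simple, we use rounded, illustrative values rather than the exact softmax of \(S\). The exact rows are approximately [0.4223, 0.4223, 0.1554], [0.1554, 0.4223, 0.4223] and [0.2119, 0.5761, 0.2119]; note that softmax never outputs an exact 0.

P = [ [0.5, 0.5, 0],
      [0, 0.5, 0.5],
      [0.2, 0.5, 0.3] ]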
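For comparison, the exact softmax of each row of \(S\) can be computed directly. This minimal sketch uses the naive formula, which is safe here only because the scores are small. The printed values are rounded to four decimals.

```python
import math

S = [[1, 1, 0], [0, 1, 1], [1, 2, 1]]


def softmax(row):
    # Naive softmax: exp of each score divided by the sum of exps in the row
    exps = [math.exp(s) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


P_exact = [softmax(row) for row in S]
for row in P_exact:
    print([round(p, 4) for p in row])
# [0.4223, 0.4223, 0.1554]
# [0.1554, 0.4223, 0.4223]
# [0.2119, 0.5761, 0.2119]
```

Every entry is strictly positive, and each row sums to 1 up to floating-point rounding.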

3. Compute Output (O):

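Multiply \(P\) by \(V\) to get the final output matrix \(O\):

O = [ [1.5, 2.5],
      [2.5, 3.5],
      [2.1, 3.1] ]

For example, the last row is \(0.2 \cdot [1, 2] + 0.5 \cdot [2, 3] + 0.3 \cdot [3, 4] = [2.1, 3.1]\).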
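The matrix product \(O = PV\) can be checked in a few lines of plain Python. This sketch multiplies \(V\) by the illustrative \(P\) above and by the exact softmax of \(S\). The helpers `matmul` and `softmax` are illustrative names. The comparison shows how much the rounded probabilities shift the output.

```python
import math

S = [[1, 1, 0], [0, 1, 1], [1, 2, 1]]
V = [[1, 2], [2, 3], [3, 4]]
P_assumed = [[0.5, 0.5, 0.0], [0.0, 0.5, 0.5], [0.2, 0.5, 0.3]]


def matmul(A, B):
    # Entry (i, j) is row i of A dotted with column j of B
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax(row):
    # Naive softmax, fine for these small scores
    exps = [math.exp(s) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


O_assumed = matmul(P_assumed, V)
O_exact = matmul([softmax(row) for row in S], V)
print([[round(x, 3) for x in row] for row in O_assumed])  # [[1.5, 2.5], [2.5, 3.5], [2.1, 3.1]]
print([[round(x, 3) for x in row] for row in O_exact])    # [[1.733, 2.733], [2.267, 3.267], [2.0, 3.0]]
```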

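Issue: softmax exponentiates every score, and exponentiation can overflow or underflow. In float64, \(\exp(x)\) overflows for \(x\) above roughly 709.78 (in float32, above roughly 88.72). For very negative \(x\), it underflows to 0.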
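The following minimal sketch applies the naive softmax to large scores to show the overflow concretely. In plain Python, `math.exp` raises `OverflowError` once its result exceeds the float64 range. Array libraries such as NumPy instead return `inf`, which becomes `nan` after the division. The helper names `naive_softmax` and `try_softmax` are illustrative.

```python
import math


def naive_softmax(row):
    exps = [math.exp(s) for s in row]  # raises OverflowError for s > ~709.78
    total = sum(exps)
    return [e / total for e in exps]


def try_softmax(row):
    # Return the probabilities, or a description of the overflow the naive formula hits
    try:
        return naive_softmax(row)
    except OverflowError as err:
        return f"OverflowError: {err}"


print(try_softmax([1.0, 2.0, 1.0]))           # small scores: works
print(try_softmax([1000.0, 1001.0, 1000.0]))  # OverflowError: math range error
```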

Logsumexp trick

The log-sum-exp (LSE) trick is a numerical technique used to handle calculations involving logarithms of sums of exponentials. The LSE trick helps to prevent numerical overflow or underflow and improves stability.

Background

When working with log probabilities or log likelihoods, directly exponentiating large or small values can lead to numerical instability. For example, suppose we have unnormalized log probabilities \(x_1, \ldots, x_N\) and need to normalize them to obtain actual probabilities \(p_i\). We might naively compute:

\[ p_i = \frac{\exp(x_i)}{\sum_{n=1}^{N} \exp(x_n)} \]

This can lead to issues because \(x_i\) can be very large or very negative, causing overflow or underflow when exponentiated.
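The opposite failure shows up with log probabilities, which are often large negative numbers (for example, a sum of many per-token log probabilities). Below is a minimal sketch of the naive normalization \(p_i = \exp(x_i) / \sum_n \exp(x_n)\) with hypothetical values. Every \(\exp(x_n)\) underflows to 0.0, so the denominator is zero even though the correct probabilities are well defined.

```python
import math

# Hypothetical unnormalized log probabilities, e.g. sums of many log-likelihood terms
log_probs = [-1000.0, -1001.0, -1002.0]

exps = [math.exp(x) for x in log_probs]
print(exps)  # [0.0, 0.0, 0.0]: exp() underflows silently

try:
    probs = [e / sum(exps) for e in exps]
except ZeroDivisionError as err:
    probs = None
    print("ZeroDivisionError:", err)  # float division by zero
```

Softmax is unchanged by adding a constant to every input. The correct answer is therefore the same as for [0, -1, -2], roughly [0.665, 0.245, 0.090].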

Log-Sum-Exp (LSE) Operation

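The log-sum-exp operation is defined as

\[ \operatorname{LSE}(x_1, \ldots, x_N) = \log \sum_{n=1}^{N} \exp(x_n) \]

It helps mitigate numerical issues by letting us work in the log domain.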
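A direct (naive) translation of \(\operatorname{LSE}(x_1, \ldots, x_N) = \log \sum_{n} \exp(x_n)\) into Python is a one-liner, and it works for small inputs. The sketch also illustrates why LSE is often called a smooth maximum: for any inputs, \(\max_n x_n \le \operatorname{LSE}(x_1, \ldots, x_N) \le \max_n x_n + \log N\). The function name `lse_naive` is illustrative.

```python
import math


def lse_naive(xs):
    # log(sum(exp(x))) computed literally; exp() can overflow or underflow
    return math.log(sum(math.exp(x) for x in xs))


xs = [1.0, 2.0, 3.0]
value = lse_naive(xs)
print(round(value, 4))  # 3.4076
# Smooth-maximum bounds: max(xs) <= LSE <= max(xs) + log(N)
print(max(xs) <= value <= max(xs) + math.log(len(xs)))  # True
```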

Rewriting the Normalization

To understand why the LSE trick works, let’s rewrite the naive normalization formula using the LSE operation.

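Since \(\sum_{n=1}^{N} \exp(x_n) = \exp\big(\operatorname{LSE}(x_1, \ldots, x_N)\big)\), we can rewrite the normalization as:

\[ p_i = \exp\big(x_i - \operatorname{LSE}(x_1, \ldots, x_N)\big), \qquad \log p_i = x_i - \operatorname{LSE}(x_1, \ldots, x_N) \]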
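The rewritten form \(p_i = \exp(x_i - \operatorname{LSE}(x_1, \ldots, x_N))\), or equivalently \(\log p_i = x_i - \operatorname{LSE}(x_1, \ldots, x_N)\), can be checked numerically. This minimal sketch compares it with the direct ratio of exponentials. The names are illustrative, and the inputs are small so that the naive LSE is still safe.

```python
import math


def lse_naive(xs):
    return math.log(sum(math.exp(x) for x in xs))


xs = [0.5, -1.0, 2.0]
lse = lse_naive(xs)

direct = [math.exp(x) / sum(math.exp(v) for v in xs) for x in xs]  # exp(x_i) / sum_n exp(x_n)
via_lse = [math.exp(x - lse) for x in xs]                           # exp(x_i - LSE)
log_probs = [x - lse for x in xs]                                   # log p_i = x_i - LSE

print(all(math.isclose(a, b) for a, b in zip(direct, via_lse)))  # True
```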

Utility of the LSE Trick

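To see why this trick is useful, consider shifting the values in the exponent by a constant \(c\). Writing \(y = \operatorname{LSE}(x_1, \ldots, x_N)\):

\[
\begin{aligned}
y &= \log \sum_{n=1}^{N} \exp(x_n) \\
e^{y} &= \sum_{n=1}^{N} e^{x_n} = e^{c} \sum_{n=1}^{N} e^{x_n - c} \\
y &= c + \log \sum_{n=1}^{N} \exp(x_n - c)
\end{aligned}
\]

This holds for any constant \(c\).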
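The shifted form \(\log \sum_n \exp(x_n) = c + \log \sum_n \exp(x_n - c)\) holds for any constant \(c\), not only the maximum. The small numeric check below confirms that the result does not depend on \(c\). The input values and shifts are arbitrary, and the name `lse_shifted` is illustrative.

```python
import math


def lse_shifted(xs, c):
    # c + log(sum(exp(x - c))): algebraically equal to log(sum(exp(x))) for any c
    return c + math.log(sum(math.exp(x - c) for x in xs))


xs = [1.0, 2.0, 3.0]
for c in [0.0, -5.0, 2.5, max(xs)]:
    print(c, round(lse_shifted(xs, c), 6))  # second value is always 3.407606
```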

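By choosing \(c\) as the maximum value in the set \(\{x_1, \ldots, x_N\}\):

\[ c = \max_{n} x_n \]

we ensure that the largest exponentiated term is \(\exp(0) = 1\), which avoids overflow. Because that term contributes 1, the sum is also at least 1. Its logarithm therefore stays finite even if the other terms underflow to 0.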
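Putting the pieces together, a minimal numerically stable LSE shifts by the maximum before exponentiating. This is a plain-Python sketch with an illustrative name, `logsumexp`. Libraries such as SciPy ship their own version as `scipy.special.logsumexp`. The sketch handles the inputs that break the naive formula.

```python
import math


def logsumexp(xs):
    # Stable log(sum(exp(x))): shift by c = max(xs) so the largest term is exp(0) = 1
    c = max(xs)
    if math.isinf(c):
        # Shifting by an infinite c would compute inf - inf = nan; the answer is c itself
        return c
    return c + math.log(sum(math.exp(x - c) for x in xs))


print(round(logsumexp([1000.0, 1001.0, 1000.0]), 4))    # 1001.5514
print(round(logsumexp([-1000.0, -1001.0, -1002.0]), 4))  # -999.5924
```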

Summary

- The LSE operation helps to normalize log probabilities in a numerically stable way.
- By subtracting the maximum value in the set from each \(x_n\) before exponentiating, we prevent overflow. The sum is then at least 1, so underflow of the small terms can no longer drive the logarithm to \(\log 0\).
- This approach allows us to work with large or small log values safely and accurately.

In essence, the log-sum-exp trick is a clever way to handle sums of exponentials in the log domain, ensuring numerical stability and preventing overflow or underflow issues.
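To connect this back to attention, the same max-subtraction gives a stable row-wise softmax: subtract each row's maximum score before exponentiating. This is a minimal plain-Python sketch with illustrative names, not FlashAttention itself. It reproduces the exact probabilities for the example scores. It also still works when the scores are scaled up by 1000, where the naive version would overflow.

```python
import math


def stable_softmax(row):
    # Subtracting the row max leaves the result unchanged (it cancels in the ratio)
    # but keeps every exponent <= 0, so exp() never overflows
    c = max(row)
    exps = [math.exp(s - c) for s in row]
    total = sum(exps)  # >= 1 because the max entry contributes exp(0) = 1
    return [e / total for e in exps]


S = [[1, 1, 0], [0, 1, 1], [1, 2, 1]]
P = [stable_softmax(row) for row in S]
print([round(p, 4) for p in P[2]])  # [0.2119, 0.5761, 0.2119]

S_big = [[1000 * s for s in row] for row in S]
P_big = [stable_softmax(row) for row in S_big]
print([round(p, 4) for p in P_big[2]])  # [0.0, 1.0, 0.0]
```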

The second part is at the link below:
https://neerajku.medium.com/logsumexp-trick-and-flash-attention-part-2-9e2d55cf1610


Written by Neeraj Kumar

Staff ML Scientist and PhD @ IIT Delhi, B-Tech @ IIT Kharagpur. Connect on Topmate for educational consulting and mock interviews: https://topmate.io/neeraj_kumar
