Logsumexp trick and Flash Attention: Part 2

Neeraj Kumar
5 min read · Sep 5, 2024


The GPU memory hierarchy consists of several types of memory that differ in size and speed, with smaller memory being faster. For instance, the A100 GPU has 40–80 GB of high-bandwidth memory (HBM) with a bandwidth of 1.5–2.0 TB/s, alongside 192 KB of on-chip SRAM for each of its 108 streaming multiprocessors, with an estimated bandwidth of around 19 TB/s. The on-chip SRAM is therefore roughly an order of magnitude faster than HBM, but far smaller in capacity.
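As a rough sanity check on these figures, the short calculation below derives the total on-chip SRAM and the two speed and size ratios. It uses only the numbers quoted above, picks the 40 GB and 2.0 TB/s ends of the quoted ranges, and treats 1 MB as 1000 KB and 1 GB as 1000 MB; it is back-of-the-envelope arithmetic, not a measurement.

```python
# A100 figures quoted above. Assumption: 1 MB = 1000 KB and 1 GB = 1000 MB,
# which is precise enough for an order-of-magnitude comparison.
num_sms = 108
sram_per_sm_kb = 192
hbm_capacity_gb = 40          # lower end of the 40-80 GB range
hbm_bandwidth_tbs = 2.0       # upper end of the 1.5-2.0 TB/s range
sram_bandwidth_tbs = 19.0     # estimated on-chip SRAM bandwidth

total_sram_kb = num_sms * sram_per_sm_kb                  # total SRAM across all SMs
total_sram_mb = total_sram_kb / 1000                      # about 20.7 MB
capacity_ratio = hbm_capacity_gb * 1000 / total_sram_mb   # how many times larger HBM is
bandwidth_ratio = sram_bandwidth_tbs / hbm_bandwidth_tbs  # how many times faster SRAM is

print(f"total SRAM: {total_sram_mb:.1f} MB")
print(f"HBM/SRAM capacity: {capacity_ratio:.0f}x, SRAM/HBM bandwidth: {bandwidth_ratio:.1f}x")
```

Even at the fastest HBM figure, SRAM is about 9.5 times faster, while HBM holds nearly two thousand times more data.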

GPUs execute an operation (called a kernel) using a massive number of threads. Each kernel loads its inputs from HBM into registers and SRAM, computes, and then writes its outputs back to HBM.

Flash Attention objective:

FlashAttention makes attention IO-aware: it is designed to manage reads and writes efficiently across memory tiers, such as the fast on-chip GPU SRAM and the relatively slow high-bandwidth memory (HBM). On modern GPUs, compute speed has outpaced memory speed, so most Transformer operations are bottlenecked by memory access rather than by arithmetic. FlashAttention reduces the number of memory accesses while still computing exact attention, not an approximation. Its primary objective is to avoid reading and writing the full N × N attention matrix to and from HBM.

Standard approach:

  1. Load Q and K:
    Load the query matrix Q and the key matrix K from HBM.
    Compute the scores matrix S = QK^T and write S back to HBM.
  2. Compute P:
    Read the scores matrix S from HBM.
    Apply the softmax function to S, resulting in the probabilities matrix P.
    Write P back to HBM.
  3. Load P and V:
    Load the probabilities matrix P and the value matrix V from HBM.
    Compute the output matrix O = PV and write O back to HBM.
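The three steps can be sketched in plain Python to make the data flow concrete. This is only an illustration under simplifying assumptions: lists stand in for matrices, HBM is not modeled (the comments mark where the standard approach writes each intermediate back), the helper names are hypothetical, and the scores are not scaled by 1/sqrt(d), matching the description above. Both s and p are full N × N matrices.

```python
import math

def matmul(a, b):
    # Plain-Python product of an (n x k) and a (k x m) list-of-lists matrix.
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax_rows(s):
    # Row-wise softmax; subtracting each row's max keeps exp from overflowing.
    out = []
    for row in s:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def standard_attention(q, k, v):
    s = matmul(q, transpose(k))  # step 1: N x N scores, written to HBM
    p = softmax_rows(s)          # step 2: N x N probabilities, written to HBM
    return matmul(p, v)          # step 3: N x d output, written to HBM

print(standard_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
```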

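Limitations:

High memory usage: the intermediate matrices S and P are both N × N, so storing them in HBM takes memory that grows quadratically with the sequence length N.

Numerical stability: exponentiating large scores in the softmax can overflow, and very negative scores underflow to zero, unless the computation is stabilized.

Performance bottlenecks: S and P each make a round trip through HBM (written once, read back once), so the operation is limited by memory access latency and bandwidth rather than by computation.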
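To make the stability issue concrete, the sketch below compares a naive softmax with the max-subtracted version using Python's 64-bit floats; the score values are made up for illustration. The naive version fails because math.exp overflows for arguments above roughly 709. Subtracting the maximum leaves the result mathematically unchanged but keeps every exponent at or below zero.

```python
import math

def naive_softmax(xs):
    # Exponentiates raw scores; math.exp raises OverflowError for large inputs.
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(xs):
    # exp(x - m) / sum(exp(x - m)) equals exp(x) / sum(exp(x)),
    # but with m = max(xs) every exponent is <= 0, so nothing overflows.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(scores)
except OverflowError as err:
    print("naive softmax failed:", err)
print(stable_softmax(scores))
```

The shift by the row maximum is the core of the logsumexp trick that FlashAttention relies on when it works block by block.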

Flash Attention Approach

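The key idea is to split the inputs Q, K, and V into blocks, load them from the slower HBM into the faster SRAM, and compute the attention output block by block. Because softmax normalizes each row of S over all keys (every column of K^T), no single block holds enough information to finish a row's softmax. Instead, the large softmax is decomposed into per-block pieces that are combined by rescaling. For numerical stability, each piece uses the logsumexp trick: scores are shifted by a running row maximum before exponentiating.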
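The decomposition with scaling can be shown on a single row of scores. In the sketch below (plain Python; the helpers block_stats and merge_stats are hypothetical names), each block keeps only its maximum m and the sum ℓ of exp(s - m), and two blocks are merged by rescaling both sums to the larger maximum. The merged statistics give logsumexp(row) = m + log ℓ, so the row's softmax is recovered without exponentiating an unshifted score. The example row happens to be the first row of QK^T in the worked example further down.

```python
import math

def block_stats(scores):
    """Statistics of one block of a score row: its maximum m and the sum of
    exp(s - m). Every exponent is <= 0, so nothing overflows."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)

def merge_stats(m_a, ell_a, m_b, ell_b):
    """Merge two blocks' statistics by rescaling both sums to the larger max."""
    m = max(m_a, m_b)
    return m, math.exp(m_a - m) * ell_a + math.exp(m_b - m) * ell_b

row = [1.0, 1.0, 0.0, 1.0]                # scores of one query against four keys
m1, ell1 = block_stats(row[:2])           # first key block
m2, ell2 = block_stats(row[2:])           # second key block
m, ell = merge_stats(m1, ell1, m2, ell2)

lse_blocked = m + math.log(ell)           # logsumexp from the merged statistics
lse_direct = math.log(sum(math.exp(s) for s in row))  # only safe for small scores
softmax = [math.exp(s - m) / ell for s in row]
print(lse_blocked, lse_direct, softmax)
```

Starting from m = -inf and ℓ = 0 also works: exp(-inf) evaluates to 0.0, so the empty state contributes nothing, which is how the running statistics are initialized in the algorithm.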

Flash Attention Algorithm:

Input Matrices:
Q (Query): shape N × d
K^T (Key transposed): shape d × N
V (Value): shape N × d

Copy Block to SRAM:
Small blocks of the matrices are copied to the fast SRAM from HBM. This is done to leverage the high bandwidth and low latency of SRAM for computations.

Compute Block on SRAM:

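The attention scores for the current pair of blocks, S_ij = Q_i K_j^T, are computed in SRAM. This is a small matrix multiplication that benefits from SRAM's high bandwidth.
A block-local softmax is then applied: each row of S_ij is shifted by its maximum and exponentiated, and the row maxima and row sums are kept so that partial results can be rescaled as further blocks arrive. The full matrix sm(QK^T) is never materialized; its effect is accumulated into the output block by block.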

Inner and Outer Loops:

Outer Loop: Iterates over blocks of the Key and Value matrices, copying them to SRAM and performing computations.

Inner Loop: Iterates over blocks of the Query matrix, copying each block to SRAM together with the matching blocks of the output O and of the running statistics ℓ (row sums) and m (row maxima).

Output to HBM:

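After each inner-loop step, the updated output block O_i, along with ℓ_i and m_i, is written back to HBM. Once every key/value block has been processed, O equals sm(QK^T)V, computed without ever storing the N × N attention matrix in HBM.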
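The loop structure above can be sketched as a plain-Python function. This is a minimal illustration under stated assumptions, not the paper's CUDA kernel: HBM and SRAM are not modeled, rows within a query block are processed one at a time (a GPU handles the block together), the scores are not scaled by 1/sqrt(d), and the function name and parameters are hypothetical. It keeps the outer loop over K/V blocks and the inner loop over Q blocks, with running statistics m and ℓ per query row and O stored normalized after every step.

```python
import math

def flash_attention_forward(q, k, v, block_c, block_r):
    """Tiled attention forward pass in the style of FlashAttention.

    Plain-Python sketch: HBM/SRAM are not modeled and scores are unscaled.
    block_c = rows of K/V per outer-loop block,
    block_r = rows of Q per inner-loop block.
    """
    n, d = len(q), len(v[0])
    o = [[0.0] * d for _ in range(n)]  # running output, kept normalized by ell
    ell = [0.0] * n                    # running row sums of exp(s - m)
    m = [-math.inf] * n                # running row maxima

    for j0 in range(0, n, block_c):                    # outer loop: load K_j, V_j
        k_j, v_j = k[j0:j0 + block_c], v[j0:j0 + block_c]
        for i0 in range(0, n, block_r):                # inner loop: load Q_i, O_i, ell_i, m_i
            for r in range(i0, min(i0 + block_r, n)):  # rows of Q_i, one at a time
                s = [sum(a * b for a, b in zip(q[r], key)) for key in k_j]  # row of S_ij
                m_blk = max(s)
                p = [math.exp(x - m_blk) for x in s]   # shifted exponentials
                m_new = max(m[r], m_blk)
                alpha = math.exp(m[r] - m_new)         # rescales the old state
                beta = math.exp(m_blk - m_new)         # rescales this block
                ell_new = alpha * ell[r] + beta * sum(p)
                pv = [sum(p[t] * v_j[t][c] for t in range(len(v_j))) for c in range(d)]
                o[r] = [(alpha * ell[r] * o[r][c] + beta * pv[c]) / ell_new
                        for c in range(d)]
                ell[r], m[r] = ell_new, m_new          # "write back" ell_i, m_i
    return o

# The small example from this post, with B_c = B_r = 2.
Q = [[1, 0], [0, 1], [1, 1], [0, 0]]
K = [[1, 0], [1, 1], [0, 1], [1, -1]]
V = [[1, 2], [2, 3], [3, 4], [4, 5]]
print(flash_attention_forward(Q, K, V, block_c=2, block_r=2))
```

Because every step rescales consistently, the result should match a direct softmax(QK^T)V up to floating-point rounding for any choice of block_c and block_r, including sizes that do not divide N.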

Example

Sequence length N = 4
Head dimension d = 2
SRAM size M = 8
Block sizes B_c = 2, B_r = 2

Q = [[1, 0],
     [0, 1],
     [1, 1],
     [0, 0]]

K = [[1, 0],
     [1, 1],
     [0, 1],
     [1, -1]]

V = [[1, 2],
     [2, 3],
     [3, 4],
     [4, 5]]
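Before tracing the blocked computation, it helps to know the answer it should reach. The snippet below (plain Python, assuming the same unscaled scores S = QK^T used in the steps that follow) computes the full score matrix and the exact output softmax(QK^T)V directly, materializing everything at once in the way the standard approach does.

```python
import math

Q = [[1, 0], [0, 1], [1, 1], [0, 0]]
K = [[1, 0], [1, 1], [0, 1], [1, -1]]
V = [[1, 2], [2, 3], [3, 4], [4, 5]]

# Full N x N score matrix S = Q K^T (no 1/sqrt(d) scaling, as in the example).
S = [[sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]

# Exact output softmax(S) V, one row at a time (row max subtracted for stability).
O_ref = []
for row in S:
    m = max(row)
    w = [math.exp(s - m) for s in row]
    total = sum(w)
    O_ref.append([sum(w[t] * V[t][c] for t in range(len(V))) / total
                  for c in range(len(V[0]))])

print(S)                                              # [[1, 1, 0, 1], [0, 1, 1, -1], [1, 2, 1, 0], [0, 0, 0, 0]]
print([[round(x, 3) for x in row] for row in O_ref])  # [[2.406, 3.406], [2.361, 3.361], [2.145, 3.145], [2.5, 3.5]]
```

The last row is the plain average of V because the query [0, 0] scores every key equally.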

Execution Steps

  1. Set Block Sizes:
    B_c = 2, B_r = 2
  2. Initialization:
    Initialize O = [[0, 0], [0, 0], [0, 0], [0, 0]]
    Initialize ℓ = [0, 0, 0, 0]
    Initialize m = [-inf, -inf, -inf, -inf]
  3. Divide Matrices into Blocks:
    Divide Q into two blocks Q_1 and Q_2:
    Q_1 = [[1, 0], [0, 1]]
    Q_2 = [[1, 1], [0, 0]]
    Divide K, V, and O into two blocks each:
    K_1 = [[1, 0], [1, 1]]
    K_2 = [[0, 1], [1, -1]]
    V_1 = [[1, 2], [2, 3]]
    V_2 = [[3, 4], [4, 5]]
    O_1 = [[0, 0], [0, 0]]
    O_2 = [[0, 0], [0, 0]]
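  4. Outer Loop:
    For j = 1:
    Load K_1 and V_1 into on-chip SRAM.

    Inner Loop:
      For i = 1:
      Load Q_1, O_1, ℓ_1, and m_1 into on-chip SRAM.
      Compute S_11 = Q_1 K_1^T:
      S_11 = [[1, 1], [0, 1]]
      Compute the block row maxima m_11, the shifted exponentials P_11, and their row sums ℓ_11:
      m_11 = rowmax(S_11) = [1, 1]
      P_11 = exp(S_11 - m_11) = exp([[0, 0], [-1, 0]]) ≈ [[1, 1], [0.368, 1]]
      ℓ_11 = rowsum(P_11) ≈ [2, 1.368]
      Update m_1 and ℓ_1 (the old-state term is zero because initially m_1 = -inf and ℓ_1 = 0):
      m_1_new = max(m_1, m_11) = [1, 1]
      ℓ_1_new = e^(m_1 - m_1_new) ℓ_1 + e^(m_11 - m_1_new) ℓ_11 ≈ [2, 1.368]
      Compute the new output block O_1 and write it, with ℓ_1 and m_1, back to HBM:
      O_1 = diag(ℓ_1_new)^-1 (diag(ℓ_1) e^(m_1 - m_1_new) O_1 + e^(m_11 - m_1_new) P_11 V_1)
          = diag([2, 1.368])^-1 [[3, 5], [2.368, 3.736]]
          ≈ [[1.5, 2.5], [1.731, 2.731]]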

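  5. Continue for Next Blocks:
    Repeat the inner loop for i = 2, loading Q_2, O_2, ℓ_2, and m_2. Then run the outer loop again for j = 2 with K_2 and V_2, updating O_1 and O_2 through the same steps. After the last block, O equals softmax(QK^T)V.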
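The first inner-loop step (j = 1, i = 1) and the later step for the same query block once K_2 and V_2 are loaded (j = 2, i = 1) can be reproduced numerically. The sketch below is plain Python; update is a hypothetical helper that folds one key/value block into the running (m, ℓ, O) state of a block of query rows. After j = 1 the first output block is about [[1.5, 2.5], [1.731, 2.731]]; after j = 2 its rows are about [[2.406, 3.406], [2.361, 3.361]], which agree with the first two rows of the exact softmax(QK^T)V for this example.

```python
import math

def update(state, q_blk, k_blk, v_blk):
    """Fold one key/value block into the running (m, ell, O) state of a block
    of query rows, using the FlashAttention-style rescaling rule."""
    m_old, ell_old, o_old = state
    m_out, ell_out, o_out = [], [], []
    for r, q in enumerate(q_blk):
        s = [sum(a * b for a, b in zip(q, k)) for k in k_blk]  # row r of S_ij
        m_blk = max(s)                                        # row max of S_ij
        p = [math.exp(x - m_blk) for x in s]                  # row r of P_ij
        m_new = max(m_old[r], m_blk)
        alpha = math.exp(m_old[r] - m_new)   # rescales the previous state
        beta = math.exp(m_blk - m_new)       # rescales the current block
        ell_new = alpha * ell_old[r] + beta * sum(p)
        pv = [sum(p[t] * v_blk[t][c] for t in range(len(v_blk)))
              for c in range(len(v_blk[0]))]
        o_out.append([(alpha * ell_old[r] * o_old[r][c] + beta * pv[c]) / ell_new
                      for c in range(len(pv))])
        m_out.append(m_new)
        ell_out.append(ell_new)
    return m_out, ell_out, o_out

Q1 = [[1, 0], [0, 1]]
K1, V1 = [[1, 0], [1, 1]], [[1, 2], [2, 3]]
K2, V2 = [[0, 1], [1, -1]], [[3, 4], [4, 5]]

state = ([-math.inf, -math.inf], [0.0, 0.0], [[0.0, 0.0], [0.0, 0.0]])
state = update(state, Q1, K1, V1)  # j = 1, i = 1
print([[round(x, 3) for x in row] for row in state[2]])  # [[1.5, 2.5], [1.731, 2.731]]
state = update(state, Q1, K2, V2)  # j = 2, i = 1
print([[round(x, 3) for x in row] for row in state[2]])  # [[2.406, 3.406], [2.361, 3.361]]
```

Only the per-row statistics m and ℓ and the normalized output block travel between steps, which is what lets FlashAttention avoid storing S and P.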

References:
1. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. https://arxiv.org/pdf/2205.14135

Written by Neeraj Kumar

Staff ML Scientist and PHD @ IIT Delhi, B-Tech @ IIT Kharagpur Connect on Topmate for educational consulting, mock interviews - https://topmate.io/neeraj_kumar