Reminder of the algorithm:
Actual FA2 pseudocode (with more comments on loading between SRAM and HBM)
- This is the code for a single head of a single sequence within a batch.
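For reference, here is a minimal NumPy sketch of that blocked forward pass for a single head, with the HBM/SRAM movements marked in comments (a simplification for illustration, not the actual FA2 pseudocode; causal masking is omitted and names are mine):

```python
import numpy as np

def flash_attention_fwd(Q, K, V, block_q=64, block_kv=64):
    T, d = Q.shape
    O = np.zeros((T, d))
    L = np.zeros(T)                       # logsumexp per query row, saved for backward
    scale = 1.0 / np.sqrt(d)
    for qs in range(0, T, block_q):       # one iteration per "program" in the kernel
        q = Q[qs:qs + block_q]            # load Q block HBM -> SRAM, keep it resident
        m = np.full(q.shape[0], -np.inf)  # running row max
        l = np.zeros(q.shape[0])          # running row sum of exp
        acc = np.zeros_like(q)            # unnormalized output accumulator
        for ks in range(0, T, block_kv):
            k = K[ks:ks + block_kv]       # load K block HBM -> SRAM
            v = V[ks:ks + block_kv]       # load V block HBM -> SRAM
            s = (q @ k.T) * scale
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            alpha = np.exp(m - m_new)     # rescale the previous accumulator
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ v
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]  # write O block SRAM -> HBM
        L[qs:qs + block_q] = m + np.log(l)     # logsumexp for the backward pass
    return O, L
```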
What can be parallelized?
- Seeing this code, it's clear that each block of the output (of size BLOCK_SIZE_Q x head_dim) is independent of the other blocks.
- For each of these blocks, we will need to load a single query block, and all the K and V blocks.
- We want to keep the query block in SRAM for the whole computation.
- We can easily parallelize over all the query blocks.
How many programs/kernels do we need to launch then?
- We will need to launch `batch_size x num_heads x (seq_len // block_size_q)` programs:
```python
grid = lambda args: (
    triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),
    BATCH_SIZE * NUM_HEADS,
    1,
)
```
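Inside the kernel, each program recovers which query block and which (batch, head) pair it owns from its program IDs. A minimal sketch consistent with the grid above (the kernel name and argument list are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def _attn_fwd(Q, K, V, O, NUM_HEADS, SEQ_LEN, BLOCK_SIZE_Q: tl.constexpr):
    # Axis 0 of the grid indexes the query block, axis 1 the (batch, head) pair.
    block_index_q = tl.program_id(0)
    index_batch_head = tl.program_id(1)
    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS
    # This program owns output rows
    # [block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q)
    # of head `index_head` in sequence `index_batch`.
```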
Accessing blocks and other optimizations
- Using `tl.make_block_ptr` and `tl.advance` abstracts away part of the stride computations, and potentially allows the compiler to optimize memory loads on H100 (see the sketch after this list).
- `tl.multiple_of(x, N)` is a compiler hint: you are asserting that every element in `x` is an exact multiple of `N` (where `N` is a compile-time constant).
- Triton doesn't touch the data; it just records the fact in its IR so that later optimization passes and the backend code generator can be more aggressive, e.g. pick the aligned variant of hardware instructions.
- Practical evidence: in the grouped-GEMM tutorial, adding the hint can make the code 3x faster.
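A minimal sketch of how these primitives could fit together in the inner loop (the function name and arguments are illustrative; the online-softmax bookkeeping and the V loads are omitted):

```python
import triton
import triton.language as tl

@triton.jit
def _qk_inner_loop(Q, K, stride_q_seq, stride_q_dim, stride_k_seq, stride_k_dim,
                   SEQ_LEN, HEAD_DIM: tl.constexpr,
                   BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_KV: tl.constexpr):
    block_index_q = tl.program_id(0)

    # Block pointer into Q for this program's query rows: shape, strides and
    # offsets are recorded once, so later loads don't redo the index arithmetic.
    Q_block_ptr = tl.make_block_ptr(
        base=Q, shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_q_seq, stride_q_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM), order=(1, 0),
    )
    # K is accessed transposed, as (HEAD_DIM, SEQ_LEN), for the Q @ K^T dot.
    K_block_ptr = tl.make_block_ptr(
        base=K, shape=(HEAD_DIM, SEQ_LEN),
        strides=(stride_k_dim, stride_k_seq),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_SIZE_KV), order=(0, 1),
    )

    q_block = tl.load(Q_block_ptr)  # stays in SRAM for the whole loop
    for start_kv in range(0, SEQ_LEN, BLOCK_SIZE_KV):
        # Hint: start_kv is an exact multiple of BLOCK_SIZE_KV, which lets the
        # backend pick aligned load instructions.
        start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)
        k_block = tl.load(K_block_ptr)
        qk = tl.dot(q_block, k_block)  # (BLOCK_SIZE_Q, BLOCK_SIZE_KV) scores
        # ... online softmax and the P @ V accumulation would go here ...
        # Slide the K block pointer BLOCK_SIZE_KV columns to the right.
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV))
```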
Separating the left-diagonal and diagonal
- Because of the block-wise processing, we need to be careful to attend only to the causal K blocks.
- We don't merge the for-loop over the K blocks left of the diagonal with the processing of the diagonal block, in order not to hurt the pipelining Triton does.
Causal Attention Lower Triangular Structure Illustration
=======================================================
In causal attention, each token can only attend to itself and previous tokens.
This creates a lower triangular structure in the attention matrix.
For a sequence of length 8, the attention matrix looks like this:
      K₀ K₁ K₂ K₃ K₄ K₅ K₆ K₇
Q₀    ✓  -  -  -  -  -  -  -    (Q₀ can only attend to K₀)
Q₁    ✓  ✓  -  -  -  -  -  -    (Q₁ can attend to K₀, K₁)
Q₂    ✓  ✓  ✓  -  -  -  -  -    (Q₂ can attend to K₀, K₁, K₂)
Q₃    ✓  ✓  ✓  ✓  -  -  -  -    (Q₃ can attend to K₀, K₁, K₂, K₃)
Q₄    ✓  ✓  ✓  ✓  ✓  -  -  -    (Q₄ can attend to K₀, K₁, K₂, K₃, K₄)
Q₅    ✓  ✓  ✓  ✓  ✓  ✓  -  -    (Q₅ can attend to K₀, K₁, K₂, K₃, K₄, K₅)
Q₆    ✓  ✓  ✓  ✓  ✓  ✓  ✓  -    (Q₆ can attend to K₀, K₁, K₂, K₃, K₄, K₅, K₆)
Q₇    ✓  ✓  ✓  ✓  ✓  ✓  ✓  ✓    (Q₇ can attend to all K₀-K₇)
Where:
- ✓ = allowed attention (additive mask of 0, not -inf)
- "-" = masked attention (set to -inf)
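As a small illustration, this additive causal mask can be built directly (a hedged NumPy sketch, not kernel code):

```python
import numpy as np

T = 8
allowed = np.tril(np.ones((T, T), dtype=bool))  # lower triangular: True = allowed
causal_bias = np.where(allowed, 0.0, -np.inf)   # added to QK^T before the softmax
```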
Block-wise Processing:
====================
When processing in blocks with different sizes (e.g., BLOCK_SIZE_Q = 2, BLOCK_SIZE_KV = 3):
Block 0 (Q₀,Q₁): Must load 1 K block:  K₀,K₁,K₂                 (up to and including the diagonal)
Block 1 (Q₂,Q₃): Must load 2 K blocks: K₀,K₁,K₂,K₃,K₄,K₅        (up to and including the diagonal)
Block 2 (Q₄,Q₅): Must load 2 K blocks: K₀,K₁,K₂,K₃,K₄,K₅        (up to and including the diagonal)
Block 3 (Q₆,Q₇): Must load 3 K blocks: K₀,K₁,K₂,K₃,K₄,K₅,K₆,K₇  (up to and including the diagonal)
The diagonal edge case occurs when a K block spans across the causal boundary.
For example, in Block 1 (Q₂,Q₃), the K block [K₃,K₄,K₅] contains:
- K₃: allowed (Q₃ can attend to K₃)
- K₄,K₅: masked (Q₂,Q₃ cannot attend to future tokens)
This requires STAGE 2 processing to handle the mixed masked/unmasked keys within a single K block.
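A minimal NumPy sketch of this two-stage split for one query block (names like `causal_scores_for_q_block` and `num_full_blocks` are illustrative; the real kernel fuses this with the online softmax):

```python
import numpy as np

def causal_scores_for_q_block(Q_blk, K, q_start, block_kv):
    """Masked scores for one query block, splitting the K loop in two stages:
    fully-unmasked K blocks first, then the diagonal block(s) with masking."""
    block_q, d = Q_blk.shape
    q_idx = q_start + np.arange(block_q)
    # STAGE 1: K blocks entirely left of the causal boundary; no mask needed.
    num_full_blocks = (q_start + 1) // block_kv
    scores = []
    for b in range(num_full_blocks):
        K_blk = K[b * block_kv:(b + 1) * block_kv]
        scores.append(Q_blk @ K_blk.T)               # no masking at all
    # STAGE 2: the block(s) that cross the boundary; mask element-wise.
    for start in range(num_full_blocks * block_kv, q_start + block_q, block_kv):
        K_blk = K[start:start + block_kv]
        kv_idx = start + np.arange(K_blk.shape[0])
        s = Q_blk @ K_blk.T
        s = np.where(q_idx[:, None] >= kv_idx[None, :], s, -np.inf)
        scores.append(s)
    return np.concatenate(scores, axis=1)

# Example: Block 1 (Q₂,Q₃) from the illustration, BLOCK_SIZE_Q=2, BLOCK_SIZE_KV=3.
rng = np.random.default_rng(0)
Q, K = rng.standard_normal((8, 16)), rng.standard_normal((8, 16))
S = causal_scores_for_q_block(Q[2:4], K, q_start=2, block_kv=3)
print(S.shape)         # (2, 6): K blocks 0 and 1 were loaded
print(np.isinf(S[0]))  # Q₂'s row: K₃, K₄, K₅ are masked
```

Splitting the loops this way means STAGE 1 never evaluates the mask, and only the diagonal block(s) pay for the element-wise masking, which is why the kernel keeps the left-of-diagonal loop separate from STAGE 2.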
Log-sum-exp "trick"
- For the attention score matrix A = softmax(QK^T) = softmax(X) (shape (T, T)), each row is defined as A_i = exp(X_i) / sum_j exp(X_ij).
- Let's define L_i = log(sum_j exp(X_ij)) (computed stably as m_i + log(sum_j exp(X_ij - m_i)), where m_i is the row max).
- Assume that in the forward pass, we store L_i for the backward pass (the `logsumexp` variable).
- Then, in the backward pass, we will recompute X_i = Q_i K^T, and we can recover A_i by computing A_i = exp(X_i - L_i).
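A tiny NumPy check of this identity (shapes and names are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 4, 8
Q, K = rng.standard_normal((T, d)), rng.standard_normal((T, d))

X = Q @ K.T
L = np.log(np.exp(X).sum(axis=1))                          # stored in the forward pass
A_ref = np.exp(X) / np.exp(X).sum(axis=1, keepdims=True)   # softmax(X), row-wise
A_rec = np.exp(X - L[:, None])                             # recovered in the backward pass
assert np.allclose(A_ref, A_rec)
```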