Given the release of our LFM2 models, I thought it would be a good time to discuss the inference efficiency of different architectures, purely focusing on inference and not on quality. For a good discussion on the representation ability of different archs, see the latest blog from Albert Gu.
Reducing FLOPs is not a universal cure to get a faster arch.
Abstracting away details of large-scale serving, for a model to be fast & efficient, you need to think about:
- compute-bound prefill
- memory-bound decode (see the sketch after this list)
- implementation on the hardware
- cache size
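To make the compute-bound prefill vs. memory-bound decode split concrete, here is a minimal roofline-style sketch. The layer shape, token counts, and the ~300 FLOPs/byte hardware ratio are illustrative assumptions (roughly an H100-class GPU in bf16), not measurements of any particular model.

```python
# Rough roofline-style estimate for one linear layer (illustrative shapes only).

def arithmetic_intensity(batch_tokens: int, in_features: int, out_features: int,
                         bytes_per_el: int = 2) -> float:
    """FLOPs per byte moved for a [batch_tokens, in] x [in, out] matmul in bf16."""
    flops = 2 * batch_tokens * in_features * out_features           # one mul + one add per weight, per token
    bytes_moved = bytes_per_el * (in_features * out_features        # weights
                                  + batch_tokens * in_features      # input activations
                                  + batch_tokens * out_features)    # output activations
    return flops / bytes_moved

# Assumed hardware ratio: peak bf16 FLOPs divided by HBM bandwidth,
# roughly ~300 FLOPs per byte for an H100-class GPU.
HW_FLOPS_PER_BYTE = 300

prefill = arithmetic_intensity(batch_tokens=2048, in_features=4096, out_features=4096)
decode = arithmetic_intensity(batch_tokens=1, in_features=4096, out_features=4096)

print(f"prefill: {prefill:7.1f} FLOPs/byte -> {'compute' if prefill > HW_FLOPS_PER_BYTE else 'memory'}-bound")
print(f"decode:  {decode:7.1f} FLOPs/byte -> {'compute' if decode > HW_FLOPS_PER_BYTE else 'memory'}-bound")
```

With these assumed numbers, prefill lands around a thousand FLOPs per byte (compute-bound), while decode sits near one FLOP per byte (memory-bound): during decode you re-read every weight just to produce a single token.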
Attention is inherently bottlenecked by its KV cache during decode, but it is very hardware-friendly and has a very fast prefill. However, its cache size grows linearly with sequence length.
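To see what a linearly growing cache means in practice, here is a small sketch of the KV cache footprint for a hypothetical attention-only model; the 32 layers, 8 KV heads, head_dim of 128, and bf16 cache are assumptions chosen for illustration, not any specific model's config.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, bytes_per_el: int = 2) -> int:
    """KV cache size per sequence: keys + values, for every layer and KV head."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_el

# Hypothetical config: 32 layers, 8 KV heads (GQA), head_dim 128, bf16 cache.
for seq_len in (4_096, 32_768, 131_072):
    gib = kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, head_dim=128) / 2**30
    print(f"{seq_len:>7} tokens -> {gib:6.2f} GiB per sequence")
```

Under these assumptions the cache costs 128 KiB per token, i.e. 0.5 GiB per sequence at 4k tokens and 16 GiB at 128k tokens, which is exactly the linear growth that hurts decode speed and batch size.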
Gated convolutions
Comparing the FLOPs
As a reminder, here is how we compute the FLOPs per token for each operator.
A linear layer with `N = in_features * out_features` parameters requires `6N` FLOPs per token seen (`2N` for the forward pass and `4N` for the backward pass).
- This is because each matmul performs one multiplication and one addition for each parameter (imagine a matrix-vector product where the token is a column vector of size `in_features`).
- The backward pass includes two matmuls for each one in the forward pass (Appendix B of the PaLM paper).
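As a quick sanity check of the `6N` rule, here is a tiny sketch that counts FLOPs per token for one linear layer; the 4096 × 11008 shape is an arbitrary illustrative choice.

```python
def linear_flops_per_token(in_features: int, out_features: int) -> dict:
    """FLOPs per token for one linear layer: 2N forward, 4N backward, 6N total."""
    n_params = in_features * out_features
    forward = 2 * n_params      # one multiply + one add per parameter
    backward = 4 * n_params     # two matmuls for each forward matmul (PaLM, Appendix B)
    return {"params": n_params, "forward": forward,
            "backward": backward, "total": forward + backward}

# Example: an arbitrary 4096 -> 11008 projection.
print(linear_flops_per_token(4096, 11008))
```

For inference only the forward `2N` term applies; the backward term is a training-time cost.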
For the
Both operators share the same input projections and output projections, which account for