ME 759 · High-Performance Computing

FlashAttention
in Raw CUDA

From-scratch CUDA: IO-aware FlashAttention on a Tesla T4, with online softmax, shared-memory tiling, and $O(N \cdot d)$ HBM traffic for the output instead of materializing the full $N \times N$ score matrix.

Causal Flash vs. naive at N = 8192: 3.71× speedup. Pick a different N in the benchmark lab below to update this figure.

[Diagram: Q, K, V, and the output O live in slow HBM; tiles $Q_i$, $K_j$, $V_j$ are staged through fast, on-chip SRAM.]

Why Attention is Memory-Bound

The core mechanism inside modern Large Language Models (like GPT or Llama) is the scaled dot-product attention operation. Before we optimize it, we must first understand the fundamental mathematical operations it requires:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$$
In this equation, $Q$, $K$, and $V$ are matrices of shape $N \times d$, where $N$ is the sequence length (how many tokens the model is looking at) and $d$ is the head dimension.
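For reference, here is what that formula computes, as a minimal CPU sketch (plain C++ for clarity; this is an illustration of the math, not the project's correctness checker):

// Reference scaled dot-product attention: O = softmax(Q K^T / sqrt(d)) V
// Plain C++; Q, K, V, O are row-major N x d arrays. Illustrative only.
#include <cmath>
#include <vector>

void attention_reference(const std::vector<float>& Q, const std::vector<float>& K,
                         const std::vector<float>& V, std::vector<float>& O,
                         int N, int d) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    std::vector<float> S(N);                       // one score row at a time
    for (int i = 0; i < N; ++i) {
        float m = -INFINITY;
        for (int j = 0; j < N; ++j) {              // S[j] = (Q_i . K_j) * scale
            float dot = 0.0f;
            for (int k = 0; k < d; ++k) dot += Q[i * d + k] * K[j * d + k];
            S[j] = dot * scale;
            m = std::fmax(m, S[j]);
        }
        float l = 0.0f;                            // numerically stable row softmax
        for (int j = 0; j < N; ++j) { S[j] = std::exp(S[j] - m); l += S[j]; }
        for (int k = 0; k < d; ++k) {              // O_i = sum_j P[i][j] * V_j
            float acc = 0.0f;
            for (int j = 0; j < N; ++j) acc += S[j] * V[j * d + k];
            O[i * d + k] = acc / l;
        }
    }
}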

The Standard (Naive) Implementation

The standard way to execute this formula on a GPU is step by step, saving each intermediate result to the GPU's main memory, known as High-Bandwidth Memory (HBM). The process looks like this:

1. Compute the score matrix $S = QK^\top / \sqrt{d}$ and write the full $N \times N$ result to HBM.
2. Read $S$ back, apply a row-wise softmax to get $P = \text{softmax}(S)$, and write $P$ (another $N \times N$ matrix) to HBM.
3. Read $P$ back and multiply by $V$ to produce the output $O = PV$ ($N \times d$).

The Bottleneck: Notice how many times we read and write an $N \times N$ matrix to and from HBM. As the sequence length $N$ grows (e.g., when passing a long document to the model), the memory footprint balloons quadratically ($\Theta(N^2)$). The GPU's compute cores end up idling, waiting for slow memory transfers to finish. This is what makes the algorithm memory-bound.
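To get a feel for the quadratic blow-up, this small host-side snippet prints the size of a single FP32 $N \times N$ intermediate (such as $S$ or $P$); at $N = 8192$ one such matrix is already 256 MB, and the naive pipeline reads and writes matrices like it several times per head:

// Footprint of a single FP32 N x N intermediate (S or P) in naive attention.
#include <cstdio>

int main() {
    for (int N = 1024; N <= 16384; N *= 2) {
        double mib = static_cast<double>(N) * N * sizeof(float) / (1 << 20);
        std::printf("N = %5d  ->  %.0f MB per N x N score matrix\n", N, mib);
    }
    return 0;
}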

The FlashAttention Solution: Tiling and Kernel Fusion

FlashAttention removes this memory-bandwidth bottleneck without sacrificing exactness, using two classic hardware-aware techniques: tiling and kernel fusion.
🐌

Naive Attention

3 distinct kernel launches. Requires 3 full round-trips to global memory for the massive $N\times N$ matrix.

$\mathcal{O}(N^2)$ HBM Accesses

FlashAttention

Fuses the operations into one kernel. Score-matrix tiles live only in fast on-chip SRAM; the full $N \times N$ matrix is never materialized.

$\mathcal{O}(N \cdot d)$ HBM Accesses

Hardware Limits: NVIDIA Tesla T4

To understand why keeping data on-chip is so vital, let's look at the specs of the GPU we are optimizing for. The Tesla T4's Streaming Multiprocessors (SMs) offer plenty of compute throughput, but the kernel is severely limited by memory bandwidth.
Global Memory (HBM) Bandwidth: 320 GB/s
On-Chip Shared Memory (SRAM) Capacity: 64 KB per SM
Compute (FP32) Throughput: 8.1 TFLOP/s
By keeping our processing chunks (tiles) small enough to fit inside that 64 KB SRAM limit, FlashAttention never has to write the intermediate $N \times N$ matrix to the slow 320 GB/s HBM. It all stays on the ultra-fast chip.
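A quick back-of-envelope on that constraint (the tile sizes here are illustrative assumptions, not necessarily the exact $B_r, B_c$ used in the project): with $B_r = B_c = 32$ and $d = 64$ in FP32, one set of $Q_i$, $K_j$, $V_j$ tiles plus a $B_r \times B_c$ score tile uses well under 64 KB.

// Per-block SRAM footprint for one set of FlashAttention tiles (FP32).
// Br, Bc are assumed example tile sizes; d = 64 matches the report.
#include <cstddef>
#include <cstdio>

constexpr int Br = 32, Bc = 32, d = 64;

constexpr std::size_t q_tile = Br * d  * sizeof(float);   //  8 KB
constexpr std::size_t k_tile = Bc * d  * sizeof(float);   //  8 KB
constexpr std::size_t v_tile = Bc * d  * sizeof(float);   //  8 KB
constexpr std::size_t s_tile = Br * Bc * sizeof(float);   //  4 KB
constexpr std::size_t total  = q_tile + k_tile + v_tile + s_tile;

static_assert(total <= 64 * 1024, "tiles must fit in the T4's 64 KB of shared memory");

int main() {
    std::printf("Q %zu B + K %zu B + V %zu B + S %zu B = %zu B (limit 65536 B)\n",
                q_tile, k_tile, v_tile, s_tile, total);
    return 0;
}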

Online Softmax & Tiling

To avoid ever materializing the score matrix in HBM, we must process it in smaller blocks (tiles). However, this introduces a major mathematical challenge: the Softmax function.

The Softmax Constraint

Mathematically, Softmax normalizes a set of numbers (one row of the score matrix). In 32-bit floating point, $e^x$ overflows quickly, so to keep the exponentials numerically stable we subtract the maximum value of the row ($m$) before exponentiating:
$$ \text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j} e^{x_j - m}} $$
The catch is that finding the row maximum $m$ requires seeing every element of the row. Because we are aggressively tiling the matrix to fit into small SRAM chunks, we never have the entire row at once; we only see one small piece of it at a time.

The Solution: Online Softmax

The key innovation of FlashAttention is using algebraic rescaling to update the Softmax denominator incrementally as new tiles arrive. We keep a running maximum ($m$) and a running normalization sum ($\ell$) directly in the ultra-fast thread registers.

As each new tile $j$ arrives from Global Memory, we update the local max $m^{(j)}$:
$$m^{(j)}_i = \max\!\bigl(m^{(j-1)}_i,\; \tilde{m}_{ij}\bigr)$$
We then update the local sum denominator $\ell^{(j)}$ and immediately apply it to the output $O^{(j)}$. Notice the scaling factors $\alpha$ and $\beta$:
$$\ell^{(j)}_i = \underbrace{e^{m^{(j-1)}_i - m^{(j)}_i}}_{\alpha}\,\ell^{(j-1)}_i + \underbrace{e^{\tilde{m}_{ij} - m^{(j)}_i}}_{\beta}\,\tilde{\ell}_{ij}$$
$$O^{(j)}_i = \alpha \,O^{(j-1)}_i + \beta \,\tilde{P}_{ij}V_j$$
By scaling the previous output by $\alpha$, we retroactively correct the old partial results as if we had known the true row maximum all along. The running state $(m_i, \ell_i, O_i)$ lives entirely in GPU registers, so no intermediate results are ever written to slow global memory.
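One way to convince yourself the rescaling is exact is to run the recurrence on the CPU and compare it against a direct full-row computation. This little check (plain C++, illustrative values only) merges one row tile by tile with the $\alpha/\beta$ update above, divides by $\ell$ at the end, and matches the straightforward softmax-weighted sum:

// CPU check: tile-wise online softmax vs. direct full-row softmax(S) * V (one row).
#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
    const int N = 8, Bc = 3, d = 2;                       // tiny illustrative sizes
    float S[N] = {0.5f, -1.2f, 3.0f, 0.1f, 2.2f, -0.7f, 1.4f, 0.9f};
    float V[N][d];
    for (int j = 0; j < N; ++j) { V[j][0] = 0.1f * j; V[j][1] = 1.0f - 0.05f * j; }

    // Online pass over tiles of width Bc, using the alpha/beta recurrence above.
    float m = -FLT_MAX, l = 0.0f, O[d] = {0.0f, 0.0f};
    for (int j0 = 0; j0 < N; j0 += Bc) {
        const int bc = (N - j0 < Bc) ? (N - j0) : Bc;
        float m_t = -FLT_MAX;
        for (int j = 0; j < bc; ++j) m_t = std::fmax(m_t, S[j0 + j]);
        float l_t = 0.0f, pv[d] = {0.0f, 0.0f};
        for (int j = 0; j < bc; ++j) {
            const float p = std::exp(S[j0 + j] - m_t);    // P_tilde for this tile
            l_t += p;
            for (int k = 0; k < d; ++k) pv[k] += p * V[j0 + j][k];
        }
        const float m_new = std::fmax(m, m_t);
        const float alpha = std::exp(m - m_new), beta = std::exp(m_t - m_new);
        for (int k = 0; k < d; ++k) O[k] = alpha * O[k] + beta * pv[k];
        l = alpha * l + beta * l_t;
        m = m_new;
    }

    // Direct reference: stable softmax over the whole row, then weighted sum of V.
    float m_ref = -FLT_MAX, l_ref = 0.0f, O_ref[d] = {0.0f, 0.0f};
    for (int j = 0; j < N; ++j) m_ref = std::fmax(m_ref, S[j]);
    for (int j = 0; j < N; ++j) {
        const float p = std::exp(S[j] - m_ref);
        l_ref += p;
        for (int k = 0; k < d; ++k) O_ref[k] += p * V[j][k];
    }

    for (int k = 0; k < d; ++k)
        std::printf("online %.6f  vs  reference %.6f\n", O[k] / l, O_ref[k] / l_ref);
    return 0;
}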

Hardware Execution Trace

Observe how a single Query tile ($Q_i$) remains stationary in SRAM while iterating over the Key/Value tiles ($K_j, V_j$). Click Next Step to advance the hardware simulation.
[Interactive trace: the memory hierarchy modeled in the simulation]
Global Memory (HBM): 16 GB capacity, 200+ cycle latency, 320 GB/s bandwidth. Holds the full Q, K, V matrices and the output O.
Streaming Multiprocessor (SM), Shared Memory (SRAM): 64 KB, banked, ~30 cycle latency. Holds the current $Q_i$, $K_j$, $V_j$ tiles.
Warp & register file: ~1 cycle latency. Holds the running state, initialized to $m_i = -\infty$, $\ell_i = 0$, $O_i = 0$.
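For readers who prefer code to the animation, here is a compilable sketch of the same loop structure: one thread block per $Q$ tile, the $Q_i$ tile staged into shared memory once, $K_j$/$V_j$ tiles streamed through shared memory, and the per-row running state in registers. The tile sizes, the one-thread-per-query-row layout, and the kernel name are simplifying assumptions for readability, not the project's exact configuration.

// Compilable sketch of the traced loop: one block per Q tile, Q stationary in SRAM,
// K_j/V_j tiles streamed, running (m, l, O) state in registers, one write of O per row.
// Launch (illustrative): flash_forward_sketch<<<(N + BR - 1) / BR, BR>>>(Q, K, V, O, N, scale);
#include <cfloat>
#include <cuda_runtime.h>

constexpr int BR = 32;   // query rows per block (assumed)
constexpr int BC = 32;   // key/value rows per tile (assumed)
constexpr int D  = 64;   // head dimension, as in the report

__global__ void flash_forward_sketch(const float* Q, const float* K, const float* V,
                                     float* O, int N, float scale) {
    __shared__ float Q_smem[BR][D];
    __shared__ float K_smem[BC][D];
    __shared__ float V_smem[BC][D];

    const int t     = threadIdx.x;            // one thread owns one query row
    const int q_idx = blockIdx.x * BR + t;

    // Stage the stationary Q tile once (cooperative strided copy).
    for (int idx = t; idx < BR * D; idx += blockDim.x) {
        const int r = idx / D, c = idx % D, g = blockIdx.x * BR + r;
        Q_smem[r][c] = (g < N) ? Q[g * D + c] : 0.0f;
    }
    __syncthreads();

    // Register-resident running state for this query row.
    float m_i = -FLT_MAX, l_i = 0.0f, O_acc[D];
    for (int k = 0; k < D; ++k) O_acc[k] = 0.0f;

    const int Tc = (N + BC - 1) / BC;
    for (int j = 0; j < Tc; ++j) {
        // Stream the next K_j, V_j tiles into SRAM.
        for (int idx = t; idx < BC * D; idx += blockDim.x) {
            const int r = idx / D, c = idx % D, g = j * BC + r;
            K_smem[r][c] = (g < N) ? K[g * D + c] : 0.0f;
            V_smem[r][c] = (g < N) ? V[g * D + c] : 0.0f;
        }
        __syncthreads();

        if (q_idx < N) {
            const int bc_actual = min(BC, N - j * BC);
            float S_row[BC], m_tilde = -FLT_MAX;
            for (int jj = 0; jj < bc_actual; ++jj) {          // S_row = Q_i . K_j^T * scale
                float dot = 0.0f;
                for (int k = 0; k < D; ++k) dot += Q_smem[t][k] * K_smem[jj][k];
                S_row[jj] = dot * scale;
                m_tilde   = fmaxf(m_tilde, S_row[jj]);
            }
            float l_tilde = 0.0f;                             // P_tilde and its row sum
            for (int jj = 0; jj < bc_actual; ++jj) {
                S_row[jj] = expf(S_row[jj] - m_tilde);
                l_tilde  += S_row[jj];
            }
            const float m_new = fmaxf(m_i, m_tilde);          // online rescale
            const float alpha = expf(m_i - m_new), beta = expf(m_tilde - m_new);
            for (int k = 0; k < D; ++k) {
                float pv = 0.0f;
                for (int jj = 0; jj < bc_actual; ++jj) pv += S_row[jj] * V_smem[jj][k];
                O_acc[k] = alpha * O_acc[k] + beta * pv;
            }
            m_i = m_new;
            l_i = alpha * l_i + beta * l_tilde;
        }
        __syncthreads();                                      // keep tiles valid for all threads
    }

    // One normalized write of this query row back to HBM.
    if (q_idx < N)
        for (int k = 0; k < D; ++k) O[q_idx * D + k] = O_acc[k] / l_i;
}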

Benchmark laboratory

This pulls from the same CSVs as the write-up. Choose N, flip kernel traces on or off, and use a log-scaled latency axis so the flash_v2 line does not flatten everything else.

Live metrics at selected N

For each kernel (Naive, Flash causal, Flash v1, WMMA, and Flash v2), the lab reports latency, HBM read bytes, HBM write bytes, and speedup relative to the Naive baseline. The Flash v2 row is flagged as an experiment; NCU shows it streaming roughly 2× the K/V bytes of v1.

Empirical Benchmark Evaluation

I timed kernels and read back global load/store bytes with NCU. Everything here is CUDA C++ on a Tesla T4 (sm_75) with head size $d = 64$, same as the report.

[Plot: Execution latency (ms) vs. sequence length N]

[Plot: Total HBM traffic (MB) vs. sequence length N]
What to notice: in the HBM plot, naive traffic blows up with $N$ because it writes the whole score matrix. Flash stays closer to linear because it never parks that matrix in HBM.
📈
At N = 4096, FlashAttention is ~1.8× faster and cuts global memory writes by ~138× compared to the naive baseline, which lines up with eliminating the $O(N^2)$ score-matrix writes to global memory.
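A rough sanity check on that write reduction, assuming the naive path writes the $N \times N$ matrices $S$ and $P$ plus the output while Flash writes only the $N \times d$ output:
$$\frac{2N^2 + Nd}{Nd} = \frac{2N}{d} + 1 = \frac{2 \cdot 4096}{64} + 1 = 129,$$
the same order of magnitude as the ~138× that NCU measures.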

Kernels & code

Each row in the lab matches one of these paths. Expand a card for the story; dot colors match the chart. Below that is a side-by-side taste of CUDA: naive softmax hammering HBM vs Flash keeping state in registers. Full sources live on GitHub and in the TeX report.

Naive — Writes the full $S = QK^\top/\sqrt{d}$, softmax, then $PV$. Correct reference, worst IO. Three launches mean the big score matrix ping-pongs through HBM; NCU tells that story in gigabytes.

Flash v1 — Fixed $(B_r, B_c, d)$ tiles, fused softmax in registers, writes only $O$. Online softmax rescaling keeps the running $(m_i, \ell_i, O_i)$ in the register file while streaming $K_j, V_j$ tiles. This is the core IO win over naive.

Flash causal — Skips future $(K, V)$ tiles and masks on the diagonal. Fastest in my T4 numbers. This is not just masking logits: skipping whole tiles means you never load K/V you will not use, so HBM reads drop sharply compared to dense Flash at large $N$.

Flash v2 (experiment) — Smaller $B_r$ sounded clever; it bought extra outer-loop passes instead. When $B_r$ is half as wide you sweep the whole K/V stream twice as often. NCU reads land near ~2× Flash v1, and wall-clock time ends up worse than naive at large $N$ (see the back-of-envelope after this list). Lesson learned: fusion without the right tile shape still loses.

WMMA — FP16 Tensor Core matmul for the $QK^\top$ blocks with FP32 accumulators on sm_75. Lowest global reads in the table and competitive latency on the non-causal path. Causal tile skipping is not fused here yet, so FP32 causal Flash still wins on pure milliseconds in these runs.
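To make the Flash v2 penalty concrete, a back-of-envelope under the assumption that every outer pass over the $Q$ tiles streams all of $K$ and $V$ once in FP32:
$$\text{K/V bytes read} \;\approx\; \underbrace{\frac{N}{B_r}}_{\text{outer passes}} \times \underbrace{2\,N d \cdot 4\ \text{B}}_{K \text{ and } V},$$
so halving $B_r$ doubles the streamed K/V traffic, consistent with the ~2× reads NCU reports for v2 relative to v1.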

Reproduce everything.

Clone, cd FinalProject, build flash_attn, run benchmarks/run_bench.sh and run_ncu_profile.sh. The numbers in the lab above come from that same CSV pipeline (data/results/).


CUDA excerpts

Representative fragments (not whole translation units). Compare memory traffic vs register-resident softmax state.

// Kernel 2: row-wise softmax (data lives in HBM)
// Reads and writes the full N×N matrix S; O(N²) HBM traffic
__global__ void softmax_kernel(float* S, int N) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= N) return;

    float* row_ptr = S + row * N;

    // Pass 1: find row max (numerical stability)
    float mx = -FLT_MAX;
    for (int j = 0; j < N; ++j)
        mx = fmaxf(mx, row_ptr[j]);          // reads N floats from DRAM

    // Pass 2: exponentiate and sum
    float sum = 0.0f;
    for (int j = 0; j < N; ++j) {
        row_ptr[j] = expf(row_ptr[j] - mx);  // read + write N floats to DRAM
        sum += row_ptr[j];
    }

    // Pass 3: normalize
    float inv = 1.0f / sum;
    for (int j = 0; j < N; ++j)
        row_ptr[j] *= inv;                   // read + write N floats to DRAM
}
// Total: ~4·N² float reads + ~2·N² float writes per softmax call
// Online softmax rescale inside FlashAttention
// Runs in registers; the full N×N matrix never exists in HBM
for (int j = 0; j < Tc; ++j) {
    // ... load K_j, V_j tiles into SRAM ...

    // Compute S_row = Q[q_idx] · K_j^T / sqrt(d)  (BC dot products)
    float S_row[BC], m_tilde = -FLT_MAX;
    for (int jj = 0; jj < bc_actual; ++jj) {
        float dot = 0.0f;
        for (int k = 0; k < D; ++k)
            dot += Q_smem[t][k] * K_smem[jj][k];
        S_row[jj] = dot * scale;
        m_tilde   = fmaxf(m_tilde, S_row[jj]);
    }

    // P_tilde = exp(S_row - m_tilde),  l_tilde = sum(P_tilde)
    float l_tilde = 0.0f;
    for (int jj = 0; jj < bc_actual; ++jj) {
        S_row[jj] = expf(S_row[jj] - m_tilde);
        l_tilde  += S_row[jj];
    }

    // Online rescale: no global memory traffic here
    float m_new = fmaxf(m_i, m_tilde);
    float alpha = expf(m_i     - m_new);   // rescale old contribution
    float beta  = expf(m_tilde - m_new);   // rescale new contribution

    for (int k = 0; k < D; ++k) {
        float pv_k = 0.0f;
        for (int jj = 0; jj < bc_actual; ++jj)
            pv_k += S_row[jj] * V_smem[jj][k];
        O_acc[k] = alpha * O_acc[k] + beta * pv_k; // update in registers
    }
    m_i = m_new;
    l_i = alpha * l_i + beta * l_tilde;
}
// Final write: O[q_idx] = O_acc / l_i  ← ONE write to HBM per row

Build, bench, profile

git clone https://github.com/batra98/me759-flashattention.git
cd me759-flashattention/FinalProject && mkdir -p build && cd build
cmake .. && make -j"$(nproc)"
./flash_attn --mode correctness --target flash --seq_len 2048
bash ../benchmarks/run_bench.sh ./flash_attn ../data/results
sudo bash ../benchmarks/run_ncu_profile.sh ./flash_attn ../data/results/hbm_traffic.csv
cd .. && python3 python/plot_results.py