Learn the single most important optimization for memory-bound workloads: fusing many operations into one kernel. Count the DRAM round-trips that make unfused code slow, implement softmax as a one-pass row-wise reduction, make it numerically stable, and measure the fused-vs-unfused speedup.
Softmax appears everywhere — attention, classification heads, mixture weights — and it's the perfect vehicle for learning kernel fusion, the optimization that defines practical Triton work.
A naive softmax in PyTorch runs as a chain of separate kernels: subtract the max, exponentiate, sum, divide. Each kernel reads its input from DRAM and writes its output back to DRAM, so the data makes the slow round-trip multiple times. For a memory-bound operation, that traffic — not the arithmetic — is the entire runtime.
A fused kernel loads each row from DRAM exactly once, performs the whole computation (max → exp → sum → divide) while the row sits in fast on-chip SRAM, and writes the result once. The math is identical; the memory traffic collapses by a factor equal to the number of fused passes. That's the whole game.
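The one-load, one-store pattern can be emulated on the CPU. The sketch below is plain Python, not Triton. `dram_in` and the returned list stand in for global memory, and the local `sram` list stands in for on-chip storage. `TrafficCounter` is a hypothetical helper that tallies the elements crossing that boundary. The loop body plays the role of one program owning one row.

```python
import math
from dataclasses import dataclass


@dataclass
class TrafficCounter:
    """Hypothetical helper: counts elements moved to and from 'DRAM'."""
    reads: int = 0
    writes: int = 0


def fused_softmax(dram_in, counter):
    """Emulate one fused kernel in which each 'program' owns one row."""
    dram_out = []
    for row in dram_in:                  # one program per row
        sram = list(row)                 # the row's single load from DRAM
        counter.reads += len(sram)
        row_max = max(sram)              # reduction 1, done on-chip
        exps = [math.exp(v - row_max) for v in sram]
        denom = sum(exps)                # reduction 2, done on-chip
        result = [e / denom for e in exps]
        dram_out.append(result)          # the row's single store to DRAM
        counter.writes += len(result)
    return dram_out


counter = TrafficCounter()
x = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
y = fused_softmax(x, counter)
print(counter)  # TrafficCounter(reads=6, writes=6)
```

Each element is read once and written once, so traffic is twice the element count, the minimum any softmax can achieve. In a real Triton kernel the Python loop disappears, and the launch grid supplies one program per row.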
Along the way you'll implement softmax as a row-wise reduction (each program owns one row), and you'll make it numerically stable by subtracting the row max before exponentiating — without which large logits overflow to infinity. By the end you'll understand why fusion is the first thing a Triton programmer reaches for.
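A small pure-Python comparison shows why the max subtraction matters. Python's `math.exp` raises `OverflowError` for arguments above roughly 709.78 instead of returning `inf`. The failure is the same one a float32 kernel hits, and float32 hits it much sooner, since its exp overflows above about 88.7. The names `naive_softmax` and `stable_softmax` are illustrative.

```python
import math


def naive_softmax(row):
    # exp(x) directly: math.exp raises OverflowError once x exceeds ~709.78
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(row):
    # Shift by the row max: softmax(x) == softmax(x - c) for any constant c,
    # and after the shift every exponent is <= 0, so each exp() lies in (0, 1].
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(logits)
except OverflowError:
    print("naive: overflow")
print([round(p, 4) for p in stable_softmax(logits)])
```

For inputs small enough that the naive version survives, the two agree. The shift changes no probability, only the range the exponentials live in.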
This chapter covers:
Unfused chains move the data to DRAM and back per op; traffic = runtime for memory-bound work.
One load, all the work in SRAM, one store — intermediates never hit DRAM.
One program per row; tl.max/tl.sum reduce on-chip. Generalizes to layer norm & attention.
Subtract the row max before exp — shift-invariant, prevents overflow, essentially free.
Measure the win: a fused softmax nears the bandwidth roofline by moving the minimum bytes.
Every time data leaves the chip and comes back, you pay the slow DRAM latency and consume precious bandwidth. For memory-bound ops, the number of these round-trips is a direct proxy for runtime.
Consider softmax done as separate ops over an $M \times N$ matrix. (1) Read X and compute each row's max, writing only the $M$ row maxes (though PyTorch may materialize other intermediates). (2) Read X again, subtract the max, exponentiate, and write a full-size temporary. (3) Read the temporary and sum each row. (4) Read the temporary again, divide, and write the output. Every pass reads all $MN$ values from DRAM, and the passes that produce a full-size result write $MN$ more. The fused version moves the matrix in once and out once, roughly $2MN$ values total, versus the $6MN$–$8MN$ of the unfused chain.
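The element accounting for the four passes just described can be checked with a CPU emulation. This is a sketch under stated assumptions: the row maxes and row sums are written as M-element vectors, and nothing else is materialized. The function name `unfused_softmax` is hypothetical.

```python
import math


def unfused_softmax(x):
    """Emulate the four-kernel chain; returns (output, elements_read, elements_written)."""
    reads = writes = 0
    m_rows, n_cols = len(x), len(x[0])

    # Pass 1: read X, write the M row maxes.
    row_max = [max(row) for row in x]
    reads += m_rows * n_cols
    writes += m_rows

    # Pass 2: read X (plus the maxes), subtract, exp, write a full-size temporary.
    tmp = [[math.exp(v - row_max[i]) for v in row] for i, row in enumerate(x)]
    reads += m_rows * n_cols + m_rows
    writes += m_rows * n_cols

    # Pass 3: read the temporary, write the M row sums.
    row_sum = [sum(row) for row in tmp]
    reads += m_rows * n_cols
    writes += m_rows

    # Pass 4: read the temporary (plus the sums), divide, write the output.
    out = [[v / row_sum[i] for v in row] for i, row in enumerate(tmp)]
    reads += m_rows * n_cols + m_rows
    writes += m_rows * n_cols

    return out, reads, writes


x = [[float(i + j) for j in range(8)] for i in range(4)]   # M = 4, N = 8
_, reads, writes = unfused_softmax(x)
fused_traffic = 2 * 4 * 8   # one read + one write per element
print(reads + writes, fused_traffic)  # 208 64  (6MN + 4M vs 2MN)
```

Setting aside the small $M$-sized terms, that is $6MN$, three times the fused $2MN$. If every pass read and wrote a full-size matrix, the count would reach $8MN$.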
For a memory-bound op (and softmax, with its low arithmetic intensity, is exactly that), runtime $\approx \text{bytes moved} / \text{DRAM bandwidth}$. Cutting the bytes by 3–4× cuts the runtime by 3–4×. The arithmetic (a few exps and adds per element) is essentially free by comparison; the data movement is the cost.
If an operation is a chain of $k$ memory-bound passes over $n$ elements, unfused traffic is $2kn$ (one read and one write per pass) while fused traffic is $2n$. The speedup ceiling for a bandwidth-bound op is therefore $2kn / 2n = k$: it grows linearly with how many passes you collapse into one kernel.
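The traffic model and the runtime rule of thumb fit in two small functions. `traffic_bytes`, `est_runtime_s`, and the 1 TB/s bandwidth are hypothetical choices for illustration, not measurements of any particular GPU. Bandwidth cancels in the ratio, which is why the ceiling depends only on $k$.

```python
def traffic_bytes(k_passes, n_elems, bytes_per_elem, fused):
    # Each pass reads and writes every element once; fusion collapses k passes into 1.
    passes = 1 if fused else k_passes
    return 2 * passes * n_elems * bytes_per_elem


def est_runtime_s(bytes_moved, bandwidth_bytes_per_s):
    # Memory-bound model: runtime ~ bytes moved / DRAM bandwidth.
    return bytes_moved / bandwidth_bytes_per_s


BANDWIDTH = 1.0e12   # hypothetical 1 TB/s, for illustration only
n = 1 << 24          # 16M float32 elements
for k in (2, 3, 4):
    unfused = est_runtime_s(traffic_bytes(k, n, 4, fused=False), BANDWIDTH)
    fused = est_runtime_s(traffic_bytes(k, n, 4, fused=True), BANDWIDTH)
    print(f"k={k}: speedup ceiling {unfused / fused:.1f}x")
```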
An unfused softmax over a 4096×4096 float32 matrix makes 4 passes that each read+write the matrix. How many bytes hit DRAM, and how much does fusing to a single load+store save?
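Working the exercise under its stated assumption (each of the 4 passes reads and writes the whole matrix, and the small row-max and row-sum vectors are ignored):

```python
M = N = 4096
BYTES_PER_FLOAT32 = 4
PASSES = 4
MiB = 1 << 20

matrix_bytes = M * N * BYTES_PER_FLOAT32      # 67,108,864 B = 64 MiB
unfused_bytes = PASSES * 2 * matrix_bytes     # read + write on every pass
fused_bytes = 2 * matrix_bytes                # one load + one store
saved_bytes = unfused_bytes - fused_bytes

print(unfused_bytes // MiB, fused_bytes // MiB, saved_bytes // MiB)  # 512 128 384
```

The unfused chain moves 512 MiB (536,870,912 bytes) and the fused kernel moves 128 MiB, so fusion saves 384 MiB. That is a 4× reduction in traffic, which is also the speedup ceiling for this bandwidth-bound op.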