Triton Kernels

gpus
Published September 8, 2025

GPU Programming: From Memory Bottlenecks to Triton

The Bandwidth Wall

Modern GPUs have massive compute but limited memory bandwidth. The key metric is arithmetic intensity (AI):

AI = FLOPs / Bytes Transferred

To be compute-bound rather than memory-bound, a kernel's arithmetic intensity must exceed the hardware's ratio of peak compute to peak bandwidth. On a B200:
- Compute: ~40 TFLOPS (FP32)
- Memory bandwidth: ~8 TB/s
- Breakeven AI: 40 TFLOPS ÷ 8 TB/s ≈ 5 FLOPs/byte

Consider different operations:

# Vector addition: y = a + b
# Bytes: read 2*N*4, write N*4 = 12N bytes
# FLOPs: N additions
# AI = N / 12N = 0.083 FLOPs/byte → Memory bound!

# Matrix multiply: C[M,N] = A[M,K] @ B[K,N]
# Bytes: (M*K + K*N + M*N) * 4
# FLOPs: 2*M*N*K (multiply + add)
# AI = 2*M*N*K / ((M*K + K*N + M*N) * 4)
# For M=N=K=4096: AI ≈ 683 FLOPs/byte → Compute bound!
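The sketch below is a small stdlib-only Python script. It recomputes both intensities and compares them with the breakeven point.
- The function names are hypothetical.
- The peak figures are the approximate B200 numbers quoted above.
- Traffic assumes each operand crosses HBM exactly once. That is the ideal that tiling tries to approach.

```python
def vector_add_ai(n, bytes_per_elem=4):
    """Arithmetic intensity of y = a + b: n FLOPs over 3n element transfers."""
    flops = n                               # one add per element
    traffic = 3 * n * bytes_per_elem        # read a, read b, write y
    return flops / traffic


def matmul_ai(m, n, k, bytes_per_elem=4):
    """Idealized AI of C[m,n] = A[m,k] @ B[k,n], each matrix moved exactly once."""
    flops = 2 * m * n * k                   # one multiply + one add per term
    traffic = (m * k + k * n + m * n) * bytes_per_elem
    return flops / traffic


# Breakeven from the approximate B200 figures (assumed, FP32)
PEAK_FLOPS = 40e12                  # ~40 TFLOPS
PEAK_BW = 8e12                      # ~8 TB/s
BREAKEVEN = PEAK_FLOPS / PEAK_BW    # 5 FLOPs/byte

for name, ai in [("vector add", vector_add_ai(1 << 20)),
                 ("matmul 4096^3", matmul_ai(4096, 4096, 4096))]:
    bound = "compute" if ai > BREAKEVEN else "memory"
    print(f"{name}: AI = {ai:.3f} FLOPs/byte -> {bound} bound")
# vector add: AI = 0.083 FLOPs/byte -> memory bound
# matmul 4096^3: AI = 682.667 FLOPs/byte -> compute bound
```

The matmul figure is an upper bound. A kernel that re-reads A and B from HBM for every output element gets nowhere near it.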

Breaking the Bandwidth Wall: Flash Attention

Standard attention has a fundamental problem:

# Standard attention: O = softmax(QK^T)V, with Q, K, V of shape [N, d]
# Must materialize S = QK^T of size [N, N] in HBM
# Memory accesses: O(N²) reads/writes to HBM
# Compute: O(N²) (treating head dim d as a constant)
# AI: O(1) → Memory bound!

Flash Attention’s key insight: tiling changes the memory access pattern.

# Flash Attention with tile size T
# Process attention in T×T blocks, never materializing full [N,N]
# Memory accesses: O(N²/T) to HBM  
# Compute: still O(N²)
# AI: O(T) → Can be compute bound!

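The magic is keeping tiles in the faster levels of the memory hierarchy:
- HBM (main memory): ~8 TB/s
- L2 cache: a few times faster than HBM
- SRAM (shared memory, summed over all SMs): roughly an order of magnitude faster than HBM
- Registers: faster still, with the lowest latency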

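By computing in tiles that fit in SRAM, Flash Attention:
1. Loads each K/V block from HBM once per query block (N/T times in total) instead of once per query (N times)
2. Keeps the intermediate score and probability tiles in SRAM and registers, never writing the N×N matrix to HBM
3. Writes only the final output O to HBM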

Tiling changes how we walk through memory. Instead of streaming all of K and V from HBM for every query, each contiguous K/V block is loaded once per query tile and reused by all T queries in that tile.
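Tiling attention has one obstacle: softmax normalizes over an entire row of scores, but a tile only sees part of that row. Flash Attention resolves this with an online softmax. For each query it keeps three running values:
- a maximum m,
- a normalizer l,
- an unnormalized output accumulator.

Whenever a new tile raises the maximum, l and the accumulator are rescaled by exp(m_old - m_new). The pure-Python sketch below checks this scheme against a naive implementation that materializes full score rows. It illustrates the math only. It assumes a single head, no masking, and one query row at a time rather than a query tile. The function names are hypothetical.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference: build each full score row, softmax it, then weight the rows of V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]   # full row of S
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / z for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, tile=4):
    """Stream K/V in tiles; only `tile` scores per query exist at any time."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:                     # a real kernel handles a whole query tile per program
        m = -math.inf               # running max of the scores seen so far
        l = 0.0                     # running sum of exp(score - m)
        acc = [0.0] * len(V[0])     # running sum of exp(score - m) * v
        for start in range(0, len(K), tile):
            k_tile, v_tile = K[start:start + tile], V[start:start + tile]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
            m_new = max(m, max(s))
            correction = math.exp(m - m_new)   # 0.0 on the first tile, since m = -inf
            p = [math.exp(x - m_new) for x in s]
            l = l * correction + sum(p)
            acc = [a * correction + sum(pj * v[c] for pj, v in zip(p, v_tile))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out


random.seed(0)
N, d = 10, 8
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(N)] for _ in range(3))
ref = naive_attention(Q, K, V)
got = tiled_attention(Q, K, V, tile=3)      # 10 keys in tiles of 3, 3, 3, 1
max_err = max(abs(a - b) for r, g in zip(ref, got) for a, b in zip(r, g))
print(f"max abs difference: {max_err:.2e}")
```

The rescaling is what lets each K/V tile be consumed once and discarded. Only m, l and the accumulator survive between tiles.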

Memory is Always 1D

Memory is a flat 1D array. Your tensor[M, N] doesn't exist as a 2D object. It is M*N consecutive elements (M*N*4 bytes for FP32), and multi-dimensional indexing maps each [i, j] to an offset in that range.

import torch

# What you write
A = torch.randn(1024, 768)  # A[i,j]; A.stride() == (768, 1)

# What actually exists in memory (row-major):
# [row0_col0, row0_col1, ..., row0_col767,
#  row1_col0, row1_col1, ..., row1_col767, ...]
# A[i,j] → memory[i*768 + j]
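The same mapping can be written out for any strides. The hypothetical helper below computes the flat offset from (row_stride, col_stride). That is the pair PyTorch reports, in elements, via Tensor.stride(); for this A it is (768, 1). A transpose such as A.T reuses the same buffer with the two strides swapped.

```python
def flat_index(i, j, strides):
    """Offset of element [i, j] in the flat buffer, given (row_stride, col_stride)."""
    row_stride, col_stride = strides
    return i * row_stride + j * col_stride


ROWS, COLS = 1024, 768
a_strides = (COLS, 1)                       # row-major [1024, 768]
at_strides = (a_strides[1], a_strides[0])   # A.T: same buffer, strides swapped -> (1, 768)

# A[5, 7] and A.T[7, 5] name the same memory location
print(flat_index(5, 7, a_strides), flat_index(7, 5, at_strides))   # 3847 3847

# Walking along a row of A moves 1 element; along a row of A.T it jumps 768
print([flat_index(0, j, a_strides) for j in range(4)])    # [0, 1, 2, 3]
print([flat_index(0, j, at_strides) for j in range(4)])   # [0, 768, 1536, 2304]
```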

Coalescing and performance depend on how your multi-dimensional indexing walks that 1D memory. When the 32 threads of a warp issue a load together, the hardware merges their requests into as few memory transactions as it can. That number is smallest when the threads touch consecutive addresses:

# Row-major traversal: threads access A[i, 0:32]
# Memory walk: [i*768+0, i*768+1, ..., i*768+31]
# Consecutive addresses → 1 coalesced transaction

# Column-major traversal: threads access A[0:32, j]  
# Memory walk: [0*768+j, 1*768+j, ..., 31*768+j]
# Addresses 768 elements apart → 32 separate transactions
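The transaction counts above can be checked with a few lines of arithmetic. The sketch makes three assumptions:
- FP32 elements,
- a buffer whose base address is 128-byte aligned,
- 128-byte cache lines.

Real GPUs track 32-byte sectors within those lines. Even at that granularity, the row access touches 4 sectors against 32 for the column access.

```python
LINE_BYTES = 128   # assumption: 128-byte cache lines
ELEM_BYTES = 4     # float32
ROW_LEN = 768      # A is [1024, 768], row-major


def lines_touched(element_offsets):
    """Distinct 128-byte lines covered by one warp's element accesses."""
    return len({(off * ELEM_BYTES) // LINE_BYTES for off in element_offsets})


i, j = 3, 5
row_walk = [i * ROW_LEN + t for t in range(32)]   # threads read A[i, 0:32]
col_walk = [t * ROW_LEN + j for t in range(32)]   # threads read A[0:32, j]

print(lines_touched(row_walk))   # 1: 32 x 4 B = 128 B, one aligned line
print(lines_touched(col_walk))   # 32: each thread lands in its own line
```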

This is why choosing the right stride matters:

import torch

x = torch.randn(4096, 4096, device='cuda')

# Row-wise sum: each row is contiguous, so the reduction walks memory sequentially
x.sum(dim=1)  # stride of 1 along the reduced dimension

# After transpose: logical rows are physical columns
y = x.T  # No data movement, just changes stride from (4096, 1) to (1, 4096)
y.sum(dim=1)  # stride of 4096 along the reduced dimension: a naive kernel walking
              # each logical row jumps 16 KB per step (PyTorch's reduction kernels
              # may reorder the walk to compensate, so benchmark before assuming)

# .contiguous() reorganizes memory to match the logical view
z = y.contiguous()  # One full copy, after which logical rows are sequential again
z.sum(dim=1)  # stride of 1 again

Triton: Think in Blocks, Not Threads

CUDA forces you to manually manage how threads walk through memory:

// CUDA: You explicitly map threads to memory
__global__ void add_kernel(float* x, float* y, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;  // Your indexing
    if (tid < n) {
        y[tid] = x[tid] + 1.0f;  // You ensure coalescing
    }
}
// Launch: add_kernel<<<num_blocks, threads_per_block>>>(x, y, n);

Triton abstracts this—you think in vectors, it handles the memory walk:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Program ID - which block are we?
    pid = tl.program_id(0)
    
    # This program handles elements [pid*BLOCK_SIZE : (pid+1)*BLOCK_SIZE]
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n  # the last block may run past n
    
    # Contiguous offsets let the compiler emit coalesced loads/stores
    x = tl.load(x_ptr + offsets, mask=mask)
    y = x + 1.0
    tl.store(y_ptr + offsets, y, mask=mask)

# Launch: each program processes BLOCK_SIZE elements
n = 1_000_000
x = torch.randn(n, device='cuda')
y = torch.empty_like(x)
BLOCK_SIZE = 1024
grid = (triton.cdiv(n, BLOCK_SIZE),)
add_kernel[grid](x, y, n, BLOCK_SIZE=BLOCK_SIZE)

What Triton abstracts:
- Thread-to-memory mapping: you specify block-level operations, and the compiler chooses how threads are assigned to elements
- Coalescing: when offsets are contiguous, tl.load/tl.store compile to coalesced (often vectorized) accesses
- Shared memory management: the compiler handles SRAM allocation and synchronization
- Register allocation: intermediates are kept in registers and reused automatically

The notation kernel[grid](args) launches the kernel on a grid of programs (blocks), where each program operates on vectors of data.
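The sketch below is a sequential pure-Python emulation of add_kernel that shows what the launch grid means. Each loop iteration plays one program instance. It builds its offsets and mask the way tl.arange and the mask comparison do, then performs masked loads and stores on plain lists. It models only the indexing; on a GPU the programs run in parallel and the compiler decides how threads cover each block. The helper names are hypothetical, and cdiv mirrors triton.cdiv.

```python
def cdiv(a, b):
    """Ceiling division; plays the role of triton.cdiv."""
    return (a + b - 1) // b


def add_kernel_program(pid, x, y, n, BLOCK_SIZE):
    """One program instance of add_kernel, written with plain Python lists."""
    offsets = [pid * BLOCK_SIZE + t for t in range(BLOCK_SIZE)]   # pid*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = [off < n for off in offsets]                          # offsets < n
    for off, ok in zip(offsets, mask):
        if ok:                       # masked-off lanes neither load nor store
            y[off] = x[off] + 1.0


def launch_add(x, BLOCK_SIZE):
    """Emulate add_kernel[grid](x, y, n, BLOCK_SIZE=...) by running programs one by one."""
    n = len(x)
    y = [0.0] * n
    grid = (cdiv(n, BLOCK_SIZE),)
    for pid in range(grid[0]):       # on a GPU these run concurrently, in no guaranteed order
        add_kernel_program(pid, x, y, n, BLOCK_SIZE)
    return y, grid


x = [float(v) for v in range(10)]
y, grid = launch_add(x, BLOCK_SIZE=4)
print(grid)                          # (3,): the last program covers 8..11; the mask drops 10 and 11
print(y == [v + 1.0 for v in x])     # True
```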

Matmul: Putting It All Together

Here’s how tiled matmul achieves high arithmetic intensity through careful memory traversal.

The Basics

For matrix multiplication C = A @ B where:
- A has shape [M, K]
- B has shape [K, N]
- C has shape [M, N]

We launch a single kernel whose grid has cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) programs. Each program computes one tile of C:

Matrix C divided into tiles:          Each program computes one tile:
[C00 C01 C02 C03]                    
[C10 C11 C12 C13]                    C11 = A[row1] @ B[col1]
[C20 C21 C22 C23]                         ↓           ↓
                                     A[128:256,:] @ B[:,128:256]

L2 Cache Optimization with GROUP_SIZE_M

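The L2 cache optimization happens at the scheduling level, inside a single kernel launch. Each program computes one tile of C, and GROUP_SIZE_M controls how program IDs are mapped to tiles. Programs are dispatched roughly in pid order, so this mapping decides which tiles are computed at about the same time and can therefore share A and B blocks through L2.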

How GROUP_SIZE_M=2 Reorganizes Execution:

Standard Row-Major Order         Grouped Column-Major Order
(process entire row 0 first)     (process 2 rows at a time)

Block IDs:                       Execution Order:
[0]  [1]  [2]  [3]  [4]  [5]    [0]  [2]  [4]  [6]  [8]  [10]
[6]  [7]  [8]  [9]  [10] [11]   [1]  [3]  [5]  [7]  [9]  [11]
[12] [13] [14] [15] [16] [17] → [12] [14] [16] [18] [20] [22]
[18] [19] [20] [21] [22] [23]   [13] [15] [17] [19] [21] [23]
[24] [25] [26] [27] [28] [29]   [24] [26] [28] [30] [32] [34]
[30] [31] [32] [33] [34] [35]   [25] [27] [29] [31] [33] [35]

Why this matters for cache:

Standard row-major:              Grouped (GROUP_SIZE_M=2):
Block 0: A[row0] × B[col0]      Block 0: A[row0] × B[col0]
Block 1: A[row0] × B[col1]      Block 1: A[row1] × B[col0] ← B[col0] reused!
Block 2: A[row0] × B[col2]      Block 2: A[row0] × B[col1]
Block 3: A[row0] × B[col3]      Block 3: A[row1] × B[col1] ← B[col1] reused!
Block 4: A[row0] × B[col4]      ...
Block 5: A[row0] × B[col5]      
Block 6: A[row1] × B[col0]      
         ↑ B[col0] likely evicted!

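GROUP_SIZE_M=2 means consecutive programs share data in both directions. A[row0] and A[row1] stay hot in L2 while the schedule sweeps across columns, and each B column block is used by two programs back to back before it can be evicted. Result: typically a modest speedup, on the order of 10-15%, depending on matrix shapes and hardware.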
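The reuse argument can be made concrete by counting blocks. The sketch below builds both orders by enumeration. It then counts the distinct A row blocks and B column blocks that the first few programs need. It relies on two simplifying assumptions:
- those programs are in flight together,
- nothing is evicted from L2 while they run.

The 9×9 grid with K split into 9 blocks is a hypothetical example, chosen so the numbers are easy to check by hand.

```python
def row_major_order(num_pid_m, num_pid_n):
    """Output tiles (m, n) in plain row-major program order."""
    return [(m, n) for m in range(num_pid_m) for n in range(num_pid_n)]


def grouped_order(num_pid_m, num_pid_n, group_size_m):
    """Output tiles in grouped order: column-major inside bands of group_size_m rows."""
    order = []
    for band in range(0, num_pid_m, group_size_m):
        rows = range(band, min(band + group_size_m, num_pid_m))
        order += [(m, n) for n in range(num_pid_n) for m in rows]
    return order


def blocks_loaded(tiles, num_k_blocks):
    """A and B blocks a set of concurrently running tiles needs (no eviction assumed)."""
    a_row_blocks = {m for m, _ in tiles}
    b_col_blocks = {n for _, n in tiles}
    return (len(a_row_blocks) + len(b_col_blocks)) * num_k_blocks


# Hypothetical: 9x9 output tiles, K in 9 blocks, first 9 programs resident together
rm = row_major_order(9, 9)[:9]
gp = grouped_order(9, 9, group_size_m=3)[:9]
print(blocks_loaded(rm, 9))   # 90: one row band of A, all 9 column bands of B
print(blocks_loaded(gp, 9))   # 54: three row bands of A, three column bands of B
```

The same grouped_order(6, 6, 2) reproduces the pid layout in the diagram above. For example, tile (1, 4) is the tenth entry, so it belongs to pid 9.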

How the Grouped Order Maps to Pointers

With grouped column-major ordering, we remap pid to get better cache reuse:

Example: pid=9, M=768, N=768, BLOCK_M=128, BLOCK_N=128, GROUP_SIZE_M=2

num_pid_m = 768/128 = 6 rows of blocks
num_pid_n = 768/128 = 6 cols of blocks  
num_pid_in_group = 2 * 6 = 12 (process 2 rows × 6 cols before moving on)

group_id = 9 // 12 = 0 (still in first group)
first_pid_m = 0 * 2 = 0 (first group starts at row 0)

pid_m = 0 + ((9 % 12) % 2) = 0 + (9 % 2) = 1
pid_n = (9 % 12) // 2 = 9 // 2 = 4

So pid=9 computes C[128:256, 512:640]
          (row block 1, col block 4)
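The index arithmetic worked through above can be packaged as a function and checked against the diagram. The sketch below is pure Python, written in the style of the pid remapping in Triton's matmul tutorial. Its `min(...)` guard handles a final group with fewer than GROUP_SIZE_M rows, a case the 768×768 example never hits. The function names are hypothetical.

```python
def remap_pid(pid, M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_M):
    """Grouped pid -> (pid_m, pid_n)."""
    num_pid_m = -(-M // BLOCK_M)                  # ceil division, like triton.cdiv
    num_pid_n = -(-N // BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group can have fewer than GROUP_SIZE_M rows of tiles
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def tile_slices(pid_m, pid_n, BLOCK_M, BLOCK_N):
    """(row range, column range) of C written by one program."""
    return ((pid_m * BLOCK_M, (pid_m + 1) * BLOCK_M),
            (pid_n * BLOCK_N, (pid_n + 1) * BLOCK_N))


cfg = dict(M=768, N=768, BLOCK_M=128, BLOCK_N=128, GROUP_SIZE_M=2)
pm, pn = remap_pid(9, **cfg)
print(pm, pn)                          # 1 4
print(tile_slices(pm, pn, 128, 128))   # ((128, 256), (512, 640))

# Rebuild the grouped layout: grid[m][n] is the pid that computes tile (m, n)
grid = [[None] * 6 for _ in range(6)]
for pid in range(36):
    m, n = remap_pid(pid, **cfg)
    grid[m][n] = pid
for row in grid:
    print(row)                         # first row: [0, 2, 4, 6, 8, 10]
```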

Building Pointers and K-Chunking

To compute C[128:256, 512:640], we need:
- From A: A[128:256, 0:K] (128 rows, all K columns)
- From B: B[0:K, 512:640] (all K rows, 128 columns)

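For K=4096 these slabs won't fit in SRAM. A[128:256, 0:4096] alone is 128 × 4096 × 4 bytes = 2 MB in FP32, far more than the shared memory of a single SM. So we process K in chunks: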

Full matrix multiply:
C[128:256, 512:640] = A[128:256, 0:4096] @ B[0:4096, 512:640]

Split into K chunks:
Iteration 0: A[128:256, 0:64]    @ B[0:64, 512:640]    → accumulate
Iteration 1: A[128:256, 64:128]  @ B[64:128, 512:640]  → accumulate  
Iteration 2: A[128:256, 128:192] @ B[128:192, 512:640] → accumulate
... (64 iterations total for K=4096, BLOCK_K=64)
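The K loop can be sketched on small lists to confirm one property: summing the per-chunk partial products gives the same tile as a full dot product. The function name and sizes are hypothetical, chosen so pure Python runs instantly. K=20 with BLOCK_K=6 also exercises a ragged last chunk, which a real kernel would handle with a mask.

```python
import random


def matmul_tile_k_chunked(A, B, rows, cols, BLOCK_K):
    """C[rows, cols] computed one BLOCK_K slice of K at a time.

    `rows` / `cols` are (start, stop) pairs. `acc` plays the role of the
    on-chip accumulator; each iteration only touches an A slice of
    len(rows) x BLOCK_K and a B slice of BLOCK_K x len(cols).
    """
    K = len(B)
    r0, r1 = rows
    c0, c1 = cols
    acc = [[0.0] * (c1 - c0) for _ in range(r1 - r0)]
    for k0 in range(0, K, BLOCK_K):
        k1 = min(k0 + BLOCK_K, K)           # ragged last chunk
        for i in range(r0, r1):
            for j in range(c0, c1):
                acc[i - r0][j - c0] += sum(A[i][k] * B[k][j] for k in range(k0, k1))
    return acc


random.seed(0)
M, N, K = 8, 8, 20
A = [[random.random() for _ in range(K)] for _ in range(M)]
B = [[random.random() for _ in range(N)] for _ in range(K)]
tile = matmul_tile_k_chunked(A, B, rows=(2, 4), cols=(4, 8), BLOCK_K=6)   # chunks of 6, 6, 6, 2
ref = [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(4, 8)] for i in range(2, 4)]
err = max(abs(t - r) for tr, rr in zip(tile, ref) for t, r in zip(tr, rr))
print(f"max abs difference: {err:.2e}")
```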

How strides create the pointer grid:

For A stored row-major with shape [M, K]:

Memory is 1D: A[i,j] maps to memory[i*K + j]

Example: Loading A[128:256, 192:256] where K=4096

Row 128, cols 192-255: memory[128*4096 + 192] to memory[128*4096 + 255]
                        = positions 524480 to 524543 (contiguous!)
                        
Row 129, cols 192-255: memory[129*4096 + 192] to memory[129*4096 + 255]  
                        = positions 528576 to 528639 (contiguous!)
                        
The pattern:
- Each row is contiguous in memory (64 consecutive elements)
- Rows are K=4096 elements apart in memory
- stride_am=4096 (jump to next row), stride_ak=1 (next column)
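The same stride arithmetic generalizes to a 2-D grid of offsets. In Triton this is typically written with broadcasting, as offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak. The pure-Python sketch below (helper name hypothetical) builds that grid with nested lists and reproduces the addresses computed above. It also shows that moving to the next K chunk just adds BLOCK_K * stride_ak to every offset.

```python
def tile_offsets(row_start, col_start, BLOCK_R, BLOCK_C, stride_r, stride_c):
    """Grid of flat element offsets: offs_r[:, None]*stride_r + offs_c[None, :]*stride_c."""
    offs_r = range(row_start, row_start + BLOCK_R)
    offs_c = range(col_start, col_start + BLOCK_C)
    return [[r * stride_r + c * stride_c for c in offs_c] for r in offs_r]


K = 4096
stride_am, stride_ak = K, 1          # A is [M, K], row-major
a_offs = tile_offsets(128, 192, 128, 64, stride_am, stride_ak)   # A[128:256, 192:256]

print(a_offs[0][0], a_offs[0][-1])   # 524480 524543
print(a_offs[1][0], a_offs[1][-1])   # 528576 528639

# The next K chunk, A[128:256, 256:320], is every offset shifted by BLOCK_K * stride_ak
next_offs = tile_offsets(128, 256, 128, 64, stride_am, stride_ak)
print(all(n - a == 64 * stride_ak
          for nr, ar in zip(next_offs, a_offs) for n, a in zip(nr, ar)))   # True
```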

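Each K iteration loads two tiles, small enough to sit in SRAM:
- an A tile of 128×64 = 8,192 elements (32 KB in FP32, 16 KB in FP16),
- a 64×128 B tile of the same size.

It accumulates the partial product, then moves to the next K chunk. The grouped ordering makes it likely that the A and B blocks a program needs are already in L2, because neighboring programs requested them moments earlier.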

The key insight: by controlling how our multi-dimensional indexing walks through 1D memory (tiles, grouped execution, and well-chosen strides), we raise arithmetic intensity enough to turn operations like matmul and attention from memory-bound into compute-bound.