CuTe GEMM Kernel Design — The Big Picture

A roadmap for designing a high-performance GEMM kernel using CuTe DSL

1. The High-Level Flow (30-Second Overview)

You want:  C[M,N] = A[M,K] × B[N,K]ᵀ   (matrix multiply; B is stored as N×K, so the product contracts over K)

You write:

  Host (@cute.jit):
    1. Pick an MMA atom            ← what hardware instruction?
    2. Pick atoms_layout           ← how many warps?
    3. Pick permutation            ← how to tile across the CTA?
    4. Pick tile sizes (BM,BN,BK)  ← how big is each thread block's work?
    5. Launch kernel

  Kernel (@cute.kernel):
    A. local_tile   → extract this CTA's data from global memory
    B. partition_C  → figure out this thread's output elements
    C. K-loop       → load data, call cute.gemm, accumulate
    D. epilogue     → store results back to global memory

CuTe handles: all index math, thread→element mapping, and address computation.
You just describe "what"; CuTe derives "how."

2. The Layer Diagram — What's Fixed, What's Chosen, What's Derived

FIXED BY HARDWARE

Layer 1: Hardware (NVIDIA ISA)

The MMA instruction's properties — baked in silicon, can't change.

Instruction:            fma.rn.f32 or mma.sync.m16n8k16 or wgmma.m64n256k16
Atom shape (M,N,K):     (1,1,1) or (16,8,16) or (64,256,16)
Threads per atom:       1 or 32 (warp) or 128 (warpgroup)
Values per thread (V):  1 or 4 or more
TV_layout strides:      The exact thread→element wiring
Data types:             fp32, fp16→fp32, fp16→fp16, etc.
↓ encoded as
FIXED (wrapper)

Layer 2: Atom (CuTe Encoding)

CuTe's layout-based encoding of the hardware truth. One-to-one mapping.

TV_layout_C:  ((4,8),(2,2)):((32,1),(16,8)) for m16n8k16 (indexing the 16×8 C tile as m + 16·n)
TV_layout_A:  How A-elements map to threads (ISA-defined)
TV_layout_B:  How B-elements map to threads (ISA-defined)
↓ + ↓
GIVEN BY PROBLEM

Problem Inputs

What your application hands you. Not free choices.

Matrix dimensions:  M, N, K (e.g., 4096×4096×4096)
Data layouts:       A: (M,K):(1,M), B: (N,K):(1,N), C: (M,N):(N,1)
Data types:         fp16 inputs, fp32 accumulator, fp16 output
↓ your engineering decisions ↓
YOUR DESIGN CHOICES

Layer 3: TiledMMA + Kernel Config

The tuning knobs you control. This is where kernel optimization lives.

atoms_layout:  How many atoms (warps) per CTA, e.g., (2,2,1) = 4 warps
permutation:   Coverage expansion + element pattern, e.g., (32,32,16)
BM, BN, BK:    CTA tile size, e.g., 128×128×32
NUM_STAGES:    Pipeline depth, e.g., 3 async stages
Copy atoms:    Which load/store instructions to use: cp.async, ldmatrix, etc.
SMEM layout:   Swizzle pattern for bank-conflict-free access
↓ CuTe derives everything below ↓
AUTO-DERIVED

Derived Quantities (You Don't Compute These)

Thread→element mapping:  From atom TV + atoms_layout + permutation
partition_C/A/B shapes:  (V, M_rest, N_rest), each thread's view
Registers per thread:    V × M_rest × N_rest (accumulator count)
Grid dimensions:         ceil(M/BM) × ceil(N/BN)
K-iterations:            K / BK (outer), BK / atom_K (inner)
SMEM size:               (BM×BK + BN×BK) × dtype_size × NUM_STAGES
Copy tiling:             How threads cooperate on G→S and S→R transfers
↓ produces ↓
THE KERNEL

Running Kernel

A compiled PTX kernel that executes the GEMM with all the above baked in as constants.
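The derived quantities listed above are plain arithmetic on the chosen knobs. The sketch below recomputes a few of them in ordinary Python for the running example (4096×4096×4096, 128×128×32 tiles, atoms_layout (2,2,1)). The helper names `ceil_div` and `derive_config` are hypothetical and exist only for illustration; they are not part of the CuTe API, and the defaults (32 threads per atom, atom_K = 16, 2-byte inputs, 3 stages) are assumptions matching the mma.sync fp16 case.

```python
from math import prod


def ceil_div(a, b):
    """Integer ceiling division, as used for grid sizing."""
    return -(-a // b)


def derive_config(M, N, K, BM, BN, BK, atoms_layout,
                  threads_per_atom=32, atom_k=16,
                  dtype_bytes=2, num_stages=3):
    """Recompute derived quantities from the chosen knobs.

    Hypothetical helper for illustration; not part of the CuTe API.
    """
    return {
        "num_threads": prod(atoms_layout) * threads_per_atom,
        "grid": (ceil_div(M, BM), ceil_div(N, BN), 1),
        "k_tiles": ceil_div(K, BK),        # outer K-loop trip count
        "k_steps_per_tile": BK // atom_k,  # inner MMA steps per K-tile
        "smem_bytes": (BM * BK + BN * BK) * dtype_bytes * num_stages,
    }


cfg = derive_config(4096, 4096, 4096, 128, 128, 32, atoms_layout=(2, 2, 1))
for name, value in cfg.items():
    print(f"{name:>17}: {value}")
```

Changing one knob and re-running shows immediately how the rest shifts, for example how NUM_STAGES scales the shared-memory footprint.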

3. The Kernel Structure (What You Actually Write)

━━━ HOST SIDE (@cute.jit) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┌──────────────────────────────────────────────────────────┐
│ 1. Create MMA atom (pick the hardware instruction)       │
│ 2. Create TiledMMA (atom + layout + permutation)         │
│ 3. Create copy atoms (G→S, S→R copy instructions)        │
│ 4. Define SMEM layouts (with swizzle)                    │
│ 5. Compute grid = ceil_div(problem, tile)                │
│ 6. Launch kernel                                         │
└──────────────────────────────────────────────────────────┘

━━━ KERNEL SIDE (@cute.kernel) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┌──────────────────────────────────────────────────────────┐
│ A. TILE: extract this CTA's work from global memory      │
│    gA = local_tile(mA, tiler, coord, proj)               │
│    gB = local_tile(mB, tiler, coord, proj)               │
│    gC = local_tile(mC, tiler, coord, proj)               │
├──────────────────────────────────────────────────────────┤
│ B. PARTITION: what does THIS thread own?                 │
│    thr_mma = tiled_mma.get_slice(tidx)                   │
│    tCgC = thr_mma.partition_C(gC)                        │
│    tCrC = make_fragment_C(tCgC)   → accumulators         │
│    tCrC.fill(0)                                          │
├──────────────────────────────────────────────────────────┤
│ C. K-LOOP: the main compute loop                         │
│    for k_tile in range(num_k_tiles):                     │
│        [Load A,B: GMEM → SMEM → Registers]               │
│        [Compute: cute.gemm(tiled_mma, ...)]              │
│        [Pipeline: commit/wait for next tile]             │
├──────────────────────────────────────────────────────────┤
│ D. EPILOGUE: store results                               │
│    cute.copy(atom, tCrC, tCgC)   → registers→GMEM        │
└──────────────────────────────────────────────────────────┘
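The A → B → C → D structure can be emulated on the host to see what each stage contributes. The sketch below is plain Python, not CuTe DSL code: one iteration of the two outer loops plays the role of one CTA, `acc` stands in for the register accumulators, and the function name `tiled_gemm` is an illustrative assumption. It also assumes the tile sizes divide the problem evenly.

```python
def tiled_gemm(A, B, BM, BN, BK):
    """Compute C[m][n] = sum_k A[m][k] * B[n][k] one CTA tile at a time.

    Host-side emulation for illustration only (nested lists, no GPU).
    """
    M, K = len(A), len(A[0])
    N = len(B)
    assert M % BM == 0 and N % BN == 0 and K % BK == 0
    C = [[0] * N for _ in range(M)]
    for bm in range(0, M, BM):          # A. tile: this "CTA" owns rows bm..bm+BM
        for bn in range(0, N, BN):      #          and columns bn..bn+BN of C
            # B. partition: zero-initialised accumulators for the tile
            acc = [[0] * BN for _ in range(BM)]
            # C. K-loop: walk the K dimension in steps of BK
            for bk in range(0, K, BK):
                for i in range(BM):
                    for j in range(BN):
                        for k in range(BK):
                            acc[i][j] += A[bm + i][bk + k] * B[bn + j][bk + k]
            # D. epilogue: write the accumulators back to C
            for i in range(BM):
                for j in range(BN):
                    C[bm + i][bn + j] = acc[i][j]
    return C


A = [[m + k for k in range(4)] for m in range(4)]   # 4×4, indexed [m][k]
B = [[n * k for k in range(4)] for n in range(6)]   # 6×4, indexed [n][k]
C = tiled_gemm(A, B, BM=2, BN=3, BK=2)
reference = [[sum(A[m][k] * B[n][k] for k in range(4)) for n in range(6)]
             for m in range(4)]
print(C == reference)  # True
```

In the real kernel the innermost three loops are what cute.gemm and the TiledMMA distribute across threads; the outer structure stays the same.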

4. The Design Decision Tree

When designing a GEMM kernel, decisions flow in this order:

(1) Choose the Atom (Which Instruction?)

Based on your GPU architecture and data types:

GPU                 Data Type   Atom             Shape       Threads
Any                 fp32        MmaUniversalOp   (1,1,1)     1
SM80 (Ampere)       fp16        MmaF16BF16Op     (16,8,16)   32
SM90 (Hopper)       fp16        GMMA             (64,N,16)   128
SM100 (Blackwell)   fp16/fp4    UMMA             (varies)    128

This is usually obvious from your target GPU and data type.
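What "the atom fixes the thread→element wiring" means can be made concrete for the m16n8k16 accumulator. The sketch below, in plain Python, follows the PTX ISA description of the C/D fragment (groupID = lane / 4, threadID_in_group = lane % 4) and checks that it agrees with the layout ((4,8),(2,2)):((32,1),(16,8)), which is how CUTLASS's C++ headers encode this 16×8 accumulator over the index m + 16·n. The function names are illustrative assumptions, not CuTe API.

```python
ATOM_M, ATOM_N = 16, 8  # C tile produced by one mma.sync.m16n8k16


def c_fragment_coord(lane, i):
    """(row, col) of accumulator register i (0..3) held by lane (0..31),
    following the PTX ISA fragment description for m16n8k16."""
    group_id, tid_in_group = lane // 4, lane % 4
    row = group_id + 8 * (i // 2)
    col = 2 * tid_in_group + (i % 2)
    return row, col


def tv_layout_index(lane, i):
    """The same mapping written as the layout ((4,8),(2,2)):((32,1),(16,8)),
    producing the column-major index m + 16*n into the 16x8 tile."""
    t0, t1 = lane % 4, lane // 4   # thread modes of size 4 and 8
    v0, v1 = i % 2, i // 2         # value modes of size 2 and 2
    return 32 * t0 + 1 * t1 + 16 * v0 + 8 * v1


owners = {}
for lane in range(32):
    for i in range(4):
        row, col = c_fragment_coord(lane, i)
        assert tv_layout_index(lane, i) == row + ATOM_M * col
        owners[(row, col)] = (lane, i)

print(len(owners))             # 128: every C element has exactly one owner
print(c_fragment_coord(5, 3))  # (9, 3)
```

Because 32 lanes × 4 values cover all 16 × 8 = 128 elements exactly once, the layout is a bijection, which is what lets CuTe derive per-thread views from it.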

(2) Choose atoms_layout (How Many Warps?)

Decide how many atoms work together in one CTA:

NUM_THREADS = product(atoms_layout) × threads_per_atom

  (2, 2, 1) × 32 = 128 threads (4 warps)   ← step 4
  (4, 1, 1) × 32 = 128 threads (4 warps)   ← alternative: all in M
  (1, 4, 1) × 32 = 128 threads (4 warps)   ← alternative: all in N

Tradeoffs:
  More atoms in M  → better B-data reuse (warps stacked along M share the same B rows)
  More atoms in N  → better A-data reuse (warps placed along N share the same A rows)
  More total atoms → more threads → potentially better latency hiding
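The thread count and the sharing pattern follow directly from the atom grid. A warp at atom coordinate (m, n) reads the A rows of block m and the B rows of block n, so each A block is read by atoms_N warps and each B block by atoms_M warps. The sketch below tabulates this for the three layouts above; `cta_shape` is a hypothetical helper, and it assumes one warp per atom (threads_per_atom = 32) and atoms_K = 1.

```python
from math import prod


def cta_shape(atoms_layout, threads_per_atom=32):
    """Thread count and operand sharing implied by an atoms_layout.

    Illustrative helper, not CuTe API; assumes atoms_K == 1.
    """
    atoms_m, atoms_n, _atoms_k = atoms_layout
    return {
        "num_threads": prod(atoms_layout) * threads_per_atom,
        "warps_sharing_each_A_block": atoms_n,  # same m, different n
        "warps_sharing_each_B_block": atoms_m,  # same n, different m
    }


for layout in [(2, 2, 1), (4, 1, 1), (1, 4, 1)]:
    print(layout, cta_shape(layout))
```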
(3) Choose Tile Size (BM, BN, BK)

How much work per CTA. The central performance knob:

Constraints:
  SMEM size:     (BM×BK + BN×BK) × sizeof(dtype) × NUM_STAGES ≤ available SMEM
  Registers:     (BM×BN / NUM_THREADS) × sizeof(acc) ≤ per-thread register budget
                 (each output element of the CTA tile is held by exactly one thread)
  Divisibility:  BM divisible by atom_M × atoms_M, same for BN, BK

Typical choices (Ampere, fp16):
  128×128×32   (balanced, step 4)
  128×256×32   (more N, fewer blocks, higher register pressure)
  256×128×32   (more M, good for tall matrices)
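These constraints are easy to check mechanically before compiling anything. The sketch below is a hypothetical `check_tile` helper in plain Python. It counts accumulators as BM×BN / NUM_THREADS, since every output element of the CTA tile is held by exactly one thread. The shared-memory budget (48 KB) and the 255-register ceiling are assumed defaults for illustration; the real limits depend on the GPU and launch configuration.

```python
def check_tile(BM, BN, BK, atoms_layout, atom_shape=(16, 8, 16),
               threads_per_atom=32, dtype_bytes=2, acc_bytes=4,
               num_stages=3, smem_limit=48 * 1024, max_regs=255):
    """Return a list of violated constraints (empty list = tile is feasible).

    Illustrative helper; the limits are assumptions, not queried from hardware.
    """
    atoms_m, atoms_n, atoms_k = atoms_layout
    atom_m, atom_n, atom_k = atom_shape
    num_threads = atoms_m * atoms_n * atoms_k * threads_per_atom
    problems = []

    smem = (BM * BK + BN * BK) * dtype_bytes * num_stages
    if smem > smem_limit:
        problems.append(f"SMEM: {smem} B > {smem_limit} B")

    # 32-bit registers needed just for the accumulators of one thread
    acc_regs = (BM * BN // num_threads) * acc_bytes // 4
    if acc_regs > max_regs:
        problems.append(f"registers: {acc_regs} accumulators > {max_regs}")

    for name, tile, unit in (("BM", BM, atom_m * atoms_m),
                             ("BN", BN, atom_n * atoms_n),
                             ("BK", BK, atom_k * atoms_k)):
        if tile % unit != 0:
            problems.append(f"{name}={tile} not divisible by {unit}")
    return problems


print(check_tile(128, 128, 32, (2, 2, 1)))  # []
print(check_tile(256, 256, 64, (2, 2, 1)))  # SMEM and register violations
print(check_tile(120, 128, 32, (2, 2, 1)))  # BM divisibility violation
```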
(4) Choose Permutation (How to Tile Atoms Across the CTA Tile)

Bridge between native atom coverage and full tile:

SIMT (scalar FMA)

Permutation = Layout

Choose R (group size) for coalescing:

R = 4 for fp32 (128-bit stores)
R = 8 for fp16 (128-bit stores)
permutation_M = (F, R):(R, 1)

Tensor Core (mma.sync)

Permutation = Integer sizes

Choose total coverage per iteration:

perm_M = atoms_M × atom_M (= native)
perm_N = atoms_N × atom_N × expand
perm_K = atoms_K × atom_K

expand ≥ 1 (more = more regs)
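Both permutation flavours reduce to small index calculations. The sketch below, in plain Python, computes the tensor-core permutation sizes from the formulas above and evaluates the SIMT layout (F, R):(R, 1) at each linear index, taking the first mode as the fastest-varying one (CuTe's usual convention). Both function names are illustrative assumptions, not CuTe API.

```python
def tensor_core_permutation(atoms_layout, atom_shape, expand_n=1):
    """Permutation sizes (perm_M, perm_N, perm_K) for the mma.sync case."""
    (atoms_m, atoms_n, atoms_k), (atom_m, atom_n, atom_k) = atoms_layout, atom_shape
    return (atoms_m * atom_m, atoms_n * atom_n * expand_n, atoms_k * atom_k)


def simt_permutation(F, R):
    """Evaluate the layout (F, R):(R, 1) at linear indices 0 .. F*R-1.

    Index i splits into (f, r) = (i % F, i // F) and maps to f*R + r.
    """
    return [(i % F) * R + (i // F) for i in range(F * R)]


print(tensor_core_permutation((2, 2, 1), (16, 8, 16), expand_n=2))  # (32, 32, 16)

perm = simt_permutation(F=16, R=4)
print(perm[:4])    # [0, 4, 8, 12]
print(perm[16])    # 1
```

With F = 16 and R = 4, consecutive positions land 4 elements apart, so each position's four elements (offsets 0 to 3) are contiguous, which is what makes a single 128-bit fp32 store possible.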
(5) Choose Data Movement Strategy

How data flows: GMEM → SMEM → Registers (steps 2-5 only)

Step        G→S Copy                  S→R Copy            Pipeline
1 (naive)   N/A (direct GMEM→RMEM)    N/A                 None
2 (smem)    CopyUniversalOp (sync)    CopyUniversalOp     __syncthreads
3 (async)   CopyG2SOp (cp.async)      CopyUniversalOp     commit/wait, 3 stages
4 (TC)      CopyG2SOp (cp.async)      LdMatrix8x8x16bOp   commit/wait, 3 stages
5 (vec)     CopyG2SOp 128-bit         LdMatrix + retile   commit/wait + reg buffer
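The "commit/wait, 3 stages" entries refer to a multistage software pipeline: loads for future K-tiles are issued while the current tile is being computed. The sketch below models only the ordering of loads and computes in plain Python. It is a simplified assumption about the schedule (prefetch NUM_STAGES − 1 tiles, then issue one load per iteration into the buffer freed by the previous compute), not a transcription of any particular kernel, and `pipeline_schedule` is a hypothetical name.

```python
def pipeline_schedule(num_k_tiles, num_stages):
    """Return the event order as (kind, k_tile, smem_stage) tuples."""
    events = []
    # Prologue: start loads for the first num_stages - 1 tiles.
    for k in range(min(num_stages - 1, num_k_tiles)):
        events.append(("load", k, k % num_stages))
    # Main loop: prefetch a future tile, then compute the current one.
    for k in range(num_k_tiles):
        nxt = k + num_stages - 1
        if nxt < num_k_tiles:
            # This stage held tile k - 1, which has already been consumed.
            events.append(("load", nxt, nxt % num_stages))
        events.append(("compute", k, k % num_stages))
    return events


for event in pipeline_schedule(num_k_tiles=4, num_stages=3):
    print(event)
```

In a real kernel the gap between a tile's load and its compute is where cp.async commit and wait calls sit; deeper pipelines widen that gap at the cost of more shared memory.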

5. How It All Connects (The Full Picture)

HARDWARE             PROBLEM                YOUR CHOICES
(can't change)       (given to you)         (engineering decisions)

mma.sync.m16n8k16    M=4096                 atoms_layout = (2,2,1)
TV_layout (fixed)    N=4096                 permutation = (32,32,16)
32 threads/atom      K=4096                 BM,BN,BK = 128,128,32
4 values/thread      A: (M,K) M-major       NUM_STAGES = 3
                     B: (N,K) N-major       Copy atoms: cp.async + ldmatrix
                     C: (M,N) row-major     SMEM swizzle: Swizzle<3,3,3>

                              ↓

CuTe DERIVES:
  NUM_THREADS  = 128
  Grid         = (32, 32, 1)
  Regs/thread  = 128 f32 accumulators (128×128 outputs / 128 threads)
  SMEM         = 3×(128+128)×32×2 B = 48 KB
  partition_C  : (V, M_rest, N_rest) = (4, 4, 8)
  K-iters      : K/BK = 128
  MMA calls/iter: 64 per warp (4 × 8 × 2)

                              ↓

COMPILED KERNEL (PTX → GPU)

6. The Kernel's 4 Steps — Memory Space Perspective

Step A: TILE                                  Memory: Global (DRAM, ~400 cycles)

  Input:   mA, mB, mC: full matrices in global memory (ALL CTAs see these)
  Output:  gA, gB, gC: this CTA's tile (views of the same memory at a different offset)

  What happens: local_tile(mA, tiler, coord=(bidx,bidy), proj)
    → Pointer arithmetic: base + block_offset
    → No data movement! Just a narrower view.

─────────────────────────────────────────────────────────────────────────────
Step B: PARTITION                             Memory: Global (describes ownership only)

  Input:   gC (128×128): this CTA's output tile
  Output:  tCgC (V, M_rest, N_rest): this THREAD's elements (still in global)
           tCrC: register accumulator (zeroed)

  What happens: TiledMMA's TV layout + permutation → which elements are mine
    → No data movement! tCgC is a view. tCrC allocates registers.

─────────────────────────────────────────────────────────────────────────────
Step C: K-LOOP                                Memory: Global → Shared → Registers

  The actual data movement and computation:

  for each k_tile:

    G→S:      cp.async (asynchronous GMEM → SMEM copy that bypasses registers)
              128 threads cooperate to fill (BM×BK + BN×BK) elements of SMEM

    S→R:      ldmatrix (warp-level load, SMEM → registers)
              Each thread loads its MMA fragment from swizzled SMEM

    Compute:  cute.gemm → mma.sync (tensor core MMA in registers)
              Accumulates: tCrC += tCrA × tCrB

─────────────────────────────────────────────────────────────────────────────
Step D: EPILOGUE                              Memory: Registers → Global

  Input:   tCrC: accumulated results in registers
           (128 values per thread for a 128×128 tile and 128 threads)
  Output:  tCgC: written to global memory

  What happens: cute.copy(atom, tCrC, tCgC)
    → Each thread stores its accumulator values to the addresses described
      by tCgC's layout.
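The "pointer arithmetic, no data movement" claim in Step A can be checked by hand. For a tile at block coordinate (i, j), the view starts at element offset i × tile_rows × stride_0 + j × tile_cols × stride_1. The sketch below computes that in plain Python using the layouts given earlier (A as (M,K):(1,M), C as (M,N):(N,1)); `tile_offset` is a hypothetical helper for illustration and does not reproduce local_tile's full semantics (tiler projection, residual modes).

```python
def tile_offset(block_coord, tile_shape, strides):
    """Element offset of a tile's first element within the full matrix."""
    return sum(c * t * s for c, t, s in zip(block_coord, tile_shape, strides))


M = N = K = 4096
BM, BN, BK = 128, 128, 32

# A is (M,K):(1,M): tile at block row 3, K-tile 2
off_A = tile_offset((3, 2), (BM, BK), (1, M))
# C is (M,N):(N,1): tile at block coordinate (3, 5)
off_C = tile_offset((3, 5), (BM, BN), (N, 1))

print(off_A)  # 262528  = 3*128*1 + 2*32*4096
print(off_C)  # 1573504 = 3*128*4096 + 5*128*1
```

The tile keeps the parent's strides; only the base offset changes, which is why creating gA, gB and gC costs nothing at run time beyond an address computation.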

7. Naming Convention Cheat Sheet

Prefix   Meaning                        Memory      Scope
m        Full matrix                    Global      All CTAs
g        CTA-tiled view                 Global      One CTA
s        Shared memory tensor           Shared      One CTA
t_g_     Thread-partitioned, global     Global      One thread
t_s_     Thread-partitioned, shared     Shared      One thread
t_r_     Thread-partitioned, register   Registers   One thread

Full naming:

  t C g A
  │ │ │ └── which tensor (A, B, or C)
  │ │ └──── memory space (g=global, s=shared, r=register)
  │ └────── partitioned FOR which operand (C's M×N pattern, A's M×K, B's N×K)
  └──────── "t" = thread-partitioned (personal view)

8. The Performance Equation (Why These Choices Matter)

Roofline: your kernel is limited by either compute or memory bandwidth.

Compute throughput:
  = NUM_THREADS × V × (atom ops per cycle) × SM clock
  → More atoms, bigger tiles → more compute per CTA

Memory bandwidth demand:
  = (BM×BK + BN×BK) × sizeof(dtype) × (K/BK) / time
  → Bigger BK → fewer, larger loads per CTA (total bytes moved stay the same)

Arithmetic intensity:
  = 2×BM×BN×BK / ((BM×BK + BN×BK) × sizeof(dtype))
  = 2×BM×BN / ((BM + BN) × sizeof(dtype))
  → Bigger BM, BN → higher intensity → less memory-bound

Your choices directly control this balance:

  Knob           Effect of increasing it
  BM, BN         ↑ = more compute per load = higher AI
                 ↑ = more registers = lower occupancy
  BK             ↑ = fewer K-iterations (fewer, larger global loads)
                 ↑ = more SMEM per stage
  NUM_STAGES     ↑ = better latency hiding (overlap)
                 ↑ = more SMEM total
  atoms_layout   More atoms → more threads per CTA → more latency hiding
                 But also more register pressure per SM
  permutation    Bigger → more work per thread → fewer iters
                 But more accumulators → more registers

The art of kernel design = finding the sweet spot in this multi-dimensional space.
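The arithmetic-intensity formula is simple enough to evaluate for a few candidate tiles. The sketch below does so in plain Python for fp16 inputs (2 bytes per element, an assumption matching the running example); `arithmetic_intensity` is an illustrative name. Note that BK appears in both numerator and denominator of the formula, so it cancels.

```python
def arithmetic_intensity(BM, BN, BK, dtype_bytes=2):
    """FLOPs per byte loaded for one K-tile of a BM×BN×BK CTA tile."""
    flops = 2 * BM * BN * BK                       # one multiply + one add per MAC
    bytes_loaded = (BM * BK + BN * BK) * dtype_bytes
    return flops / bytes_loaded


print(arithmetic_intensity(64, 64, 32))     # 32.0
print(arithmetic_intensity(128, 128, 32))   # 64.0
print(arithmetic_intensity(128, 128, 64))   # 64.0 (BK cancels)
print(arithmetic_intensity(256, 128, 32))   # about 85.3
```

Doubling BM and BN together doubles the intensity, while changing BK alone leaves it untouched and instead trades K-loop iterations against shared memory per stage.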

9. Quick Reference: Step 1 vs Step 4

                      Step 1 (Naive SIMT)           Step 4 (Tensor Core)
Atom                  MmaUniversalOp(f32), 1×1×1    MmaF16BF16Op(f16,f32,(16,8,16))
atoms_layout          (16, 16, 1) = 256 atoms       (2, 2, 1) = 4 atoms
Threads               256 × 1 = 256                 4 × 32 = 128
Tile                  128×128×8                     128×128×32
Permutation           (16,4):(4,1) (layout)         (32, 32, 16) (sizes)
Data path             GMEM→RMEM→FMA→RMEM→GMEM       GMEM→SMEM→RMEM→HMMA→RMEM→GMEM
Regs per thread (C)   64 × f32 = 256 B              128 × f32 = 512 B
SMEM                  None                          3×(128+128)×32×2 B = 48 KB
Peak FLOPS            Low (scalar FMA)              High (tensor core: 256 ops/cycle/SM)

10. The One-Page Summary

To design a CuTe GEMM kernel:

  1. Accept what's fixed:  The atom (hardware instruction + TV layout)
  2. Accept what's given:  Problem size, data layouts, data types
  3. Choose your knobs:    atoms_layout, permutation, BM/BN/BK, stages, copy atoms
  4. Let CuTe derive:      Thread mapping, partition shapes, index math, grid

Your job = Step 3. Everything else is handled.

The kernel code itself is a template:

  A. Tile (local_tile) → B. Partition → C. K-loop → D. Epilogue

This structure is THE SAME for all 5 tutorial steps.
Only the components plugged in change (atoms, copies, pipeline).