CuTe GEMM Kernel Design — The Big Picture
A roadmap for designing a high-performance GEMM kernel using CuTe DSL
1. The High-Level Flow (30-Second Overview)
You want: C[M,N] = A[M,K] × B[N,K]ᵀ, i.e. C[m,n] = Σ_k A[m,k]·B[n,k] (matrix multiply with B stored as N×K)
You write:
Host (@cute.jit):
1. Pick an MMA atom ← what hardware instruction?
2. Pick atoms_layout ← how many warps?
3. Pick permutation ← how to tile across the CTA?
4. Pick tile sizes (BM,BN,BK) ← how big is each thread block's work?
5. Launch kernel
Kernel (@cute.kernel):
A. local_tile → extract this CTA's data from global memory
B. partition_C → figure out this thread's output elements
C. K-loop → load data, call cute.gemm, accumulate
D. epilogue → store results back to global memory
CuTe handles:
All index math, thread→element mapping, address computation.
You just describe "what" — CuTe derives "how."
2. The Layer Diagram — What's Fixed, What's Chosen, What's Derived
FIXED BY HARDWARE
Layer 1: Hardware (NVIDIA ISA)
The MMA instruction's properties — baked in silicon, can't change.
| Instruction | fma.rn.f32 or mma.sync.m16n8k16 or wgmma.m64n256k16 |
| Atom shape (M,N,K) | (1,1,1) or (16,8,16) or (64,256,16) |
| Threads per atom | 1 or 32 (warp) or 128 (warpgroup) |
| Values per thread (V) | 1 or 4 or more |
| TV_layout strides | The exact thread→element wiring |
| Data types | fp32, fp16→fp32, fp16→fp16, etc. |
↓ encoded as
FIXED (wrapper)
Layer 2: Atom (CuTe Encoding)
CuTe's layout-based encoding of the hardware truth. One-to-one mapping.
| TV_layout_C | ((4,8),(2,2)):((32,1),(16,8)) for m16n8k16 (indexes a column-major 16×8 tile) |
| TV_layout_A | How A-elements map to threads (ISA-defined) |
| TV_layout_B | How B-elements map to threads (ISA-defined) |
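To make the "thread→element wiring" concrete, here is a minimal pure-Python sketch that evaluates a (thread, value) layout for the m16n8k16 accumulator and compares it with the fragment rule given in the PTX ISA. It assumes the layout ((4,8),(2,2)):((32,1),(16,8)) over a column-major 16×8 tile; the function names are illustrative and are not part of CuTe.

```python
# Assumption: accumulator TV layout ((4,8),(2,2)):((32,1),(16,8)),
# producing a column-major index m + 16*n into the 16x8 C tile.
THR_SHAPE, THR_STRIDE = (4, 8), (32, 1)
VAL_SHAPE, VAL_STRIDE = (2, 2), (16, 8)
ATOM_M, ATOM_N = 16, 8

def tv_to_mn(tid, vid):
    """Return the (m, n) coordinate of value `vid` held by lane `tid`."""
    t0, t1 = tid % THR_SHAPE[0], tid // THR_SHAPE[0]   # first mode is fastest
    v0, v1 = vid % VAL_SHAPE[0], vid // VAL_SHAPE[0]
    idx = (t0 * THR_STRIDE[0] + t1 * THR_STRIDE[1]
           + v0 * VAL_STRIDE[0] + v1 * VAL_STRIDE[1])
    return idx % ATOM_M, idx // ATOM_M

def ptx_rule(tid, vid):
    """The same mapping, phrased the way the PTX ISA describes the C fragment."""
    group_id, tid_in_group = tid // 4, tid % 4
    row = group_id + (8 if vid >= 2 else 0)
    col = tid_in_group * 2 + (vid & 1)
    return row, col

# Lane 0 holds C[0,0], C[0,1], C[8,0], C[8,1].
print([tv_to_mn(0, v) for v in range(4)])
```

Because every (thread, value) pair lands on a distinct element, the 32 lanes × 4 values cover the 16×8 tile exactly once.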
↓ + ↓
GIVEN BY PROBLEM
Problem Inputs
What your application hands you. Not free choices.
| Matrix dimensions | M, N, K (e.g., 4096×4096×4096) |
| Data layouts | A: (M,K):(1,M), B: (N,K):(1,N), C: (M,N):(N,1) |
| Data types | fp16 inputs, fp32 accumulator, fp16 output |
↓ your engineering decisions ↓
YOUR DESIGN CHOICES
Layer 3: TiledMMA + Kernel Config
The tuning knobs you control. This is where kernel optimization lives.
| atoms_layout | How many atoms (warps) per CTA, e.g., (2,2,1) = 4 warps |
| permutation | Coverage expansion + element pattern, e.g., (32,32,16) |
| BM, BN, BK | CTA tile size, e.g., 128×128×32 |
| NUM_STAGES | Pipeline depth, e.g., 3 async stages |
| Copy atoms | Which load/store instructions — cp.async, ldmatrix, etc. |
| SMEM layout | Swizzle pattern for bank-conflict-free access |
↓ CuTe derives everything below ↓
AUTO-DERIVED
Derived Quantities (You Don't Compute These)
| Thread→element mapping | From atom TV + atoms_layout + permutation |
| partition_C/A/B shapes | (V, M_rest, N_rest): each thread's view |
| Registers per thread | V × M_rest × N_rest (accumulator count) |
| Grid dimensions | ceil(M/BM) × ceil(N/BN) |
| K-iterations | K / BK (outer), BK / atom_K (inner) |
| SMEM size | (BM×BK + BN×BK) × dtype_size × NUM_STAGES |
| Copy tiling | How threads cooperate on G→S and S→R transfers |
↓ produces ↓
THE KERNEL
Running Kernel
A compiled PTX kernel that executes the GEMM with all the above baked in as constants.
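The derived quantities listed above follow from simple arithmetic on the chosen configuration. The sketch below reproduces that arithmetic in plain Python; it does not call CuTe, the function and key names are illustrative, and the accumulator count assumes each output element of the CTA tile is owned by exactly one thread.

```python
from math import prod

def ceil_div(a, b):
    return -(-a // b)

def derive(M, N, K, BM, BN, BK, atoms_layout, atom_shape,
           threads_per_atom, in_bytes, num_stages):
    """Compute the quantities CuTe would otherwise derive for you."""
    atom_k = atom_shape[2]
    num_threads = prod(atoms_layout) * threads_per_atom
    return {
        "num_threads": num_threads,
        "grid": (ceil_div(M, BM), ceil_div(N, BN), 1),
        "k_tiles": K // BK,                      # outer K iterations
        "k_blocks_per_tile": BK // atom_k,       # inner K iterations
        "acc_per_thread": BM * BN // num_threads,
        "smem_bytes": (BM * BK + BN * BK) * in_bytes * num_stages,
    }

cfg = derive(M=4096, N=4096, K=4096, BM=128, BN=128, BK=32,
             atoms_layout=(2, 2, 1), atom_shape=(16, 8, 16),
             threads_per_atom=32, in_bytes=2, num_stages=3)
for name, value in cfg.items():
    print(f"{name:18s} {value}")
```

For the 4096³ fp16 example this gives 128 threads, a 32×32 grid, 128 K-tiles, and 49152 bytes (48 KB) of shared memory.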
3. The Kernel Structure (What You Actually Write)
━━━ HOST SIDE (@cute.jit) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┌─────────────────────────────────────────────────────┐
│ 1. Create MMA atom (pick the hardware instruction) │
│ 2. Create TiledMMA (atom + layout + permutation) │
│ 3. Create copy atoms (G→S, S→R copy instructions) │
│ 4. Define SMEM layouts (with swizzle) │
│ 5. Compute grid = ceil_div(problem, tile) │
│ 6. Launch kernel │
└─────────────────────────────────────────────────────┘
━━━ KERNEL SIDE (@cute.kernel) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┌─────────────────────────────────────────────────────┐
│ A. TILE — Extract this CTA's work from global memory │
│ gA = local_tile(mA, tiler, coord, proj) │
│ gB = local_tile(mB, tiler, coord, proj) │
│ gC = local_tile(mC, tiler, coord, proj) │
├─────────────────────────────────────────────────────┤
│ B. PARTITION — What does THIS thread own? │
│ thr_mma = tiled_mma.get_slice(tidx) │
│ tCgC = thr_mma.partition_C(gC) │
│ tCrC = make_fragment_C(tCgC) → accumulators │
│ tCrC.fill(0) │
├─────────────────────────────────────────────────────┤
│ C. K-LOOP — The main compute loop │
│ for k_tile in range(num_k_tiles): │
│ [Load A,B: GMEM → SMEM → Registers] │
│ [Compute: cute.gemm(tiled_mma, ...)] │
│ [Pipeline: commit/wait for next tile] │
├─────────────────────────────────────────────────────┤
│ D. EPILOGUE — Store results │
│ cute.copy(atom, tCrC, tCgC) → registers→GMEM │
└─────────────────────────────────────────────────────┘
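As a reference for what the four kernel stages compute, the following pure-Python sketch mirrors the Tile → Accumulate → K-loop → Epilogue structure on nested lists. It is only a functional model: there are no threads, no shared memory, and no CuTe calls, and it assumes M, N, K are divisible by the tile sizes. Each (bm, bn) pair plays the role of one CTA.

```python
def tiled_gemm(A, B, M, N, K, BM, BN, BK):
    """C[m][n] = sum_k A[m][k] * B[n][k], computed tile by tile."""
    C = [[0.0] * N for _ in range(M)]
    for bm in range(0, M, BM):                 # A. TILE: pick this "CTA's" rows
        for bn in range(0, N, BN):             #          and columns
            acc = [[0.0] * BN for _ in range(BM)]   # B. zeroed accumulators
            for bk in range(0, K, BK):         # C. K-LOOP over K-tiles
                for i in range(BM):
                    for j in range(BN):
                        for k in range(BK):
                            acc[i][j] += A[bm + i][bk + k] * B[bn + j][bk + k]
            for i in range(BM):                # D. EPILOGUE: write results back
                for j in range(BN):
                    C[bm + i][bn + j] = acc[i][j]
    return C

A = [[1, 2], [3, 4]]        # M x K
B = [[5, 6], [7, 8]]        # N x K (note: not K x N)
print(tiled_gemm(A, B, M=2, N=2, K=2, BM=1, BN=1, BK=1))
```

Note that B is indexed as B[n][k], matching the (N,K) convention used throughout.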
4. The Design Decision Tree
When designing a GEMM kernel, decisions flow in this order:
Decision 1: Choose the Atom (Which Instruction?)
Based on your GPU architecture and data types:
| GPU | Data Type | Atom | Shape | Threads |
| Any | fp32 | MmaUniversalOp | (1,1,1) | 1 |
| SM80 (Ampere) | fp16 | MmaF16BF16Op | (16,8,16) | 32 |
| SM90 (Hopper) | fp16 | GMMA | (64,N,16) | 128 |
| SM100 (Blackwell) | fp16/fp4 | UMMA | (varies) | 1 issuing thread (tcgen05.mma) |
This is usually obvious from your target GPU and data type.
Decision 2: Choose atoms_layout (How Many Warps?)
Decide how many atoms work together in one CTA:
NUM_THREADS = product(atoms_layout) × threads_per_atom
(2, 2, 1) × 32 = 128 threads (4 warps) ← step 4
(4, 1, 1) × 32 = 128 threads (4 warps) ← alternative: all in M
(1, 4, 1) × 32 = 128 threads (4 warps) ← alternative: all in N
Tradeoffs:
More atoms in M → warps stacked along M read the same B rows (B reuse), but each needs its own A rows
More atoms in N → warps stacked along N read the same A rows (A reuse), but each needs its own B rows
More total atoms → more threads → potentially better latency hiding
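A rough way to compare these layouts is to count the threads each produces and how many operand elements one warp must read per K-tile. The sketch below assumes each warp owns a (BM/atoms_M) × (BN/atoms_N) share of the CTA tile; the helper names are illustrative and not CuTe API.

```python
from math import prod

def num_threads(atoms_layout, threads_per_atom):
    return prod(atoms_layout) * threads_per_atom

def warp_operand_elems(BM, BN, BK, atoms_layout):
    """A + B elements one warp reads from SMEM per K-tile."""
    atoms_m, atoms_n, _ = atoms_layout
    rows_of_a = BM // atoms_m      # A rows this warp needs
    rows_of_b = BN // atoms_n      # B rows this warp needs
    return (rows_of_a + rows_of_b) * BK

for layout in [(2, 2, 1), (4, 1, 1), (1, 4, 1)]:
    print(layout,
          num_threads(layout, threads_per_atom=32),
          warp_operand_elems(128, 128, 32, layout))
```

All three layouts give 128 threads, but under this model the square (2,2,1) layout reads fewer operand elements per warp than either elongated layout.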
Decision 3: Choose Tile Size (BM, BN, BK)
How much work per CTA. The central performance knob:
Constraints:
SMEM size: (BM×BK + BN×BK) × sizeof(dtype) × NUM_STAGES ≤ available SMEM
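Registers: accumulators per thread = BM × BN / NUM_THREADS (equivalently V × M_rest × N_rest); these, plus the A/B fragments, must fit in the per-thread register budget (at most 255 32-bit registers on current NVIDIA GPUs)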
Divisibility: BM divisible by atom_M × atoms_M, same for BN, BK
Typical choices (Ampere, fp16):
128×128×32 (balanced, step 4)
128×256×32 (more N, fewer blocks, higher register pressure)
256×128×32 (more M, good for tall matrices)
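The three constraints above can be checked mechanically before launching anything. The following sketch is one way to do it; the limits passed in (a 48 KB shared-memory budget and a 200-register accumulator budget) are assumptions chosen for illustration, not values taken from CuTe or from a specific GPU's specification.

```python
from math import prod

def check_tile(BM, BN, BK, atoms_layout, atom_shape, threads_per_atom,
               in_bytes, num_stages, smem_limit_bytes, acc_reg_budget):
    """Return a list of violated constraints (empty list = tile is feasible)."""
    problems = []

    smem = (BM * BK + BN * BK) * in_bytes * num_stages
    if smem > smem_limit_bytes:
        problems.append(f"SMEM {smem} B exceeds limit {smem_limit_bytes} B")

    threads = prod(atoms_layout) * threads_per_atom
    acc = BM * BN // threads            # 32-bit accumulators per thread
    if acc > acc_reg_budget:
        problems.append(f"{acc} accumulators exceed budget {acc_reg_budget}")

    for name, tile, atoms, atom in zip("MNK", (BM, BN, BK),
                                       atoms_layout, atom_shape):
        if tile % (atoms * atom) != 0:
            problems.append(f"B{name}={tile} not divisible by {atoms * atom}")
    return problems

common = dict(atoms_layout=(2, 2, 1), atom_shape=(16, 8, 16),
              threads_per_atom=32, in_bytes=2,
              smem_limit_bytes=48 * 1024, acc_reg_budget=200)
print(check_tile(128, 128, 32, num_stages=3, **common))   # feasible
print(check_tile(128, 128, 32, num_stages=4, **common))   # too much SMEM
print(check_tile(120, 128, 32, num_stages=3, **common))   # BM not divisible
```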
Decision 4: Choose Permutation (How to Tile Atoms Across the CTA Tile)
Bridge between native atom coverage and full tile:
SIMT (scalar FMA)
Permutation = Layout
Choose R (group size) for coalescing:
R = 4 for fp32 (128-bit stores)
R = 8 for fp16 (128-bit stores)
permutation_M = (F, R):(R, 1)
Tensor Core (mma.sync)
Permutation = Integer sizes
Choose total coverage per iteration:
perm_M = atoms_M × atom_M (= native)
perm_N = atoms_N × atom_N × expand
perm_K = atoms_K × atom_K
expand ≥ 1 (more = more regs)
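Both flavours of permutation reduce to small formulas. The sketch below computes the tensor-core permutation sizes from the rules above and evaluates the SIMT layout (F, R):(R, 1) at a linear index, assuming CuTe's convention that the leftmost mode varies fastest. The function names are illustrative only.

```python
def tensor_core_permutation(atoms_layout, atom_shape, expand_n=1):
    """(perm_M, perm_N, perm_K) = atoms × atom shape, with N optionally expanded."""
    atoms_m, atoms_n, atoms_k = atoms_layout
    atom_m, atom_n, atom_k = atom_shape
    return (atoms_m * atom_m, atoms_n * atom_n * expand_n, atoms_k * atom_k)

def simt_permutation(i, F, R):
    """Evaluate the layout (F, R):(R, 1) at linear index i."""
    f, r = i % F, i // F        # leftmost mode is the fastest-varying one
    return f * R + r * 1

print(tensor_core_permutation((2, 2, 1), (16, 8, 16), expand_n=2))
print([simt_permutation(i, F=16, R=4) for i in range(4)])
```

With atoms_layout (2,2,1), the (16,8,16) atom and expand = 2, the first call yields (32, 32, 16). In the SIMT case, consecutive indices are spaced R elements apart, so each one starts its own contiguous run of R elements.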
Decision 5: Choose Data Movement Strategy
How data flows: GMEM → SMEM → Registers (steps 2-5 only)
| Step | G→S Copy | S→R Copy | Pipeline |
| 1 (naive) | N/A (direct GMEM→RMEM) | N/A | None |
| 2 (smem) | CopyUniversalOp (sync) | CopyUniversalOp | __syncthreads |
| 3 (async) | CopyG2SOp (cp.async) | CopyUniversalOp | commit/wait, 3 stages |
| 4 (TC) | CopyG2SOp (cp.async) | LdMatrix8x8x16bOp | commit/wait, 3 stages |
| 5 (vec) | CopyG2SOp 128-bit | LdMatrix + retile | commit/wait + reg buffer |
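The "commit/wait, 3 stages" entries describe a software pipeline in which loads for later K-tiles are issued while the current tile is being computed. The sketch below models only the ordering of events for such a pipeline; it is a schedule simulator written for illustration, not the cp.async API, and the event names are hypothetical.

```python
def pipeline_schedule(num_k_tiles, num_stages):
    """List (event, k_tile, smem_buffer) tuples in issue order."""
    events = []
    # Prologue: prefetch the first num_stages - 1 tiles.
    for t in range(min(num_stages - 1, num_k_tiles)):
        events.append(("load", t, t % num_stages))
    for k in range(num_k_tiles):
        events.append(("wait", k, k % num_stages))      # tile k has landed
        nxt = k + num_stages - 1
        if nxt < num_k_tiles:
            # Reuses the buffer that tile k - 1 has finished computing on.
            events.append(("load", nxt, nxt % num_stages))
        events.append(("compute", k, k % num_stages))
    return events

for event in pipeline_schedule(num_k_tiles=4, num_stages=3):
    print(event)
```

With 3 stages, two loads are always in flight ahead of the compute, which is what hides global-memory latency.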
5. How It All Connects (The Full Picture)
┌─────────────────────────────────────────────────────────────────────────────┐
│ │
│ HARDWARE PROBLEM YOUR CHOICES │
│ (can't change) (given to you) (engineering decisions) │
│ │
│ mma.sync.m16n8k16 M=4096 atoms_layout = (2,2,1) │
│ TV_layout (fixed) N=4096 permutation = (32,32,16) │
│ 32 threads/atom K=4096 BM,BN,BK = 128,128,32 │
│ 4 values/thread A: (M,K) M-major NUM_STAGES = 3 │
│ B: (N,K) N-major Copy atoms: cp.async + ldmatrix│
│ C: (M,N) row-major SMEM swizzle: Swizzle<3,3,3> │
│ │
│ │ │ │ │
│ └────────────────────┼───────────────────────┘ │
│ ↓ │
│ ┌─────────────────────┐ │
│ │ CuTe DERIVES: │ │
│ │ │ │
│ │ NUM_THREADS = 128 │ │
│ │ Grid = (32, 32, 1) │ │
│ │ Acc regs/thread = 128 f32 (128×128 / 128 threads) │ │
│ │ SMEM = 3×(128+128) │ │
│ │ ×32×2 = 48 KB │ │
│ │ partition_C sizes: │ │
│ │ (V, M_rest, N_rest) = (4, 4, 8) │ │
│ │ K-iters: K/BK = 128 │ │
│ │ MMAs per warp: 32 per k16 step, 64 per K-tile │ │
│ └─────────────────────┘ │
│ ↓ │
│ ┌─────────────────────┐ │
│ │ COMPILED KERNEL │ │
│ │ (PTX → GPU) │ │
│ └─────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
6. The Kernel's 4 Steps — Memory Space Perspective
Step A: TILE Memory: Global (DRAM, ~400 cycles)
Input: mA, mB, mC — full matrices in global memory (ALL CTAs see these)
Output: gA, gB, gC — this CTA's tile (views, same memory, different offset)
What happens: local_tile(mA, tiler, coord=(bidx,bidy), proj)
→ Pointer arithmetic: base + block_offset
→ No data movement! Just a narrower view.
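The "base + block_offset" arithmetic can be written out directly: the offset of a tile is the tile coordinate times the tile extent times the stride, summed over modes. This is a minimal sketch of that calculation using the A and C layouts from the problem inputs; tile_offset is an illustrative helper, not the local_tile implementation.

```python
def tile_offset(coord, tile_shape, strides):
    """Element offset of the tile at `coord` within the full matrix."""
    return sum(c * t * s for c, t, s in zip(coord, tile_shape, strides))

M = N = K = 4096
BM, BN, BK = 128, 128, 32

# A is (M,K):(1,M): M-major. Tile (bidx=3, k_tile=2).
off_a = tile_offset((3, 2), (BM, BK), (1, M))
# C is (M,N):(N,1): row-major. Tile (bidx=3, bidy=5).
off_c = tile_offset((3, 5), (BM, BN), (N, 1))
print(off_a, off_c)
```

The tile view keeps the parent's strides; only the base offset changes, which is why no data moves.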
─────────────────────────────────────────────────────────────────────────────
Step B: PARTITION Memory: Global → describes ownership
Input: gC (128×128) — this CTA's output tile
Output: tCgC (V, M_rest, N_rest) — this THREAD's elements (still in global)
tCrC — register accumulator (zeroed)
What happens: TiledMMA's TV layout + permutation → which elements are mine
→ No data movement! tCgC is a view. tCrC allocates registers.
─────────────────────────────────────────────────────────────────────────────
Step C: K-LOOP Memory: Global → Shared → Registers
The actual data movement and computation:
┌────────────────────────────────────────────────────────────────────┐
│ for each k_tile: │
│ │
│ G→S: cp.async (asynchronous GMEM → SMEM copy that bypasses registers) │
│ 128 threads cooperate to fill (BM×BK + BN×BK) of SMEM │
│ │
│ S→R: ldmatrix (warp-level load, SMEM → registers) │
│ Each thread loads its MMA fragment from swizzled SMEM │
│ │
│ Compute: cute.gemm → mma.sync (tensor core MMA in registers) │
│ Accumulates: tCrC += tCrA × tCrB │
│ │
└────────────────────────────────────────────────────────────────────┘
─────────────────────────────────────────────────────────────────────────────
Step D: EPILOGUE Memory: Registers → Global
Input: tCrC, the accumulated results in registers (128 values per thread for the 128×128 tile with 128 threads)
Output: tCgC — written to global memory
What happens: cute.copy(atom, tCrC, tCgC)
→ Up to 128 scalar stores per thread (fewer if the copy atom vectorizes), targeting the addresses given by tCgC's layout.
7. Naming Convention Cheat Sheet
| Prefix | Meaning | Memory | Scope |
| m | Full matrix | Global | All CTAs |
| g | CTA-tiled view | Global | One CTA |
| s | Shared memory tensor | Shared | One CTA |
| t_g_ | Thread-partitioned, global | Global | One thread |
| t_s_ | Thread-partitioned, shared | Shared | One thread |
| t_r_ | Thread-partitioned, register | Registers | One thread |
Full naming: t C g A
│ │ │ │
│ │ │ └── which tensor (A, B, or C)
│ │ └──── memory space (g=global, s=shared, r=register)
│ └────── partitioned FOR which operand (C's M×N pattern, A's M×K, B's N×K)
└──────── "t" = thread-partitioned (personal view)
8. The Performance Equation (Why These Choices Matter)
Roofline: your kernel is limited by either compute or memory bandwidth.
Compute throughput:
= NUM_THREADS × V × (atom ops per cycle) × SM clock
→ More atoms, bigger tiles → more compute per CTA
Memory bandwidth demand:
= (BM×BK + BN×BK) × sizeof(dtype) × (K/BK) / time
→ BK cancels: bytes per CTA = (BM + BN) × K × sizeof(dtype). Bigger BK means fewer, larger loads and fewer loop iterations, not fewer bytes
Arithmetic intensity:
= 2×BM×BN×BK / ((BM×BK + BN×BK) × sizeof(dtype))
→ Bigger tiles → higher intensity → less memory-bound
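Plugging numbers into the arithmetic-intensity formula shows which knobs actually move it. The sketch below evaluates the formula as written above; arithmetic_intensity is an illustrative helper and the tile sizes are examples.

```python
def arithmetic_intensity(BM, BN, BK, in_bytes):
    """FLOPs per byte loaded for one K-tile of one CTA."""
    flops = 2 * BM * BN * BK                       # one multiply + one add
    bytes_loaded = (BM * BK + BN * BK) * in_bytes  # A tile + B tile
    return flops / bytes_loaded

for tile in [(64, 64, 32), (128, 128, 32), (128, 128, 64), (256, 128, 32)]:
    print(tile, round(arithmetic_intensity(*tile, in_bytes=2), 1))
```

Doubling BK from 32 to 64 leaves the intensity at 64 FLOPs/byte for a 128×128 fp16 tile, while growing BM or BN raises it. BK affects pipelining and SMEM footprint rather than the compute-to-memory ratio.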
Your choices directly control this balance:
┌──────────────┬─────────────────────────────────────────────┐
│ BM, BN │ ↑ = more compute per load = higher AI │
│ │ ↑ = more registers = lower occupancy │
├──────────────┼─────────────────────────────────────────────┤
│ BK │ ↑ = fewer K-iterations (fewer, larger loads) │
│ │ ↑ = more SMEM per stage │
├──────────────┼─────────────────────────────────────────────┤
│ NUM_STAGES │ ↑ = better latency hiding (overlap) │
│ │ ↑ = more SMEM total │
├──────────────┼─────────────────────────────────────────────┤
│ atoms_layout │ More atoms → more threads → more warps to hide latency │
│ │ But also more register pressure per SM │
├──────────────┼─────────────────────────────────────────────┤
│ permutation │ Bigger → more work per thread → fewer iters │
│ │ But more accumulators → more registers │
└──────────────┴─────────────────────────────────────────────┘
The art of kernel design = finding the sweet spot in this multi-dimensional space.
9. Quick Reference: Step 1 vs Step 4
| | Step 1 (Naive SIMT) | Step 4 (Tensor Core) |
| Atom | MmaUniversalOp(f32) — 1×1×1 | MmaF16BF16Op(f16,f32,(16,8,16)) |
| atoms_layout | (16, 16, 1) = 256 atoms | (2, 2, 1) = 4 atoms |
| Threads | 256 × 1 = 256 | 4 × 32 = 128 |
| Tile | 128×128×8 | 128×128×32 |
| Permutation | (16,4):(4,1) — layout | (32, 32, 16) — sizes |
| Data path | GMEM→RMEM→FMA→RMEM→GMEM | GMEM→SMEM→RMEM→HMMA→RMEM→GMEM |
| Regs per thread (C) | 64 × f32 = 256 B | 128 × f32 = 512 B |
| SMEM | None | 3×(128+128)×32×2B = 48 KB |
| Peak FLOPS | Low (scalar FMA) | High (tensor core: on A100, 1024 dense fp16 FMAs/cycle/SM) |
10. The One-Page Summary
To design a CuTe GEMM kernel:
1. Accept what's fixed: The atom (hardware instruction + TV layout)
2. Accept what's given: Problem size, data layouts, data types
3. Choose your knobs: atoms_layout, permutation, BM/BN/BK, stages, copy atoms
4. Let CuTe derive: Thread mapping, partition shapes, index math, grid
Your job = Step 3. Everything else is handled.
The kernel code itself is a template:
A. Tile (local_tile) → B. Partition → C. K-loop → D. Epilogue
This structure is THE SAME for all 5 tutorial steps.
Only the components plugged in change (atoms, copies, pipeline).