Step 1: Thread Trace Deep Dive

Tracing Thread 0, Block (0,0) through the naive GEMM kernel
Problem: M=256, N=128, K=32 | Tile: 128×128×8 | 256 threads

What is a CTA? What is a CTA Tile?

CTA (Cooperative Thread Array) = CUDA Thread Block

A group of threads (here 256) that run on the same SM, can synchronize with one another, and can exchange data through shared memory.

In this example: 1 CTA = 256 threads, assigned to compute one 128×128 chunk of C.

CTA Tile = The Work Assigned to One CTA

The cta_tiler = (128, 128, 8) defines how the problem is sliced:

Full C matrix (256×128): Grid = (2, 1) = 2 CTAs total

          N=128 columns
    ←────────────────────────→
    ┌────────────────────────┐
    │                        │ ← CTA (0,0): rows 0-127, cols 0-127
    │      128×128 tile      │   128×128 = 16384 elements
    │      = CTA (0,0)       │   256 threads → 64 elements/thread
    │                        │
    ├────────────────────────┤
    │                        │ ← CTA (1,0): rows 128-255, cols 0-127
    │      128×128 tile      │
    │      = CTA (1,0)       │
    │                        │
    └────────────────────────┘
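The grid size and per-thread workload in the diagram follow from a few integer divisions. The sketch below reproduces that arithmetic in plain Python; the helper names `ceil_div` and `cta_extent` are made up for illustration and are not part of CuTe.

```python
# Problem and tile sizes taken from the text above.
M, N, K = 256, 128, 32
BM, BN, BK = 128, 128, 8
NUM_THREADS = 256


def ceil_div(a, b):
    """Integer ceiling division (hypothetical helper)."""
    return -(-a // b)


# One CTA per 128x128 tile of C.
grid = (ceil_div(M, BM), ceil_div(N, BN))

# Each CTA tile is split evenly across its threads.
elements_per_thread = (BM * BN) // NUM_THREADS


def cta_extent(bidx, bidy):
    """Half-open (start, stop) row and column ranges of C for one CTA."""
    rows = (bidx * BM, min((bidx + 1) * BM, M))
    cols = (bidy * BN, min((bidy + 1) * BN, N))
    return rows, cols


print("grid:", grid)
print("elements per thread:", elements_per_thread)
print("CTA (0,0):", cta_extent(0, 0))
print("CTA (1,0):", cta_extent(1, 0))
```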

Step A: local_tile — Extract This CTA's Data

How local_tile works (3 operations in one call)

  1. proj — filter the 3D tiler to match the tensor's dimensions
  2. tile — logically divide the tensor into tiles
  3. coord — select which tile(s) to keep

Tiling A: local_tile(mA, (128,128,8), coord=(0,0,None), proj=(1,None,1))

mA: shape (256, 32), stride (1, 256)
    ← M-major: consecutive in M, stride-256 in K
      (A[m,k] at address: base + m + k*256)

proj = (1, None, 1)    "A has M and K, not N"
    tiler (128, 128, 8) → keep M=128, drop N, keep K=8
    → effective tiler = (128, 8)

Tile (256,32) by (128,8):
    M: 256/128 = 2 tiles
    K: 32/8 = 4 tiles
    Logical shape: (128, 8, 2_M_tiles, 4_K_tiles)

coord = (bidx=0, _, None):
    M-tile index = 0 (first 128 rows)
    K-tile index = None (keep all 4 as iteration dimension)

Result: gA shape (128, 8, 4), stride (1, 256, 2048)
                  ↑    ↑  ↑
                  BM   BK num_k_tiles
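The three operations (proj, tile, coord) can be imitated with ordinary tuple arithmetic. This is a rough sketch of the bookkeeping only, not the CuTe implementation: `project` and `tile_a` are hypothetical helpers, and the sketch assumes the tile sizes divide the tensor evenly. It shows where the K-tile stride 2048 comes from (BK × K-stride = 8 × 256).

```python
def project(tiler, proj):
    """Keep the tiler modes whose proj entry is not None (the 'proj' step)."""
    return tuple(t for t, p in zip(tiler, proj) if p is not None)


def tile_a(shape, stride, tiler_mk, m_tile):
    """Model of tiling A and picking one M-tile while keeping all K-tiles.

    Returns (tile_shape, tile_stride, base_offset) for the selected tile.
    Assumes the tiler divides the tensor shape evenly.
    """
    (m, k), (s_m, s_k), (bm, bk) = shape, stride, tiler_mk
    if m % bm or k % bk:
        raise ValueError("sketch only handles evenly divisible shapes")
    tile_shape = (bm, bk, k // bk)        # (BM, BK, num_k_tiles)
    tile_stride = (s_m, s_k, bk * s_k)    # stepping one K-tile skips BK columns
    base_offset = m_tile * bm * s_m       # where the chosen M-tile starts
    return tile_shape, tile_stride, base_offset


tiler_mk = project((128, 128, 8), (1, None, 1))
g_a = tile_a((256, 32), (1, 256), tiler_mk, m_tile=0)
print("effective tiler:", tiler_mk)
print("gA (shape, stride, base offset):", g_a)
```

Calling `tile_a(..., m_tile=1)` models the second CTA, whose tile starts 128 elements further along M.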

Tiling C: local_tile(mC, (128,128,8), coord=(0,0,None), proj=(1,1,None))

mC: shape (256, 128), stride (128, 1)
    ← row-major: C[m,n] at address: base + m*128 + n

proj = (1, 1, None)    "C has M and N, not K"
    tiler (128, 128, 8) → keep M=128, keep N=128, drop K
    → effective tiler = (128, 128)

Tile (256,128) by (128,128):
    M: 256/128 = 2 tiles
    N: 128/128 = 1 tile

coord = (bidx=0, bidy=0, _):
    Select M-tile 0, N-tile 0

Result: gC shape (128, 128), stride (128, 1)
                  ↑    ↑
                  BM   BN

Step B: partition_C — Thread 0's 64 Output Elements

How 256 threads map to the 128×128 tile

atoms_layout = (16, 16, 1):(16, 1, 0)

Thread ID → atom position:
    thr_idx = M_atom * 16 + N_atom    (from stride (16, 1))

    Thread 0:    M_atom=0,  N_atom=0
    Thread 1:    M_atom=0,  N_atom=1
    Thread 16:   M_atom=1,  N_atom=0
    Thread 255:  M_atom=15, N_atom=15

16 atoms in M cover 128 rows: 128/16 = 8 rows per atom
16 atoms in N cover 128 cols: 128/16 = 8 cols per atom
Each thread: 8 × 8 = 64 elements ✓
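The thread-to-atom mapping above is a plain strided index, so it can be inverted with one `divmod`. A minimal sketch, with `thread_index` and `atom_position` as illustrative names rather than CuTe functions:

```python
# Shape and stride of the (M, N) part of atoms_layout from the text above.
ATOMS_SHAPE = (16, 16)
ATOMS_STRIDE = (16, 1)


def thread_index(m_atom, n_atom):
    """Forward map: atom coordinate -> thread index."""
    return m_atom * ATOMS_STRIDE[0] + n_atom * ATOMS_STRIDE[1]


def atom_position(thr_idx):
    """Inverse map: thread index -> (M_atom, N_atom)."""
    return divmod(thr_idx, ATOMS_STRIDE[0])


for t in (0, 1, 16, 255):
    print(f"thread {t}: (M_atom, N_atom) = {atom_position(t)}")
```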

The Permutation Reshapes the 8 Elements

permutation_M = (16, 4):(4, 1)

Without permutation:
    thread 0's 8 M-elements = rows 0,1,2,3,4,5,6,7 (consecutive)

With permutation (16,4):(4,1):
    Group 1: 4 consecutive rows starting at offset 0  → rows 0, 1, 2, 3
    Group 2: 4 consecutive rows starting at offset 64 → rows 64, 65, 66, 67
             (stride 8192 / 128 = 64 rows gap)

Why? Four f32 values are 128 bits wide, so wherever a group of 4 is
contiguous in memory (along M in M-major A, along N in row-major C) it
can potentially be moved with a single 128-bit vectorized access.
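The row pattern for thread 0 can be generated from the permutation's shape and stride. The sketch below extends the same rule to the other M-atoms; that generalization is an assumption extrapolated from the thread 0 case described above, and `owned_indices` is a hypothetical helper, not a CuTe call.

```python
def owned_indices(atom, tile_extent=128, perm_shape=(16, 4), perm_stride=(4, 1)):
    """Rows (or columns) owned by one atom along a single tile dimension.

    The permutation covers perm_shape[0] * perm_shape[1] = 64 indices, so a
    128-wide tile needs two repetitions, 64 apart.
    """
    perm_size = perm_shape[0] * perm_shape[1]
    indices = []
    for rep in range(tile_extent // perm_size):
        for j in range(perm_shape[1]):
            indices.append(rep * perm_size + atom * perm_stride[0] + j * perm_stride[1])
    return indices


print("atom 0:", owned_indices(0))
print("atom 1:", owned_indices(1))

# Sanity check: the 16 atoms together cover each of the 128 rows exactly once.
covered = sorted(i for a in range(16) for i in owned_indices(a))
print("full coverage:", covered == list(range(128)))
```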

Thread 0's C Elements Visualized

gC tile (128×128), Thread 0 owns elements:

    Col:          0 1 2 3  ...  64 65 66 67  ...  127
                  ↓ ↓ ↓ ↓       ↓  ↓  ↓  ↓
    Row 0:        ■ ■ ■ ■       ■  ■  ■  ■     ← 8 elements
    Row 1:        ■ ■ ■ ■       ■  ■  ■  ■
    Row 2:        ■ ■ ■ ■       ■  ■  ■  ■
    Row 3:        ■ ■ ■ ■       ■  ■  ■  ■     ← 32 elements
    Row 4-63:     (other threads)
    Row 64:       ■ ■ ■ ■       ■  ■  ■  ■
    Row 65:       ■ ■ ■ ■       ■  ■  ■  ■
    Row 66:       ■ ■ ■ ■       ■  ■  ■  ■
    Row 67:       ■ ■ ■ ■       ■  ■  ■  ■     ← 64 total ✓
    Row 68-127:   (other threads)

= 4 blocks of 4×4, arranged as 2 M-groups × 2 N-groups

The Output Tensor Shape Explained

tCgC: shape (1, (4,2), (4,2)), stride (0, (128, 8192), (1, 64))
             ↑    ↑      ↑
             V    M      N

V = 1
    Values per atom invocation. Scalar FMA produces 1 output.

M = (4, 2)
    4 = consecutive rows in each group (rows 0-3, then 64-67)
    2 = number of groups
    Strides (128, 8192):
        128  = row stride in C (next row = +128 in memory, since C is 128-wide)
        8192 = gap between groups = 64 rows × 128 = 8192

N = (4, 2)
    4 = consecutive cols in each group (cols 0-3, then 64-67)
    2 = number of groups
    Strides (1, 64):
        1  = adjacent column
        64 = gap between groups = 64 columns apart
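A nested shape/stride pair is enough to enumerate every memory offset the thread touches. The sketch below flattens the nested tuples and takes the dot product of each coordinate with the strides; `flatten` and `offsets` are illustrative helpers, and the enumeration order used here is not necessarily the order CuTe iterates in.

```python
from itertools import product


def flatten(x):
    """Flatten a nested tuple such as (1, (4, 2), (4, 2)) into (1, 4, 2, 4, 2)."""
    if isinstance(x, tuple):
        return tuple(v for item in x for v in flatten(item))
    return (x,)


def offsets(shape, stride):
    """All offsets reachable through a (possibly nested) shape/stride pair."""
    flat_shape, flat_stride = flatten(shape), flatten(stride)
    return [
        sum(c * d for c, d in zip(coord, flat_stride))
        for coord in product(*(range(s) for s in flat_shape))
    ]


tcgc_offsets = offsets((1, (4, 2), (4, 2)), (0, (128, 8192), (1, 64)))

# C is row-major with 128 columns, so offset = row * 128 + col.
cells = sorted(divmod(off, 128) for off in tcgc_offsets)
rows = sorted({r for r, _ in cells})
cols = sorted({c for _, c in cells})
print("distinct offsets:", len(set(tcgc_offsets)))
print("rows:", rows)
print("cols:", cols)
```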

Step C: K-Loop — Thread 0 Loads A and B

Iteration Structure

K = 32, BK = 8 → 4 outer k_tiles
Within each k_tile: 8 k_blocks (one per K-element, since atom is 1×1×1)

    for k_tile in range(4):          ← outer loop: 4 iterations
        load A[:,k_tile], B[:,k_tile] from global memory to registers
        for k_block in range(8):     ← inner loop: 8 FMAs per output element per k_tile
            cute.gemm(...)           ← 64 FMAs per k_block (8M × 8N accumulations)

Total FMAs: 4 × 8 × 64 = 2048 = 64 outputs × 32 K-reductions ✓
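The FMA count can be checked by walking the same two-level loop and counting. This is only a counting sketch with the `cute.gemm` call replaced by a tally; the variable names are illustrative.

```python
K, BK = 32, 8
OUTPUTS_M, OUTPUTS_N = 8, 8   # thread 0 owns 8 rows x 8 columns of C

fma_count = 0
fmas_per_k_tile = []
for k_tile in range(K // BK):            # outer loop: 4 K-tiles
    before = fma_count
    for k_block in range(BK):            # inner loop: 8 K-elements per tile
        # Stand-in for cute.gemm: one FMA per owned (m, n) pair.
        fma_count += OUTPUTS_M * OUTPUTS_N
    fmas_per_k_tile.append(fma_count - before)

print("FMAs per k_tile:", fmas_per_k_tile)
print("total FMAs:", fma_count)
```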

partition_A for Thread 0, k_tile=0

gA[:, :, 0] → shape (128, 8), stride (1, 256)    ← first 8 columns of A for this CTA

tCgA = thr_mma.partition_A(gA[:,:,0])
    → shape (1, (4,2), 8), stride (0, (1,64), 256)
             V  M-rows K-cols

Thread 0 needs the same M-rows it owns in C: {0,1,2,3, 64,65,66,67}
and ALL 8 K-columns (to compute the dot product).

A (128 rows × 8 cols), thread 0 reads:

              k=0 k=1 k=2 k=3 k=4 k=5 k=6 k=7
    Row 0:  [  •   •   •   •   •   •   •   •  ]   ← 8 values
    Row 1:  [  •   •   •   •   •   •   •   •  ]
    Row 2:  [  •   •   •   •   •   •   •   •  ]
    Row 3:  [  •   •   •   •   •   •   •   •  ]
    ...     (rows 4-63: other threads)
    Row 64: [  •   •   •   •   •   •   •   •  ]
    Row 65: [  •   •   •   •   •   •   •   •  ]
    Row 66: [  •   •   •   •   •   •   •   •  ]
    Row 67: [  •   •   •   •   •   •   •   •  ]

Total per k_tile: 8 rows × 8 cols = 64 elements

Strides: (1, 64, 256)
    1   = M-major (consecutive rows in memory)
    64  = gap between M-groups (row 0 to row 64 = offset 64, since stride-in-M = 1)
    256 = K-stride (next column = +256 in original A, since A is 256-tall)

partition_B for Thread 0, k_tile=0

gB[:, :, 0] → shape (128, 8), stride (1, 128)    ← B is N-major (128 tall)

tCgB = thr_mma.partition_B(gB[:,:,0])
    → shape (1, (4,2), 8), stride (0, (1,64), 128)
             V  N-cols K-cols

Thread 0 needs the same N-columns it owns in C: {0,1,2,3, 64,65,66,67}
and ALL 8 K-columns.

Same pattern as A: 8 rows of B (which are N-indices) × 8 K-elements = 64 values.
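The A and B partitions differ only in their K-stride, so one helper can list thread 0's load offsets for either operand. The sketch assumes the strides quoted above; the B K-tile stride of 1024 is derived here as BK × 128 and is not stated in the text, and `thread0_offsets` is a hypothetical name.

```python
BK = 8
OWNED = [0, 1, 2, 3, 64, 65, 66, 67]   # thread 0's M-rows (for A) or N-indices (for B)


def thread0_offsets(k_tile, k_stride):
    """Offsets thread 0 reads in one K-tile, for strides (0, (1, 64), k_stride)."""
    k_tile_stride = BK * k_stride       # derived: one K-tile spans BK columns
    return [
        i * 1 + g * 64 + kk * k_stride + k_tile * k_tile_stride
        for kk in range(BK)             # 8 K-columns
        for g in range(2)               # 2 groups, 64 apart
        for i in range(4)               # 4 consecutive indices per group
    ]


a_offsets = thread0_offsets(k_tile=0, k_stride=256)   # A is 256 tall
b_offsets = thread0_offsets(k_tile=0, k_stride=128)   # B is 128 tall

print("A loads per k_tile:", len(a_offsets), "first column:", a_offsets[:8])
print("B loads per k_tile:", len(b_offsets), "first column:", b_offsets[:8])
print("A k_tile 1 starts at offset", min(thread0_offsets(1, 256)))
```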

The Inner Compute: What Happens Per k_block

k_block = 0 (first of 8 within the BK=8 tile):

    tCrA[:, :, 0] = shape (1, (4,2)) = 8 A-values: A[m, k=0] for m in {0,1,2,3,64,65,66,67}
    tCrB[:, :, 0] = shape (1, (4,2)) = 8 B-values: B[n, k=0] for n in {0,1,2,3,64,65,66,67}
    tCrC          = shape (1, (4,2), (4,2)) = 64 accumulators

cute.gemm does (for scalar FMA atom):

    ┌─────────────────────────────────────────────────────────┐
    │ for each m in {0..7}:           (8 M-values)            │
    │     for each n in {0..7}:       (8 N-values)            │
    │         C[m,n] += A[m] * B[n]   (1 FMA)                 │
    │                                                         │
    │ = 64 FMAs for this k_block                              │
    └─────────────────────────────────────────────────────────┘

After all 8 k_blocks in this k_tile: 512 FMAs
After all 4 k_tiles: 2048 FMAs total for thread 0
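To see that the per-k_block outer products add up to the right answer, the whole K-loop for thread 0 can be simulated on the CPU. This is a plain-Python model, not kernel code: A and B hold small random integers so the comparison with a reference dot product is exact, and names such as `r_a`, `r_b` and `acc` merely stand in for tCrA, tCrB and tCrC.

```python
import random

M, N, K, BK = 256, 128, 32, 8
OWNED = [0, 1, 2, 3, 64, 65, 66, 67]     # thread 0's rows of A and N-indices of B

random.seed(0)
A = [[random.randint(-3, 3) for _ in range(K)] for _ in range(M)]   # A[m][k]
B = [[random.randint(-3, 3) for _ in range(K)] for _ in range(N)]   # B[n][k]

acc = [[0] * 8 for _ in range(8)]        # 64 accumulators, zeroed
fmas = 0
for k_tile in range(K // BK):
    # "Load" this K-tile's slice into per-thread registers.
    r_a = [[A[m][k_tile * BK + kb] for kb in range(BK)] for m in OWNED]
    r_b = [[B[n][k_tile * BK + kb] for kb in range(BK)] for n in OWNED]
    for kb in range(BK):
        # One k_block: rank-1 update of the 8x8 accumulator block.
        for mi in range(8):
            for ni in range(8):
                acc[mi][ni] += r_a[mi][kb] * r_b[ni][kb]
                fmas += 1

# Reference: direct dot products for the same 64 outputs.
reference = [[sum(A[m][k] * B[n][k] for k in range(K)) for n in OWNED] for m in OWNED]
print("FMAs executed:", fmas)
print("matches reference:", acc == reference)
```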

Step D: Epilogue — Store Results

cute.copy(atom, tCrC, tCgC)

    Source: tCrC, 64 values in registers, shape (1,(4,2),(4,2))
    Dest:   tCgC, 64 locations in global C, stride (0,(128,8192),(1,64))

The copy atom (CopyUniversalOp) = simple store instruction.
Strides in tCgC tell copy where each register value goes in global memory.

Register[v=0, m=(i,g_m), n=(j,g_n)] → C[row, col] where:
    row = i + g_m * 64    (i in 0..3, g_m in 0..1)
    col = j + g_n * 64    (j in 0..3, g_n in 0..1)
    address = base + row*128 + col
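The epilogue amounts to a scatter of 64 register values into a flat buffer. The sketch below models C as a Python list and checks that the (row, col) formula and the stride formula give the same address; the register numbering used to fill in values is an assumption made for illustration.

```python
M, N = 256, 128
c_flat = [0] * (M * N)       # row-major C: address = row * 128 + col

written = []
mismatches = 0
value = 1                    # stand-in register contents: 1, 2, ..., 64
for g_n in range(2):
    for j in range(4):
        for g_m in range(2):
            for i in range(4):
                row = i + g_m * 64
                col = j + g_n * 64
                address = row * N + col
                # Same address computed directly from tCgC's strides.
                via_strides = i * 128 + g_m * 8192 + j * 1 + g_n * 64
                mismatches += address != via_strides
                c_flat[address] = value
                written.append(address)
                value += 1

print("stores issued:", len(written))
print("distinct addresses:", len(set(written)))
print("formula mismatches:", mismatches)
```

Only the four 4×4 blocks owned by thread 0 end up non-zero; every other entry of `c_flat` is left for other threads.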

Complete Flow for Thread 0

Setup (host):
    TiledMMA = atom(1×1×1 FMA) × layout(16×16) + permutation((16,4),(16,4))
    Grid = (2, 1, 1), Block = (256, 1, 1)

Kernel (thread 0, block 0,0):

    A. local_tile:
        mA(256,32)  → gA(128, 8, 4)     this CTA's A chunk, 4 K-tiles
        mC(256,128) → gC(128, 128)      this CTA's output

    B. partition_C:
        gC(128,128) → tCgC(1,(4,2),(4,2))    thread 0's 64 output slots
        allocate tCrC = 64 registers, zeroed

    C. K-loop (×4 tiles):
        gA(:,:,k) → partition_A → tCgA(1,(4,2),8) → copy to tCrA (registers)
        gB(:,:,k) → partition_B → tCgB(1,(4,2),8) → copy to tCrB (registers)
        Inner loop (×8 k_blocks):
            cute.gemm: tCrC += tCrA[:,:,kb] outer_product tCrB[:,:,kb]
            (64 scalar FMAs per k_block)

    D. Epilogue:
        cute.copy: tCrC(registers) → tCgC(global memory)
        64 stores to the 4 scattered 4×4 blocks in C

Summary Table

Quantity                  Value             Derivation
------------------------  ----------------  ------------------------------------
CTAs (blocks)             2                 ceil(256/128) × ceil(128/128) = 2×1
Threads per CTA           256               16×16 atom grid
C elements per thread     64                128×128 / 256
C element layout          4 blocks of 4×4   permutation (16,4)×(16,4)
K-tiles (outer loop)      4                 32 / 8
K-blocks (inner loop)     8                 BK=8, atom K=1
A loads per k_tile        64                8 rows × 8 K-elements
B loads per k_tile        64                8 rows × 8 K-elements
FMAs total                2048              64 outputs × K=32
Registers for accum       64 × f32          = 256 bytes per thread

Naming Convention Reference

Name   Breakdown   Meaning
-----  ----------  -------------------------------------------------------------------
tCgC   t-C-g-C     thread-partitioned, using C's partition, in global mem, of tensor C
tCgA   t-C-g-A     thread-partitioned, using C's partition, in global mem, of tensor A
tCrC   t-C-r-C     thread-partitioned, using C's partition, in registers, of tensor C
tCrA   t-C-r-A     thread-partitioned, using C's partition, in registers, of tensor A
gA     g-A         CTA-tiled view of A in global memory
gC     g-C         CTA-tiled view of C in global memory

Why is A partitioned "using C's scheme"?

Because the MMA output (C) determines which (M,N) coordinates each thread owns. That thread must load the A rows with the same M-indices and the B entries with the same N-indices. So partition_A uses the M-component of C's thread mapping, and partition_B uses the N-component.
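One consequence is that threads sharing an M_atom read the same rows of A, and threads sharing an N_atom read the same N-indices of B. The sketch below illustrates this with hypothetical helpers; extending thread 0's ownership pattern to the other threads is an assumption based on the permutation (16,4):(4,1) described earlier.

```python
def atom_position(thr_idx):
    """(M_atom, N_atom) from thr_idx = M_atom * 16 + N_atom."""
    return divmod(thr_idx, 16)


def owned(atom):
    """Indices owned along one dimension: 4 consecutive, repeated 64 apart."""
    return [rep * 64 + atom * 4 + j for rep in range(2) for j in range(4)]


def a_rows(thr_idx):
    """Rows of A a thread loads: set by the M-component of its C partition."""
    return owned(atom_position(thr_idx)[0])


def b_indices(thr_idx):
    """N-indices of B a thread loads: set by the N-component of its C partition."""
    return owned(atom_position(thr_idx)[1])


print("thread 0  A rows:", a_rows(0), " B indices:", b_indices(0))
print("thread 1  A rows:", a_rows(1), " B indices:", b_indices(1))
print("thread 16 A rows:", a_rows(16), " B indices:", b_indices(16))
```

Threads 0 and 1 share M_atom 0 and so read the same A rows, while threads 0 and 16 share N_atom 0 and read the same B entries.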