Step 1: Thread Trace Deep Dive
Tracing Thread 0, Block (0,0) through the naive GEMM kernel
Problem: M=256, N=128, K=32 | Tile: 128×128×8 | 256 threads
What is a CTA? What is a CTA Tile?
CTA (Cooperative Thread Array) = CUDA Thread Block
A group of threads (here 256) that:
- Execute on the same SM (Streaming Multiprocessor)
- Share the same shared memory (roughly 48 KB to 164 KB, depending on the GPU architecture and configuration)
- Can synchronize with __syncthreads()
- Are scheduled together as a unit
In this example: 1 CTA = 256 threads, assigned to compute one 128×128 chunk of C.
CTA Tile = The Work Assigned to One CTA
The cta_tiler = (128, 128, 8) defines how the problem is sliced:
- M=128 rows of the output per CTA
- N=128 columns of the output per CTA
- K=8 reduction elements processed per iteration
Full C matrix (256×128): Grid = (2, 1) = 2 CTAs total
        N=128 columns
  ←──────────────────────→
  ┌────────────────────────┐
  │                        │ ← CTA (0,0): rows 0-127, cols 0-127
  │      128×128 tile      │   128×128 = 16384 elements
  │      = CTA (0,0)       │   256 threads → 64 elements/thread
  │                        │
  ├────────────────────────┤
  │                        │ ← CTA (1,0): rows 128-255, cols 0-127
  │      128×128 tile      │
  │      = CTA (1,0)       │
  │                        │
  └────────────────────────┘
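The grid size and per-thread workload in the diagram follow from simple integer arithmetic. The sketch below reproduces those numbers in plain Python; the helper names (`ceil_div`, `cta_rows`) are illustrative and not part of any CuTe API.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer division rounded up."""
    return (a + b - 1) // b

M, N, K = 256, 128, 32        # problem size used in this walkthrough
BM, BN, BK = 128, 128, 8      # cta_tiler
THREADS_PER_CTA = 256

# One CTA per (BM x BN) tile of C.
grid = (ceil_div(M, BM), ceil_div(N, BN))
elements_per_cta = BM * BN
elements_per_thread = elements_per_cta // THREADS_PER_CTA

def cta_rows(bidx: int) -> range:
    """Rows of C covered by the CTA whose block index along M is bidx."""
    return range(bidx * BM, min((bidx + 1) * BM, M))

print("grid:", grid)
print("elements per CTA:", elements_per_cta)
print("elements per thread:", elements_per_thread)
print("CTA (1,0) rows:", cta_rows(1)[0], "to", cta_rows(1)[-1])
```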
Step A: local_tile — Extract This CTA's Data
How local_tile works (3 operations in one call)
- proj — filter the 3D tiler to match the tensor's dimensions
- tile — logically divide the tensor into tiles
- coord — select which tile(s) to keep
Tiling A: local_tile(mA, (128,128,8), coord=(0,0,None), proj=(1,None,1))
mA: shape (256, 32), stride (1, 256) ← M-major: consecutive in M, stride-256 in K
(A[m,k] at address: base + m + k*256)
proj = (1, None, 1) "A has M and K, not N"
tiler (128, 128, 8) → keep M=128, drop N, keep K=8 → effective tiler = (128, 8)
Tile (256,32) by (128,8):
M: 256/128 = 2 tiles K: 32/8 = 4 tiles
Logical shape: (128, 8, 2_M_tiles, 4_K_tiles)
coord = (bidx=0, _, None):
M-tile index = 0 (first 128 rows)
K-tile index = None (keep all 4 as iteration dimension)
Result: gA shape (128, 8, 4), stride (1, 256, 2048)
                   ↑   ↑  ↑
                  BM  BK  num_k_tiles
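One way to convince yourself that the strides (1, 256, 2048) are right is to check that the tiled view addresses exactly the same memory as the original tensor. The following is a plain-Python model of the address arithmetic only (it does not call CuTe); `mA_offset` and `gA_offset` are hypothetical helper names.

```python
M, K = 256, 32
BM, BK = 128, 8
A_STRIDE = (1, M)  # M-major: A[m, k] lives at m*1 + k*256

def mA_offset(m: int, k: int) -> int:
    """Offset of A[m, k] in the full tensor."""
    return m * A_STRIDE[0] + k * A_STRIDE[1]

# Tiled view for one CTA: (BM, BK, num_k_tiles).
gA_shape = (BM, BK, K // BK)
# Moving to the next K-tile skips BK columns, each 256 elements apart.
gA_stride = (A_STRIDE[0], A_STRIDE[1], BK * A_STRIDE[1])

def gA_offset(bidx: int, m: int, k: int, k_tile: int) -> int:
    """Offset of gA[m, k, k_tile] for the CTA with M-tile index bidx."""
    base = bidx * BM * A_STRIDE[0]
    return base + m * gA_stride[0] + k * gA_stride[1] + k_tile * gA_stride[2]

# The tiled view must alias the original tensor element for element.
tile_view_matches = all(
    gA_offset(bidx, m, k, kt) == mA_offset(bidx * BM + m, kt * BK + k)
    for bidx in range(M // BM)
    for m in range(BM)
    for k in range(BK)
    for kt in range(K // BK)
)
print(gA_shape, gA_stride, tile_view_matches)
```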
Tiling C: local_tile(mC, (128,128,8), coord=(0,0,None), proj=(1,1,None))
mC: shape (256, 128), stride (128, 1) ← row-major: C[m,n] at address: base + m*128 + n
proj = (1, 1, None) "C has M and N, not K"
tiler (128, 128, 8) → keep M=128, keep N=128, drop K → effective tiler = (128, 128)
Tile (256,128) by (128,128):
M: 256/128 = 2 tiles N: 128/128 = 1 tile
coord = (bidx=0, bidy=0, _):
Select M-tile 0, N-tile 0
Result: gC shape (128, 128), stride (128, 1)
                   ↑    ↑
                  BM   BN
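The projection step used for A and C can be modelled in a few lines. This is a sketch of the idea only, not the CuTe implementation: `apply_proj`, `tile_counts`, and `gC_base` are hypothetical helpers, and the projection shown for B, `(None, 1, 1)`, is an assumption by analogy, since the walkthrough only spells out A and C.

```python
def apply_proj(tiler, proj):
    """Keep the tiler modes whose projection entry is not None."""
    return tuple(t for t, p in zip(tiler, proj) if p is not None)

def tile_counts(shape, tiler):
    """Number of tiles along each mode (assumes exact divisibility)."""
    return tuple(s // t for s, t in zip(shape, tiler))

cta_tiler = (128, 128, 8)
tiler_A = apply_proj(cta_tiler, (1, None, 1))   # A has M and K
tiler_B = apply_proj(cta_tiler, (None, 1, 1))   # assumed: B has N and K
tiler_C = apply_proj(cta_tiler, (1, 1, None))   # C has M and N

def gC_base(bidx: int, bidy: int, stride=(128, 1)) -> int:
    """Offset of the first element of the C tile owned by CTA (bidx, bidy)."""
    return bidx * tiler_C[0] * stride[0] + bidy * tiler_C[1] * stride[1]

print(tiler_A, tile_counts((256, 32), tiler_A))
print(tiler_C, tile_counts((256, 128), tiler_C))
print(gC_base(0, 0), gC_base(1, 0))
```

The tile shape and stride of gC are the same for both CTAs; only the base offset differs, which is why CTA (1,0) starts 128 rows × 128 elements further into C.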
Step B: partition_C — Thread 0's 64 Output Elements
How 256 threads map to the 128×128 tile
atoms_layout = (16, 16, 1):(16, 1, 0)
Thread ID → atom position:
thr_idx = M_atom * 16 + N_atom (from stride (16, 1))
Thread 0: M_atom=0, N_atom=0
Thread 1: M_atom=0, N_atom=1
Thread 16: M_atom=1, N_atom=0
Thread 255: M_atom=15, N_atom=15
16 atoms in M cover 128 rows: 128/16 = 8 rows per atom
16 atoms in N cover 128 cols: 128/16 = 8 cols per atom
Each thread: 8 × 8 = 64 elements ✓
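The thread-to-atom mapping above is a plain row-major decomposition of the thread index. A minimal sketch, with `atom_coord` as an illustrative name:

```python
ATOMS_M, ATOMS_N = 16, 16

def atom_coord(thr_idx: int) -> tuple:
    """Invert thr_idx = M_atom * 16 + N_atom (strides (16, 1))."""
    return divmod(thr_idx, ATOMS_N)

for t in (0, 1, 16, 255):
    print(t, atom_coord(t))

# Every thread lands on a distinct atom position.
all_positions = {atom_coord(t) for t in range(ATOMS_M * ATOMS_N)}
```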
The Permutation Reshapes the 8 Elements
permutation_M = (16, 4):(4, 1)
Without permutation: thread 0's 8 M-elements = rows 0, 16, 32, ..., 112 (one row in each 16-row repeat of the atom grid)
With permutation (16,4):(4,1):
Group 1: 4 consecutive rows starting at offset 0 → rows 0, 1, 2, 3
Group 2: 4 consecutive rows starting at offset 64 → rows 64, 65, 66, 67
(stride 8192 / 128 = 64 rows gap)
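Why? Each thread owns runs of 4 consecutive indices, which allows 128-bit (4×f32) vectorized accesses wherever that mode is contiguous in memory (M for the M-major A, N for the row-major C)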
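The row ownership implied by the permutation can be written as a small index formula. The sketch below reproduces thread 0's rows from the text; the extension to other atom indices (offsetting by `atom_idx * 4`) is derived here from the (16,4):(4,1) layout and is an assumption rather than something the walkthrough states. `owned_indices` is a hypothetical helper.

```python
def owned_indices(atom_idx: int, tile_extent: int = 128,
                  atoms: int = 16, group: int = 4) -> list:
    """Indices along one mode owned by the atom at position atom_idx.

    Each atom owns `group` consecutive indices, and the pattern repeats
    every atoms * group = 64 indices across the tile.
    """
    period = atoms * group
    return [rep * period + atom_idx * group + i
            for rep in range(tile_extent // period)
            for i in range(group)]

print(owned_indices(0))
print(owned_indices(1))

# Sanity check: the 16 atoms together cover every index exactly once.
covered = sorted(i for a in range(16) for i in owned_indices(a))
```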
Thread 0's C Elements Visualized
gC tile (128×128), Thread 0 owns ■ elements:
Col:      0  1  2  3  ...  64 65 66 67  ...  127
          ↓  ↓  ↓  ↓       ↓  ↓  ↓  ↓
Row 0:    ■  ■  ■  ■       ■  ■  ■  ■   ← 8 elements
Row 1:    ■  ■  ■  ■       ■  ■  ■  ■
Row 2:    ■  ■  ■  ■       ■  ■  ■  ■
Row 3:    ■  ■  ■  ■       ■  ■  ■  ■   ← 32 elements
Row 4-63:   (other threads)
Row 64:   ■  ■  ■  ■       ■  ■  ■  ■
Row 65:   ■  ■  ■  ■       ■  ■  ■  ■
Row 66:   ■  ■  ■  ■       ■  ■  ■  ■
Row 67:   ■  ■  ■  ■       ■  ■  ■  ■   ← 64 total ✓
Row 68-127: (other threads)
= 4 blocks of 4×4, arranged as 2 M-groups × 2 N-groups
The Output Tensor Shape Explained
tCgC: shape (1, (4,2), (4,2)), stride (0, (128, 8192), (1, 64))
             ↑    ↑      ↑
             V    M      N
V = 1
Values per atom invocation. Scalar FMA produces 1 output.
M = (4, 2)
4 = consecutive rows in each group (rows 0-3, then 64-67)
2 = number of groups
Strides (128, 8192):
128 = row stride in C (next row = +128 in memory, since C is 128-wide)
8192 = gap between groups = 64 rows × 128 = 8192
N = (4, 2)
4 = consecutive cols in each group (cols 0-3, then 64-67)
2 = number of groups
Strides (1, 64):
1 = adjacent column
64 = gap between groups = 64 columns apart
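A nested shape/stride pair like this one can be expanded into concrete memory offsets by taking the dot product of every coordinate with the strides. The following plain-Python sketch does that for tCgC and decodes each offset back into a (row, col) pair; `flatten` is an illustrative helper, not a CuTe function.

```python
from itertools import product

def flatten(t):
    """Flatten a nested tuple of ints into a flat list."""
    if isinstance(t, int):
        return [t]
    return [x for sub in t for x in flatten(sub)]

shape = (1, (4, 2), (4, 2))
stride = (0, (128, 8192), (1, 64))
dims, strides = flatten(shape), flatten(stride)

# Offset of a coordinate = sum of coordinate * stride over all modes.
offsets = sorted(
    sum(c * s for c, s in zip(coord, strides))
    for coord in product(*(range(d) for d in dims))
)

# C is row-major with 128 columns, so offset = row * 128 + col.
coords = [divmod(off, 128) for off in offsets]
rows = sorted({r for r, _ in coords})
cols = sorted({c for _, c in coords})
print(len(offsets), rows, cols)
```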
Step C: K-Loop — Thread 0 Loads A and B
Iteration Structure
K = 32, BK = 8 → 4 outer k_tiles
Within each k_tile: 8 k_blocks (one per K-element, since atom is 1×1×1)
for k_tile in range(4):              ← outer loop: 4 iterations
    load gA[:,:,k_tile], gB[:,:,k_tile] from global memory to registers
    for k_block in range(8):         ← inner loop: 8 iterations, so 8 FMAs per output element
        cute.gemm(...)               ← 64 FMAs per k_block (8 M × 8 N accumulations)
Total FMAs: 4 × 8 × 64 = 2048 = 64 outputs × 32 K-reductions ✓
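The FMA count can be confirmed by walking the same loop nest and counting. This is a counting sketch only; it models the loop structure, not the actual `cute.gemm` call.

```python
K, BK = 32, 8
M_PER_THREAD, N_PER_THREAD = 8, 8
THREADS_PER_CTA = 256

fma_count = 0
for k_tile in range(K // BK):              # outer loop over K-tiles
    for k_block in range(BK):              # inner loop, one K-element each
        for m in range(M_PER_THREAD):      # what one gemm call expands to
            for n in range(N_PER_THREAD):  # for a scalar 1x1x1 atom
                fma_count += 1

fmas_per_cta = fma_count * THREADS_PER_CTA
print(fma_count, fmas_per_cta)
```

The per-CTA total equals 128 × 128 × 32, one FMA per (m, n, k) triple in the tile, so no work is duplicated or skipped.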
partition_A for Thread 0, k_tile=0
gA[:, :, 0] → shape (128, 8), stride (1, 256) ← first 8 columns of A for this CTA
tCgA = thr_mma.partition_A(gA[:,:,0])
→ shape (1, (4,2), 8), stride (0, (1,64), 256)
         V  M-rows K-cols
Thread 0 needs the same M-rows it owns in C: {0,1,2,3, 64,65,66,67}
And ALL 8 K-columns (to compute the dot product).
A (128 rows × 8 cols), thread 0 reads:
         k=0 k=1 k=2 k=3 k=4 k=5 k=6 k=7
Row 0:  [ •   •   •   •   •   •   •   • ] ← 8 values
Row 1:  [ •   •   •   •   •   •   •   • ]
Row 2:  [ •   •   •   •   •   •   •   • ]
Row 3:  [ •   •   •   •   •   •   •   • ]
...     (rows 4-63: other threads)
Row 64: [ •   •   •   •   •   •   •   • ]
Row 65: [ •   •   •   •   •   •   •   • ]
Row 66: [ •   •   •   •   •   •   •   • ]
Row 67: [ •   •   •   •   •   •   •   • ]
Total per k_tile: 8 rows × 8 cols = 64 elements
Strides: (1, 64, 256)
1 = M-major (consecutive rows in memory)
64 = gap between M-groups (row 0 to row 64 = offset 64, since stride-in-M = 1)
256 = K-stride (next column = +256 in original A, since A is 256-tall)
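These strides can be checked by generating thread 0's 64 load offsets and decoding each back into an (m, k) coordinate of A. A minimal sketch in plain Python, with `tCgA_offset` as an illustrative name:

```python
A_ROWS = 256  # A is M-major, so consecutive K-columns are 256 elements apart

def tCgA_offset(i: int, g: int, k: int) -> int:
    """Offset for M-coordinate (i, g) and K-coordinate k: strides (1, 64), 256."""
    return i * 1 + g * 64 + k * 256

offsets = [tCgA_offset(i, g, k)
           for k in range(8) for g in range(2) for i in range(4)]

# Decode offset = m + k * 256 back into (k, m).
decoded = [divmod(off, A_ROWS) for off in offsets]
rows = sorted({m for _, m in decoded})
ks = sorted({k for k, _ in decoded})

# Within each group of 4, the addresses are adjacent in memory.
contiguous_groups = all(
    tCgA_offset(i + 1, g, k) - tCgA_offset(i, g, k) == 1
    for k in range(8) for g in range(2) for i in range(3)
)
print(rows, ks, contiguous_groups)
```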
partition_B for Thread 0, k_tile=0
gB[:, :, 0] → shape (128, 8), stride (1, 128) ← B is N-major (128 tall)
tCgB = thr_mma.partition_B(gB[:,:,0])
→ shape (1, (4,2), 8), stride (0, (1,64), 128)
         V  N-cols K-cols
Thread 0 needs the same N-columns it owns in C: {0,1,2,3, 64,65,66,67}
And ALL 8 K-columns.
Same pattern as A: 8 rows of B (which are N-indices) × 8 K-elements = 64 values.
The Inner Compute: What Happens Per k_block
k_block = 0 (first of 8 within the BK=8 tile):
tCrA[:, :, 0] = shape (1, (4,2)) = 8 A-values: A[m, k=0] for m in {0,1,2,3,64,65,66,67}
tCrB[:, :, 0] = shape (1, (4,2)) = 8 B-values: B[n, k=0] for n in {0,1,2,3,64,65,66,67}
tCrC = shape (1, (4,2), (4,2)) = 64 accumulators
cute.gemm does (for scalar FMA atom):
┌─────────────────────────────────────────────────────────┐
│ for each m in {0..7}:            (8 M-values)           │
│   for each n in {0..7}:          (8 N-values)           │
│     C[m,n] += A[m] * B[n]        (1 FMA)                │
│                                                         │
│ = 64 FMAs for this k_block                              │
└─────────────────────────────────────────────────────────┘
After all 8 k_blocks in this k_tile: 512 FMAs
After all 4 k_tiles: 2048 FMAs total for thread 0
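To see that this loop structure computes the right answer, the sketch below simulates thread 0 of CTA (0,0) in plain Python and compares its 64 accumulators against a direct dot product. The input data is random and purely illustrative; integers are used so the comparison is exact. Nested lists stand in for registers and global memory, so this models the arithmetic only, not the CuTe calls.

```python
import random

M, N, K, BK = 256, 128, 32, 8
random.seed(0)
A = [[random.randint(-3, 3) for _ in range(K)] for _ in range(M)]  # A[m][k]
B = [[random.randint(-3, 3) for _ in range(K)] for _ in range(N)]  # B[n][k]

OWNED = [0, 1, 2, 3, 64, 65, 66, 67]  # thread 0's M-rows and N-columns

acc = [[0] * 8 for _ in range(8)]     # stands in for tCrC (64 accumulators)
fmas = 0
for k_tile in range(K // BK):
    # "Load" this K-tile's slice of A and B into per-thread registers.
    rA = [[A[m][k_tile * BK + kb] for kb in range(BK)] for m in OWNED]
    rB = [[B[n][k_tile * BK + kb] for kb in range(BK)] for n in OWNED]
    for kb in range(BK):
        # One k_block: outer product of 8 A-values with 8 B-values.
        for mi in range(8):
            for ni in range(8):
                acc[mi][ni] += rA[mi][kb] * rB[ni][kb]
                fmas += 1

reference = [[sum(A[m][k] * B[n][k] for k in range(K)) for n in OWNED]
             for m in OWNED]
print(acc == reference, fmas)
```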
Step D: Epilogue — Store Results
cute.copy(atom, tCrC, tCgC)
Source: tCrC — 64 values in registers, shape (1,(4,2),(4,2))
Dest: tCgC — 64 locations in global C, stride (0,(128,8192),(1,64))
The copy atom (CopyUniversalOp) = simple store instruction.
Strides in tCgC tell copy where each register value goes in global memory.
Register[v=0, m=(i,g_m), n=(j,g_n)] → C[row, col]
where:
row = i + g_m * 64 (i in 0..3, g_m in 0..1)
col = j + g_n * 64 (j in 0..3, g_n in 0..1)
address = base + row*128 + col
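The address formula can be extended from thread 0 to every thread to check that the 256 threads together write each element of the 128×128 tile exactly once. In this sketch the per-thread offsets `m_atom * 4` and `n_atom * 4` are an assumption derived from the (16,4):(4,1) permutation; the text only gives thread 0, where both are zero. `thread_store_addresses` is a hypothetical helper.

```python
C_ROW_STRIDE = 128  # row-major C, 128 columns

def thread_store_addresses(thr_idx: int, base: int = 0) -> list:
    """Global offsets written by one thread in the epilogue."""
    m_atom, n_atom = divmod(thr_idx, 16)
    addrs = []
    for g_m in range(2):
        for i in range(4):
            for g_n in range(2):
                for j in range(4):
                    row = i + g_m * 64 + m_atom * 4  # m_atom term: assumption
                    col = j + g_n * 64 + n_atom * 4  # n_atom term: assumption
                    addrs.append(base + row * C_ROW_STRIDE + col)
    return addrs

thread0 = thread_store_addresses(0)
all_addrs = sorted(a for t in range(256) for a in thread_store_addresses(t))
print(len(thread0), thread0[:4], max(thread0))
print(all_addrs == list(range(128 * 128)))
```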
Complete Flow for Thread 0
Setup (host):
  TiledMMA = atom(1×1×1 FMA) × layout(16×16) + permutation((16,4),(16,4))
  Grid = (2, 1, 1), Block = (256, 1, 1)
Kernel (thread 0, block 0,0):
  A. local_tile:
     mA(256,32)  → gA(128, 8, 4)    this CTA's A chunk, 4 K-tiles
     mC(256,128) → gC(128, 128)     this CTA's output
  B. partition_C:
     gC(128,128) → tCgC(1,(4,2),(4,2))    thread 0's 64 output slots
     allocate tCrC = 64 registers, zeroed
  C. K-loop (×4 tiles):
     gA(:,:,k) → partition_A → tCgA(1,(4,2),8) → copy to tCrA (registers)
     gB(:,:,k) → partition_B → tCgB(1,(4,2),8) → copy to tCrB (registers)
     Inner loop (×8 k_blocks):
       cute.gemm: tCrC += tCrA[:,:,kb] outer_product tCrB[:,:,kb]
       (64 scalar FMAs per k_block)
  D. Epilogue:
     cute.copy: tCrC(registers) → tCgC(global memory)
     64 stores to the 4 scattered 4×4 blocks in C
Summary Table
| Quantity | Value | Derivation |
| --- | --- | --- |
| CTAs (blocks) | 2 | ceil(256/128) × ceil(128/128) = 2×1 |
| Threads per CTA | 256 | 16×16 atom grid |
| C elements per thread | 64 | 128×128 / 256 |
| C element layout | 4 blocks of 4×4 | permutation (16,4)×(16,4) |
| K-tiles (outer loop) | 4 | 32 / 8 |
| K-blocks (inner loop) | 8 | BK=8, atom K=1 |
| A loads per k_tile | 64 | 8 M-rows × 8 K-elements |
| B loads per k_tile | 64 | 8 N-indices × 8 K-elements |
| FMAs total | 2048 | 64 outputs × K=32 |
| Registers for accum | 64 × f32 | = 256 bytes per thread |
Naming Convention Reference
| Name | Breakdown | Meaning |
| --- | --- | --- |
| tCgC | t-C-g-C | thread-partitioned, using C's partition, in global mem, of tensor C |
| tCgA | t-C-g-A | thread-partitioned, using C's partition, in global mem, of tensor A |
| tCrC | t-C-r-C | thread-partitioned, using C's partition, in registers, of tensor C |
| tCrA | t-C-r-A | thread-partitioned, using C's partition, in registers, of tensor A |
| gA | g-A | CTA-tiled view of A in global memory |
| gC | g-C | CTA-tiled view of C in global memory |
Why is A partitioned "using C's scheme"?
Because the MMA output (C) determines which (M,N) coordinates each thread owns. That thread must load the matching A rows (same M-indices) and B columns (same N-indices). So partition_A uses the M-component of C's thread mapping, and partition_B uses the N-component.