Step1 GEMM: Thread → Element Mapping

128×128 output tile, 256 threads, scalar FMA — no algebra, just the picture

The Setup (3 numbers, that's all)

What | Value | Plain English
Output tile | 128 × 128 = 16,384 elements | One thread block computes this chunk of C
Threads per block | 256 | Arranged as a 16×16 grid
Elements per thread | 16,384 ÷ 256 = 64 | Each thread accumulates 64 FMA results

The one question: which 64 elements does each thread own?

There are many valid answers. Step1 uses this particular pattern:

Thread (tm, tn) owns rows {tm×4 .. tm×4+3, tm×4+64 .. tm×4+67} × cols {tn×4 .. tn×4+3, tn×4+64 .. tn×4+67}

Where: tm = threadIdx.x / 16 (0..15), tn = threadIdx.x % 16 (0..15)

Each thread gets four 4×4 blocks = 4 × 16 = 64 elements.
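The ownership rule can be checked with a short Python sketch. It assumes the tm = threadIdx.x / 16, tn = threadIdx.x % 16 split given above. The helper names owned_rows and owned_elements are hypothetical, not CuTe APIs. The sketch builds each thread's 64 (row, col) pairs and asserts that the 256 threads cover the 128×128 tile exactly once.

```python
TILE = 128
THREADS = 256

def owned_rows(tm):
    """Rows owned by row-index tm: 4 consecutive rows, plus the same 4 shifted by 64."""
    base = tm * 4
    return [base + i for i in range(4)] + [base + 64 + i for i in range(4)]

def owned_elements(tid):
    """All (row, col) pairs owned by thread tid (0..255)."""
    tm, tn = tid // 16, tid % 16          # split assumed from the text
    rows = owned_rows(tm)
    cols = owned_rows(tn)                 # columns follow the same pattern
    return [(r, c) for r in rows for c in cols]

# Every thread owns 64 elements; 256 × 64 = 16,384 distinct pairs means
# the threads cover the tile exactly once, with no overlap.
seen = set()
for tid in range(THREADS):
    elems = owned_elements(tid)
    assert len(elems) == 64
    seen.update(elems)
assert len(seen) == TILE * TILE

print(owned_rows(0))    # [0, 1, 2, 3, 64, 65, 66, 67]
```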

The Permutation: The ONE Design Choice

Everything above comes from a single parameter: the permutation layout (16, 4):(4, 1)

permutation = (Groups, Chunk):(group_stride, chunk_stride)
  Groups       = number of groups, one per thread (16)
  Chunk        = elements per group (4)
  group_stride = distance between group starts (4)
  chunk_stride = distance between elements within a group (1)

Thread k, element r → position = k × group_stride + r × chunk_stride

Example: (16, 4):(4, 1)
  Thread 0:  0×4 + {0,1,2,3}×1 = {0, 1, 2, 3}
  Thread 1:  1×4 + {0,1,2,3}×1 = {4, 5, 6, 7}
  Thread 15: 15×4 + {0,1,2,3}×1 = {60, 61, 62, 63}

This covers 16×4 = 64 of the 128 positions, so rest = 128/64 = 2, and each thread gets a SECOND group at +64:
  Thread 0 → also positions {64, 65, 66, 67}
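The position formula translates directly into a few lines of Python. This is a minimal sketch. It assumes the permutation's image is exactly 0 .. Groups×Chunk − 1, so the rest simply repeats it at +64. perm_positions is a hypothetical helper, not a CuTe function.

```python
def perm_positions(k, groups=16, chunk=4, group_stride=4, chunk_stride=1, extent=128):
    """Positions owned by thread k under (groups, chunk):(group_stride, chunk_stride).

    The permutation covers groups*chunk positions; the remaining extent is
    handled by repeating it `rest` times at an offset of groups*chunk.
    Assumes the permutation's image is exactly {0, ..., groups*chunk - 1}.
    """
    perm_size = groups * chunk
    assert extent % perm_size == 0, "extent must be a multiple of the permutation size"
    rest = extent // perm_size
    return [rep * perm_size + k * group_stride + r * chunk_stride
            for rep in range(rest) for r in range(chunk)]

print(perm_positions(0))    # [0, 1, 2, 3, 64, 65, 66, 67]
print(perm_positions(1))    # [4, 5, 6, 7, 68, 69, 70, 71]
print(perm_positions(15))   # [60, 61, 62, 63, 124, 125, 126, 127]
```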

Why this matters:

With the tile size and thread count fixed (and the atom fixed by the hardware instruction), the permutation is the only free parameter the kernel author picks. Everything else is derived:

Derived from | Result
Permutation + tile size | Thread ownership pattern
Ownership + tensor strides | Final layout strides
Atom TV layout | V dimension (hardware-fixed)

Alternative permutations (same tile, same thread count):

(16,4):(4,1) → 4 consecutive, no gap ← step1
(16,4):(1,16) → 4 elements spaced 16 apart
(16,8):(8,1) → 8 consecutive (covers the full 128, rest = 1)

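Constraint: F, the first (thread) mode of the permutation, must be divisible by ThrM = 16, the number of threads along M. A different permutation gives the same element count but a different access pattern, and therefore different performance.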

Permutation Comparison: What Different Choices Look Like

[Figure: ONE dimension (M = 128 positions) distributed among the 16 threads along M; each row is one permutation choice, colored by thread ownership.]

Key insight: All permutations give each thread exactly 8 positions (128 ÷ 16 = 8). But the spatial arrangement differs — consecutive elements vs. scattered elements determine cache line utilization, vectorization potential, and bank-conflict behavior.
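To make the spatial difference concrete, the sketch below evaluates each of the three permutations along M. It counts how many contiguous runs thread 0's positions form, where 1 run means fully contiguous. It uses the same position formula and again assumes each permutation covers 0 .. F×R − 1 exactly. perm_positions and runs are hypothetical helpers.

```python
def perm_positions(k, perm, extent=128):
    """Positions of thread k for permutation ((F, R), (sF, sR)), repeated to fill extent.
    Assumes the permutation's image is exactly {0, ..., F*R - 1}."""
    (F, R), (sF, sR) = perm
    size = F * R
    return sorted(rep * size + k * sF + r * sR
                  for rep in range(extent // size) for r in range(R))

def runs(positions):
    """Number of maximal runs of consecutive positions (1 run = fully contiguous)."""
    return 1 + sum(1 for a, b in zip(positions, positions[1:]) if b != a + 1)

PERMS = {
    "(16,4):(4,1)":  ((16, 4), (4, 1)),
    "(16,4):(1,16)": ((16, 4), (1, 16)),
    "(16,8):(8,1)":  ((16, 8), (8, 1)),
}

for name, perm in PERMS.items():
    owned = [perm_positions(k, perm) for k in range(16)]
    # Same element count for every thread and full coverage of 0..127 ...
    assert all(len(p) == 8 for p in owned)
    assert sorted(x for p in owned for x in p) == list(range(128))
    # ... but a different spatial arrangement.
    print(name, owned[0], "runs:", runs(owned[0]))
```

Thread 0 forms 2 runs under (16,4):(4,1), 8 runs under (16,4):(1,16), and 1 run under (16,8):(8,1). Every thread still owns 8 positions in each case.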

logical_divide Results for Each Permutation

Applying Step ① logical_divide(gC_M=128:128, permutation) to the M-dimension. Same process applies to N.

Permutation | Perm size (F×R) | Rest = 128/(F×R) | logical_divide result (M only) | Physical meaning of strides
(16,4):(4,1) | 64 | 2 | (16, 4, 2):(512, 128, 8192) | 16 thread-groups × 4 consecutive rows × 2 repetitions at +64 rows; thread 0 → rows {0,1,2,3, 64,65,66,67}
(16,4):(1,16) | 64 | 2 | (16, 4, 2):(128, 2048, 8192) | 16 thread-groups × 4 rows spaced 16 apart × 2 repetitions at +64 rows; thread 0 → rows {0,16,32,48, 64,80,96,112}
(16,8):(8,1) | 128 | 1 | (16, 8):(1024, 128) | 16 thread-groups × 8 consecutive rows, no rest (full coverage); thread 0 → rows {0,1,2,3,4,5,6,7}

How the Strides Are Computed

For the M-dimension, the base stride is 128 (row-major, 128 cols per row).

(16,4):(4,1) applied to size 128, stride 128:
  F=16:   stride = sF × base = 4 × 128 = 512          ← thread spacing in memory
  R=4:    stride = sR × base = 1 × 128 = 128          ← element spacing (= 1 row)
  rest=2: stride = perm_size × base = 64 × 128 = 8192 ← repetition offset

(16,4):(1,16) applied to size 128, stride 128:
  F=16:   stride = sF × base = 1 × 128 = 128          ← threads are adjacent rows!
  R=4:    stride = sR × base = 16 × 128 = 2048        ← elements skip 16 rows
  rest=2: stride = 64 × 128 = 8192                    ← same repetition offset

(16,8):(8,1) applied to size 128, stride 128:
  F=16:   stride = sF × base = 8 × 128 = 1024         ← thread spacing (8 rows apart)
  R=8:    stride = sR × base = 1 × 128 = 128          ← element spacing (= 1 row)
  rest=1: no rest dimension (the permutation covers all 128)

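Pattern: the base layout here is a single mode (128:128 for rows, 128:1 for cols), so composing with it just scales each stride. Each output stride is output_stride = permutation_stride × input_base_stride, and the rest mode's stride is perm_size × input_base_stride.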

The permutation strides (4,1), (1,16), (8,1) get multiplied by the tensor's base stride (128 for M-rows, 1 for N-cols) to produce the physical memory strides in the logical_divide result.
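The stride rule can be written as a tiny function. This is a sketch for the single-mode case only: a base layout extent:base_stride and a compact permutation whose complement is rest:perm_size. It is not a general implementation of CuTe's logical_divide, and divide_single_mode is a hypothetical name.

```python
def divide_single_mode(perm_shape, perm_stride, extent, base_stride):
    """Shape/stride of logical_divide(extent:base_stride, perm), flattened.

    Assumptions: the base layout is one mode, and the permutation's image is
    exactly {0, ..., size-1}, so its complement is rest:size. Composing with a
    single stride then just scales every stride by base_stride.
    """
    size = 1
    for s in perm_shape:
        size *= s
    rest = extent // size
    shape = tuple(perm_shape)
    stride = tuple(s * base_stride for s in perm_stride)
    if rest > 1:                       # size-1 rest modes are omitted, as in the table
        shape += (rest,)
        stride += (size * base_stride,)
    return shape, stride

print(divide_single_mode((16, 4), (4, 1), 128, 128))   # ((16, 4, 2), (512, 128, 8192))  M rows
print(divide_single_mode((16, 4), (4, 1), 128, 1))     # ((16, 4, 2), (4, 1, 64))        N cols
```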

Full 128×128 Tile — All 256 Threads Color-Coded

[Figure: the full 128×128 tile (cols 0 to 127 left to right, rows 0 to 127 top to bottom) with all 256 threads color-coded. Each colored square is 1 output element C[row][col], one float32 value, drawn as a 4×4-pixel block, so the grid is 512×512 pixels.]

Zoomed In: Thread 0's 64 Elements (Actual Cell View)

Below we zoom into thread 0's four 4×4 blocks. Each labeled cell = one C[row][col] element = one float = one FMA result.

Block A (top-left): rows 0-3, cols 0-3

Block C (bottom-left): rows 64-67, cols 0-3

Block B (top-right): rows 0-3, cols 64-67

Block D (bottom-right): rows 64-67, cols 64-67

Total: 4 blocks × 16 elements/block = 64 elements = exactly what thread 0 computes and stores.

Each cell shows C[row][col]. The memory offset = row × 128 + col (row-major).

Where these 4 blocks sit in the full 128×128 tile:

[Figure: blue = thread 0's 64 elements (4 small squares), dark = the other 255 threads' elements; each blue pixel is 1 element.]

1:1 Mapping — CuTe Layout ↔ Physical Elements

CuTe's output layout for thread 0, and the actual C[row][col] elements it points to:

CuTe layout for thread 0: tCgC = (1, (4,2), (4,2)):(0, (128,8192), (1,64))

Formula: offset = v×0 + m0×128 + m1×8192 + n0×1 + n1×64

Then: row = offset / 128, col = offset % 128

All 64 Elements Enumerated

v is always 0 (scalar atom). Showing (m0, m1, n0, n1) → offset → C[row][col]
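The 64-row enumeration can be regenerated from the layout alone. This Python sketch hard-codes thread 0's tCgC strides from above and converts each offset to (row, col) with divmod by 128. The names offset and table are illustrative only.

```python
# tCgC for thread 0: (1, (4,2), (4,2)) : (0, (128,8192), (1,64))
STRIDE_V = 0
STRIDE_M = (128, 8192)
STRIDE_N = (1, 64)

def offset(v, m0, m1, n0, n1):
    """Dot product of CuTe coordinates with the layout's strides."""
    return (v * STRIDE_V + m0 * STRIDE_M[0] + m1 * STRIDE_M[1]
            + n0 * STRIDE_N[0] + n1 * STRIDE_N[1])

table = []
for m1 in range(2):
    for m0 in range(4):
        for n1 in range(2):
            for n0 in range(4):
                off = offset(0, m0, m1, n0, n1)     # v is always 0 (scalar atom)
                row, col = divmod(off, 128)         # row-major C with 128 columns
                table.append(((m0, m1, n0, n1), off, (row, col)))

for coord, off, (row, col) in table:
    print(coord, "->", off, "->", f"C[{row}][{col}]")
```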

Visual: CuTe indices → 2D position

Each cell shows its CuTe flat index (0..63). Color = which block (m1, n1).

■ Block A (m1=0,n1=0)  ■ Block B (m1=0,n1=1)  ■ Block C (m1=1,n1=0)  ■ Block D (m1=1,n1=1)
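Flat indices follow CuTe's default colexicographic ordering: the left-most mode varies fastest, and so does the left-most sub-mode inside a nested mode. Here is a minimal sketch of that decomposition for shape (1, (4,2), (4,2)). flat_to_coord and BLOCK are hypothetical names, and the row/col formulas are specific to thread 0.

```python
def flat_to_coord(i):
    """Map a flat index 0..63 to (v, (m0, m1), (n0, n1)) for shape (1, (4,2), (4,2))."""
    assert 0 <= i < 64
    v = 0                      # mode 0 has size 1, so v is always 0
    m0, i = i % 4, i // 4      # left-most sub-mode varies fastest
    m1, i = i % 2, i // 2
    n0, i = i % 4, i // 4
    n1 = i                     # remaining 0 or 1
    return v, (m0, m1), (n0, n1)

BLOCK = {(0, 0): "A", (0, 1): "B", (1, 0): "C", (1, 1): "D"}   # keyed by (m1, n1)

for i in (0, 1, 4, 8, 32, 63):
    v, (m0, m1), (n0, n1) = flat_to_coord(i)
    row = m0 + 64 * m1          # thread 0's row (offset / 128)
    col = n0 + 64 * n1          # thread 0's col (offset % 128)
    print(i, (m0, m1, n0, n1), (row, col), "block", BLOCK[(m1, n1)])
```

Flat indices 0..7 walk down column 0 through rows 0-3 and then 64-67. Index 32 is the first element of Block B.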

How CuTe Derives This Layout

The 5 transformation steps that produce (1,(4,2),(4,2)):(0,(128,8192),(1,64))

INPUT — the CTA's output tile
gC = (128, 128):(128, 1)
128 rows × 128 cols, row-major (stride 128 per row, stride 1 per col)
① logical_divide by permutation (16,4):(4,1)
Split each 128-wide mode into the 64-position permutation × a rest of 2: how 16 threads tile 128 positions
→ ((16,4,2), (16,4,2))
  :((512,128,8192), (4,1,64))
16 = thread-indexing part   4, 2 = per-thread part (4 elements × 2 repetitions)
② zipped_divide by atom (1,1)
Separate atom footprint from repetitions. Atom=1×1 (trivial for scalar)
→ ((1,1), (16,4,2, 16,4,2))
  :((★,★), (512,128,8192, 4,1,64))
★ = don't-care (atom is size 1)
③ compose with AtomLayout_TV = (1,1):(0,0)
Relabel atom (M,N) → (Thread, Value). Trivial here (1 thr, 1 val)
→ ((Thr=1,Val=1), (16,4,2, 16,4,2))
  :((0,0), (512,128,8192, 4,1,64))
For tensor cores: TV would be (32,4) — non-trivial!
④ zipped_divide by thread layout (1,(16,16))
Separate "which thread" from "what it owns"
→ ((1,(16,16)), (1,(4,2),(4,2)))
  :((0,(512,4)), (0,(128,8192),(1,64)))
mode 0 = thread selector   mode 1 = fragment (per-thread)
⑤ slice thread 0 → (v=0, tm=0, tn=0)
Plug in thread coords, return fragment layout only
tCgC = (1,(4,2),(4,2)):(0,(128,8192),(1,64))
ptr offset = 0×0 + 0×512 + 0×4 = 0 (thread 0 starts at base)
shape: V=1 × M=8 × N=8 = 64 elements ← matches the 64-element enumeration
Connection to the enumeration:
  • (4,2):(128,8192) → the 8 row indices
  • (4,2):(1,64) → the 8 col indices
  • Cross-product of rows × cols = the 64 enumerated cells
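Step ⑤ generalizes from thread 0 to every thread. The thread coordinates only shift the base pointer by tm×512 + tn×4, and the fragment layout stays the same. This sketch assumes the mode-0 strides (512, 4) and the fragment strides from step ④. It checks the result against the ownership rule for all 256 threads. thread_fragment and expected are hypothetical helpers.

```python
THREAD_STRIDE = (512, 4)                   # mode-0 strides for (tm, tn) from step ④
FRAG_STRIDE = ((128, 8192), (1, 64))       # fragment strides, identical for every thread

def thread_fragment(tm, tn):
    """(row, col) pairs selected by slicing step ④'s layout at thread (tm, tn)."""
    base = tm * THREAD_STRIDE[0] + tn * THREAD_STRIDE[1]   # ptr offset of the fragment
    (sm0, sm1), (sn0, sn1) = FRAG_STRIDE
    out = []
    for m1 in range(2):
        for m0 in range(4):
            for n1 in range(2):
                for n0 in range(4):
                    off = base + m0 * sm0 + m1 * sm1 + n0 * sn0 + n1 * sn1
                    out.append(divmod(off, 128))            # -> (row, col)
    return out

def expected(tm, tn):
    """Direct ownership rule: rows {4tm..4tm+3, 4tm+64..}, cols likewise."""
    rows = [4 * tm + i for i in range(4)] + [4 * tm + 64 + i for i in range(4)]
    cols = [4 * tn + i for i in range(4)] + [4 * tn + 64 + i for i in range(4)]
    return {(r, c) for r in rows for c in cols}

# The sliced layout reproduces the ownership rule for all 16×16 threads.
for tm in range(16):
    for tn in range(16):
        assert set(thread_fragment(tm, tn)) == expected(tm, tn)
```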

Reading the mapping:

CuTe index (m0=2, m1=0, n0=3, n1=1):
  offset = 2×128 + 0×8192 + 3×1 + 1×64 = 256 + 0 + 3 + 64 = 323
  row = 323 / 128 = 2, col = 323 % 128 = 67 → C[2][67] ✓
  (row 2 is in the top row band, col 67 in the right col band → Block B)
CuTe index (m0=1, m1=1, n0=2, n1=0):
  offset = 1×128 + 1×8192 + 2×1 + 0×64 = 128 + 8192 + 2 + 0 = 8322
  row = 8322 / 128 = 65, col = 8322 % 128 = 2 → C[65][2] ✓
  (row 65 is in the bottom row band, col 2 in the left col band → Block C)

The Pattern Rule (Plain English)

Think of it as a checkerboard with 4×4 blocks:

  1. Divide 128 rows into 32 groups of 4 rows each (groups 0..31)
  2. Divide 128 cols into 32 groups of 4 cols each (groups 0..31)
  3. Thread (tm, tn) owns row-groups {tm, tm+16} and col-groups {tn, tn+16}
  4. That's 2 row-groups × 2 col-groups × 4×4 = 64 elements per thread
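The checkerboard rule also runs in reverse. Given C[row][col], the owning thread follows from the 4-wide group index modulo 16. Here is a sketch, assuming threadIdx.x = tm × 16 + tn as in the text. owner is a hypothetical helper.

```python
def owner(row, col):
    """Thread (tm, tn) and flat thread id owning C[row][col] in the 128×128 tile.

    Row-group g = row // 4 belongs to tm = g % 16 (groups tm and tm+16 share
    an owner); columns work the same way.
    """
    tm = (row // 4) % 16
    tn = (col // 4) % 16
    return tm, tn, tm * 16 + tn

# Cross-check against the direct ownership rule for every thread.
for tid in range(256):
    tm, tn = tid // 16, tid % 16
    rows = [4 * tm + i for i in range(4)] + [4 * tm + 64 + i for i in range(4)]
    cols = [4 * tn + i for i in range(4)] + [4 * tn + 64 + i for i in range(4)]
    assert all(owner(r, c)[2] == tid for r in rows for c in cols)

print(owner(65, 2))     # (0, 0, 0)
print(owner(127, 5))    # (15, 1, 241)
```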

Why this specific pattern?

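Reason 1: Load balancing. Every thread gets exactly 64 elements ✓
Reason 2: Spatial locality. Each 4×4 block is four runs of 4 consecutive columns. In row-major C, each run is 4 contiguous floats (16 bytes), while the 4 rows of a block sit 128 floats apart.
Reason 3: Spread-out accesses. Neighboring threads start 4 elements apart in each dimension (stride-4 separation), so at any given fragment index the threads of a warp write distinct elements of C. Whether this also avoids shared-memory bank conflicts depends on how A and B are laid out in shared memory.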

Could you use a simpler pattern? Yes! For example, each thread could own 64 consecutive elements of one row. Then each k step would need 1 value of A and 64 values of B (65 loads) for its 64 FMAs. The 8×8 footprint needs only 8 + 8 = 16 loads for the same 64 FMAs. The 4×4 block pattern balances reuse across both dimensions.
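The reuse argument is simple arithmetic. Per k step, a thread with an r×c accumulator loads r values of A and c values of B, and performs r×c FMAs. This sketch ignores caches and shared memory entirely. loads_per_k is a hypothetical helper.

```python
def loads_per_k(rows, cols):
    """Scalar loads per k step for a rows×cols accumulator:
    one A value per row plus one B value per column."""
    return rows + cols

for r, c in [(8, 8), (1, 64), (4, 16)]:
    fmas = r * c                       # FMAs performed per k step
    loads = loads_per_k(r, c)
    print(f"{r}x{c}: {loads} loads for {fmas} FMAs -> {fmas / loads:.2f} FMAs per load")
```

The 8×8 footprint gives 64/16 = 4 FMAs per load, while the 1×64 strip gives 64/65, slightly under 1.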

Equivalent CUDA Code (No CuTe)

__global__ void gemm_step1(float* C, const float* A, const float* B, int M, int N, int K) {
    // Assumes M and N are multiples of 128 (no bounds checks) and a 2D grid:
    // blockIdx.x picks the 128-row tile, blockIdx.y the 128-col tile.
    // This CTA offset corresponds to CuTe's local_tile step, before partition_C.
    int tile_m = blockIdx.x * 128;
    int tile_n = blockIdx.y * 128;

    // --- THE MAPPING (this is what partition_C computes) ---
    int tm = threadIdx.x / 16;   // my row-group index (0..15)
    int tn = threadIdx.x % 16;   // my col-group index (0..15)

    // My 8 rows: 4 consecutive + 4 more at offset 64
    int my_rows[8];
    for (int i = 0; i < 4; i++) my_rows[i]     = tile_m + tm * 4 + i;        // rows tm*4 .. tm*4+3
    for (int i = 0; i < 4; i++) my_rows[4 + i] = tile_m + tm * 4 + 64 + i;   // rows tm*4+64 .. tm*4+67

    // My 8 cols: same pattern
    int my_cols[8];
    for (int i = 0; i < 4; i++) my_cols[i]     = tile_n + tn * 4 + i;
    for (int i = 0; i < 4; i++) my_cols[4 + i] = tile_n + tn * 4 + 64 + i;

    // --- ACCUMULATOR (64 registers) ---
    float acc[8][8] = {0};

    // --- K-LOOP: C = A · Bᵀ, with A (M,K) and B (N,K), both row-major ---
    for (int k = 0; k < K; k++) {
        for (int mi = 0; mi < 8; mi++) {
            float a_val = A[my_rows[mi] * K + k];
            for (int ni = 0; ni < 8; ni++) {
                acc[mi][ni] += a_val * B[my_cols[ni] * K + k];
            }
        }
    }

    // --- STORE (C is (M,N) row-major) ---
    for (int mi = 0; mi < 8; mi++)
        for (int ni = 0; ni < 8; ni++)
            C[my_rows[mi] * N + my_cols[ni]] = acc[mi][ni];
}

That's it. The my_rows and my_cols arrays ARE what partition_C computes — just encoded as a layout (strides) instead of explicit arrays.
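The kernel's mapping can be sanity-checked on the CPU by running its body once per simulated thread for a single 128×128 tile. This Python sketch is an emulation, not a CUDA launch. It assumes A is (M,K) and B is (N,K), both row-major. It uses small integers so the comparison with a direct C = A·Bᵀ is exact, and K = 3 is an arbitrary choice.

```python
import random

M = N = 128
K = 3
random.seed(0)
A = [[random.randint(-3, 3) for _ in range(K)] for _ in range(M)]   # (M, K) row-major
B = [[random.randint(-3, 3) for _ in range(K)] for _ in range(N)]   # (N, K) row-major
C = [[None] * N for _ in range(M)]

for tid in range(256):                      # emulate the 256 threads one at a time
    tm, tn = tid // 16, tid % 16
    my_rows = [tm * 4 + i for i in range(4)] + [tm * 4 + 64 + i for i in range(4)]
    my_cols = [tn * 4 + i for i in range(4)] + [tn * 4 + 64 + i for i in range(4)]
    acc = [[0] * 8 for _ in range(8)]       # the 64 per-thread accumulators
    for k in range(K):
        for mi in range(8):
            a_val = A[my_rows[mi]][k]
            for ni in range(8):
                acc[mi][ni] += a_val * B[my_cols[ni]][k]
    for mi in range(8):
        for ni in range(8):
            assert C[my_rows[mi]][my_cols[ni]] is None   # no element written twice
            C[my_rows[mi]][my_cols[ni]] = acc[mi][ni]

# Compare against a direct C = A · Bᵀ.
ref = [[sum(A[i][k] * B[j][k] for k in range(K)) for j in range(N)] for i in range(M)]
assert C == ref
```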

CuTe's Layout Encoding: Strides = Compressed Index Arrays

Step 1: Understand one dimension — (4, 2):(1, 64)

Explicit indices (what CUDA does)

my_rows = [0, 1, 2, 3, 64, 65, 66, 67]
my_cols = [0, 1, 2, 3, 64, 65, 66, 67]
// To get the memory offset for element (mi, ni):
offset = my_rows[mi] * 128 + my_cols[ni]

Stride encoding (what CuTe does)

shape  = (4, 2)    // 4 consecutive, then 2 groups
stride = (1, 64)   // step 1 within a group, jump 64 between groups
// row[mi] = (mi % 4) * 1 + (mi / 4) * 64
// mi=0: 0, mi=1: 1, mi=2: 2, mi=3: 3
// mi=4: 64, mi=5: 65, mi=6: 66, mi=7: 67 ✓

Step 2: The full 3-mode layout — (1, (4,2), (4,2)):(0, (128,8192), (1,64))

CuTe layouts are hierarchical: a shape can contain nested tuples. Read it as 3 modes:

Layout: (1, (4, 2), (4, 2)) : (0, (128, 8192), (1, 64))
  Mode 0 = 1:0    Mode 1 = (4, 2):(128, 8192)    Mode 2 = (4, 2):(1, 64)

Mode | Name | Shape | Stride | What it means | Elements generated
0 | V (Value) | 1 | 0 | Values per atom invocation; scalar FMA → 1 value. Stride 0 means "no movement" (the size is 1, so it does not matter) | Just 1 element (no iteration)
1 | M (Row) | (4, 2) | (128, 8192) | 4 consecutive rows (stride 128 = one row in row-major C); 2 groups (stride 8192 = 64 rows × 128 cols/row) | rows 0,1,2,3,64,65,66,67
2 | N (Col) | (4, 2) | (1, 64) | 4 consecutive cols (stride 1 = one column in row-major C); 2 groups (stride 64 = jump 64 columns) | cols 0,1,2,3,64,65,66,67

Step 3: How to read nested shape (4, 2) with stride (128, 8192)

Shape (4, 2) means two levels of iteration, like a nested loop. Equivalent Python:

positions = []
for i1 in range(2):                  # outer shape = 2
    for i0 in range(4):              # inner shape = 4
        pos = i0 * 128 + i1 * 8192   # inner stride = 128, outer stride = 8192
        positions.append(pos)
# positions = [0, 128, 256, 384,        # i1=0: rows 0,1,2,3
#              8192, 8320, 8448, 8576]  # i1=1: rows 64,65,66,67
rows = [p // 128 for p in positions]    # convert offsets to row numbers
# rows = [0, 1, 2, 3, 64, 65, 66, 67]  ✓ matches CUDA my_rows!

Step 4: Putting all 3 modes together

// The full offset formula is just a dot product of indices and strides:
offset(v, (m0, m1), (n0, n1)) = v×0 + m0×128 + m1×8192 + n0×1 + n1×64
// Total elements = product of all shapes: 1 × (4×2) × (4×2) = 1 × 8 × 8 = 64
// This is a 3D tensor of shape [1][8][8] stored with custom strides
// Example: element [0][(2,1)][(3,0)], i.e. v=0, m0=2, m1=1, n0=3, n1=0
// offset = 0 + 2×128 + 1×8192 + 3×1 + 0×64 = 0 + 256 + 8192 + 3 + 0 = 8451
// row = 8451 / 128 = 66, col = 8451 % 128 = 3
// → C[66][3] ✓ (row 66 is in group m1=1, col 3 is in group n1=0)

Summary — what each piece tells you:

Layout piece | Question it answers | CUDA equivalent
Shape (1, (4,2), (4,2)) | How many elements per thread? 1×8×8 = 64 | Loop bounds: for mi in 0..8, for ni in 0..8
Stride 0 (mode V) | N/A: only 1 value per atom | (doesn't appear in CUDA)
Stride (128, 8192) (mode M) | How far apart are my rows in memory? | my_rows[i] * N (N = 128)
Stride (1, 64) (mode N) | How far apart are my cols in memory? | my_cols[j]

Step 5: The (V, M, N) Convention — Always This Order

The output of partition_C/A/B always follows a fixed mode ordering. This is a hard-coded design choice in CuTe, not something that varies:

Function | Output layout | Mode 0 | Mode 1 | Mode 2
partition_C(gC) | (V, M_rest, N_rest) | V = values per atom | M = rows this thread owns | N = cols this thread owns
partition_A(gA) | (V, M_rest, K_rest) | V = values per atom | M = rows this thread owns | K = reduction positions
partition_B(gB) | (V, N_rest, K_rest) | V = values per atom | N = cols this thread owns | K = reduction positions
// Why this matters: cute.gemm() relies on this fixed ordering to know:
//   - Mode 0 (V): loop over values within one atom invocation
//   - Mode 2 of A and B (K): the shared dimension to contract over
//   - Mode 1 of A matches Mode 1 of C (both are M-rows)
//   - Mode 1 of B matches Mode 2 of C (both are N-cols)
//
// cute.gemm(mma, tCrA, tCrB, tCrC):
//   for k in K_rest:                 ← mode 2 of A/B
//     for m in M_rest:               ← mode 1 of A and C
//       for n in N_rest:             ← mode 1 of B, mode 2 of C
//         tCrC[v,m,n] += tCrA[v,m,k] × tCrB[v,n,k]   (v = 0 for the scalar atom)
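The fixed (V, M, N) / (V, M, K) / (V, N, K) ordering is what lets the contraction loop be written generically. This is a plain-Python sketch for the scalar atom only (V = 1), with nested lists standing in for CuTe tensors. gemm_vmn is a hypothetical helper and is not cute.gemm itself.

```python
def gemm_vmn(tCrA, tCrB, tCrC):
    """Contract fragments laid out as A:(V, M, K), B:(V, N, K), C:(V, M, N).

    Mode 2 of A and B is the reduction dimension; mode 1 of A matches mode 1
    of C, and mode 1 of B matches mode 2 of C.
    """
    V, M, K = len(tCrA), len(tCrA[0]), len(tCrA[0][0])
    N = len(tCrB[0])
    assert V == 1, "this sketch only models the scalar (1-value) atom"
    for k in range(K):                 # mode 2 of A/B
        for m in range(M):             # mode 1 of A and C
            for n in range(N):         # mode 1 of B, mode 2 of C
                tCrC[0][m][n] += tCrA[0][m][k] * tCrB[0][n][k]
    return tCrC

# Thread-local fragments: 8 rows × 2 k-steps for A, 8 cols × 2 k-steps for B.
tCrA = [[[m + k for k in range(2)] for m in range(8)]]
tCrB = [[[n - k for k in range(2)] for n in range(8)]]
tCrC = [[[0] * 8 for _ in range(8)]]
gemm_vmn(tCrA, tCrB, tCrC)
```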

Stride Annotation on the Visual Grid

[Figure: thread 0's 8×8 elements annotated with stride arrows showing how far each index step jumps in memory: +128 per m0, +8192 per m1, +1 per n0, +64 per n1.]

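Any Thread

The same picture holds for every thread: thread (tm, tn) owns rows {4tm..4tm+3, 4tm+64..4tm+67} and cols {4tn..4tn+3, 4tn+64..4tn+67}. For example:

tm=0, tn=0 → rows {0..3, 64..67}, cols {0..3, 64..67}
tm=1, tn=2 → rows {4..7, 68..71}, cols {8..11, 72..75}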