Each thread gets four 4×4 blocks = 4 × 16 = 64 elements.
The Permutation: The ONE Design Choice
Everything above comes from a single parameter: the permutation layout(16, 4):(4, 1)
permutation = (Groups, Chunk):(group_stride, chunk_stride)
Groups = how many groups (one per thread) (16)
Chunk = elements per group (4)
group_stride = distance between group starts (4)
chunk_stride = distance between elements in a group (1)
Thread k, element r → position = k × group_stride + r × chunk_stride
Example: (16, 4):(4, 1)
Thread 0: 0×4 + {0,1,2,3}×1 = {0, 1, 2, 3}
Thread 1: 1×4 + {0,1,2,3}×1 = {4, 5, 6, 7}
Thread 15: 15×4 + {0,1,2,3}×1 = {60,61,62,63}
Covers 16×4 = 64 of 128 → rest = 128/64 = 2
→ each thread gets a SECOND group at +64:
Thread 0 → also positions {64, 65, 66, 67}
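The position formula above can be checked with a few lines of Python. This is a minimal sketch, not CuTe code: the helper name owned_positions and its argument names are made up for illustration, and it assumes the pattern simply repeats at offsets of Groups × Chunk until the whole extent is covered.

```python
def owned_positions(k, groups, chunk, group_stride, chunk_stride, extent=128):
    """Positions along one dimension owned by thread k (hypothetical helper)."""
    covered = groups * chunk          # positions covered by one application
    rest = extent // covered          # how many times the pattern repeats
    return [
        rep * covered + k * group_stride + r * chunk_stride
        for rep in range(rest)
        for r in range(chunk)
    ]

print(owned_positions(0, 16, 4, 4, 1))   # [0, 1, 2, 3, 64, 65, 66, 67]
print(owned_positions(15, 16, 4, 4, 1))  # [60, 61, 62, 63, 124, 125, 126, 127]
```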
Why this matters:
The permutation is the only free parameter the kernel author picks. Everything else is derived:
Derived from | Result
Permutation + tile size | Thread ownership pattern
Ownership + tensor strides | Final layout strides
Atom TV layout | V dimension (hardware-fixed)
Alternative permutations (same tile, same thread count):
(16,4):(4,1) → 4 consecutive, no gap ← step1
(16,4):(1,16) → 4 elements spaced 16 apart
(16,8):(8,1) → 8 consecutive (covers full 128, rest=1)
Constraint: F must be divisible by ThrM (=16). Different permutation → same element count, different access pattern → different performance.
Permutation Comparison: What Different Choices Look Like
The figure below shows ONE dimension (M=128 positions) distributed among 16 threads. Each row = one permutation choice. Colors = thread ownership.
Key insight: All permutations give each thread exactly 8 positions (128 ÷ 16 = 8). But the spatial arrangement differs — consecutive elements vs. scattered elements determine cache line utilization, vectorization potential, and bank-conflict behavior.
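The claim that every permutation hands out the same number of positions can be checked directly. The sketch below is plain Python (ownership is an illustrative helper, not a CuTe function). It builds the owner of each of the 128 positions for the three permutations listed earlier and reports how many positions each thread gets and which ones thread 0 owns.

```python
from collections import Counter

def ownership(groups, chunk, group_stride, chunk_stride, extent=128):
    """Map each position 0..extent-1 to its owning thread (hypothetical helper)."""
    covered = groups * chunk
    owner = {}
    for rep in range(extent // covered):
        for k in range(groups):
            for r in range(chunk):
                pos = rep * covered + k * group_stride + r * chunk_stride
                assert pos not in owner, "two threads claim the same position"
                owner[pos] = k
    return owner

for perm in [(16, 4, 4, 1), (16, 4, 1, 16), (16, 8, 8, 1)]:
    owner = ownership(*perm)
    counts = Counter(owner.values())
    thread0 = sorted(p for p, k in owner.items() if k == 0)
    print(perm, "positions per thread:", set(counts.values()), "thread 0:", thread0)
```

In every case the per-thread count is 8; only the positions change (for example, thread 0 owns 0,16,32,... under (16,4):(1,16) but 0..7 under (16,8):(8,1)).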
logical_divide Results for Each Permutation
Applying Step ① logical_divide(gC_M=128:128, permutation) to the M-dimension. Same process applies to N.
For M-dimension with base stride = 128 (row-major, 128 cols per row):
(16,4):(4,1) applied to size 128, stride 128:
F=16, stride = sF × base = 4 × 128 = 512 ← thread spacing in memory
R=4, stride = sR × base = 1 × 128 = 128 ← element spacing (= 1 row)
rest=2, stride = perm_size × base = 64 × 128 = 8192 ← repetition offset
(16,4):(1,16) applied to size 128, stride 128:
F=16, stride = sF × base = 1 × 128 = 128 ← threads are adjacent rows!
R=4, stride = sR × base = 16 × 128 = 2048 ← elements skip 16 rows
rest=2, stride = 64 × 128 = 8192 ← same repetition offset
(16,8):(8,1) applied to size 128, stride 128:
F=16, stride = sF × base = 8 × 128 = 1024 ← thread spacing (8 rows apart)
R=8, stride = sR × base = 1 × 128 = 128 ← element spacing (= 1 row)
rest=1, no rest dimension (perm covers all 128)
Pattern: The stride formula is always output_stride = permutation_stride × input_base_stride.
The permutation strides (4,1), (1,16), (8,1) get multiplied by the tensor's base stride (128 for M-rows, 1 for N-cols) to produce the physical memory strides in the logical_divide result.
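The stride arithmetic above is mechanical enough to script. The following sketch only reproduces the multiplication shown in this section; it is not the general logical_divide algorithm from CuTe, and divide_strides is a hypothetical helper name.

```python
def divide_strides(groups, chunk, group_stride, chunk_stride, extent, base_stride):
    """Shape and memory strides of the (F, R, rest) modes (illustrative helper)."""
    covered = groups * chunk
    rest = extent // covered
    shape = (groups, chunk, rest)
    stride = (
        group_stride * base_stride,   # F: thread spacing
        chunk_stride * base_stride,   # R: element spacing within a group
        covered * base_stride,        # rest: repetition offset
    )
    return shape, stride

# M dimension of a row-major 128x128 tile: base stride 128
print(divide_strides(16, 4, 4, 1, 128, 128))   # ((16, 4, 2), (512, 128, 8192))
print(divide_strides(16, 4, 1, 16, 128, 128))  # ((16, 4, 2), (128, 2048, 8192))
# rest has size 1 here, so its stride is never used
print(divide_strides(16, 8, 8, 1, 128, 128))   # ((16, 8, 1), (1024, 128, 16384))
# N dimension: base stride 1
print(divide_strides(16, 4, 4, 1, 128, 1))     # ((16, 4, 2), (4, 1, 64))
```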
Full 128×128 Tile — All 256 Threads Color-Coded
Figure: the full 128×128 tile with each element colored by its owning thread (columns 0 to 127 run left to right, rows 0 to 127 top to bottom). Each colored square is one output element (one float32 C[row][col] value), drawn as a 4×4 pixel block for visibility, so the 128×128-element grid is 512×512 pixels. There is no grouping: each pixel block is one float.
Zoomed In: Thread 0's 64 Elements (Actual Cell View)
Below we zoom into thread 0's four 4×4 blocks. Each labeled cell = one C[row][col] element = one float = one FMA result.
Block A (top-left): rows 0-3, cols 0-3
Block C (bottom-left): rows 64-67, cols 0-3
Block B (top-right): rows 0-3, cols 64-67
Block D (bottom-right): rows 64-67, cols 64-67
Total: 4 blocks × 16 elements/block = 64 elements = exactly what thread 0 computes and stores.
Each cell shows C[row][col]. The memory offset = row × 128 + col (row-major).
Where these 4 blocks sit in the full 128×128 tile:
Blue = thread 0's 64 elements (4 tiny squares), dark = other 255 threads' elements. Each blue pixel = 1 element.
1:1 Mapping — CuTe Layout ↔ Physical Elements
Left: CuTe's output layout for thread 0. Right: the actual C[row][col] element each entry points to.
That's 2 row-groups × 2 col-groups × 4×4 = 64 elements per thread
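To see the same ownership for any thread, not just thread 0, the mapping can be written out in Python. This is an illustrative sketch under the assumptions used throughout this section (one 128×128 tile, 256 threads arranged 16×16, chunks of 4); thread_elements is a made-up helper, not part of CuTe.

```python
def thread_elements(tid, tile=128, threads_per_dim=16, chunk=4):
    """(row, col) pairs owned by thread `tid` in one 128x128 tile (illustrative)."""
    tm, tn = divmod(tid, threads_per_dim)      # row-group and col-group index
    covered = threads_per_dim * chunk          # 64 positions per repetition
    offsets = [rep * covered + i for rep in range(tile // covered) for i in range(chunk)]
    rows = [tm * chunk + o for o in offsets]
    cols = [tn * chunk + o for o in offsets]
    return [(r, c) for r in rows for c in cols]

elems = thread_elements(0)
print(len(elems))                                   # 64
print(sorted({(r // 4 * 4, c // 4 * 4) for r, c in elems}))
# [(0, 0), (0, 64), (64, 0), (64, 64)]  -> top-left corners of blocks A, B, C, D
print([r * 128 + c for r, c in elems[:4]])          # [0, 1, 2, 3]

# Every element of the tile is owned by exactly one of the 256 threads.
all_elems = [e for t in range(256) for e in thread_elements(t)]
assert len(all_elems) == len(set(all_elems)) == 128 * 128
```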
Why this specific pattern?
Reason 1: Load balancing — every thread gets exactly 64 elements ✓
Reason 2: Spatial locality. Each block is 4 consecutive rows × 4 consecutive cols =
16 elements; the 4 cols within a row are contiguous in memory, and consecutive rows are 128 elements apart
Reason 3: Regular access pattern. Threads in the same warp touch
disjoint elements at a fixed stride (stride-4 separation in both dims)
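Could you use a simpler pattern? Yes. For example, each thread could own 64 consecutive elements in a row. But that would give poor data reuse when loading A and B: per k step, such a thread needs 1 value of A and 64 values of B (65 loads for 64 FMAs), whereas an 8×8 footprint needs only 8 + 8 = 16 loads for the same 64 FMAs. The 4×4 block pattern balances the work across both dimensions.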
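One way to quantify this trade-off is to count how many distinct A and B values a thread reads per step of the K loop for a given footprint. The sketch below assumes a scalar FMA and ignores caching entirely; loads_per_k_step is a hypothetical helper used only for this comparison.

```python
def loads_per_k_step(rows_owned, cols_owned):
    """Distinct A and B values one thread reads per k for a rows x cols footprint."""
    fmas = rows_owned * cols_owned
    loads = rows_owned + cols_owned      # one A value per row, one B value per col
    return loads, fmas, fmas / loads     # (loads, FMAs, FMAs per load)

print(loads_per_k_step(8, 8))    # 16 loads for 64 FMAs: the 8x8 footprint (four 4x4 blocks)
print(loads_per_k_step(1, 64))   # 65 loads for 64 FMAs: 64 consecutive elements in one row
```

The 8×8 footprint performs 4 FMAs per loaded value, while the single-row footprint performs slightly less than 1.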
Equivalent CUDA Code (No CuTe)
__global__ void gemm_step1(float* C, const float* A, const float* B, int M, int N, int K) {
    // Assumes a single 128×128 tile (M = N = 128) and one block of 256 threads.
    // A is (M,K) row-major, B is (N,K) row-major, C is (M,N) row-major.

    // --- THE MAPPING (this is what partition_C computes) ---
    int tm = threadIdx.x / 16;  // my row-group index (0..15)
    int tn = threadIdx.x % 16;  // my col-group index (0..15)

    // My 8 rows: 4 consecutive + 4 more at offset 64
    int my_rows[8];
    for (int i = 0; i < 4; i++) my_rows[i]     = tm * 4 + i;       // rows tm*4 .. tm*4+3
    for (int i = 0; i < 4; i++) my_rows[4 + i] = tm * 4 + 64 + i;  // rows tm*4+64 .. tm*4+67

    // My 8 cols: same pattern
    int my_cols[8];
    for (int i = 0; i < 4; i++) my_cols[i]     = tn * 4 + i;
    for (int i = 0; i < 4; i++) my_cols[4 + i] = tn * 4 + 64 + i;

    // --- ACCUMULATOR (64 registers) ---
    float acc[8][8] = {0};

    // --- K-LOOP ---
    for (int k = 0; k < K; k++) {
        for (int mi = 0; mi < 8; mi++) {
            float a_val = A[my_rows[mi] * K + k];
            for (int ni = 0; ni < 8; ni++) {
                acc[mi][ni] += a_val * B[my_cols[ni] * K + k];  // B is (N,K) row-major
            }
        }
    }

    // --- STORE ---
    for (int mi = 0; mi < 8; mi++)
        for (int ni = 0; ni < 8; ni++)
            C[my_rows[mi] * N + my_cols[ni]] = acc[mi][ni];
}
That's it. The my_rows and my_cols arrays ARE what partition_C computes — just encoded as a layout (strides) instead of explicit arrays.
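To convince yourself that the mapping covers the tile correctly, the kernel can be emulated thread by thread on the CPU. The following is a pure-Python sketch, not a performance model: it assumes one 128×128 tile, 256 threads, and a deliberately tiny K, and it uses small integer test data so the comparison with a naive reference is exact. The names emulate_thread and TILE are illustrative.

```python
import random

TILE, K = 128, 3   # K kept small so the pure-Python loops finish quickly

def emulate_thread(tid, A, B, C):
    """Do in Python what one CUDA thread of gemm_step1 does (A is MxK, B is NxK)."""
    tm, tn = divmod(tid, 16)
    my_rows = [tm * 4 + i for i in range(4)] + [tm * 4 + 64 + i for i in range(4)]
    my_cols = [tn * 4 + i for i in range(4)] + [tn * 4 + 64 + i for i in range(4)]
    acc = [[0] * 8 for _ in range(8)]
    for k in range(K):
        for mi, r in enumerate(my_rows):
            a_val = A[r * K + k]
            for ni, c in enumerate(my_cols):
                acc[mi][ni] += a_val * B[c * K + k]
    for mi, r in enumerate(my_rows):
        for ni, c in enumerate(my_cols):
            C[r * TILE + c] = acc[mi][ni]

random.seed(0)
A = [random.randint(-3, 3) for _ in range(TILE * K)]
B = [random.randint(-3, 3) for _ in range(TILE * K)]
C = [None] * (TILE * TILE)          # None marks "not written yet"
for tid in range(256):
    emulate_thread(tid, A, B, C)

# Reference: C[r][c] = sum_k A[r][k] * B[c][k]
ok = all(
    C[r * TILE + c] == sum(A[r * K + k] * B[c * K + k] for k in range(K))
    for r in range(TILE) for c in range(TILE)
)
print(ok)  # True
```

Because C starts out filled with None, the check also fails if any element is left unwritten by all 256 threads.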
CuTe's Layout Encoding: Strides = Compressed Index Arrays
Step 1: Understand one dimension — (4, 2):(1, 64)
Explicit indices (what CUDA does)
my_rows = [0, 1, 2, 3, 64, 65, 66, 67]
my_cols = [0, 1, 2, 3, 64, 65, 66, 67]
// To get memory offset for element (mi, ni):
offset = my_rows[mi] * 128 + my_cols[ni]
Mode | Name | Shape | Stride | Meaning | Positions
0 | V | 1 | 0 | Values per atom invocation. Scalar FMA → 1 value. Stride 0 means "no movement" (the size is 1, so it doesn't matter) | Just 1 element (no iteration)
1 | M (Row) | (4, 2) | (128, 8192) | 4 consecutive rows (stride 128 = one row in row-major C); 2 groups (stride 8192 = 64 rows × 128 cols/row) | rows: 0,1,2,3,64,65,66,67
2 | N (Col) | (4, 2) | (1, 64) | 4 consecutive cols (stride 1 = one column in row-major); 2 groups (stride 64 = jump 64 columns) | cols: 0,1,2,3,64,65,66,67
Step 3: How to read nested shape (4, 2) with stride (128, 8192)
Shape (4, 2) means: two levels of iteration, like a nested loop.
# Equivalent Python:
positions = []
for i1 in range(2):                     # outer shape = 2
    for i0 in range(4):                 # inner shape = 4
        pos = i0 * 128 + i1 * 8192      # inner_stride=128, outer_stride=8192
        positions.append(pos)
# Result:
# positions == [0, 128, 256, 384,          (i1=0: rows 0,1,2,3)
#               8192, 8320, 8448, 8576]    (i1=1: rows 64,65,66,67)
# Convert to row numbers (÷128):
rows = [p // 128 for p in positions]
# rows == [0, 1, 2, 3, 64, 65, 66, 67]  ✓ matches CUDA my_rows!
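The nested loop generalizes to any flat shape:stride pair. Below is a small sketch of that generalization; layout_offsets is a hypothetical helper (CuTe's own layouts also support nested modes, which this does not attempt), and it assumes the leftmost index varies fastest, as in the loop above.

```python
from itertools import product

def layout_offsets(shape, stride):
    """Offsets of a flat layout shape:stride, leftmost index varying fastest."""
    offsets = []
    # product() varies its last argument fastest, so feed it the modes reversed.
    for idx in product(*(range(s) for s in reversed(shape))):
        coords = idx[::-1]  # back to (i0, i1, ...)
        offsets.append(sum(i * d for i, d in zip(coords, stride)))
    return offsets

row_offsets = layout_offsets((4, 2), (128, 8192))
col_offsets = layout_offsets((4, 2), (1, 64))
print([o // 128 for o in row_offsets])  # [0, 1, 2, 3, 64, 65, 66, 67]
print(col_offsets)                      # [0, 1, 2, 3, 64, 65, 66, 67]
```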
Step 4: Putting all 3 modes together
// The full offset formula is just the dot product of indices and strides:
offset(v, (m0, m1), (n0, n1)) = v×0 + m0×128 + m1×8192 + n0×1 + n1×64
// Total elements = product of all shapes: 1 × (4×2) × (4×2) = 1 × 8 × 8 = 64
// This is a 3D tensor of shape [1][8][8] stored with custom strides
// Example: element [0][(2,1)][(3,0)]
//   v=0, m0=2, m1=1, n0=3, n1=0
//   offset = 0 + 2×128 + 1×8192 + 3×1 + 0×64 = 0 + 256 + 8192 + 3 + 0 = 8451
//   row = 8451 / 128 = 66, col = 8451 % 128 = 3
//   → C[66][3] ✓ (row 66 is in group m1=1, col 3 is in group n1=0)
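The dot-product formula can be cross-checked against the explicit index arrays from the CUDA version. The sketch below hard-codes thread 0's strides from this section; the function name offset and the constant names are illustrative only.

```python
STRIDE_V, STRIDE_M, STRIDE_N = 0, (128, 8192), (1, 64)

def offset(v, m, n):
    """Dot product of the indices with the strides; m = (m0, m1), n = (n0, n1)."""
    (m0, m1), (n0, n1) = m, n
    return (v * STRIDE_V
            + m0 * STRIDE_M[0] + m1 * STRIDE_M[1]
            + n0 * STRIDE_N[0] + n1 * STRIDE_N[1])

off = offset(0, (2, 1), (3, 0))
print(off, divmod(off, 128))  # 8451 (66, 3)  -> C[66][3]

# Cross-check against the explicit index arrays for thread 0.
my_rows = my_cols = [0, 1, 2, 3, 64, 65, 66, 67]
for mi in range(8):
    for ni in range(8):
        m, n = (mi % 4, mi // 4), (ni % 4, ni // 4)
        assert offset(0, m, n) == my_rows[mi] * 128 + my_cols[ni]
```

The flat index mi splits into (m0, m1) = (mi % 4, mi // 4), which is exactly the nesting that shape (4, 2) encodes.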
Summary — what each piece tells you:
Layout piece | Question it answers | CUDA equivalent
Shape (1, (4,2), (4,2)) | How many elements per thread? 1×8×8 = 64 | Loop bounds: for mi in 0..8, for ni in 0..8
Stride 0 (mode V) | N/A (only 1 value per atom) | (doesn't appear in CUDA)
Stride (128, 8192) (mode M) | How far apart are my rows in memory? | my_rows[i]*N (N=128)
Stride (1, 64) (mode N) | How far apart are my cols in memory? | my_cols[j]
Step 5: The (V, M, N) Convention — Always This Order
The output of partition_C/A/B always follows a fixed mode ordering. This is a hard-coded design choice in CuTe, not something that varies:
Function | Output layout | Mode 0 | Mode 1 | Mode 2
partition_C(gC) | (V, M_rest, N_rest) | V = values per atom | M = rows this thread owns | N = cols this thread owns
partition_A(gA) | (V, M_rest, K_rest) | V = values per atom | M = rows this thread owns | K = reduction positions
partition_B(gB) | (V, N_rest, K_rest) | V = values per atom | N = cols this thread owns | K = reduction positions
// Why this matters: cute.gemm() relies on this fixed ordering to know:
// - Mode 0 (V): loop over values within one atom invocation
// - Mode 2 of A and B (K): the shared dimension to contract over
// - Mode 1 of A matches Mode 1 of C (both are M-rows)
// - Mode 1 of B matches Mode 2 of C (both are N-cols)
//
// cute.gemm(mma, tCrA, tCrB, tCrC):
//   for k in K_rest:          ← mode 2 of A/B
//     tCrC[v,m,n] += tCrA[v,m,k] × tCrB[v,n,k]
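The loop structure in the pseudocode above can be written out with plain nested lists. This is only a sketch of the index bookkeeping implied by the (V, M, K) / (V, N, K) / (V, M, N) ordering; it is not how cute.gemm is implemented, and gemm_sketch is a made-up name. It treats the V mode elementwise, exactly as the pseudocode does.

```python
def gemm_sketch(tCrA, tCrB, tCrC):
    """Accumulate C[v][m][n] += A[v][m][k] * B[v][n][k] over nested lists.

    Modes follow the (V, M, K) / (V, N, K) / (V, M, N) ordering described above.
    """
    V, M, N = len(tCrC), len(tCrC[0]), len(tCrC[0][0])
    K = len(tCrA[0][0])
    for k in range(K):              # mode 2 of A and B: the contracted dimension
        for m in range(M):          # mode 1 of A == mode 1 of C
            for n in range(N):      # mode 1 of B == mode 2 of C
                for v in range(V):  # mode 0: values within one atom invocation
                    tCrC[v][m][n] += tCrA[v][m][k] * tCrB[v][n][k]
    return tCrC

# Tiny example: V=1, M=2, N=2, K=3
A = [[[1, 2, 3], [4, 5, 6]]]          # shape (V=1, M=2, K=3)
B = [[[1, 0, 1], [0, 1, 0]]]          # shape (V=1, N=2, K=3)
C = [[[0, 0], [0, 0]]]                # shape (V=1, M=2, N=2)
print(gemm_sketch(A, B, C))           # [[[4, 2], [10, 5]]]
```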
Stride Annotation on the Visual Grid
The grid below shows thread 0's 8×8 elements with stride arrows showing what each stride physically means — how far you jump in memory when incrementing each index.