partition_C: Thread-to-Element Mapping Derivation

How CuTe transforms (128, 128):(128, 1) → (1, (4,2), (4,2)):(0, (128,8192), (1,64)) in 5 algebraic steps

🔴 The Problem partition_C Solves

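Given: the 128 × 128 output tile gC with layout (128, 128):(128, 1), and a TiledMma that spreads the work over 256 threads.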

Need:

For each thread, produce a view (pointer + layout) into gC that selects exactly the 64 elements this thread is responsible for, without moving any data.

Why 64? 128 × 128 = 16,384 elements ÷ 256 threads = 64 elements per thread.

In raw CUDA you'd write: int row = threadIdx.x / 16; int col = threadIdx.x % 16; and manually compute which 64 elements (four 4×4 blocks) this thread owns. CuTe replaces all of that with layout algebra: composable, type-safe, zero-cost abstractions that work for any MMA instruction, not just scalar FMA.
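For comparison, the hand-written index math that the layout algebra replaces could look roughly like the sketch below. The helper name manual_offset and the order in which a thread walks its 64 elements are assumptions made for illustration; neither comes from CuTe.

```cpp
// Hypothetical helper, not part of CuTe: hand-written index math for a
// row-major 128x128 tile shared by 256 threads (16x16 thread grid).
// The ordering of i within a thread is an arbitrary choice for this sketch.
constexpr int manual_offset(int tid, int i) {
    int tm = tid / 16;   // thread row in the 16x16 grid
    int tn = tid % 16;   // thread column in the 16x16 grid
    int mi = i / 8;      // 0..7: which of the thread's 8 rows
    int ni = i % 8;      // 0..7: which of the thread's 8 columns
    int row = tm * 4 + (mi % 4) + (mi / 4) * 64;  // 4 consecutive rows, repeated at +64
    int col = tn * 4 + (ni % 4) + (ni / 4) * 64;  // 4 consecutive cols, repeated at +64
    return row * 128 + col;
}

static_assert(manual_offset(0, 0) == 0, "thread 0 starts at the tile origin");
static_assert(manual_offset(1, 0) == 4, "thread 1 starts 4 columns to the right");
static_assert(manual_offset(16, 0) == 4 * 128, "thread 16 starts 4 rows down");
```

Every constant here (16, 4, 8, 64, 128) is tied to this one tile shape and thread arrangement, which is the maintenance burden the layout-based approach is meant to avoid.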

📖 Glossary: Abbreviations → Full Names

| Abbreviation | Full Name | What It Is |
|---|---|---|
| MMA | Matrix Multiply-Accumulate | Hardware instruction: D = A × B + C |
| FMA | Fused Multiply-Add | Scalar MMA: d = a × b + c |
| CTA | Cooperative Thread Array | Thread block in CUDA |
| TV | Thread-Value | Layout mapping (thread_idx, value_idx) → element offset |
| Thr | Thread | Thread dimension in a layout |
| Frg / Frag | Fragment | Per-thread owned data elements |
| V / Val | Value | Elements within one atom execution |
| Perm | Permutation | Reordering of element positions |
| GMEM | Global Memory | GPU DRAM (HBM) |
| SMEM | Shared Memory | On-chip per-CTA scratchpad |
| RMEM | Register Memory | Per-thread register file |
| BM/BN/BK | Block M/N/K | CTA tile dimensions |
| TiledMma | Tiled MMA | MMA atom replicated across threads + positions |
| thrfrg_C | Thread-Fragment C | Algorithm: partition C into thread and fragment dims |

🟣 Core Principle: Layout Composition as Function Composition

The Key Idea

A CuTe Layout = Shape + Stride is a function from logical coordinates to memory offsets.

Layout(coord) = dot(coord, stride) → memory offset

For example, (128, 128):(128, 1) is the function:

f(m, n) = m × 128 + n × 1
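The "layout is a function" idea can be sketched in a few lines of plain C++. Layout2D below is a minimal stand-in written for this note, not the cute::Layout class; it only shows that changing the strides changes the function while the shape stays the same.

```cpp
// Minimal stand-in for a rank-2 layout; this is not the cute::Layout class.
struct Layout2D {
    int shape[2];
    int stride[2];
    // The layout viewed as a function: dot(coord, stride).
    constexpr int operator()(int m, int n) const {
        return m * stride[0] + n * stride[1];
    }
};

constexpr Layout2D row_major{{128, 128}, {128, 1}};  // (128,128):(128,1)
constexpr Layout2D col_major{{128, 128}, {1, 128}};  // (128,128):(1,128)

static_assert(row_major(1, 0) == 128, "one row down skips 128 elements");
static_assert(row_major(0, 1) == 1, "one column right skips 1 element");
static_assert(col_major(1, 0) == 1, "same shape, different strides");
```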

Why Layout Algebra?

partition_C needs to answer: "Given that thread 0 owns M-positions {0,1,2,3,64,65,66,67} and N-positions {0,1,2,3,64,65,66,67}, what is the layout that maps a flat index 0..63 to the correct memory offsets?"

Instead of computing 64 offsets explicitly, CuTe composes layout functions. The output is a new layout (pure metadata, no data movement) whose strides encode the answer.

Three Primitive Operations

logical_divide(A, B)

Splits layout A using tiler B.
Math: A ∘ (B, B*) where B* = complement(B, size(A)), the layout that covers the positions B leaves out.

Shape transforms: (M) → (TileM, RestM)

Like reshape(128) → (64, 2) but respecting the stride structure

zipped_divide(A, B)

= logical_divide + zip: groups all tile dims together and all rest dims together.

((TileM,RestM),(TileN,RestN))
→ ((TileM,TileN),(RestM,RestN))

Makes the "atom" part and "rest" part independently addressable

compose(A, B)

Function composition: new_layout(x) = A(B(x))

Transforms coordinate space without touching data.

Like: B maps (thread, value) → atom coords; A maps atom coords → memory
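As a rough illustration of compose as function composition, the sketch below composes the M-stride of gC with the permutation (16,4):(4,1) and checks that the result behaves like the single layout (16,4):(512,128). The function names are invented for this sketch and are not CuTe calls.

```cpp
// B: (f, r) -> position along M, i.e. the layout (16,4):(4,1).
constexpr int perm(int f, int r) { return f * 4 + r * 1; }

// A: position along M -> memory offset for row-major gC (M-stride 128).
constexpr int row_offset(int m) { return m * 128; }

// (A o B)(f, r) = A(B(f, r)).
constexpr int composed(int f, int r) { return row_offset(perm(f, r)); }

// The same function written directly as a layout: (16,4):(512,128).
constexpr int as_layout(int f, int r) { return f * 512 + r * 128; }

// Checks every coordinate of the (16,4) domain.
constexpr bool composition_matches_layout() {
    for (int f = 0; f < 16; ++f)
        for (int r = 0; r < 4; ++r)
            if (composed(f, r) != as_layout(f, r)) return false;
    return true;
}

static_assert(composition_matches_layout(),
              "composition folds the base stride into the new strides");
```

The composed function never touches memory; it only produces new strides, which is why the result is pure metadata.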

📥 Concrete Inputs (Step 1 Example)

The Tensor

gC = (128, 128):(128, 1)
      ↑ shape    ↑ stride
128 rows × 128 cols, row-major: f(m,n) = m*128 + n

The TiledMma

Atom:          1×1×1 scalar FMA
AtomLayoutC:   (1, 1):(0, 0)            ← TV layout: 1 thread, 1 value per atom
atoms_layout:  (16, 16, 1):(16, 1, 0)   256 threads as a 16×16 M×N grid
permutation_M: (16, 4):(4, 1)
permutation_N: (16, 4):(4, 1)
product = 16×4 = 64 positions per dim; 128/64 = 2 repetitions needed

What does permutation (16, 4):(4, 1) mean?

It maps a 2D coordinate (f, r) where f∈[0,16), r∈[0,4) to position: f*4 + r*1

For thread 0 (f=0): positions are {0, 1, 2, 3}, the first 4 consecutive elements.

For thread 1 (f=1): positions are {4, 5, 6, 7}, the next 4 elements.

This covers 64 positions total (0..63). The tile is 128, so there's a "Rest" factor of 2 for positions {64..127}.
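The positions owned by one thread-group along a single dimension can be enumerated directly. This is a small sketch with a hypothetical helper name (owned_positions); it simply evaluates f*4 + r for both values of the Rest factor.

```cpp
#include <array>

// Hypothetical helper: the 8 positions along one dimension that belong to
// thread-group f (0..15) under permutation (16,4):(4,1) with a Rest factor of 2.
inline std::array<int, 8> owned_positions(int f) {
    std::array<int, 8> out{};
    for (int rest = 0; rest < 2; ++rest)      // the two 64-wide repetitions
        for (int r = 0; r < 4; ++r)           // 4 consecutive positions
            out[rest * 4 + r] = f * 4 + r * 1 + rest * 64;
    return out;
}
```

For f = 0 this yields {0, 1, 2, 3, 64, 65, 66, 67}, the same set quoted earlier for thread 0.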

🔄 The 5-Step Pipeline

INPUT (128,128):(128,1)
→ ① logical_divide: split by perm
→ ② zipped_divide: group atom + rest
→ ③ compose: MN → TV
→ ④ zipped_divide: group thr + frg
→ ⑤ slice: pick thread 0
→ OUTPUT (1,(4,2),(4,2)):(0,(128,8192),(1,64))

① logical_divide: Split Layout by Permutation

C++ Source (mma_atom.hpp:256-258)

auto t_tile = make_tile(permutation_mnk<0>(),   // permutation_M
                        permutation_mnk<1>());  // permutation_N
auto t_tensor = logical_divide(ctensor, t_tile);  // (PermM, PermN)

What It Does

Splits each dimension of gC according to the permutation layout. The permutation (16,4):(4,1) covers 64 of the 128 positions, leaving a "rest" factor of 2.

Math: A ∘ (B, B*)

For the M dimension (size 128, stride 128):

B = permutation_M = (16, 4):(4, 1), which covers 64 positions: {0,4,8,...,60, 1,5,...,61, 2,...,62, 3,...,63}
B* = complement(B, 128) = (2):(64), the 2 remaining tiles: {0, 64}
(B, B*) = ((16, 4), 2):((4, 1), 64), the full factorization of 128 positions
Compose with M-stride = 128:
stride((16, 4), 2) = ((4×128, 1×128), 64×128) = ((512, 128), 8192)

For the N dimension (size 128, stride 1):

Same structure, but base stride is 1 instead of 128:
stride((16, 4), 2) = ((4×1, 1×1), 64×1) = ((4, 1), 64)

Input: (128, 128):(128, 1)
Output: ((16,4,2), (16,4,2)):((512,128,8192), (4,1,64))
Meaning: M split → (F=16 thread-groups, R=4 consecutive, Rest=2 repetitions)
         N split → (F=16 thread-groups, R=4 consecutive, Rest=2 repetitions)
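One way to sanity-check the divided layout is to evaluate it at every coordinate and compare against the original m*128 + n. The sketch below does that in plain C++; the function names are made up for this check and the loops are deliberately brute-force.

```cpp
#include <vector>

// Offset of coordinate ((fm, rm, qm), (fn, rn, qn)) under the divided layout
// ((16,4,2),(16,4,2)):((512,128,8192),(4,1,64)). q is the Rest index.
constexpr int divided_offset(int fm, int rm, int qm, int fn, int rn, int qn) {
    return fm * 512 + rm * 128 + qm * 8192 + fn * 4 + rn * 1 + qn * 64;
}

// True if the divided layout agrees with m*128 + n everywhere and reaches
// each of the 16,384 offsets exactly once.
inline bool divided_layout_is_consistent() {
    std::vector<int> hits(128 * 128, 0);
    for (int fm = 0; fm < 16; ++fm)
    for (int rm = 0; rm < 4; ++rm)
    for (int qm = 0; qm < 2; ++qm)
    for (int fn = 0; fn < 16; ++fn)
    for (int rn = 0; rn < 4; ++rn)
    for (int qn = 0; qn < 2; ++qn) {
        int m = fm * 4 + rm + qm * 64;   // logical row this coordinate names
        int n = fn * 4 + rn + qn * 64;   // logical column this coordinate names
        int off = divided_offset(fm, rm, qm, fn, rn, qn);
        if (off != m * 128 + n) return false;
        ++hits[off];
    }
    for (int h : hits)
        if (h != 1) return false;
    return true;
}
```

The check passing means the division only re-indexes the tile: no element is lost or duplicated, and no data has moved.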

Visualization: M-dimension split

128 M-positions factored as (16 groups × 4 consecutive) × 2 repetitions

Rest 0: [0 1 2 3] [4 5 6 7] ... [60 61 62 63]
Rest 1: [64 65 66 67] [68 69 70 71] ... [124 ... 127]

Legend: in each row, the first bracket belongs to Thread 0 (F=0), the second to Thread 1 (F=1), and the last to Thread 15 (F=15). The Rest boundary is at position 64.

② zipped_divide: Group Atom Dimensions Together

C++ Source (mma_atom.hpp:261-263)

auto c_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})),   // AtomM = 1
                        make_layout(size<1>(AtomShape_MNK{})));  // AtomN = 1
auto c_tensor = zipped_divide(t_tensor, c_tile);  // ((AtomM,AtomN),(RestM,RestN))

What It Does

Groups "atom" dimensions (the elements consumed by one MMA instruction) into one mode, and "rest" dimensions (how many atom invocations are needed) into another.

Why?

Step 3 needs to apply the TV (Thread-Value) layout within the atom. By zipping atom dimensions together, the atom mode becomes a self-contained unit we can transform independently.

Derivation

c_tile = (1, 1): the atom covers 1×1 elements (scalar FMA)
zipped_divide splits by (1, 1), which is trivial:
  Each (16,4,2) subgroup in M: atom takes 1, rest gets (16,4,2)
  Each (16,4,2) subgroup in N: atom takes 1, rest gets (16,4,2)
Then zip: atom parts → mode 0, rest parts → mode 1

Input: ((16,4,2), (16,4,2)):((512,128,8192), (4,1,64))
Output: ((1, 1), (16,4,2, 16,4,2)):((★, ★), (512,128,8192, 4,1,64))
Modes: Mode 0 = Atom (1,1): one MMA invocation processes 1 element
       Mode 1 = Rest (16,4,2,16,4,2): 16384 total atom invocations to cover the tile
★ Note: Atom strides are "don't care" since the atom is size 1×1.

For a tensor core atom (e.g., 16×8×8 HMMA), the atom mode would be (16, 8) with meaningful strides, and the rest would be smaller. The structure of the algorithm stays the same; only the numbers change. This is the power of CuTe's generality.
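To make "the rest would be smaller" concrete, the sketch below counts how many atom footprints are needed to cover the 128×128 tile for two atom sizes. It assumes the tile divides evenly by the atom and ignores the permutation; AtomGrid and atom_grid are names invented here.

```cpp
// Hypothetical helper: how many atom footprints tile the output in M and N.
// Assumes tile_m and tile_n are exact multiples of atom_m and atom_n.
struct AtomGrid {
    int rest_m;
    int rest_n;
    constexpr int count() const { return rest_m * rest_n; }
};

constexpr AtomGrid atom_grid(int tile_m, int tile_n, int atom_m, int atom_n) {
    return AtomGrid{tile_m / atom_m, tile_n / atom_n};
}

static_assert(atom_grid(128, 128, 1, 1).count() == 16384,
              "scalar FMA: one atom per element");
static_assert(atom_grid(128, 128, 16, 8).count() == 128,
              "16x8 atom: an 8 x 16 grid of atoms");
```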

③ compose: Transform Atom from (M,N) to (Thread, Value)

C++ Source (mma_atom.hpp:265-266)

// Transform the Atom mode from (M,N) to (Thr,Val)
auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{}, _);  // ((ThrV,FrgV),(RestM,RestN))

What It Does

Applies the atom's TV (Thread-Value) layout to mode 0 (the atom). This re-labels atom coordinates from spatial (M, N) to functional (which thread, which value within that thread).

Why This Step Exists

The atom layout knows how a hardware MMA instruction distributes its inputs/outputs across threads. For example, a tensor core MMA might spread a 16×8 output across 32 threads, each holding 4 values. The TV layout encodes that hardware-specific mapping.

For our scalar FMA: 1 thread computes 1 value, so TV = (1,1):(0,0), which is trivial.

Derivation

AtomLayoutC_TV = (1, 1):(0, 0)
                  ↑ 1 thread (ThrV), 1 value per atom (FrgV = Fragment Value)
compose(atom_mode=(1,1), TV=(1,1)) → (ThrV=1, FrgV=1)
This is just relabeling, with no numeric change.
The "_" in compose(..., _) means: leave mode 1 (Rest) untouched.

Input: ((1, 1), (16,4,2, 16,4,2))
Output: ((ThrV=1, FrgV=1), (16,4,2, 16,4,2))
Strides: ((0, 0), (512,128,8192, 4,1,64)), unchanged numerically
Key: ThrV = threads within atom (intra-atom thread index)
     FrgV = values per thread within atom (fragment values)

For a 16×8 tensor core (HMMA): AtomLayoutC_TV would be e.g., (32, 4):(some_strides), meaning 32 threads each holding 4 values. Then compose would transform (16,8) → (32, 4) with non-trivial strides derived from the hardware's register layout. This is where CuTe encodes hardware knowledge.
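For a feel of what a non-trivial TV layout looks like, the sketch below evaluates ((4,8),(2,2)):((32,1),(16,8)), which, as far as I can tell, matches the accumulator fragment layout of the 16×8 mma.sync instructions and the corresponding CUTLASS atom traits. Treat the exact numbers as an assumption to verify against the CUTLASS headers; the struct and function names are invented here.

```cpp
// Assumed TV layout for a 16x8 C atom: ((4,8),(2,2)):((32,1),(16,8)).
// Input: thread 0..31 and value 0..3. Output: (m, n) inside the 16x8 tile,
// where the layout's codomain is the column-major index m + 16*n.
struct MN { int m; int n; };

constexpr MN c_coord_16x8(int thread, int value) {
    int t0 = thread % 4;   // thread mode of size 4, stride 32
    int t1 = thread / 4;   // thread mode of size 8, stride 1
    int v0 = value % 2;    // value mode of size 2, stride 16
    int v1 = value / 2;    // value mode of size 2, stride 8
    int idx = t0 * 32 + t1 * 1 + v0 * 16 + v1 * 8;
    return MN{idx % 16, idx / 16};
}

// True if the 32 threads x 4 values touch each of the 128 elements once.
inline bool covers_atom_once() {
    int hits[16][8] = {};
    for (int t = 0; t < 32; ++t)
        for (int v = 0; v < 4; ++v) {
            MN c = c_coord_16x8(t, v);
            ++hits[c.m][c.n];
        }
    for (int m = 0; m < 16; ++m)
        for (int n = 0; n < 8; ++n)
            if (hits[m][n] != 1) return false;
    return true;
}

static_assert(c_coord_16x8(0, 2).m == 8, "value 2 of thread 0 sits 8 rows down");
```

Under this assumed layout, thread 0 holds (0,0), (0,1), (8,0), and (8,1): two adjacent columns in two rows that are 8 apart, rather than a contiguous block.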

④ zipped_divide: Separate Thread-Indexing from Thread-Owned

C++ Source (mma_atom.hpp:268-272)

// Tile the tensor for the C-threads
auto thr_tile = make_tile(_,   // keep ThrV as-is
                          make_tile(make_layout(size<1>(thr_layout_vmnk_)),    // ThrM = 16
                                    make_layout(size<2>(thr_layout_vmnk_))));  // ThrN = 16
auto thr_tensor = zipped_divide(tv_tensor, thr_tile);  // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN)))

What It Does

The "Rest" mode from step 3 contains both thread-indexing dimensions (which of the 16×16 threads) and per-thread-owned dimensions (the (4,2) × (4,2) = 64 elements each thread computes). This step separates them.

Why?

After this step, mode 0 = "which thread am I?" and mode 1 = "what data do I own?". The final slice (step 5) uses mode 0 to pick a thread and returns mode 1 as the result.

Derivation

thr_tile for mode 1 (Rest) = (ThrM=16, ThrN=16)
Mode 1 was: (16, 4, 2, 16, 4, 2):(512, 128, 8192, 4, 1, 64)
            (first three modes = M-part, last three = N-part)
Divide M-rest (16, 4, 2) by ThrM=16:
  • Tile: 16 positions (strides from F dimension) → stride 512
  • Rest: (4, 2) → strides (128, 8192)
Divide N-rest (16, 4, 2) by ThrN=16:
  • Tile: 16 positions → stride 4
  • Rest: (4, 2) → strides (1, 64)
Zip → thread parts together, fragment parts together:

Input: ((ThrV=1, FrgV=1), (16,4,2, 16,4,2))
Output: ((1, (16, 16)), (1, (4,2), (4,2)))
Strides: ((0, (512, 4)), (0, (128,8192), (1,64)))
Meaning: Mode 0 = Thread selector: (ThrV=1, (ThrM=16, ThrN=16)), which picks one of 256 threads
         Mode 1 = Fragment: (FrgV=1, (RestM=(4,2)), (RestN=(4,2))), the 64 elements this thread owns

Thread Indexing Strides Explained

Thread (tm, tn) where tm ∈ [0,16), tn ∈ [0,16):
offset = tm × 512 + tn × 4

tm=0, tn=0: offset = 0 → row 0, col 0
tm=0, tn=1: offset = 4 → row 0, col 4
tm=1, tn=0: offset = 512 → row 4, col 0 (512/128 = row 4)
tm=1, tn=1: offset = 516 → row 4, col 4

Thread 0 (tm=0, tn=0) starts at (row=0, col=0)
Thread 1 (tm=0, tn=1) starts at (row=0, col=4)
Thread 16 (tm=1, tn=0) starts at (row=4, col=0)
Thread 255 (tm=15, tn=15) starts at (row=60, col=60)
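The thread-selector strides can be turned into a small lookup from a flat thread ID to its starting row and column. This sketch assumes the thread ID decomposes as tid = tm*16 + tn, which is what the atoms_layout strides (16, 1, 0) imply; RowCol and the function names are invented for illustration.

```cpp
struct RowCol { int row; int col; };

// Base offset into gC for a flat thread ID, using thread strides (512, 4).
// Assumes tid = tm*16 + tn, per atoms_layout (16,16,1):(16,1,0).
constexpr int thread_base_offset(int tid) {
    int tm = tid / 16;
    int tn = tid % 16;
    return tm * 512 + tn * 4;
}

// The same base expressed as (row, col) in the row-major 128x128 tile.
constexpr RowCol thread_base(int tid) {
    int off = thread_base_offset(tid);
    return RowCol{off / 128, off % 128};
}

static_assert(thread_base(16).row == 4 && thread_base(16).col == 0,
              "thread 16 starts 4 rows below thread 0");
```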

⑤ slice: Pick One Thread's Fragment

C++ Source (mma_atom.hpp:469-472)

// In partition_C:
auto thr_vmn = make_coord(get<0>(thr_vmnk_),               // thr_v = 0
                          make_coord(get<1>(thr_vmnk_),    // thr_m
                                     get<2>(thr_vmnk_)));  // thr_n
return thr_tensor(thr_vmn, make_coord(_, repeat<...>(_)));  // → fragment view

What It Does

Plugs in the concrete thread coordinates into mode 0 (the thread selector). This computes a pointer offset for this thread and returns mode 1 (the fragment) as the result layout.

Derivation for Thread 0

Thread 0: thr_vmnk = (0, 0, 0, 0) → thr_vmn = (v=0, (m=0, n=0))
Pointer offset = 0×0 + 0×512 + 0×4 = 0 (no shift: thread 0 starts at gC's base)
Remaining layout (mode 1) = the fragment: (1, (4,2), (4,2)):(0, (128,8192), (1,64))

Input: ((1,(16,16)), (1,(4,2),(4,2))), the full thrfrg_C result
Slice: thread_coord = (0, (0, 0)), fragment_coord = (_, (_, _), (_, _))
Output: (1, (4,2), (4,2)):(0, (128,8192), (1,64)), thread 0's view of gC

What the Output Strides Mean Physically

tCgC[v, (m0,m1), (n0,n1)] → gC offset = v×0 + m0×128 + m1×8192 + n0×1 + n1×64

The 64 elements thread 0 owns (m0∈[0,4), m1∈[0,2), n0∈[0,4), n1∈[0,2)):
(m0,m1,n0,n1) → (row, col) = (m0 + m1×64, n0 + n1×64)

                    n1=0                               n1=1
             n0: 0       1       2       3        n0: 0        1        2        3
m1=0  m0=0:  (0,0)   (0,1)   (0,2)   (0,3)        (0,64)   (0,65)   (0,66)   (0,67)
      m0=1:  (1,0)   (1,1)   (1,2)   (1,3)        (1,64)   (1,65)   (1,66)   (1,67)
      m0=2:  (2,0)   (2,1)   (2,2)   (2,3)        (2,64)   (2,65)   (2,66)   (2,67)
      m0=3:  (3,0)   (3,1)   (3,2)   (3,3)        (3,64)   (3,65)   (3,66)   (3,67)
m1=1  m0=0:  (64,0)  (64,1)  (64,2)  (64,3)       (64,64)  (64,65)  (64,66)  (64,67)
      m0=1:  (65,0)  (65,1)  (65,2)  (65,3)       (65,64)  (65,65)  (65,66)  (65,67)
      m0=2:  (66,0)  (66,1)  (66,2)  (66,3)       (66,64)  (66,65)  (66,66)  (66,67)
      m0=3:  (67,0)  (67,1)  (67,2)  (67,3)       (67,64)  (67,65)  (67,66)  (67,67)
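The fragment layout can also be evaluated from a flat index 0..63, which is how a loop over the fragment would typically walk it. The sketch below assumes the usual CuTe convention that a flat index unpacks with the leftmost mode varying fastest; the function names are invented for this illustration.

```cpp
// Thread 0's fragment layout (1,(4,2),(4,2)):(0,(128,8192),(1,64)).
constexpr int frag_offset(int v, int m0, int m1, int n0, int n1) {
    return v * 0 + m0 * 128 + m1 * 8192 + n0 * 1 + n1 * 64;
}

// Flat index 0..63 -> offset, unpacking with the leftmost mode fastest
// (assumed to match CuTe's 1-D indexing of a multi-mode shape).
constexpr int frag_offset_flat(int i) {
    int m0 = i % 4;         // size-4 mode, stride 128
    int m1 = (i / 4) % 2;   // size-2 mode, stride 8192
    int n0 = (i / 8) % 4;   // size-4 mode, stride 1
    int n1 = i / 32;        // size-2 mode, stride 64
    return frag_offset(0, m0, m1, n0, n1);
}

static_assert(frag_offset_flat(1) == 128, "index 1 steps one row down");
static_assert(frag_offset_flat(8) == 1, "index 8 steps one column right");
static_assert(frag_offset_flat(63) == 67 * 128 + 67, "last element is (67,67)");
```

Note that consecutive flat indices step down rows first, so a plain loop over the fragment does not walk memory contiguously.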

✓ Verification: 16×16 Thread Ownership Grid

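Each of the 256 threads owns a 4×4 block at position (tm×4, tn×4) plus three copies shifted by (+0, +64), (+64, +0), and (+64, +64). Four 4×4 blocks = 64 elements. Let's visualize thread ownership for the first 16×16 corner of the 128×128 C tile:

[Figure: the 16×16 corner drawn as a 4×4 grid of 4×4 blocks, labeled by thread number row by row: 0 1 2 3 / 16 17 18 19 / 32 33 34 35 / 48 49 50 51.]

Each colored 4×4 block is one thread's first sub-block. Thread numbers shown. The other three sub-blocks follow the same pattern at the +64 offsets.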

Element Count Check

Thread 0 owns: 1 × (4×2) × (4×2) = 1 × 8 × 8 = 64 elements
256 threads × 64 = 16,384 = 128 × 128 ✓
No element is claimed by two threads (permutation is a bijection).
No element is unclaimed (complement covers the remainder).
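The count check can be strengthened into an exhaustive one: walk every thread's fragment and confirm each offset in the tile is claimed exactly once. This is a brute-force sketch in plain C++ using the strides derived in this walkthrough; the function name is invented.

```cpp
#include <vector>

// True if the 256 per-thread fragments tile the 128x128 output exactly once.
// Thread strides (512, 4) and fragment strides ((128,8192),(1,64)) are taken
// from the derivation; tid = tm*16 + tn is assumed.
inline bool partition_covers_tile_once() {
    std::vector<int> owners(128 * 128, 0);
    for (int tid = 0; tid < 256; ++tid) {
        int base = (tid / 16) * 512 + (tid % 16) * 4;   // pointer offset from the slice
        for (int m0 = 0; m0 < 4; ++m0)
        for (int m1 = 0; m1 < 2; ++m1)
        for (int n0 = 0; n0 < 4; ++n0)
        for (int n1 = 0; n1 < 2; ++n1)
            ++owners[base + m0 * 128 + m1 * 8192 + n0 * 1 + n1 * 64];
    }
    for (int c : owners)
        if (c != 1) return false;   // 0 = unclaimed, 2+ = claimed twice
    return true;
}
```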

🎯 Summary: Why Each Step Exists

| Step | Operation | Problem It Solves | Output Shape |
|---|---|---|---|
| ① | logical_divide | Factorize the tile according to the thread permutation: separate "which group" from "position within group" | ((16,4,2),(16,4,2)) |
| ② | zipped_divide | Isolate the atom's footprint (what one MMA instruction touches) from the rest (how many times we repeat) | ((1,1),(16,4,2,16,4,2)) |
| ③ | compose | Re-label atom coordinates from spatial (M,N) to functional (Thread, Value) using hardware knowledge | ((ThrV,FrgV),(Rest...)) |
| ④ | zipped_divide | Separate thread-indexing dimensions from per-thread fragment dimensions across the "rest" mode | ((Thr),(Frg)) |
| ⑤ | slice | Plug in a concrete thread ID and return just that thread's fragment layout | (1,(4,2),(4,2)) |

The Design Principle

CuTe decomposes the "which thread owns which elements" problem into composable algebraic primitives (divide, zip, compose, slice), each handling one concern:

  1. Spatial structure (how elements are arranged in memory)
  2. Hardware structure (how the MMA instruction maps threads to elements)
  3. Tiling structure (how many atoms tile across the output)

By keeping these orthogonal, the same code works for a scalar FMA atom, a 16×8×8 tensor core, or a 64×256×32 Blackwell UMMA: only the atom (its shape and TV layout) changes. The algorithm is invariant to the hardware.