How CuTe transforms (128, 128):(128, 1) → (1, (4,2), (4,2)):(0, (128,8192), (1,64)) in 5 algebraic steps
The Problem partition_C Solves
Given:
A 128×128 output tile gC in global memory (row-major)
256 threads that need to collaboratively compute this tile
A Matrix Multiply-Accumulate (MMA) instruction (here: a scalar FMA, i.e., Fused Multiply-Add)
Need:
For each thread, produce a view (pointer + layout) into gC that selects exactly the 64 elements this thread is responsible for, without moving any data.
Why 64? 128 × 128 = 16,384 elements ÷ 256 threads = 64 elements per thread.
In raw CUDA you'd write: int row = threadIdx.x / 16; int col = threadIdx.x % 16; and then manually compute which 64 elements this thread owns (8 rows × 8 columns, scattered as four 4×4 blocks). CuTe replaces all of that with layout algebra: composable, type-safe, zero-cost abstractions that work for any MMA instruction, not just a scalar FMA.
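For comparison, here is a sketch of that hand-written index math as plain host C++, so it can be checked without a GPU. It assumes the configuration used in this walkthrough: a row-major 128×128 tile, row = tid / 16 and col = tid % 16, and four 4×4 blocks per thread spaced 64 apart. `raw_thread_offsets` is a hypothetical helper. Its loop order is chosen to match the order in which CuTe's fragment layout enumerates the 64 values.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: the manual index math a raw CUDA kernel would need.
// Assumes a row-major 128x128 tile (row stride 128), 256 threads laid out as
// a 16x16 grid with tid = 16*tm + tn, and 4x4 blocks repeated every 64 rows/cols.
std::vector<int> raw_thread_offsets(int tid) {
  const int ldc = 128;      // row stride of gC
  const int tm = tid / 16;  // thread's row in the 16x16 thread grid
  const int tn = tid % 16;  // thread's column in the 16x16 thread grid
  std::vector<int> offsets;
  for (int rest_n = 0; rest_n < 2; ++rest_n)      // +64 column repeat
    for (int c = 0; c < 4; ++c)                   // column inside a 4x4 block
      for (int rest_m = 0; rest_m < 2; ++rest_m)  // +64 row repeat
        for (int r = 0; r < 4; ++r) {             // row inside a 4x4 block (fastest)
          int row = tm * 4 + r + rest_m * 64;
          int col = tn * 4 + c + rest_n * 64;
          offsets.push_back(row * ldc + col);
        }
  return offsets;
}
```

For thread 0 the first four offsets are 0, 128, 256, 384 (rows 0 to 3 of column 0). Every constant here is hard-coded in the loops. CuTe's layout algebra derives these constants instead.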
Glossary: Abbreviations → Full Names
Abbreviation | Full Name | What It Is
--- | --- | ---
MMA | Matrix Multiply-Accumulate | Hardware instruction: D = A × B + C
FMA | Fused Multiply-Add | Scalar MMA: d = a × b + c
CTA | Cooperative Thread Array | Equivalent to a thread block in CUDA
TV | Thread-Value | Layout mapping (thread_idx, value_idx) → element offset
Thr | Thread | Thread dimension in a layout
Frg / Frag | Fragment | Data elements owned by one thread
V / Val | Value | Element index within one atom execution
Perm | Permutation | Reordering of element positions
GMEM | Global Memory | GPU DRAM (HBM)
SMEM | Shared Memory | On-chip per-CTA scratchpad
RMEM | Register Memory | Per-thread register file
BM/BN/BK | Block M/N/K | CTA tile dimensions
TiledMma | Tiled MMA | MMA atom replicated across threads and tile positions
thrfrg_C | Thread-Fragment C | Algorithm that partitions C into thread and fragment dimensions
Core Principle: Layout Composition as Function Composition
The Key Idea
A CuTe Layout (Shape:Stride) is a function from logical coordinates to memory offsets: offset = Σᵢ coordᵢ × strideᵢ.
partition_C needs to answer: "Given that thread 0 owns M-positions {0,1,2,3,64,65,66,67} and N-positions {0,1,2,3,64,65,66,67}, what is the layout that maps a flat index 0..63 to the correct memory offsets?"
Instead of computing 64 offsets explicitly, CuTe composes layout functions. The output is a new layout (pure metadata, no data movement) whose strides encode the answer.
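The following is a minimal sketch that makes "layout = function" concrete. It shows a flat layout evaluated at a coordinate or at a flat index, and composition tabulated pointwise. `FlatLayout` and `compose_table` are hypothetical names. Real CuTe layouts are hierarchical and usually compile-time, and CuTe's `composition()` returns a new layout rather than a table.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal stand-in for a *flat* CuTe layout: parallel shape/stride vectors.
struct FlatLayout {
  std::vector<int> shape;
  std::vector<int> stride;

  // Layout as a function of a multi-dimensional coordinate: sum(c_i * d_i).
  int operator()(const std::vector<int>& coord) const {
    int offset = 0;
    for (std::size_t i = 0; i < shape.size(); ++i) offset += coord[i] * stride[i];
    return offset;
  }

  // Layout as a function of a flat index: split the index colexicographically
  // (leftmost mode varies fastest, as in CuTe), then apply the strides.
  int operator()(int idx) const {
    int offset = 0;
    for (std::size_t i = 0; i < shape.size(); ++i) {
      offset += (idx % shape[i]) * stride[i];
      idx /= shape[i];
    }
    return offset;
  }

  int size() const {
    int n = 1;
    for (int s : shape) n *= s;
    return n;
  }
};

// Function composition R(i) = A(B(i)), tabulated over B's domain.
std::vector<int> compose_table(const FlatLayout& A, const FlatLayout& B) {
  std::vector<int> out;
  for (int i = 0; i < B.size(); ++i) out.push_back(A(B(i)));
  return out;
}
```

Composing the M mode of gC, 128:128, with the permutation (16,4):(4,1) gives the same function as (16,4):(512,128). The strides simply multiply, so the answer is pure metadata.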
Three Primitive Operations
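logical_divide(A, B)
Splits layout A using tiler B.
Math: logical_divide(A, B) = composition(A, (B, B*)), where B* = complement(B, size(A)) covers the positions B does not reach.
Shape transforms: (M) → (TileM, RestM)
Like reshape(128) → (64, 2), but respecting the stride structure
zipped_divide(A, B)
= logical_divide + zip: groups all tile dims together and all rest dims together.
compose(A, B)
Function composition: the result maps a coordinate c to A(B(c)). Its shape comes from B and its offsets from A.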
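The B* in logical_divide comes from complement. The sketch below implements a simplified complement for flat layouts with positive strides. It assumes that, after sorting by stride, each stride is a multiple of the extent already covered. That precondition and the names `Mode` and `complement` here belong to the sketch only. CuTe's own `complement()` handles more general cases on compile-time shapes.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using Mode = std::pair<int, int>;  // (shape, stride)

// Simplified complement(B, n) for a flat layout B with positive strides.
// Precondition of this sketch: after sorting by stride, each stride is a
// multiple of the extent covered by the previous modes.
std::vector<Mode> complement(std::vector<Mode> B, int n) {
  std::sort(B.begin(), B.end(),
            [](const Mode& a, const Mode& b) { return a.second < b.second; });
  std::vector<Mode> result;
  int covered = 1;  // extent spanned by the modes processed so far
  for (const auto& [shape, stride] : B) {
    assert(stride % covered == 0);
    if (stride / covered > 1) result.push_back({stride / covered, covered});  // fill the gap
    covered = shape * stride;
  }
  int remaining = (n + covered - 1) / covered;  // ceil(n / covered)
  if (remaining > 1) result.push_back({remaining, covered});  // repeat up to n
  return result;
}
```

For the permutation (16,4):(4,1) inside a 128-element mode, this returns the single mode 2:64, which is the rest factor of 2 used throughout. For 4:2 inside 16 it returns (2,2):(1,8), which fills the interleaved gaps and the upper half.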
Example Configuration
Atom: 1×1×1 scalar FMA
AtomLayoutC: (1, 1):(0, 0), the TV layout: 1 thread, 1 value per atom
atoms_layout: (16, 16, 1):(16, 1, 0): 256 threads as a 16×16 M×N grid (thread index = 16·m + n)
permutation_M: (16, 4):(4, 1)
permutation_N: (16, 4):(4, 1)
Each permutation covers 16 × 4 = 64 positions per dimension, so 128 / 64 = 2 repetitions are needed.
What does permutation (16, 4):(4, 1) mean?
It maps a 2D coordinate (f, r), where f ∈ [0,16) and r ∈ [0,4), to position f*4 + r*1. Here f is a thread's coordinate along this dimension (tm for M, tn for N), and r indexes the 4 consecutive values it holds.
For f = 0 (e.g., thread 0): positions are {0, 1, 2, 3}, the first 4 consecutive elements.
For f = 1: positions are {4, 5, 6, 7}, the next 4 elements. Along N this is thread 1; along M it is thread 16.
This covers 64 positions total (0..63). The tile is 128 wide, so a "Rest" factor of 2 covers positions {64..127}.
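A quick enumeration confirms these position sets. The hypothetical `positions_for` helper evaluates the permutation (16,4):(4,1) together with the rest mode 2:64 for one fixed f.

```cpp
#include <cassert>
#include <set>

// Positions along one dimension (M or N) owned by thread coordinate f:
// permutation (16,4):(4,1) for the first 64 positions, plus the rest mode
// 2:64 for the second half of the 128-wide tile.
std::set<int> positions_for(int f) {
  std::set<int> pos;
  for (int rest = 0; rest < 2; ++rest)        // rest mode 2:64
    for (int r = 0; r < 4; ++r)               // value mode 4:1
      pos.insert(f * 4 + r * 1 + rest * 64);  // thread mode 16:4
  return pos;
}
```

Taking the union over f = 0..15 yields all 128 positions exactly once.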
The 5-Step Pipeline
INPUT: (128,128):(128,1)
↓ ① logical_divide: split by permutation
↓ ② zipped_divide: group atom + rest
↓ ③ compose: MN → TV
↓ ④ zipped_divide: group thr + frg
↓ ⑤ slice: pick thread 0
OUTPUT: (1,(4,2),(4,2)):(0,(128,8192),(1,64))
① logical_divide: Split Layout by Permutation
C++ Source (mma_atom.hpp:256-258)
auto t_tile = make_tile(permutation_mnk<0>(), // permutation_M
permutation_mnk<1>()); // permutation_N
auto t_tensor = logical_divide(ctensor, t_tile); // (PermM, PermN)
What It Does
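Splits each dimension of gC according to the permutation layout. The permutation (16,4):(4,1) covers 64 of the 128 positions, and its complement 2:64 supplies the "rest" factor of 2. Composing with gC's strides gives M: ((16,4),2):((512,128),8192) and N: ((16,4),2):((4,1),64).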
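The per-dimension result can be sketched as follows. The sketch assumes that each dimension of gC is a single mode (128:128 for M, 128:1 for N). It takes the tiler together with its complement, ((16,4),2):((4,1),64), as given. When A has a single mode, composition just multiplies every stride of the tiler by A's stride. `logical_divide_1d` and `eval_modes` are hypothetical helpers, not CuTe functions.

```cpp
#include <cassert>
#include <vector>

struct Mode { int shape; int stride; };

// Hypothetical helper: logical_divide(A, B) for a single-mode A = n:dA.
// `perm_with_rest` is the tiler B followed by its complement B*, flattened.
// For single-mode A, composition(A, L) multiplies every stride of L by dA.
std::vector<Mode> logical_divide_1d(int n, int dA, const std::vector<Mode>& perm_with_rest) {
  int size = 1;
  for (const Mode& m : perm_with_rest) size *= m.shape;
  assert(size == n);  // tile + rest must cover the whole mode exactly
  std::vector<Mode> out;
  for (const Mode& m : perm_with_rest) out.push_back({m.shape, m.stride * dA});
  return out;
}

// Evaluate a flat mode list at a flat index (leftmost mode varies fastest).
int eval_modes(const std::vector<Mode>& modes, int idx) {
  int offset = 0;
  for (const Mode& m : modes) { offset += (idx % m.shape) * m.stride; idx /= m.shape; }
  return offset;
}

// ((16,4),2):((4,1),64): permutation (16,4):(4,1) plus complement 2:64, flattened.
const std::vector<Mode> kPermWithRest = {{16, 4}, {4, 1}, {2, 64}};
```

The only difference between M and N is the scaling factor, dA = 128 versus dA = 1. For this reason both dimensions share one permutation yet end up with different strides.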
② zipped_divide: Group Atom Dimensions Together
C++ Source (mma_atom.hpp:261-263)
auto c_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), // AtomM = 1
make_layout(size<1>(AtomShape_MNK{}))); // AtomN = 1
auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN))
What It Does
Groups "atom" dimensions (the elements consumed by one MMA instruction) into one mode, and "rest" dimensions (how many atom invocations are needed) into another.
Why?
Step 3 needs to apply the TV (Thread-Value) layout within the atom. By zipping atom dimensions together, the atom mode becomes a self-contained unit we can transform independently.
Derivation
c_tile = (1, 1): the atom covers 1×1 elements (scalar FMA)
zipped_divide splits by (1, 1), which is trivial:
Each (16,4,2) subgroup in M: atom takes 1, rest gets (16,4,2)
Each (16,4,2) subgroup in N: atom takes 1, rest gets (16,4,2)
Then zip: atom parts → mode 0, rest parts → mode 1
Modes:
Mode 0 = Atom (1,1): one MMA invocation processes 1 element
Mode 1 = Rest (16,4,2,16,4,2): 16384 total atom invocations to cover the tile
Note: Atom strides are "don't care" since the atom is size 1×1.
For a tensor core atom (e.g., 16×8×8 HMMA), the atom mode would be (16, 8) with meaningful strides, and the rest would be smaller. The structure of the algorithm stays the same; only the numbers change. This is the power of CuTe's generality.
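The zip half of zipped_divide is only a regrouping. In this sketch, each dimension's divided (tile, rest) parts are written out by hand from the derivation above. `zip_dims` then gathers them into the atom mode and the rest mode. The `Divided` struct and the helper names are hypothetical.

```cpp
#include <cassert>
#include <vector>

struct Mode { int shape; int stride; };
// One dimension after division: the part inside the atom tile, and the rest.
struct Divided { std::vector<Mode> tile, rest; };

// The "zip" in zipped_divide: tile parts of all dimensions go to mode 0,
// rest parts of all dimensions go to mode 1.
Divided zip_dims(const std::vector<Divided>& dims) {
  Divided out;
  for (const Divided& d : dims) {
    out.tile.insert(out.tile.end(), d.tile.begin(), d.tile.end());
    out.rest.insert(out.rest.end(), d.rest.begin(), d.rest.end());
  }
  return out;
}

// Total number of coordinates in a flat mode list.
int mode_size(const std::vector<Mode>& modes) {
  int n = 1;
  for (const Mode& m : modes) n *= m.shape;
  return n;
}
```

With the scalar atom, mode 0 has size 1, and mode 1 holds all six rest modes, 16 · 4 · 2 · 16 · 4 · 2 = 16,384 atom invocations.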
③ compose: Transform Atom from (M,N) to (Thread, Value)
C++ Source (mma_atom.hpp:265-266)
// Transform the Atom mode from (M,N) to (Thr,Val)
auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{}, _); // ((ThrV,FrgV),(RestM,RestN))
What It Does
Applies the atom's TV (Thread-Value) layout to mode 0 (the atom). This re-labels atom coordinates from spatial (M, N) to functional (which thread, which value within that thread).
Why This Step Exists
The atom layout knows how a hardware MMA instruction distributes its inputs/outputs across threads. For example, a tensor core MMA might spread a 16×8 output across 32 threads, each holding 4 values. The TV layout encodes that hardware-specific mapping.
For our scalar FMA: 1 thread computes 1 value, so TV = (1,1):(0,0), which is trivial.
Derivation
AtomLayoutC_TV = (1, 1):(0, 0)
→ 1 thread (ThrV), 1 value per atom (FrgV = Fragment Value)
compose(atom_mode=(1,1), TV=(1,1)) → (ThrV=1, FrgV=1)
This is just relabeling: no numeric change.
The "_" in compose(..., _) means: leave mode 1 (Rest) untouched.
Key:
ThrV = threads within the atom (intra-atom thread index)
FrgV = values per thread within the atom (fragment values)
For a 16×8 tensor core (HMMA): AtomLayoutC_TV would be e.g., (32, 4):(some_strides), i.e., 32 threads each holding 4 values. Then compose would transform (16,8) → (32, 4) with non-trivial strides derived from the hardware's register layout. This is where CuTe encodes hardware knowledge.
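To see what a TV layout encodes, the following shows the scalar atom's TV function next to a tensor-core-style example. The sketch assumes that CuTe's SM80_16x8_Row C layout is ((4,8),(2,2)):((32,1),(16,8)), and that it produces column-major offsets m + 16·n within the 16×8 atom. Check that assumption against mma_traits_sm80.hpp before relying on it. All function names are hypothetical.

```cpp
#include <cassert>
#include <utility>

// Scalar FMA atom: AtomLayoutC_TV = (1,1):(0,0). The only (thr, val) pair is
// (0, 0), and it maps to atom-local offset 0.
int scalar_tv(int thr, int val) { return thr * 0 + val * 0; }

// ASSUMED tensor-core C layout ((4,8),(2,2)):((32,1),(16,8)) over a 16x8 atom,
// producing column-major atom offsets (offset = m + 16 * n).
int sm80_16x8_tv(int thr, int val) {
  int t0 = thr % 4, t1 = thr / 4;  // thread mode (4,8), strides (32,1)
  int v0 = val % 2, v1 = val / 2;  // value mode (2,2), strides (16,8)
  return t0 * 32 + t1 * 1 + v0 * 16 + v1 * 8;
}

// Atom-local offset -> (m, n) inside the 16x8 atom (column-major).
std::pair<int, int> atom_mn(int offset) { return {offset % 16, offset / 16}; }

// The compose step for one (thr, val): apply the TV layout first, then place
// the resulting (m, n) in a row-major C tile with leading dimension ldc.
int composed_offset(int thr, int val, int ldc) {
  std::pair<int, int> mn = atom_mn(sm80_16x8_tv(thr, val));
  return mn.first * ldc + mn.second;
}
```

Composing with the row-major gC strides (128, 1) turns each (thr, val) pair directly into a memory offset. Step ③ does the same thing symbolically for the whole atom mode. For the scalar atom the composition is the identity on the single element.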
④ zipped_divide: Separate Thread-Indexing from Thread-Owned Dimensions
C++ Source (mma_atom.hpp:268-272)
// Tile the tensor for the C-threads
auto thr_tile = make_tile(
_, // keep ThrV as-is
make_tile(make_layout(size<1>(thr_layout_vmnk_)), // ThrM = 16
make_layout(size<2>(thr_layout_vmnk_)))); // ThrN = 16
auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN)))
What It Does
The "Rest" mode from step 3 contains both thread-indexing dimensions (which of the 16×16 threads) and per-thread-owned dimensions (the (4,2) × (4,2) = 64 elements each thread computes). This step separates them.
Why?
After this step, mode 0 = "which thread am I?" and mode 1 = "what data do I own?". The final slice (step 5) uses mode 0 to pick a thread and returns mode 1 as the result.
Derivation
thr_tile for mode 1 (Rest) = (ThrM=16, ThrN=16)
Mode 1 was: (16, 4, 2, 16, 4, 2):(512, 128, 8192, 4, 1, 64), where the first three modes are the M part and the last three are the N part.
Divide the M part (16, 4, 2) by ThrM=16:
• Tile: 16 positions (the f dimension of the permutation) → stride 512
• Rest: (4, 2) → strides (128, 8192)
Divide the N part (16, 4, 2) by ThrN=16:
• Tile: 16 positions → stride 4
• Rest: (4, 2) → strides (1, 64)
Zip: thread parts together, fragment parts together:
Input:((ThrV=1, FrgV=1), (16,4,2, 16,4,2))
Output:
((1, (16, 16)), (1, (4,2), (4,2)))
Strides:
((0, (512, 4)), (0, (128,8192), (1,64)))
Meaning:
Mode 0 = Thread selector: (ThrV=1, (ThrM=16, ThrN=16)) → picks one of 256 threads
Mode 1 = Fragment: (FrgV=1, (RestM=(4,2)), (RestN=(4,2))) → 64 elements this thread owns
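The M/N split can be sketched for the simple case that arises here. The leading mode of each rest part has exactly the thread count (16) as its shape, so dividing by 16 peels that mode off. The size-1 ThrV/FrgV modes are omitted. `divide_leading`, `zip_thr_frg` and `Split` are hypothetical. CuTe's logical_divide handles the general case.

```cpp
#include <cassert>
#include <vector>

struct Mode { int shape; int stride; };
// Result of dividing a mode list: the tile part and the rest part.
struct Split { std::vector<Mode> tile, rest; };

// Divide a flat mode list by a tile of size t, in the easy case that occurs
// here: the leading mode has shape exactly t, so it becomes the tile.
Split divide_leading(const std::vector<Mode>& modes, int t) {
  assert(!modes.empty() && modes[0].shape == t);  // precondition of this sketch
  Split s;
  s.tile.push_back(modes[0]);
  s.rest.assign(modes.begin() + 1, modes.end());
  return s;
}

// Zip two divided dimensions: tiles -> thread selector, rests -> fragment.
Split zip_thr_frg(const Split& m, const Split& n) {
  Split out = m;
  out.tile.insert(out.tile.end(), n.tile.begin(), n.tile.end());
  out.rest.insert(out.rest.end(), n.rest.begin(), n.rest.end());
  return out;
}
```

Feeding in the M part (16,4,2):(512,128,8192) and the N part (16,4,2):(4,1,64) reproduces the thread selector strides (512, 4) and the fragment strides (128, 8192, 1, 64).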
⑤ slice: Pick One Thread's Fragment
C++ Source (partition_C)
// In partition_C:
auto thr_vmn = make_coord(get<0>(thr_vmnk_), // thr_v = 0
make_coord(get<1>(thr_vmnk_), // thr_m
get<2>(thr_vmnk_))); // thr_n
return thr_tensor(thr_vmn, make_coord(_, repeat<...>(_))); // → fragment view
What It Does
Plugs the concrete thread coordinates into mode 0 (the thread selector). This contributes the thread's pointer offset, 512·tm + 4·tn, and returns mode 1 (the fragment) as the result layout.
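The following sketch shows the slice for a given thread. The thread selector contributes a base offset. The fragment layout then enumerates the 64 values in CuTe's colexicographic order, with the M modes varying fastest. The sketch assumes tid = 16·tm + tn, as in atoms_layout. The helper names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Thread selector from step 4: (ThrM,ThrN) = (16,16):(512,4); ThrV has size 1.
// Assumes tid = 16*tm + tn, matching atoms_layout (16,16,1):(16,1,0).
int thread_base_offset(int tid) {
  int tm = tid / 16, tn = tid % 16;
  return tm * 512 + tn * 4;
}

// Fragment from step 4: ((4,2),(4,2)):((128,8192),(1,64)); FrgV has size 1.
// Flat index i is split colexicographically: M modes first, then N modes.
int fragment_offset(int i) {
  const int shape[4] = {4, 2, 4, 2};
  const int stride[4] = {128, 8192, 1, 64};
  int offset = 0;
  for (int k = 0; k < 4; ++k) {
    offset += (i % shape[k]) * stride[k];
    i /= shape[k];
  }
  return offset;
}

// The sliced view: this thread's base offset plus every fragment offset.
std::vector<int> thread_view(int tid) {
  std::vector<int> out;
  for (int i = 0; i < 64; ++i) out.push_back(thread_base_offset(tid) + fragment_offset(i));
  return out;
}
```

thread_view(0) starts 0, 128, 256, 384. Thread 17 (tm = 1, tn = 1) follows the same pattern shifted by 516 = 512 + 4.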
Each of the 256 threads owns a 4×4 block at (4·tm, 4·tn) plus three copies of it: one shifted by +64 rows, one by +64 columns, and one by +64 in both. Four 4×4 blocks = 64 elements. Let's visualize thread ownership for the top-left 16×16 corner of the 128×128 C tile:
[Figure: each 4×4 block of the corner is labeled with the thread that owns it (that thread's first sub-block). The other three sub-blocks of each thread, shifted by +64 rows and/or columns, follow the same pattern.]
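The ownership pattern in the figure can be reproduced numerically. The following sketch inverts the mapping (row, col) → thread under the same assumptions as before: rows and columns repeat every 64, each thread covers a 4×4 block, and tid = 16·tm + tn. `owner` and `corner_map` are hypothetical helpers.

```cpp
#include <cassert>
#include <string>

// Which thread owns element (row, col) of the 128x128 C tile?
// Rows/cols repeat every 64 (the rest factor 2); within 64, each thread
// covers a 4x4 block, and tid = 16*tm + tn.
int owner(int row, int col) {
  int tm = (row % 64) / 4;
  int tn = (col % 64) / 4;
  return 16 * tm + tn;
}

// Render the owners of the top-left 16x16 corner, one line per row,
// each owner right-aligned in a 4-character cell.
std::string corner_map() {
  std::string s;
  for (int r = 0; r < 16; ++r) {
    for (int c = 0; c < 16; ++c) {
      std::string cell = std::to_string(owner(r, c));
      s += std::string(4 - cell.size(), ' ') + cell;
    }
    s += '\n';
  }
  return s;
}
```

The first row of corner_map() reads 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3, and rows 4 to 7 start at thread 16.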
Element Count Check
Thread 0 owns: 1 × (4×2) × (4×2) = 1 × 8 × 8 = 64 elements
256 threads × 64 = 16,384 = 128 × 128 ✓
No element is claimed by two threads (the permutation is a bijection).
No element is unclaimed (the complement covers the remainder).
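Both claims can be checked exhaustively by counting how often each of the 16,384 offsets is produced. This sketch hard-codes the thread selector (16,16):(512,4) and the fragment (4,2,4,2):(128,8192,1,64) from step 4. `element_offset` and `covers_tile_exactly_once` are hypothetical helpers.

```cpp
#include <cassert>
#include <vector>

// Offset of value i (0..63) of thread tid: thread selector + fragment,
// assuming tid = 16*tm + tn.
int element_offset(int tid, int i) {
  int base = (tid / 16) * 512 + (tid % 16) * 4;
  int m0 = i % 4, m1 = (i / 4) % 2, n0 = (i / 8) % 4, n1 = (i / 32) % 2;
  return base + m0 * 128 + m1 * 8192 + n0 * 1 + n1 * 64;
}

// True iff every one of the 128*128 offsets is produced exactly once.
bool covers_tile_exactly_once() {
  std::vector<int> hits(128 * 128, 0);
  for (int tid = 0; tid < 256; ++tid)
    for (int i = 0; i < 64; ++i) {
      int off = element_offset(tid, i);
      if (off < 0 || off >= 128 * 128) return false;  // out of the tile
      ++hits[off];
    }
  for (int h : hits)
    if (h != 1) return false;  // claimed twice, or never claimed
  return true;
}
```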
Summary: Why Each Step Exists
Step | Operation | Problem It Solves | Output Shape
--- | --- | --- | ---
① | logical_divide | Factorize the tile according to the thread permutation: separate "which group" from "position within group" | ((16,4,2),(16,4,2))
② | zipped_divide | Isolate the atom's footprint (what one MMA instruction touches) from the rest (how many times we repeat) | ((1,1),(16,4,2,16,4,2))
③ | compose | Re-label atom coordinates from spatial (M,N) to functional (Thread, Value) using hardware knowledge | ((ThrV,FrgV),(Rest...))
④ | zipped_divide | Separate thread-indexing dimensions from per-thread fragment dimensions across the "rest" mode | ((Thr),(Frg))
⑤ | slice | Plug in a concrete thread ID and return just that thread's fragment layout | (1,(4,2),(4,2))
The Design Principle
CuTe decomposes the "which thread owns which elements" problem into composable algebraic primitives (divide, zip, compose, slice), each handling one concern:
Spatial structure (how elements are arranged in memory)
Hardware structure (how the MMA instruction maps threads to elements)
Tiling structure (how many atoms tile across the output)
By keeping these orthogonal, the same code works for a scalar FMA atom, a 16×8×8 tensor core, or a 64×256×32 Blackwell UMMA; only the atom's TV layout changes. The algorithm is invariant to the hardware.