Sharded Matmul Simulator
Configure A[I, J] · B[J, K] → C[I, K] and watch the mesh compute it: which collectives run, how shards move, and where the time goes.
Case 3: both inputs sharded along the contracting dimension
Each device can multiply its matching J-chunks, but the result is only a partial sum, written C {U_axis}. An AllReduce (or a cheaper ReduceScatter, if a sharded output is fine) completes the sum.
A[I, J_X] · B[J_X, K] → C[I, K]
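The partial-sum behaviour described above can be checked with a minimal pure-Python sketch. The helper names (`matmul`, `split_cols`, `split_rows`, `add`) and the tiny 2x4 and 4x3 matrices are illustrative assumptions, not part of the simulator; the two "devices" stand for the two positions along mesh axis X.

```python
def matmul(a, b):
    """Plain dense matmul on lists of lists."""
    return [[sum(a[i][j] * b[j][k] for j in range(len(b)))
             for k in range(len(b[0]))] for i in range(len(a))]

def split_cols(m, n):
    """Split a matrix into n equal chunks along its columns (A's J axis)."""
    w = len(m[0]) // n
    return [[row[s * w:(s + 1) * w] for row in m] for s in range(n)]

def split_rows(m, n):
    """Split a matrix into n equal chunks along its rows (B's J axis)."""
    h = len(m) // n
    return [m[s * h:(s + 1) * h] for s in range(n)]

def add(x, y):
    """Elementwise sum of two equally shaped matrices."""
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]

# Hypothetical toy inputs: I=2, J=4, K=3.
A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
B = [[1, 0, 2],
     [0, 1, 0],
     [3, 0, 1],
     [0, 2, 0]]

n_x = 2                          # size of mesh axis X (assumed)
a_shards = split_cols(A, n_x)    # A[I, J_X]
b_shards = split_rows(B, n_x)    # B[J_X, K]

# Each device multiplies its matching J-chunks: a partial sum only.
partials = [matmul(a, b) for a, b in zip(a_shards, b_shards)]

# AllReduce: every device ends up holding the complete sum.
all_reduced = partials[0]
for p in partials[1:]:
    all_reduced = add(all_reduced, p)

# ReduceScatter: each device keeps one slice of the sum
# (sharded along I here, which is an arbitrary choice for the sketch).
reduce_scattered = split_rows(all_reduced, n_x)
```

Neither entry of `partials` equals the full product on its own; only their sum does, which is why a reduction has to follow the local matmul.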
Step 1 of 4 (layout). Initial layout: A[I, J_X] · B[J_X, K] on a 2x2 mesh.
[Mesh diagram: a 2x2 grid with axes X (0, 1) and Y (0, 1). Devices: TPU 0 (0,0), TPU 1 (0,1), TPU 2 (1,0), TPU 3 (1,1). Each device's C block is shown as "?".]
Blocks are coloured by data identity; striped blocks are partial sums awaiting a reduction. Arriving blocks fly in from the device that sent them.
Cost model (TPU v5e, bf16)
| Operation | Size | Time |
|---|---|---|
| AllReduce_X (C) | 8.39 MB | 186 µs |
| Local matmul / device | 8.59 GFLOPs | 43.6 µs |
Comms total: 186 µs. Compute total: 43.6 µs. Communication bound.
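The figures in the table are consistent with a simple roofline-style estimate. The sketch below is an inference, not the simulator's code. It assumes I = J = K = 2048 (which reproduces the 8.39 MB and 8.59 GFLOPs entries), 2 bytes per bf16 element, roughly 1.97e14 bf16 FLOPs/s per TPU v5e chip, roughly 9e10 bytes/s of bidirectional ICI bandwidth along the reduced axis, and an AllReduce costed as twice the array size divided by that bandwidth.

```python
# All constants below are assumptions chosen to match the table above.
I = J = K = 2048          # matrix dimensions (inferred, not stated by the page)
MESH_X = 2                # size of mesh axis X on the 2x2 mesh
BYTES_PER_ELEM = 2        # bf16
PEAK_FLOPS = 1.97e14      # assumed TPU v5e bf16 FLOPs/s per chip
ICI_BW = 9.0e10           # assumed bidirectional ICI bytes/s

# C[I, K] is the array being reduced over X.
c_bytes = I * K * BYTES_PER_ELEM

# Each device contracts over its local J chunk only (Y is not used for sharding).
flops_per_device = 2 * I * (J // MESH_X) * K

# AllReduce modelled as ReduceScatter + AllGather: two passes over the array.
t_comms = 2 * c_bytes / ICI_BW
t_compute = flops_per_device / PEAK_FLOPS

bound = "communication" if t_comms > t_compute else "compute"

print(f"AllReduce_X (C): {c_bytes / 1e6:.2f} MB, {t_comms * 1e6:.0f} us")
print(f"Local matmul:    {flops_per_device / 1e9:.2f} GFLOPs, {t_compute * 1e6:.1f} us")
print(f"{bound} bound")
```

Under the same assumptions, the ReduceScatter mentioned earlier would need only one pass over the array, so it would take about half the AllReduce time.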