Sharded Matmul Simulator

Configure A[I, J] · B[J, K] → C[I, K] and watch the mesh compute it: which collectives run, how shards move, and where the time goes.

Case 3: both inputs sharded along the contracting dimension

Each device multiplies its local J-chunks of A and B, but the result is only a partial sum, written C[I, K]{U_X}. An AllReduce over X (or a cheaper ReduceScatter, if a sharded output is acceptable) completes the sum.

A[I, J_X] · B[J_X, K] → C[I, K]
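The reduction in this case can be sketched in plain Python. This is a toy under stated assumptions, not the simulator's code: the helper names (matmul, shard_j, all_reduce, reduce_scatter) are hypothetical, only the X axis (2 devices) is modelled, and the ReduceScatter is assumed to split the summed C along I.

```python
# Hypothetical toy model of Case 3 along mesh axis X (2 devices), using plain lists.
# The Y axis is ignored: in A[I, J_X] · B[J_X, K] it only replicates work.

def matmul(a, b):
    """Plain-Python product of an (n x m) matrix a and an (m x p) matrix b."""
    return [[sum(a[i][t] * b[t][k] for t in range(len(b))) for k in range(len(b[0]))]
            for i in range(len(a))]

def shard_j(a, b, n):
    """Give device d the d-th chunk of J: a slice of A's columns and of B's rows."""
    step = len(b) // n
    return [([row[d * step:(d + 1) * step] for row in a], b[d * step:(d + 1) * step])
            for d in range(n)]

def elementwise_sum(blocks):
    """Sum same-shaped matrices element by element (the reduction itself)."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*blocks)]

def all_reduce(partials):
    """Every device ends up with the full sum C[I, K]."""
    total = elementwise_sum(partials)
    return [total for _ in partials]

def reduce_scatter(partials):
    """Device d keeps only its row block of the sum (assumed here: C[I_X, K])."""
    total = elementwise_sum(partials)
    rows = len(total) // len(partials)
    return [total[d * rows:(d + 1) * rows] for d in range(len(partials))]

A = [[1, 2, 3, 4], [5, 6, 7, 8]]          # A[I=2, J=4]
B = [[1, 0], [0, 1], [1, 0], [0, 1]]      # B[J=4, K=2]
partials = [matmul(a_d, b_d) for a_d, b_d in shard_j(A, B, 2)]  # C[I, K]{U_X}
print(partials)                  # [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
print(all_reduce(partials)[0])   # [[4, 6], [12, 14]]
print(reduce_scatter(partials))  # [[[4, 6]], [[12, 14]]]
```

Neither partial matches A · B on its own; only their sum does, which is why a collective is unavoidable in this case.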


Initial layout: A[I, J_X] · B[J_X, K] on a 2x2 mesh.

[Figure: 2x2 device mesh with rows X=0, X=1 and columns Y=0, Y=1: TPU 0 (0,0), TPU 1 (0,1), TPU 2 (1,0), TPU 3 (1,1). Each device holds a block of A and a block of B; its block of C is not yet computed. Blocks are coloured by data identity; striped blocks are partial sums awaiting a reduction.]
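The layout can also be written down as a mapping from device to J-chunk. A minimal sketch, assuming row-major device numbering (TPU id = x · 2 + y, which matches the TPU labels above) and J = 2048 (an assumption, chosen because it is consistent with the byte and FLOP counts in the cost model); the function name j_chunks is hypothetical. Because J is sharded only along X, the two devices that share an X coordinate hold the same chunks, so the Y axis just replicates the work.

```python
def j_chunks(mesh_x=2, mesh_y=2, J=2048):
    """Map TPU id -> ((x, y), (j_start, j_end)) for A[I, J_X] and B[J_X, K]."""
    chunk = J // mesh_x
    layout = {}
    for x in range(mesh_x):
        for y in range(mesh_y):
            tpu = x * mesh_y + y  # assumed row-major numbering
            # The chunk depends only on x: J is not sharded along Y.
            layout[tpu] = ((x, y), (x * chunk, (x + 1) * chunk))
    return layout

for tpu, (coords, (lo, hi)) in j_chunks().items():
    print(f"TPU {tpu} {coords}: J[{lo}:{hi}]")
```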

Cost model (TPU v5e, bf16)

Operation               Bytes / FLOPs   Time
AllReduce_X (C)         8.39 MB         186 µs
Local matmul / device   8.59 GFLOPs     43.6 µs

Comms total: 186 µs. Compute total: 43.6 µs. Verdict: communication bound (186 µs > 43.6 µs).
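The two rows of the table can be reproduced with a simple roofline-style estimate. A minimal sketch, assuming I = J = K = 2048, a v5e bf16 peak of 1.97e14 FLOPs/s, a bidirectional ICI bandwidth of 9e10 bytes/s, and AllReduce time ≈ 2 · bytes / bandwidth with ReduceScatter at half that. These constants and the function name case3_cost are assumptions that happen to match the numbers above, not the simulator's published formulas.

```python
PEAK_FLOPS = 1.97e14  # assumed TPU v5e bf16 peak, FLOPs/s
ICI_BW = 9e10         # assumed bidirectional ICI bandwidth, bytes/s
BF16_BYTES = 2

def case3_cost(I, J, K, x_size):
    """Estimate Case 3 costs: local partial matmul plus a reduction of C over X."""
    flops = 2 * I * (J // x_size) * K        # each device contracts only its J chunk
    c_bytes = I * K * BF16_BYTES             # size of the full partial-sum C[I, K]
    t_compute = flops / PEAK_FLOPS
    t_all_reduce = 2 * c_bytes / ICI_BW      # assumed: AllReduce = ReduceScatter + AllGather
    t_reduce_scatter = c_bytes / ICI_BW      # assumed: half the AllReduce cost
    return flops, c_bytes, t_compute, t_all_reduce, t_reduce_scatter

flops, c_bytes, t_comp, t_ar, t_rs = case3_cost(2048, 2048, 2048, x_size=2)
print(f"matmul: {flops / 1e9:.2f} GFLOPs, {t_comp * 1e6:.1f} µs")
print(f"AllReduce_X (C): {c_bytes / 1e6:.2f} MB, {t_ar * 1e6:.0f} µs")
print("communication bound" if t_ar > t_comp else "compute bound")
```

Under these assumptions, switching to a ReduceScatter halves the communication to about 93 µs, which is still above the 43.6 µs of compute, so the configuration stays communication bound.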