How to Scale Your Model · Part 3
Sharded Matrices and How to Multiply Them
When a model no longer fits on one chip, its arrays get sharded across a mesh of devices — and multiplying sharded matrices means moving data. These interactive visualisations let you change the mesh, the matrix dimensions and the sharding strategies, and watch exactly what communication each choice costs.
Sharding Notation Explorer
Build the vocabulary: meshes, mesh axes, and shardings like A[I_X, J_Y]. See which block of the matrix each device holds, what its local shape is, and how many devices hold replicated copies of it.
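To make the notation concrete, here is a minimal sketch in plain Python of the bookkeeping the explorer displays: given a global shape, a mesh, and a sharding, it works out which block a device holds. The encoding is an assumption of this sketch, not the explorer's API. A mesh is a dict from single-letter axis names to sizes. A spec has one string per array dimension listing the mesh axes that shard it, so A[I_X, J_Y] becomes ("X", "Y") and a replicated dimension is "". A dimension sharded over several axes, such as I_XY, is treated as one flattened axis with the first letter most significant. The function names are hypothetical.

```python
from math import prod


def local_block(global_shape, mesh, spec, coords):
    """Return the (start, stop) index range of each dimension held by one device.

    Assumes single-letter mesh axis names.
    global_shape: tuple of ints, e.g. (1024, 4096)
    mesh:         dict of axis name -> axis size, e.g. {"X": 2, "Y": 4}
    spec:         one string per dimension naming the mesh axes that shard it
    coords:       this device's position along each mesh axis, e.g. {"X": 1, "Y": 2}
    """
    ranges = []
    for size, axes in zip(global_shape, spec):
        n_shards = prod(mesh[a] for a in axes)  # "" -> 1 shard (replicated)
        if size % n_shards:
            raise ValueError(f"dimension of size {size} not divisible by {n_shards}")
        block = size // n_shards
        # Flatten the sharding axes into one index; the first listed axis is most major.
        idx = 0
        for a in axes:
            idx = idx * mesh[a] + coords[a]
        ranges.append((idx * block, (idx + 1) * block))
    return ranges


def local_shape(global_shape, mesh, spec):
    """Shape of the shard each device holds (the same on every device)."""
    return tuple(size // prod(mesh[a] for a in axes)
                 for size, axes in zip(global_shape, spec))


def replication_factor(mesh, spec):
    """Number of devices holding an identical copy of each block: the product
    of the sizes of the mesh axes the spec does not mention."""
    used = set("".join(spec))
    return prod(n for a, n in mesh.items() if a not in used)


mesh = {"X": 2, "Y": 4}
# A[I_X, J_Y] with A of shape (1024, 4096), seen from device (X=1, Y=2)
print(local_block((1024, 4096), mesh, ("X", "Y"), {"X": 1, "Y": 2}))  # [(512, 1024), (2048, 3072)]
print(local_shape((1024, 4096), mesh, ("X", "Y")))                     # (512, 1024)
print(replication_factor(mesh, ("X", "Y")))                            # 1
# A[I_X, J]: each block is replicated over the 4 devices along Y
print(replication_factor(mesh, ("X", "")))                             # 4
```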
Sharded Matmul Simulator
The centrepiece: configure the shardings of A[I, J] · B[J, K], see which of the four cases (1 to 4) the configuration falls into, and watch the mesh AllGather, multiply, and reduce its way to C, with the full cost breakdown.
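The arithmetic the simulator animates can be checked on a single machine. The sketch below assumes NumPy and models a 1D mesh of `n` devices as a Python list of blocks; it does not use any real multi-device API. It runs case 2 (AllGather, then multiply) and case 3 (multiply local shards, then AllReduce or ReduceScatter) and confirms that both reproduce the unsharded product. The function names are illustrative only.

```python
import numpy as np


def case2_allgather_matmul(a, b, n):
    """A[I, J_X] . B[J, K]: AllGather A along X, then multiply locally."""
    a_shards = np.split(a, n, axis=1)          # device d holds column block d of A
    a_full = np.concatenate(a_shards, axis=1)  # AllGather result: the full A[I, J]
    return [a_full @ b for _ in range(n)]      # each device computes the replicated C[I, K]


def case3_reduce_matmul(a, b, n):
    """A[I, J_X] . B[J_X, K]: multiply local shards, then reduce the partial sums."""
    a_shards = np.split(a, n, axis=1)  # device d holds A[:, J block d]
    b_shards = np.split(b, n, axis=0)  # device d holds B[J block d, :]
    # Each local matmul yields a full-size (I, K) array that is only a partial sum of C.
    partials = [a_d @ b_d for a_d, b_d in zip(a_shards, b_shards)]
    # AllReduce: every device ends up with the full sum, i.e. a replicated C[I, K].
    c_allreduce = sum(partials)
    # ReduceScatter: device d sums only row block d of every partial, giving C[I_X, K].
    c_reducescatter = [
        sum(np.split(p, n, axis=0)[d] for p in partials) for d in range(n)
    ]
    return c_allreduce, c_reducescatter


rng = np.random.default_rng(0)
a = rng.standard_normal((8, 12))
b = rng.standard_normal((12, 6))
n = 4  # devices along mesh axis X

c2 = case2_allgather_matmul(a, b, n)
c3_full, c3_scattered = case3_reduce_matmul(a, b, n)
print(all(np.allclose(c, a @ b) for c in c2))                    # True
print(np.allclose(c3_full, a @ b))                               # True
print(np.allclose(np.concatenate(c3_scattered, axis=0), a @ b))  # True
print(c3_scattered[0].shape)                                     # (2, 6)
```

The ReduceScatter variant leaves each device with a quarter of the rows of C, which is why it is the cheaper choice when the next operation can consume C[I_X, K] directly.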
Collectives Playground
AllGather, ReduceScatter, AllReduce and AllToAll, animated hop by hop on a ring, with a runtime table explaining why an AllReduce costs 2x an AllGather and an AllToAll roughly 1/4x.
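Both ratios can be reproduced by counting bytes on ring links. The sketch below is a simplified model of our own, not the playground's code. It assumes a bidirectional ring of `n` devices in which every directed link runs concurrently at the same bandwidth, so runtime is proportional to the busiest link's load. AllGather forwards each shard along the shorter arc, and AllToAll sends each chunk point to point by the shortest route. ReduceScatter is taken as the mirror image of AllGather, and AllReduce as a ReduceScatter followed by an AllGather.

```python
from collections import defaultdict

# Simplified model (assumption): runtime is proportional to the bytes carried by
# the busiest directed link of a bidirectional ring; all links run in parallel.


def allgather_max_link_load(n, array_bytes):
    """Busiest directed link when each device broadcasts its shard of an
    array_bytes-sized array around a bidirectional ring."""
    shard = array_bytes / n
    load = defaultdict(float)
    cw_hops = n // 2          # devices reached by forwarding clockwise
    ccw_hops = (n - 1) // 2   # the remaining devices, reached counter-clockwise
    for src in range(n):
        for h in range(cw_hops):
            load[((src + h) % n, (src + h + 1) % n)] += shard
        for h in range(ccw_hops):
            load[((src - h) % n, (src - h - 1) % n)] += shard
    return max(load.values())


def alltoall_max_link_load(n, array_bytes):
    """Busiest directed link when each device sends a distinct 1/n slice of its
    shard to every other device along the shortest route."""
    chunk = array_bytes / n**2
    load = defaultdict(float)
    for src in range(n):
        for dst in range(n):
            d = (dst - src) % n
            if d == 0:
                continue
            step, hops = (1, d) if d <= n // 2 else (-1, n - d)
            node = src
            for _ in range(hops):
                nxt = (node + step) % n
                load[(node, nxt)] += chunk
                node = nxt
    return max(load.values())


def ring_costs(n, array_bytes):
    ag = allgather_max_link_load(n, array_bytes)
    return {
        "allgather": ag,
        "reducescatter": ag,  # same traffic pattern, run in reverse
        "allreduce": 2 * ag,  # ReduceScatter followed by AllGather
        "alltoall": alltoall_max_link_load(n, array_bytes),
    }


for n in (8, 16, 64):
    costs = ring_costs(n, array_bytes=n * n)
    print(n, {k: v / costs["allgather"] for k, v in costs.items()})
```

Under these assumptions AllReduce is exactly 2x AllGather for every ring size, while the AllToAll/AllGather ratio is 0.3125 at n = 8, 0.28125 at n = 16 and 0.2578125 at n = 64, tending to 1/4 as the ring grows.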
The four cases, in one table
| Case | Situation | What to do |
|---|---|---|
| 1 | Neither input sharded along the contracting dimension | Multiply local shards. No communication. |
| 2 | One input has a sharded contracting dimension | AllGather the sharded input along J first. |
| 3 | Both inputs sharded along the contracting dimension, over the same mesh axis | Multiply local shards to get partial sums, then AllReduce (or ReduceScatter) them over that axis. |
| 4 | The non-contracting dimensions of both inputs (I and K) are sharded along the same mesh axis | Invalid as-is: AllGather one input first. |
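The table's decision rule is short enough to write down directly. In this sketch, the encoding is an assumption: each dimension's sharding is a string of single-letter mesh axis names, with "" meaning replicated, so A[I_X, J] · B[J, K_Y] is written as ("X", "") and ("", "Y"). It assumes each input is itself a valid sharding, checks case 4 first, and rejects contracting dimensions sharded over different axes, which the table does not cover. The function name `classify_matmul` is hypothetical.

```python
def classify_matmul(a_spec, b_spec):
    """Classify A[I, J] . B[J, K] into the four cases of the table.

    a_spec = (I sharding, J sharding), b_spec = (J sharding, K sharding),
    each a string of single-letter mesh axis names ("" = replicated).
    """
    a_i, a_j = a_spec
    b_j, b_k = b_spec
    # Case 4: I and K share a mesh axis, so C[I, K] would use that axis twice.
    # This sketch checks it first, before looking at the contracting dimension.
    if set(a_i) & set(b_k):
        return 4
    if a_j and b_j:
        if a_j != b_j:
            raise ValueError("contracting dimension sharded over different axes")
        return 3  # partial sums, then AllReduce / ReduceScatter
    if a_j or b_j:
        return 2  # AllGather the input whose J is sharded
    return 1      # no communication


for a_spec, b_spec in [(("X", ""), ("", "Y")),
                       (("", "X"), ("", "")),
                       (("", "X"), ("X", "")),
                       (("X", ""), ("", "X"))]:
    print(a_spec, b_spec, "-> case", classify_matmul(a_spec, b_spec))
```

The four example pairs are A[I_X, J] · B[J, K_Y], A[I, J_X] · B[J, K], A[I, J_X] · B[J_X, K] and A[I_X, J] · B[J, K_X], which fall into cases 1, 2, 3 and 4 respectively.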