How to Scale Your Model · Part 3

Sharded Matrices and How to Multiply Them

When a model no longer fits on one chip, its arrays get sharded across a mesh of devices — and multiplying sharded matrices means moving data. These interactive visualisations let you change the mesh, the matrix dimensions and the sharding strategies, and watch exactly what communication each choice costs.

1. Sharding Notation Explorer

Build the vocabulary: meshes, mesh axes, and shardings like A[I_X, J_Y]. See which block of the matrix each device holds, local shapes, and replication.

Open →
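As a rough illustration of the notation, the sketch below computes the local block shape and replication factor for a sharding such as A[I_X, J_Y]. The function names, the dict-based mesh, and the tuple encoding of shardings are hypothetical choices made for this example. They are not the explorer's own code or any library's API.

```python
from math import prod

def local_shape(global_shape, sharding, mesh):
    """Per-device block shape of an array sharded over a named mesh.

    global_shape: full array shape, e.g. (1024, 4096)
    sharding: one tuple of mesh-axis names per array dimension,
              e.g. (("X",), ("Y",)) for A[I_X, J_Y]; () means unsharded
    mesh: mesh-axis name -> size, e.g. {"X": 4, "Y": 2}
    """
    used = [axis for axes in sharding for axis in axes]
    if len(used) != len(set(used)):
        raise ValueError("a mesh axis may shard at most one array dimension")
    shape = []
    for dim, axes in zip(global_shape, sharding):
        ways = prod(mesh[a] for a in axes)  # product over no axes is 1
        if dim % ways:
            raise ValueError(f"dimension {dim} is not divisible by {ways}")
        shape.append(dim // ways)
    return tuple(shape)

def replication_factor(sharding, mesh):
    """Number of devices holding an identical copy of each block."""
    used = {axis for axes in sharding for axis in axes}
    return prod(size for name, size in mesh.items() if name not in used)

mesh = {"X": 4, "Y": 2}
print(local_shape((1024, 4096), (("X",), ("Y",)), mesh))  # (256, 2048)
print(replication_factor((("X",), ()), mesh))             # 2
```

With A[I_X, J] on this mesh, the Y axis is unused, so each block is held by two devices.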

2. Sharded Matmul Simulator

The centrepiece: configure the shardings of A[I, J] · B[J, K], get the case classification (1–4), and watch the mesh AllGather, multiply, and reduce its way to C, with the full cost breakdown.

Open →
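To make the "multiply, then reduce" step concrete, here is a minimal NumPy sketch that emulates devices with a Python list. It assumes both A and B are split along the contracting dimension J across the same `n_dev` simulated devices. The function name `matmul_sharded_contracting` is hypothetical, and a plain sum stands in for the AllReduce.

```python
import numpy as np

def matmul_sharded_contracting(A, B, n_dev):
    """Emulate A[I, J_X] . B[J_X, K] on n_dev simulated devices."""
    A_shards = np.split(A, n_dev, axis=1)  # each device holds a column block of A
    B_shards = np.split(B, n_dev, axis=0)  # and the matching row block of B
    # Local matmuls: every device produces a full-shape but partial C[I, K].
    partials = [a @ b for a, b in zip(A_shards, B_shards)]
    # Summing the partial results is what an AllReduce would do across devices.
    return np.sum(partials, axis=0)

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 16))
B = rng.standard_normal((16, 6))
C = matmul_sharded_contracting(A, B, n_dev=4)
print(np.allclose(C, A @ B))  # True
```

Each partial product already has the full shape of C, which is why only a reduction, not a gather, is needed afterwards.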

3. Collectives Playground

AllGather, ReduceScatter, AllReduce and AllToAll animated hop by hop on a ring, with the runtime table explaining the 2x and 1/4x cost relationships.

Open →
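One way to read the cost relationships is as multipliers on the AllGather time. The sketch below is a simplified model under stated assumptions: a bandwidth-bound bidirectional ring, per-hop latency ignored, with the 2x taken to mean AllReduce versus AllGather and the 1/4x to mean AllToAll versus AllGather. The function name, the factor table, and the byte and bandwidth figures are hypothetical.

```python
# Assumed cost multipliers relative to AllGather on a bidirectional ring.
RING_COST_FACTOR = {
    "AllGather": 1.0,
    "ReduceScatter": 1.0,
    "AllReduce": 2.0,   # modelled as ReduceScatter followed by AllGather
    "AllToAll": 0.25,
}

def ring_collective_time(op, total_bytes, bandwidth_bytes_per_s):
    """Approximate bandwidth-bound time in seconds for one collective.

    total_bytes is the size of the full (unsharded) array;
    bandwidth_bytes_per_s is the bidirectional link bandwidth.
    """
    return RING_COST_FACTOR[op] * total_bytes / bandwidth_bytes_per_s

total_bytes = 2**30   # hypothetical 1 GiB array
bandwidth = 1e11      # hypothetical link bandwidth, bytes/s
for op in RING_COST_FACTOR:
    print(f"{op:14s} {ring_collective_time(op, total_bytes, bandwidth) * 1e3:.2f} ms")
```

In this model the time depends on the total array size and link bandwidth but not on the number of devices, which only holds while the transfer is bandwidth-bound.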

The four cases, in one table

| Case | Situation | What to do |
|------|-----------|------------|
| 1 | Neither input sharded along the contracting dimension | Multiply local shards. No communication. |
| 2 | One input has a sharded contracting dimension | AllGather the sharded input along J first. |
| 3 | Both inputs sharded along the contracting dimension | Multiply local shards, then AllReduce (or ReduceScatter). |
| 4 | Both inputs share a mesh axis on non-contracting dimensions | Invalid as-is: AllGather one input first. |
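The four cases can be sketched as a small classifier. This is a simplified illustration, not the simulator's logic: the function name and the tuple encoding of shardings are hypothetical, case 4 is checked first because it is invalid as-is, and case 3 is assumed to require that both contracting dimensions use the same mesh axes (any other sharded J falls back to case 2, where an AllGather comes first).

```python
def classify_matmul(a_sharding, b_sharding):
    """Return the case number (1-4) for A[I, J] . B[J, K].

    a_sharding: (axes on I, axes on J), e.g. (("X",), ()) for A[I_X, J]
    b_sharding: (axes on J, axes on K), e.g. ((), ("Y",)) for B[J, K_Y]
    """
    a_i, a_j = (set(axes) for axes in a_sharding)
    b_j, b_k = (set(axes) for axes in b_sharding)
    if a_i & b_k:
        return 4  # same mesh axis on both non-contracting dims: invalid as-is
    if a_j and a_j == b_j:
        return 3  # J sharded identically: local matmul, then AllReduce
    if a_j or b_j:
        return 2  # a sharded J that must be AllGathered first
    return 1      # no sharding on J: purely local matmul

print(classify_matmul((("X",), ()), ((), ("Y",))))  # 1: A[I_X, J] . B[J, K_Y]
print(classify_matmul(((), ("X",)), ((), ())))      # 2: A[I, J_X] . B[J, K]
print(classify_matmul(((), ("X",)), (("X",), ())))  # 3: A[I, J_X] . B[J_X, K]
print(classify_matmul((("X",), ()), ((), ("X",))))  # 4: A[I_X, J] . B[J, K_X]
```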