Expert Parallelism

distributed-training
Published

September 1, 2025

Let’s distill how to run a Mixture-of-Experts (MoE) model with expert parallelism, using a concrete example.

The setup is 8 GPUs using a 2-D device mesh with Expert Parallel (EP) and FSDP2 (fully_shard).

This short post explains who stores which weights, where all-to-all happens, and includes minimal code (meshes and wrapping).


TL;DR

  • Columns move tokens. Rows share weights.
  • Cols (dp_shard_in_ep) = EP axis → which experts live where + token all-to-all.
  • Rows (dp_shard_mod_ep) = FSDP axis → how owned experts are sharded.
  • Non-MoE: FSDP(8) across all ranks (no persistent replication).
  • Experts: EP(2) across columns (ownership) × FSDP(4) across rows (shards).
  • EP all-to-all is row-local pairs: (0↔︎1), (2↔︎3), (4↔︎5), (6↔︎7). Microbatches can differ (variable-size A2A handles it).

Device Mesh & Rank Layout

Arrange 8 GPUs as a 4×2 grid:

(rows, cols) → rank
(0,0)→0   (0,1)→1
(1,0)→2   (1,1)→3
(2,0)→4   (2,1)→5
(3,0)→6   (3,1)→7
  • Rows (dp_shard_mod_ep) size 4 → FSDP sharding axis for experts.
  • Cols (dp_shard_in_ep) size 2 → EP ownership + row-local all-to-all.
  • Flattened dp view (rows×cols) size 8 → dataloader + FSDP for non-MoE.

Here is the PyTorch code:

# Assumed for this post: 8 GPUs total and 2-way EP.
from torch.distributed.device_mesh import init_device_mesh

dp_shard = 8                            # all 8 GPUs participate in data parallelism
ep_size  = 2                            # number of expert-parallel groups (columns)

dp_shard_in_ep  = ep_size               # borrowed by EP (forms the EP groups)
dp_shard_mod_ep = dp_shard // ep_size   # leftover for FSDP sharding. borrowing!

# Create the 2D mesh
device_mesh = init_device_mesh(
    "cuda",
    (dp_shard_mod_ep, dp_shard_in_ep),
    mesh_dim_names=("dp_shard", "ep"),
)

# Create aliases for DP
# DP will be used for data loading
device_mesh[("dp_shard", "ep")]._flatten(mesh_dim_name="dp")

dp_mesh  = device_mesh["dp"]         # size = R*C = 8 -> used for non-MoE FSDP
row_mesh = device_mesh["dp_shard"]   # size = R   = 4 -> used for expert FSDP (inside each column)
col_mesh = device_mesh["ep"]         # size = C   = 2 -> used for EP ownership + a2a

Who Stores Which Expert?

Assume 8 experts E0..E7. With EP=2:

  • Column 0 owns experts E0–E3 (no copy of E4–E7).
  • Column 1 owns experts E4–E7 (no copy of E0–E3).
  • Inside each column, FSDP(4) shards the owned experts across the 4 rows.
  • This is what makes the row-local token exchange work: in every row, the rank in column 0 owns E0–E3 and the rank in column 1 owns E4–E7.
Rank        Column owns   This rank holds (expert weights)
0 = (0,0)   E0–E3         1/4 shard of E0–E3
2 = (1,0)   E0–E3         1/4 shard of E0–E3
4 = (2,0)   E0–E3         1/4 shard of E0–E3
6 = (3,0)   E0–E3         1/4 shard of E0–E3
1 = (0,1)   E4–E7         1/4 shard of E4–E7
3 = (1,1)   E4–E7         1/4 shard of E4–E7
5 = (2,1)   E4–E7         1/4 shard of E4–E7
7 = (3,1)   E4–E7         1/4 shard of E4–E7
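
This ownership pattern follows mechanically from each rank's position in the 4×2 mesh. A minimal sketch (the helper below is illustrative, not part of the training code; it assumes the default row-major rank layout shown above):

NUM_EXPERTS, ROWS, COLS = 8, 4, 2            # 8 experts, 4 FSDP rows, 2 EP columns

def expert_ownership(rank):
    row, col = divmod(rank, COLS)            # row-major layout of the 4x2 mesh
    per_col = NUM_EXPERTS // COLS            # 4 experts owned per column
    owned = list(range(col * per_col, (col + 1) * per_col))
    # each of the ROWS ranks in this column stores a 1/ROWS shard of `owned`
    return row, col, owned

for r in range(8):
    print(r, expert_ownership(r))
# 0 (0, 0, [0, 1, 2, 3])   1 (0, 1, [4, 5, 6, 7])
# 2 (1, 0, [0, 1, 2, 3])   3 (1, 1, [4, 5, 6, 7])   ... and so on for ranks 4-7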

Non-MoE parameters (embeddings, attention, MLP, norms) are FSDP-sharded 8-way across all ranks (the flattened dp mesh). There is no persistent replication.


Where Does All-to-All Happen?

EP all-to-all is row-local:

  • Row 0: GPU0 ⇄ GPU1
  • Row 1: GPU2 ⇄ GPU3
  • Row 2: GPU4 ⇄ GPU5
  • Row 3: GPU6 ⇄ GPU7

Each rank runs the router on its own microbatch. Tokens are split by destination column (owner of the chosen expert), then a size-exchange + variable-size all-to-all exchanges exactly those slices between the two columns in the same row.

After dispatch, each rank holds only tokens for experts it owns. Expert matmuls then use FSDP across the 4 rows in that column (all-gather → compute → reduce-scatter). Finally, the inverse all-to-all returns outputs to the source ranks.

Let’s illustrate the token dispatch operation with a well-annotated code snippet. (Note: the inline comments walk through a 4-way EP group with 2 local experts per rank, a different example size from the 2-way EP layout above; the mechanics are identical.)

# Assumes torch, torch.distributed as dist, and an autograd-aware
# all_to_all_single_autograd (e.g. from torch.distributed._functional_collectives
# in recent PyTorch) are in scope.
def _token_dispatch(self, model, inputs, device_mesh):
    """
    All-to-all token dispatch.
    input_splits differ per device (each rank routes its own microbatch).
    """
    ep_size = device_mesh.shape[0]
    x_gathered, num_tokens_per_expert = inputs
    num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
        num_tokens_per_expert.shape[0]
    )

    # Preliminary all-to-all to exchange token counts: a distributed transpose of
    # the counts matrix, so GPU i receives, from every rank, the counts destined
    # for GPU i's local experts. These counts determine the split sizes for the
    # main token all-to-all dispatch.
    #
    # Before (on GPU 0):
    #   `num_tokens_per_expert`: [10, 5, 12, 8, 11, 6, 13, 7]
    #   (Counts of local tokens for all 8 global experts)
    #
    # After (on GPU 0, which hosts experts 0 and 1):
    #   `num_tokens_per_expert_group` is filled with:
    #   [10, 5, | 9, 4, | 14, 2, | 3, 11]
    #   (Counts for my local experts [E0,E1] from GPU0, GPU1, GPU2, GPU3)
    
    dist.all_to_all_single(
        num_tokens_per_expert_group, # output!
        num_tokens_per_expert, # input
        group=device_mesh.get_group(),
    )


    input_splits = num_tokens_per_expert.view(
        ep_size, -1
    ).sum(dim=1).to(torch.device("cpu"))

    output_splits = num_tokens_per_expert_group.view(
        ep_size, -1
    ).sum(dim=1).to(torch.device("cpu"))

    self.input_splits = input_splits.tolist()
    self.output_splits = output_splits.tolist()

    # this is an uneven communication (e.g. ragged), where each GPU receives an uneven amount of tokens.

    # On GPU 0:
    # - Total tokens before send (sum of num_tokens_per_expert): 72
    # - input_splits (how to slice the 72 tokens for sending): [15, 20, 17, 20]
    # - output_splits (how many tokens to expect from each GPU): [15, 13, 16, 14]

    # Before all_to_all, each GPU has a different number of tokens and a different plan:
    # GPU 0: tensor of size 72, sends chunks of [15, 20, 17, 20]
    # GPU 1: (example) tensor of size 80, sends chunks of [13, 25, 22, 20]
    # GPU 2: (example) tensor of size 75, sends chunks of [16, 18, 21, 20]
    # GPU 3: (example) tensor of size 68, sends chunks of [14, 15, 19, 20]

    # After all_to_all on GPU 0:
    # - Receives: 15 from GPU0, 13 from GPU1, 16 from GPU2, 14 from GPU3
    # - Output tensor size = sum(output_splits) = 15 + 13 + 16 + 14 = 58
    # - This new tensor of 58 tokens contains data for GPU 0's local experts (E0, E1),
    #   but is grouped by source GPU, not by expert ID. It needs a local shuffle.

    # all_to_all_single_autograd allows differentiable data transfer
    print(f"{self.output_splits=} {self.input_splits=}")

    x_gathered = all_to_all_single_autograd(
        x_gathered,
        self.output_splits,
        self.input_splits,
        device_mesh.get_group(),
    )

    # num_tokens_per_expert_group
    #   [10, 5, | 9, 4, | 14, 2, | 3, 11]
    # 
    #   x_gathered on GPU 0 (shape: [58, h])
    #  +------------------------------------------------+
    #  |                                                |
    #  |  Block of 15 tokens RECEIVED from GPU 0        |
    #  |  (Contains 10 tokens for MY E0, 5 for MY E1)   |
    #  |                                                |
    #  +------------------------------------------------+  <-- Boundary at index 14
    #  |                                                |
    #  |  Block of 13 tokens RECEIVED from GPU 1        |
    #  |  (Contains 9 tokens for MY E0, 4 for MY E1)    |
    #  |                                                |
    #  +------------------------------------------------+  <-- Boundary at index 27 (14+13)
    #  |                                                |
    #  |  Block of 16 tokens RECEIVED from GPU 2        |
    #  |  (Contains 14 tokens for MY E0, 2 for MY E1)   |
    #  |                                                |
    #  +------------------------------------------------+  <-- Boundary at index 43 (27+16)
    #  |                                                |
    #  |  Block of 14 tokens RECEIVED from GPU 3        |
    #  |  (Contains 3 tokens for MY E0, 11 for MY E1)   |
    #  |                                                |
    #  +------------------------------------------------+  <-- Final boundary at index 57

    #   Target layout for x_gathered (shape: [58, h])
    #  +------------------------------------------------+
    #  |                                                |
    #  |  All 36 tokens for MY Expert 0                 |
    #  |  (Gathered from the 4 blocks above)            |
    #  |                                                |
    #  +------------------------------------------------+  <-- Boundary at index 35
    #  |                                                |
    #  |  All 22 tokens for MY Expert 1                 |
    #  |  (Gathered from the 4 blocks above)            |
    #  |                                                |
    #  +------------------------------------------------+ 

    # target for num_tokens_per_expert_group
    #    [36, 22]


    # Reshape to see the (source GPU, local expert) structure
    tokens = num_tokens_per_expert_group.view(ep_size, -1)
    # Shape: [ep_size, experts_per_rank] = [4, 2] here, where dim0=source GPU, dim1=local expert
    # [[10,  5],  <- GPU 0: 10 tokens for E0, 5 for E1
    #  [ 9,  4],  <- GPU 1: 9 tokens for E0, 4 for E1
    #  [14,  2],  <- GPU 2: 14 tokens for E0, 2 for E1
    #  [ 3, 11]]  <- GPU 3: 3 tokens for E0, 11 for E1
    expert_per_device = num_tokens_per_expert_group.shape[0] // ep_size
    expert_ids = torch.repeat_interleave(
        torch.arange(expert_per_device).repeat(ep_size).to('cuda'),  # [0, 1, 0, 1, 0, 1, 0, 1] - expert pattern
        num_tokens_per_expert_group  # [10,5,9,4,14,2,3,11] - repeat counts
    )
    
    # index looks like
    # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 15, 16, 17, 18, 19, 20, 21, 22,
    # 23, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 44, 45, 46,
    # 10, 11, 12, 13, 14, 24, 25, 26, 27, 42, 43, 47, 48, 49, 50, 51, 52, 53,
    # 54, 55, 56, 57])
    self.index = torch.argsort(expert_ids, stable=True)
    x_reorganized = x_gathered[self.index, :]

    # per-expert aggregation: sum over source GPUs (dim 0) -> e.g. [36, 22]
    num_tokens_per_expert_group_agg = tokens.sum(dim=0)

    return x_reorganized, num_tokens_per_expert_group_agg
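
The two returned values are exactly what the expert feed-forward consumes: tokens grouped by local expert, plus per-expert counts. As a sketch of what a fused grouped GEMM replaces, here is a naive per-expert loop (the SwiGLU weight shapes w1, w3: [n_local_experts, ffn_dim, h] and w2: [n_local_experts, h, ffn_dim] are assumptions for illustration, not the post's exact module):

import torch
import torch.nn.functional as F

def experts_forward(x_sorted, tokens_per_expert, w1, w2, w3):
    # x_sorted: [T, h] tokens already grouped by local expert (output of _token_dispatch)
    # tokens_per_expert: [n_local_experts] counts, e.g. [36, 22] in the example above
    outs, start = [], 0
    for e, n in enumerate(tokens_per_expert.tolist()):
        xe = x_sorted[start:start + n]                    # this expert's contiguous slice
        he = F.silu(xe @ w1[e].T) * (xe @ w3[e].T)        # [n, ffn_dim]
        outs.append(he @ w2[e].T)                         # [n, h]
        start += n
    return torch.cat(outs, dim=0)                         # same expert-grouped order as the input

In the actual layer these per-expert matmuls are fused into one grouped GEMM and run under the FSDP(4) wrapping across the column's rows, as described above.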

The Three “Boxes” in the Forward

  1. Non-experts: FSDP(8) across all ranks

    all_gather(8) → compute → reduce_scatter(8)
  2. EP dispatch/return: row-local all-to-all

    Pairs: (0↔1), (2↔3), (4↔5), (6↔7). Size-exchange then A2A.
  3. Experts: FSDP(4) inside each column

    all_gather(4) → grouped GEMM → reduce_scatter(4)

For the visually inclined, here is how the placement looks:

🎮 View Interactive 3D Visualization - See how data flows through the Non-MoE, MoE FSDP, and Token Routing layers.


Minimal Mesh & Wrapping


# Imports (module paths as of recent PyTorch; some of these have moved between releases):
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Shard, distribute_tensor, distribute_module
from torch.distributed.tensor.parallel import ParallelStyle, parallelize_module
from torch.distributed.fsdp import fully_shard
from torch.distributed._functional_collectives import all_to_all_single_autograd

# 0) EP: assign experts to columns + install dispatch/combine hooks
class ExpertParallel(ParallelStyle):
    def __init__(self):
        super().__init__()
        self.input_splits = None
        self.output_splits = None

    # performing all-to-all dispatch on the input
    def _token_dispatch(self, mod, inputs, device_mesh):
        ...

    @staticmethod
    def _partition_fn(name, mod, device_mesh):
        # shard on the expert dimension
        for name, param in mod.named_parameters(recurse=False):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
            mod.register_parameter(name, dist_param)

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
        routed_output = all_to_all_single_autograd(
            routed_output,
            self.input_splits,
            self.output_splits,
            device_mesh.get_group(),
        )
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )


parallelize_module(
    model, device_mesh=col_mesh,
    parallelize_plan={"layers.*.moe.experts": ExpertParallel()},
)

# 1) FSDP on experts (rows): shard column-local experts across rows
for tb in model.transformer_blocks:
    fully_shard(tb.moe.experts, mesh=row_mesh, reshard_after_forward=False)

# 2) FSDP on each block (dp): shard non-MoE 8-way
for tb in model.transformer_blocks:
    fully_shard(tb, mesh=dp_mesh, reshard_after_forward=False)   # experts already DTensors on rows

# 3) FSDP on root (dp): embeddings / head / leftovers
fully_shard(model, mesh=dp_mesh, reshard_after_forward=True)     # True/False per memory tradeoff

Pseudo-Forward

def forward_on_rank(row, col, x_rc):
    # A) Non-MoE on FSDP: AG(8) → compute → RS(8)
    h = non_moe_stack(x_rc)

    # Router: expert id per token (e.g., [2,2,5,1,6,...])
    eids = router(h)

    # B) EP dispatch within this row: send tokens to owner column
    h_owned = ep_a2a_dispatch(h, eids, group={(row,0),(row,1)})

    # C) Experts on owner column: FSDP(4) across rows of this column
    y_local = experts_matmul(h_owned)           # AG(4) → GEMM → RS(4)

    # D) EP combine within this row: inverse A2A back to source rank
    y = ep_a2a_combine(y_local, eids, group={(row,0),(row,1)})

    # E) Tail Non-MoE on dp: AG(8) → compute → RS(8)
    out = non_moe_tail(y)
    return out

Backward (intuition): EP’s A2A autograd returns activation grads to their source ranks; expert parameter grads are reduce-scattered across the row group (the 4 ranks that shard that expert); non-MoE grads reduce-scatter across all 8.
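
In training code all of this is implicit in loss.backward(). A minimal sketch of a step (the optimizer choice, the (input_ids, labels) batch format, and the cross-entropy loss are assumptions for illustration; loader is the flattened-dp loader built in the appendix below):

import torch
import torch.nn.functional as F

optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

for input_ids, labels in loader:                 # assumed (input_ids, labels) batches
    logits = model(input_ids.cuda())             # assumes the model returns logits
    loss = F.cross_entropy(logits.flatten(0, 1), labels.cuda().flatten())
    loss.backward()   # non-MoE grads reduce-scatter over dp (8), expert grads over the row group (4),
                      # and activation grads flow back through the A2A autograd to their source ranks
    optim.step()
    optim.zero_grad()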


After Parallelization: Expected Placements

def show_placements(model, keys=("moe.experts","attention","router","embeddings","norm","output")):
    for name, p in model.named_parameters():
        if any(k in name for k in keys):
            pl = getattr(p, "placements", None)
            kind = "DTensor" if pl is not None else "LOCAL"
            print(f"{name:<60} -> {kind:7} {pl}")

Example lines you should see:

tok_embeddings.weight                                   -> DTensor (Shard(dim=0),)
layers.0.attention.wq.weight                            -> DTensor (Shard(dim=0),)
layers.0.feed_forward.router.router.weight              -> DTensor (Shard(dim=0),)
layers.0.feed_forward.experts.w1                        -> DTensor (_StridedShard(dim=0, sf=2), Shard(dim=0))
layers.0.feed_forward.experts.w2                        -> DTensor (_StridedShard(dim=0, sf=2), Shard(dim=0))
layers.0.feed_forward.experts.w3                        -> DTensor (_StridedShard(dim=0, sf=2), Shard(dim=0))
...

Legend

  • (Shard(dim=0),) → Non-MoE params sharded on a 1-D mesh (flattened dp → 8-way).
  • (_StridedShard(dim=0, sf=2), Shard(dim=0)) → Expert weights are split along the same tensor dim-0 (the expert dimension) on two separate mesh axes:
    • Shard(dim=0) on cols (EP ownership).
    • _StridedShard(dim=0, sf=2) on rows (FSDP across the 4 rows, after a prior split by the 2 cols). sf=2 (split factor) records that dim-0 was already split once by the 2 columns, so the row sharding is strided rather than contiguous.
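
A quick way to confirm these placements on any rank is to compare each expert parameter's global shape with its local shard via DTensor's to_local() (a sketch; the [8, dim, hidden_dim] shape is just an example):

for name, p in model.named_parameters():
    if "experts" in name:
        # .shape is the global (logical) shape; .to_local() is this rank's shard
        print(name, tuple(p.shape), "->", tuple(p.to_local().shape), p.placements)
        # e.g. w1: (8, dim, hidden_dim) global -> (1, dim, hidden_dim) local,
        # since 8 experts / 2 EP columns / 4 FSDP rows = 1 expert held per rank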

Appendix: Dataloader on Flattened DP

# Use the flattened 8-way dp mesh for sampling
global_dp_rank = dp_mesh.get_rank()   # the global rank doubles as the dp rank since dp spans all 8 ranks
num_replicas   = dp_mesh.size()       # 8

from torch.utils.data import DistributedSampler, DataLoader
sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=global_dp_rank)
loader  = DataLoader(dataset, batch_size=per_rank_bsz, sampler=sampler, pin_memory=True)

This yields 8 microbatches per step (one per rank), while placements and collectives follow the mesh rules above.