Building 10T tokens/month throughput infrastructure for MoE training

A worklog documenting the journey of scaling expert parallelism to achieve high-throughput pretraining

For reference: Qwen3 30B A3B was trained on 36T tokens

Technical worklogs documenting the development of MoE pretraining infrastructure, optimization experiments, and scaling challenges.

Worklogs

November 14 - Scaling Expert Parallelism Linearly

Fixed critical issues to achieve near-linear scaling of expert parallelism across nodes.

Problem 1: Intranode kernels scale poorly with expert parallelism

Intranode kernels (especially cached_notify_combine) scale poorly with expert parallelism (EP). I mitigated this by tuning the number of SMs (num_sms) allocated to DeepEP.

The Nsight reports for EP=2 and EP=4 list the 15 kernels with the most GPU time, and DeepEP's intranode kernels take a large share of it. The worst offender is deep_ep::intranode::cached_notify_combine.

EP=4 – Top 15 Kernels (kernel name as truncated in the report, total GPU time, share of GPU time)

void deep_ep::intranode::cached_notify_combine<(int)4> void **, int *,     54.306s  ( 38.1%)
ncclDevKernel_AllGather_RING_LL(ncclDevKernelArgsStorage<unsigned lon      17.169s  ( 12.1%)
void deep_ep::intranode::dispatch<(int)4, (int)768, (int)8192> int4 *,      9.107s  (  6.4%)
void deep_ep::intranode::combine<__nv_bfloat16, (int)4, (int)768, (int      8.539s  (  6.0%)
void at::native::indexFuncLargeIndex<c10::BFloat16, long, unsigned int      4.289s  (  3.0%)
void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm1      3.930s  (  2.8%)
void at::native::_scatter_gather_elementwise_kernel<(int)128, (int)8,       3.583s  (  2.5%)
cudnn_generated_fort_native_sdpa_sm100_flash_bprop_f16_knob_3i_128x128       2.811s  (  2.0%)
ncclDevKernel_ReduceScatter_Sum_f32_RING_LL(ncclDevKernelArgsStorage<       2.559s  (  1.8%)
void at::native::elementwise_kernel<(int)128, (int)4, void at::native:      2.504s  (  1.8%)
void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unn      2.246s  (  1.6%)
void deep_ep::intranode::cached_notify_dispatch<(int)4> const int *, i      2.226s  (  1.6%)
void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm1      2.045s  (  1.4%)
void at::native::<unnamed>::vectorized_layer_norm_kernel<c10::BFloat16       1.943s  (  1.4%)
void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm1      1.940s  (  1.4%)

EP=2 – Top 15 Kernels (kernel name as truncated in the report, total GPU time, share of GPU time)

ncclDevKernel_AllGather_RING_LL(ncclDevKernelArgsStorage<unsigned lon      31.187s  ( 30.5%)
void deep_ep::intranode::cached_notify_combine<(int)2> void **, int *,     19.856s  ( 19.4%)
ncclDevKernel_ReduceScatter_Sum_f32_RING_LL(ncclDevKernelArgsStorage<       6.529s  (  6.4%)
void deep_ep::intranode::combine<__nv_bfloat16, (int)2, (int)768, (int      5.060s  (  5.0%)
void deep_ep::intranode::dispatch<(int)2, (int)768, (int)8192> int4 *,      3.468s  (  3.4%)
cudnn_generated_fort_native_sdpa_sm100_flash_bprop_f16_knob_3i_128x128       2.406s  (  2.4%)
void at::native::elementwise_kernel<(int)128, (int)4, void at::native:      2.064s  (  2.0%)
void at::native::_scatter_gather_elementwise_kernel<(int)128, (int)8,       1.843s  (  1.8%)
void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unn      1.776s  (  1.7%)
void at::native::indexFuncLargeIndex<c10::BFloat16, long, unsigned int      1.742s  (  1.7%)
void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm1      1.696s  (  1.7%)
void at::native::<unnamed>::vectorized_layer_norm_kernel<c10::BFloat16       1.681s  (  1.6%)
cudnn_generated_fort_native_sdpa_sm100_flash_fprop_f16_knob_7_128x128x       1.602s  (  1.6%)
void at::native::detail::chunk_cat_cuda_kernel<float, c10::BFloat16>::T      1.441s  (  1.4%)
void deep_ep::intranode::cached_notify_dispatch<(int)2> const int *, i      1.397s  (  1.4%)

EP = 4: cached_notify_combine = 54.306 s (≈ 38.1% of GPU time)
EP = 2: cached_notify_combine = 19.856 s (≈ 19.4% of GPU time)
Result: 2.73× slowdown in that kernel when doubling expert parallelism.

Other DeepEP intranode kernels also scale poorly:

dispatch<4> (EP=4): 9.107 s vs dispatch<2> (EP=2): 3.468 s → 2.63× slower
combine<4> (EP=4): 8.539 s vs combine<2> (EP=2): 5.060 s → 1.69× slower

At the system level, EP=4 delivers 4,624 tokens/sec versus 6,599 at EP=2. Each token therefore takes 42.7% longer at EP=4, which is a throughput drop of about 30%.
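For reference, the ratios quoted above can be recomputed from the profile numbers. This snippet only copies values from the two kernel tables and the tokens/sec figures in this entry. Nothing new is measured.

```python
# Total GPU time per kernel (seconds), copied from the Nsight top-15 tables above.
kernel_time_s = {
    "cached_notify_combine": {2: 19.856, 4: 54.306},
    "dispatch": {2: 3.468, 4: 9.107},
    "combine": {2: 5.060, 4: 8.539},
}
# End-to-end throughput (tokens/sec) quoted above, keyed by EP degree.
tokens_per_sec = {2: 6599, 4: 4624}


def slowdown(times: dict) -> float:
    """How many times longer a kernel runs at EP=4 than at EP=2."""
    return times[4] / times[2]


for name, times in kernel_time_s.items():
    print(f"{name}: {slowdown(times):.2f}x slower at EP=4")

# Factor by which time per token grows from EP=2 to EP=4.
time_per_token_ratio = tokens_per_sec[2] / tokens_per_sec[4]
# Relative throughput loss from EP=2 to EP=4.
throughput_drop = 1 - tokens_per_sec[4] / tokens_per_sec[2]
print(f"time/token: {time_per_token_ratio:.3f}x, throughput: -{throughput_drop:.1%}")
```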

In csrc/kernels/intranode.cu:613-628, the kernel is launched with 1 + num_channels blocks, and each block processes all ranks, with one warp per rank. Because num_channels depends only on num_sms, doubling EP does not add blocks. It doubles the work each block does instead of adding parallelism.

This explains the ~2.7× slowdown in cached_notify_combine and the general degradation of DeepEP intranode kernels at higher EP.
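The launch-geometry argument can be made concrete with a toy cost model in Python. `NotifyCombineLaunch` is a hypothetical name, not DeepEP code. The model assumes only what is stated above: 1 + num_channels blocks, num_channels = num_sms / 2, and one warp per rank inside every block. It ignores memory traffic and synchronization and only counts per-rank work items per block.

```python
from dataclasses import dataclass


@dataclass
class NotifyCombineLaunch:
    """Toy model of the cached_notify_combine launch geometry (hypothetical, not DeepEP code)."""

    num_sms: int
    num_ranks: int

    @property
    def num_channels(self) -> int:
        # Same derivation as csrc/config.hpp: num_channels = num_sms / 2.
        return self.num_sms // 2

    @property
    def num_blocks(self) -> int:
        # 1 + num_channels blocks, independent of the number of ranks.
        return 1 + self.num_channels

    @property
    def work_per_block(self) -> int:
        # Every block walks over all ranks, one warp per rank.
        return self.num_ranks


for ep in (2, 4, 8):
    launch = NotifyCombineLaunch(num_sms=24, num_ranks=ep)
    print(f"EP={ep}: {launch.num_blocks} blocks, {launch.work_per_block} rank-units per block")
```

With num_sms = 24, the block count stays at 13 for every EP degree while per-block work grows with EP. That is the pattern behind the slowdown, and it is also why raising num_sms, which adds blocks, helps.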

After fixing Problem 2 (see below) to make num_sms tunable, I swept over multiple SM counts and found a significantly better configuration:

num_sms = 128 (up from 24)

Dispatch Config:

- turbo_deepep_num_cus = 128
- turbo_deepep_dispatch_tuned_config = (32, 1024, 8, 128)
- Performance: 122.32 μs, 496.28 GB/s

Combine Config:

- turbo_deepep_combine_tuned_config = (16, 256, 8, 128)
- Performance: 127.43 μs, 476.36 GB/s

Performance Improvements

Comparison vs Current Baseline (num_sms=24):

- Dispatch: 56.3% lower latency (279.69 μs → 122.32 μs)
- Combine: 61.8% lower latency (333.95 μs → 127.43 μs)
- Bandwidth: 2.28× higher for dispatch (217.90 GB/s → 496.28 GB/s)
- Bandwidth: 2.61× higher for combine (182.49 GB/s → 476.36 GB/s)

Comparison vs Worst (num_sms=8):

- Dispatch: 83.0% lower latency (721.57 μs → 122.32 μs)
- Combine: 84.0% lower latency (794.58 μs → 127.43 μs)
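The deltas above can be rechecked with a small script. The percentages are latency reductions (1 − new/old), and the bandwidth figures are plain ratios. All inputs are copied from the sweep results above.

```python
# Latency (μs) and bandwidth (GB/s) per num_sms, copied from the results above.
latency_us = {
    "dispatch": {8: 721.57, 24: 279.69, 128: 122.32},
    "combine": {8: 794.58, 24: 333.95, 128: 127.43},
}
bandwidth_gbps = {
    "dispatch": {24: 217.90, 128: 496.28},
    "combine": {24: 182.49, 128: 476.36},
}


def latency_reduction(before: float, after: float) -> float:
    """Fraction of latency removed; 0.563 means 56.3% lower latency."""
    return 1 - after / before


for op in ("dispatch", "combine"):
    lat, bw = latency_us[op], bandwidth_gbps[op]
    print(
        f"{op}: -{latency_reduction(lat[24], lat[128]):.1%} vs num_sms=24, "
        f"-{latency_reduction(lat[8], lat[128]):.1%} vs num_sms=8, "
        f"bandwidth {bw[128] / bw[24]:.2f}x"
    )
```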

Problem 2: Make num_sms tunable in DeepEP's benchmarking script

While trying to tune num_sms, I discovered that DeepEP's intranode benchmarking/tuning code implicitly assumes a single fixed num_sms. Any attempt to change it mid-run triggers an assertion failure.

From csrc/config.hpp:61:

const int num_channels = num_sms / 2;  // KEY: derived from num_sms

This breaks when we vary num_sms:

• Initial run (baseline config): num_sms = 24 → num_channels = 24 / 2 = 12
• Cached matrix shape becomes [4, 12] for a 4-rank setup.
• A later run in the same process tests num_sms = 32, which expects num_channels = 32 / 2 = 16
• DeepEP checks the cached matrix via an assertion in deep_ep.cpp:403: cached_matrix->size(1) == num_channels
• But the matrix is still [4, 12], so: Expected: 16, Actual: 12

The assertion fails and the tuner crashes as soon as num_sms changes. The Buffer's cached routing metadata is intrinsically tied to num_sms, but the code treats it as if it were reusable across configurations.

To unblock tuning, I changed the intranode benchmarking flow to create a separate Buffer instance for each num_sms value in the sweep.
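The sketch below reproduces the failure mode and the workaround using a hypothetical `ToyBuffer` class. It mimics only the relevant behavior: it caches a [num_ranks, num_channels] layout on first use and checks it on later calls. It is not DeepEP's `Buffer` API. The check mirrors the deep_ep.cpp assertion described above.

```python
class ToyBuffer:
    """Hypothetical stand-in for a DeepEP Buffer (not its real API).

    Caches a [num_ranks, num_channels] layout on first use, like the cached matrix above.
    """

    def __init__(self, num_ranks: int):
        self.num_ranks = num_ranks
        self.cached_matrix_shape = None

    def dispatch(self, num_sms: int) -> None:
        num_channels = num_sms // 2  # same derivation as csrc/config.hpp
        if self.cached_matrix_shape is None:
            self.cached_matrix_shape = (self.num_ranks, num_channels)
        # Mirrors the size(1) == num_channels check: the cached layout must match.
        if self.cached_matrix_shape[1] != num_channels:
            raise AssertionError(
                f"Expected: {num_channels}, Actual: {self.cached_matrix_shape[1]}"
            )


def sweep_shared_buffer(sms_values, num_ranks=4):
    """Reuses one buffer for the whole sweep; returns the first error message, if any."""
    buf = ToyBuffer(num_ranks)
    for num_sms in sms_values:
        try:
            buf.dispatch(num_sms)
        except AssertionError as err:
            return f"num_sms={num_sms}: {err}"
    return None


def sweep_fresh_buffers(sms_values, num_ranks=4):
    """The workaround: a separate buffer per num_sms value."""
    results = []
    for num_sms in sms_values:
        buf = ToyBuffer(num_ranks)
        buf.dispatch(num_sms)
        results.append((num_sms, buf.cached_matrix_shape))
    return results


print(sweep_shared_buffer([24, 32]))        # num_sms=32: Expected: 16, Actual: 12
print(sweep_fresh_buffers([24, 32, 128]))   # [(24, (4, 12)), (32, (4, 16)), (128, (4, 64))]
```

Creating one buffer per configuration costs an extra allocation per sweep point. It avoids having to invalidate cached state inside a shared buffer.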

Results: Our current EP runs achieve about 57% of the theoretical limit, while DeepEP's baseline reaches about 34% of the theoretical limit on their hardware.

DeepEP's Reference Squeeze (H800)

- Theoretical: 450 GB/s
- Achieved: 153 GB/s
- Squeezed: 34.0% of hardware capability

Our Squeeze (B200)

- Theoretical: 900 GB/s
- Achieved: 516.71 GB/s
- Squeezed: 57.4% of hardware capability

November 18 - Identifying the Next Bottleneck: CPU Launch Overhead

With near-linear scaling across nodes achieved, the focus shifted to optimizing the performance of a single end-to-end training step. Profiling identified two critical bottlenecks:

#1. ScatterAddBackward0

CPU Boundary Time:  4.32s (36.26% of total)
GPU Kernel Time:    1.13s (26.16% of CPU time)
CPU Overhead:       3.19s (73.84% of CPU time)

Primary Kernel:

void at::native::indexFuncLargeIndex<c10::BFloat16, long, unsigned int, 2,
  GPU Time: 419.0ms (96 invocations, avg 4364.7μs)

Associated Kernels (12 total):

• void at::native::indexFuncLargeIndex<c10::BFloat16, long, unsigned int, 2, 2, -2, true
• void at::native::_scatter_gather_elementwise_kernel<128, 8, at::native::_cuda_scatter_
• void at::native::vectorized_gather_kernel<16, long>(char*, char*, long*, int, long, lo
• ncclDevKernel_AllGather_RING_LL(ncclDevKernelArgsStorage<4096UL>): 110.8ms (198 calls)
• void at::native::sbtopk::gatherTopK<float, unsigned int, 2, false>(at::cuda::detail::T
• ... and 7 more

#2. FusedDispatch

CPU Boundary Time:  2.12s (17.77% of total)
GPU Kernel Time:  309.3ms (14.61% of CPU time)
CPU Overhead:       1.81s (85.39% of CPU time)

Primary Kernel:

void deep_ep::intranode::dispatch<8, 768, 8192>(int4*, float*, int*, long*,
  GPU Time: 95.2ms (288 invocations, avg 330.4μs)

Associated Kernels (9 total):

• void deep_ep::intranode::dispatch<8, 768, 8192>(int4*, float*, int*, long*, float*, int*
• void at::native::(anonymous namespace)::multi_tensor_apply_kernel<at::native::(anonymous
• void deep_ep::layout::get_dispatch_layout<256, 4, 8>(long const*, int*, int*, int*, bool
• void deep_ep::intranode::notify_dispatch<8>(int const*, int*, int const*, int*, int, int
• void deep_ep::intranode::cached_notify_dispatch<8>(int const*, int, void**, int**, int):
• ... and 4 more

These two ops alone account for 36.3% + 17.8% ≈ 54% of the total end-to-end step time. Most of that is CPU-side dispatch/launch overhead rather than raw GPU compute. (All numbers are per-step CPU boundary time.)
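The per-op numbers above follow a simple decomposition: CPU overhead is the part of the CPU boundary time not covered by GPU kernel time. The sketch below recomputes each op's share and overhead from the reported values, which are copied rather than measured. This decomposition only applies per op. At the aggregate level, GPU kernel time can exceed CPU boundary time, as the summary statistics show.

```python
TOTAL_CPU_BOUNDARY_S = 11.91  # per-step total, from the summary statistics
ops = {
    "ScatterAddBackward0": {"cpu_boundary": 4.32, "gpu_kernel": 1.13},
    "FusedDispatch": {"cpu_boundary": 2.12, "gpu_kernel": 0.3093},
}


def breakdown(cpu_boundary: float, gpu_kernel: float) -> dict:
    """Split an op's CPU boundary time into GPU kernel time and CPU-side overhead."""
    overhead = cpu_boundary - gpu_kernel
    return {
        "share_of_total": cpu_boundary / TOTAL_CPU_BOUNDARY_S,
        "overhead_s": overhead,
        "overhead_frac": overhead / cpu_boundary,
    }


for name, op in ops.items():
    b = breakdown(**op)
    print(
        f"{name}: {b['share_of_total']:.1%} of total, "
        f"overhead {b['overhead_s']:.2f}s ({b['overhead_frac']:.1%})"
    )

combined = sum(op["cpu_boundary"] for op in ops.values()) / TOTAL_CPU_BOUNDARY_S
print(f"combined share: {combined:.1%}")
```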

Sanity check: I temporarily commented out the single scatter_add line [link]. With everything else unchanged, throughput jumps to ~21k tokens/sec (a ~30% increase) and we reach ~500 TFLOP/s. This clearly confirms scatter_add as the next real bottleneck.

Summary statistics of all operations

+ Total CPU Boundary Time: 11.91s
+ Total GPU Kernel Time:   31.32s (262.9% of CPU time)
+ Total CPU Overhead:       8.99s (75.4% of CPU time)
+ Total Operations:           153

November 24 - Achieving 10T Tokens/Month Throughput

With 256 GPUs (Qwen3-30B-A3B):

Projected Throughput at 256 GPUs

+ Expected aggregate throughput: 3.6M tokens/sec (14,123 × 256 GPUs)
+ Translates to: ~9.4T tokens per 30-day month (≈10T tokens/month)

+ With scatter optimization (~30% improvement, ≈4.7M tokens/sec):
  - 10T tokens in ~25 days
  - ~24T tokens in 2 months
  - For reference: Qwen3-30B-A3B was trained on 36T tokens
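Here is a back-of-envelope check of these projections. It rests on three assumptions, none of them measurements: a 30-day month, that the 16-node per-GPU throughput (14,123 tokens/sec) holds at 256 GPUs, and that the ~30% scatter speedup applies uniformly.

```python
SECONDS_PER_DAY = 86_400
GPUS = 256
base = 14_123          # tokens/sec per GPU (LBS=10, 16 nodes)
optimized = base * 1.3  # assumption: the ~30% scatter speedup carries over unchanged


def tokens_per_day(tokens_per_sec_per_gpu: float, num_gpus: int) -> float:
    return tokens_per_sec_per_gpu * num_gpus * SECONDS_PER_DAY


def days_to_train(target_tokens: float, tokens_per_sec_per_gpu: float, num_gpus: int) -> float:
    return target_tokens / tokens_per_day(tokens_per_sec_per_gpu, num_gpus)


print(f"aggregate: {base * GPUS / 1e6:.2f}M tokens/sec")                          # 3.62M
print(f"30-day month: {tokens_per_day(base, GPUS) * 30 / 1e12:.1f}T tokens")       # 9.4T
print(f"10T with scatter fix: {days_to_train(10e12, optimized, GPUS):.1f} days")   # 24.6 days
print(f"60 days with scatter fix: {tokens_per_day(optimized, GPUS) * 60 / 1e12:.1f}T")  # 24.4T
```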

Scaling Results: Qwen3-30B (128 experts, top-k=8)

Weak Scaling (fixed local batch size LBS=8, increasing nodes):

Configuration  Nodes  GPUs  Tokens/sec/GPU  TFLOPS/GPU  Memory/GPU
1 node         1      8     14,796          341         167.93 GiB (94.15%)
2 nodes        2      16    14,380          331         138.75 GiB (77.78%)
4 nodes        4      32    14,276          329         124.58 GiB (69.84%)
8 nodes        8      64    14,107          325         117.50 GiB (65.88%)
16 nodes       16     128   13,856          319         114.78 GiB (64.35%)

Local Batch Size Sweep (LBS tuned for 16 nodes):

LBS  Nodes  GPUs  Tokens/sec/GPU  TFLOPS/GPU  Memory/GPU
8    16     128   13,856          319         114.78 GiB (64.35%)
10   16     128   14,123          326         142.21 GiB (79.73%)
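Per-GPU scaling efficiency, defined as per-GPU tokens/sec at N nodes divided by the 1-node value, summarizes the first table in one number per row. The LBS gain summarizes the second table. Inputs are copied from the tables.

```python
# Per-GPU tokens/sec at LBS=8, keyed by node count (first table).
per_gpu_tps = {1: 14_796, 2: 14_380, 4: 14_276, 8: 14_107, 16: 13_856}
# Per-GPU tokens/sec at 16 nodes, keyed by local batch size (second table).
lbs_sweep_16_nodes = {8: 13_856, 10: 14_123}


def scaling_efficiency(nodes: int) -> float:
    """Per-GPU throughput at `nodes` relative to the single-node run."""
    return per_gpu_tps[nodes] / per_gpu_tps[1]


for nodes in per_gpu_tps:
    print(f"{nodes:>2} nodes: {scaling_efficiency(nodes):.1%}")  # 16 nodes: 93.6%

lbs_gain = lbs_sweep_16_nodes[10] / lbs_sweep_16_nodes[8] - 1
print(f"LBS 8 -> 10 at 16 nodes: +{lbs_gain:.1%}")  # +1.9%
```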

With expert parallelism optimized to remain within a single node for the 30B-A3B configuration, throughput is expected to scale near-linearly to 256 GPUs. Any degradation at that scale would come from factors other than expert parallelism.

November 25 - Fixing Gradient Norm Explosion in MoE Training

Root cause: router weights initialized to zeros.

The investigation started from a gradient norm blowup (a 143,000× imbalance between rank 0 and ranks 1-7). I traced it backwards through the routing logic to a token distribution issue.

Token Distribution - All tokens routing to rank 0

tokens_per_dest_rank: [524288, 0, 0, 0, 0, 0, 0, 0]
expert_idx: min=0, max=7, mean=3.5

All 8 ranks send 100% of their tokens to rank 0. With 128 experts across EP=8, a balanced router would instead send each rank roughly 1/8 of the total: 524288 / 8 = 65,536 tokens, or about 524288 / 128 = 4,096 per expert.

Router output analysis:

Router Scores - All zeros

scores.shape: (8192, 128)  # looks right
scores: mean=0, std=0, min=0, max=0  # all zeros

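The router computes scores = gate(x), where gate is a linear projection with weight shape (128, 2048). If all scores are zero, topk(scores, k=8) has to break ties, and here PyTorch simply returned the first k indices [0, 1, 2, 3, 4, 5, 6, 7]. With 128 / 8 = 16 experts per rank, experts 0-7 all live on rank 0, which is why every token was routed there.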

Gate weights and input analysis revealed:

Root Cause - Gate weights initialized to zeros

input x: (8192, 2048), mean=0.005, std=1.0, range [-5, 5.4]  # fine
gate.weight: (128, 2048), mean=0, std=0  # zeros

Input activations look healthy (normalized, no NaN/Inf). But the gate weight is entirely zeros, and since scores = x @ W.T, W = 0 yields zero scores regardless of the input. That's the cause.
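The whole failure chain can be reproduced in a few lines of plain Python. This is a sketch, not the router code, and it makes three simplifications. Dimensions are shrunk (d_model=16, 64 tokens). Experts are assumed to be laid out contiguously by rank (16 per rank at EP=8). Top-k ties are broken by lowest index via Python's stable sort, which matches the expert_idx range 0-7 in the log but is an assumption about how torch.topk breaks ties.

```python
import random

D_MODEL, N_EXPERTS, TOP_K, EP = 16, 128, 8, 8


def router_scores(x, weight):
    """scores = x @ weight.T for one token; weight has shape [n_experts][d_model]."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]


def topk_stable(scores, k):
    """Top-k expert indices; ties go to the lowest index because sorted() is stable."""
    return sorted(range(len(scores)), key=lambda e: -scores[e])[:k]


def tokens_per_dest_rank(tokens, weight, k, ep):
    """Count routed (token, expert) pairs per destination rank."""
    experts_per_rank = len(weight) // ep
    counts = [0] * ep
    for x in tokens:
        for expert in topk_stable(router_scores(x, weight), k):
            counts[expert // experts_per_rank] += 1  # contiguous expert layout per rank
    return counts


rng = random.Random(0)
tokens = [[rng.gauss(0.0, 1.0) for _ in range(D_MODEL)] for _ in range(64)]
zero_gate = [[0.0] * D_MODEL for _ in range(N_EXPERTS)]
random_gate = [[rng.gauss(0.0, 0.02) for _ in range(D_MODEL)] for _ in range(N_EXPERTS)]

print(tokens_per_dest_rank(tokens, zero_gate, TOP_K, EP))    # [512, 0, 0, 0, 0, 0, 0, 0]
print(tokens_per_dest_rank(tokens, random_gate, TOP_K, EP))  # spread across ranks
```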

December 2 - Grad Norm Issue Due to torch._grouped_mm Bug

torch._grouped_mm backward pass produces garbage gradients when an expert has 0 tokens.

The Problem

When force load balancing is disabled, i.e. the router is trained from scratch, some experts receive no tokens early in training. Those experts should simply get zero gradients. For this to work, however, torch._grouped_mm requires experts with a token count of 0 to be padded up to 8, which is surprising. Ideally, this shouldn't be necessary.

The Fix

I padded every expert that receives zero tokens up to 8 tokens, so the kernel produces correct gradients in the backward pass.

Torchtitan works around this with clamp_min set to 8 [link]. This technically produces incorrect gradients if the router routes exactly 8 tokens to an expert. In training, though, we usually have at least 2M tokens in total, so the probability of routing exactly 8 tokens to an expert is very small. Still, it is technically a bug that originates in torch._grouped_mm.
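To make the two padding policies concrete, here is a plain-Python sketch that operates on per-expert token counts. The function names are hypothetical, and nothing here calls torch._grouped_mm. In a real implementation the padding rows themselves would be zero-filled. The sketch only shows how the counts and the cumulative group offsets change.

```python
from itertools import accumulate

MIN_TOKENS = 8  # padding target discussed above


def pad_zero_experts(counts):
    """Pad only the experts that received no tokens up to MIN_TOKENS (the fix above)."""
    return [MIN_TOKENS if c == 0 else c for c in counts]


def clamp_min_counts(counts):
    """clamp_min-style padding: every expert gets at least MIN_TOKENS rows."""
    return [max(c, MIN_TOKENS) for c in counts]


def group_offsets(counts):
    """Cumulative end offset of each expert's rows in the packed input."""
    return list(accumulate(counts))


counts = [0, 3, 8, 21]  # tokens routed to 4 experts
for policy in (pad_zero_experts, clamp_min_counts):
    padded = policy(counts)
    print(policy.__name__, padded, group_offsets(padded))
```

The two policies only differ for experts with between 1 and 7 tokens. In this example, expert 1 keeps its 3 rows under the zero-only padding but is padded to 8 rows under clamp_min.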

Verification

Comparing the commit just before the DeepEP PR (Reference) against the DeepEP PR confirms that the loss curves and grad norms match ✓

                     EP = 1                 EP = 8 (No DeepEP)     EP = 8 (With DeepEP)
                     DeepEP PR  Reference   DeepEP PR  Reference   DeepEP PR
Force Load Balance = ❌
  Loss               7.55       7.57        7.54       7.61        7.52
  Grad Norm          1.65       1.83        1.33       1.54        1.01
Force Load Balance = ✓
  Loss               7.70       ❌ OOM       7.79       7.97        7.69
  Grad Norm          0.84       -           1.70       1.52        0.70

Note: Everything is identical except the commit state, force_load_balance (true or false), and ep_degree (1 or 8). The DeepEP PR with EP=8 matches the reference results.

December 5 - Fused Kernel Rounding Mode Differences

Recovered the original end-to-end throughput using a new fused Triton kernel [SiLU + expert_output * router_prob]. The benchmark comparison against the original PyTorch SiLU is below.

[Figure: Fused kernel benchmark comparison]
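For testing a kernel like this, it helps to have the elementwise math as a plain-Python reference. The function names are hypothetical, and this is a test oracle, not the Triton kernel. The formula follows the prob-before-w2 form given later in this entry. The sketch works in Python floats, whereas the real kernel works in bf16/f32, and that is where the rounding differences discussed below come from.

```python
import math


def silu(x: float) -> float:
    """SiLU(x) = x * sigmoid(x), written to avoid overflow in exp for large |x|."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def silu_mul_prob(gate, up, probs):
    """Fused reference: silu(x @ w1) * (x @ w3) * router_prob in one elementwise pass.

    gate, up: [num_rows][hidden] outputs of x @ w1 and x @ w3 for the dispatched rows.
    probs: one router probability per row.
    """
    return [
        [silu(g) * u * p for g, u in zip(g_row, u_row)]
        for g_row, u_row, p in zip(gate, up, probs)
    ]


def silu_mul_prob_two_pass(gate, up, probs):
    """Unfused equivalent: materialize silu(gate) * up first, then scale by prob."""
    h = [[silu(g) * u for g, u in zip(g_row, u_row)] for g_row, u_row in zip(gate, up)]
    return [[v * p for v in row] for row, p in zip(h, probs)]


gate, up, probs = [[0.0, 1.0], [-2.0, 3.0]], [[2.0, 3.0], [1.0, -1.0]], [0.5, 0.25]
print(silu_mul_prob(gate, up, probs) == silu_mul_prob_two_pass(gate, up, probs))  # True
```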

With the Triton kernel, end-to-end training showed a loss offset, though without a grad norm explosion. It was caused by gradient differences that accumulate from Triton's rounding.

[Figure: Gradient difference due to rounding]

The 1-ULP Difference (0.00390625)

This is due to the following:

  • Triton uses a different rounding mode when casting f32 → bf16.
  • PyTorch uses IEEE round-to-nearest-even.

For the value 0.5683592558, the true value lies almost exactly halfway between two adjacent bf16 values, just below their midpoint 0.568359375:

  • PyTorch rounds it to 0.56640625.
  • Triton rounds it to 0.5703125.

Both are valid representations (error ≈ 0.002).
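To check which rounding could produce each value, here is a bit-level emulation of the f32 → bf16 cast in plain Python. It covers round-to-nearest-even and round-toward-zero, and it is a sketch that skips NaN and overflow handling. Both modes map 0.5683592558 to 0.56640625, because its f32 representation sits just below the midpoint 0.568359375. Getting 0.5703125 requires the f32 input to be at or above that midpoint. So a tiny difference in the f32 value computed before the cast can flip the result on its own.

```python
import struct


def f32_bits(x: float) -> int:
    """Bit pattern of x after rounding it to IEEE float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def from_f32_bits(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_rtne(x: float) -> float:
    """f32 -> bf16, round-to-nearest-even (finite inputs only)."""
    bits = f32_bits(x)
    lsb = (bits >> 16) & 1  # lowest mantissa bit that bf16 keeps
    return from_f32_bits(((bits + 0x7FFF + lsb) >> 16) << 16)


def bf16_rtz(x: float) -> float:
    """f32 -> bf16, round-toward-zero (drop the low 16 bits)."""
    return from_f32_bits(f32_bits(x) & 0xFFFF0000)


v = 0.5683592558
midpoint = 0.568359375  # halfway between 0.56640625 and 0.5703125
print(bf16_rtne(v), bf16_rtz(v))                                # 0.56640625 0.56640625
print(bf16_rtne(midpoint), bf16_rtne(midpoint + 2.0 ** -23))    # 0.5703125 0.5703125
```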

Root Cause: Probability Scaling Position

The gradient difference appears to come from where the router probability is applied:

Standard EP:  out = (silu(x@w1) * (x@w3) @ w2).float() * prob   [prob AFTER w2]
DeepEP fused: out = ((silu(x@w1) * (x@w3)).float() * prob).bf16() @ w2   [prob BEFORE w2]

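Repeating the test in float64 showed no difference, which confirms that the two forms are mathematically equivalent. In bf16, however, the order of operations changes which intermediate gets rounded, and therefore the numerical result.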
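Here is a scalar toy version of the two orderings. It uses a bf16 emulation that applies round-to-nearest-even to the f32 bit pattern. The vectors and the prob value are arbitrary. Python floats stand in for f32 in the non-bf16 steps and for the float64 experiment. The example only shows that moving the multiply before the bf16 cast changes which quantity gets rounded.

```python
import struct


def bf16(x: float) -> float:
    """Emulate an f32 -> bf16 cast with round-to-nearest-even (finite values only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def no_cast(x: float) -> float:
    """Stand-in for the float64 experiment: keep full double precision."""
    return x


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def prob_after_w2(h, w2, prob, cast=bf16):
    """Standard EP: out = (h @ w2).float() * prob."""
    return cast(dot(h, w2)) * prob


def prob_before_w2(h, w2, prob, cast=bf16):
    """Fused path: out = (h.float() * prob).bf16() @ w2."""
    return cast(dot([cast(v * prob) for v in h], w2))


h, w2, prob = [1.0, 2.0], [0.5, 0.25], 0.3
print(prob_after_w2(h, w2, prob), prob_before_w2(h, w2, prob))                    # 0.3 0.30078125
print(prob_after_w2(h, w2, prob, no_cast), prob_before_w2(h, w2, prob, no_cast))  # 0.3 0.3
```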

December 9 - Achieving 15,057 TPS with Fused Kernel Optimizations

[Qwen3 30B A3B] Last week I made a final attempt to recover the original per-GPU throughput of 14,796 tokens/sec, and our best optimization now reaches 15,057 tokens/sec. I implemented the following three optimizations, trained each for 5k steps, and compared them against torchtitan's default EP:

Three Optimizations

1. Wrote a triton kernel to directly fuse expert_output * router_probs (expert multiplication) with the silu activations in the FFN [commit]

2. Wrote another triton kernel to fuse expert multiplication with scatter_add_ [commit]

3. Created a fork of DeepEP and modified it to trade a 50% increase in all-to-all communication volume for higher throughput by fusing the multiplication into DeepEP's combining kernel [our deepep fork].

I haven't added optimization 3 to torchtitan. Optimizations 1 and 2 are already good enough, and integrating 3 would require refactoring the MoE modeling code, which would take too much time. I've deprioritized it for now.
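Optimization 2 is easiest to see as reference semantics. Below is a plain-Python sketch, not the Triton kernel, and the function names are hypothetical. The unfused version first materializes the prob-scaled expert outputs and then scatter-adds them into token rows. The fused version applies the router prob inside the accumulation loop, so the scaled intermediate never exists. Avoiding that extra buffer, and the memory traffic to write and re-read it, is the usual motivation for this kind of fusion.

```python
def combine_two_pass(num_tokens, token_idx, expert_out, probs):
    """Unfused: materialize prob-scaled rows, then scatter_add them into token rows."""
    dim = len(expert_out[0])
    scaled = [[p * v for v in row] for row, p in zip(expert_out, probs)]  # extra buffer
    out = [[0.0] * dim for _ in range(num_tokens)]
    for t, row in zip(token_idx, scaled):
        for j in range(dim):
            out[t][j] += row[j]
    return out


def combine_fused(num_tokens, token_idx, expert_out, probs):
    """Fused: scale by the router prob inside the scatter_add loop, no scaled buffer."""
    dim = len(expert_out[0])
    out = [[0.0] * dim for _ in range(num_tokens)]
    for t, row, p in zip(token_idx, expert_out, probs):
        for j in range(dim):
            out[t][j] += p * row[j]
    return out


# 2 tokens with top-2 routing -> 4 expert-output rows; token_idx maps each row to its token.
token_idx = [0, 1, 0, 1]
expert_out = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
probs = [0.75, 0.5, 0.25, 0.5]
print(combine_fused(2, token_idx, expert_out, probs))  # [[2.0, 3.0], [5.0, 6.0]]
```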

Throughput Comparison

[Figure: Optimization throughput comparison]

Loss Curves (5k steps)

[Figure: Loss curves comparison]

Gradient Norms (5k steps)

[Figure: Gradient norms comparison]