Training 14B Video Models: Infrastructure Deep Dive
January 12, 2026
Introduction
Training a 14 billion parameter video diffusion model on 720p sequences is an extreme engineering challenge. A single 5-second 720p video at 16 FPS produces roughly 72,000 tokens after VAE encoding—far exceeding typical language model contexts. The quadratic memory scaling of attention makes this infeasible on single GPUs.
This article walks through the complete distributed training infrastructure needed to make this work, with concrete memory calculations at every step using real numbers from production systems.
Working Example: One 720p Video
Let's establish concrete numbers we'll use throughout. Consider a 5-second video clip:
- Resolution: 1280 × 720 (720p)
- Frame rate: 16 FPS
- Duration: 5 seconds
- Total frames: 80 frames
After encoding through a 3D VAE with 16× spatial compression and 4× temporal compression:
- Spatial: (1280 ÷ 16) × (720 ÷ 16) = 80 × 45 = 3,600 tokens per latent frame
- Temporal: 80 frames ÷ 4 = 20 latent frames
- Total: 3,600 × 20 = 72,000 tokens
This single video clip produces 72,000 tokens. Now let's see why this breaks everything.
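As a quick sanity check, the token count falls out of a few lines of arithmetic. This is a minimal sketch using the compression factors above; the function name and defaults are illustrative, not from any particular codebase.

```python
# Token count for one 5-second 720p clip after 3D VAE encoding.
# Assumes 16x spatial and 4x temporal compression, as stated above.

def latent_tokens(width: int, height: int, frames: int,
                  spatial_stride: int = 16, temporal_stride: int = 4) -> int:
    """Number of latent tokens the VAE produces for one clip."""
    latent_w = width // spatial_stride      # 1280 / 16 = 80
    latent_h = height // spatial_stride     # 720  / 16 = 45
    latent_t = frames // temporal_stride    # 80   / 4  = 20
    return latent_w * latent_h * latent_t   # 80 * 45 * 20 = 72,000

if __name__ == "__main__":
    print(latent_tokens(1280, 720, 80))     # -> 72000
```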
Scenario 1: Single GPU (Why It Fails)
Memory Requirements
Consider a 14B parameter DiT (Diffusion Transformer) with:
- Layers: 40
- Hidden dimension ($d$): 5,120
- Attention heads: 40 (head dim = 128)
- FFN dimension: 13,824 (2.7× hidden dim)
Model Parameters (bf16): $14 \times 10^9 \times 2$ bytes = 28 GB
Optimizer States (AdamW in fp32): AdamW maintains first moment $m_t$ (momentum) and second moment $v_t$ (variance) for each parameter: $14 \times 10^9 \times 2 \times 4$ bytes = 112 GB
Gradients (bf16): $14 \times 10^9 \times 2$ bytes = 28 GB
Total static memory: 28 + 112 + 28 = 168 GB
Already impossible on an H100 with 80 GB! And we haven't even counted activations yet.
Activation Memory (per layer, batch size 1): For self-attention with sequence length $n = 72{,}000$:
- Q, K, V tensors: $3 \times 72{,}000 \times 5{,}120 \times 2 \text{ bytes} = 2.16$ GB
- Attention scores matrix: $72{,}000 \times 72{,}000 \times 2 = 10.37$ GB
- Attention output: $72{,}000 \times 5{,}120 \times 2 = 0.72$ GB
- FFN intermediate: $72{,}000 \times 13{,}824 \times 2 = 1.94$ GB
- FFN output: $72{,}000 \times 5{,}120 \times 2 = 0.72$ GB
Per layer total: 2.16 + 10.37 + 0.72 + 1.94 + 0.72 ≈ 15.9 GB
All 40 layers: 40 × 15.9 = 636 GB
Grand total: 168 GB (static) + 636 GB (activations) = 804 GB
This is 10× the capacity of a single H100 GPU (80 GB). Single GPU training is impossible.
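These totals can be reproduced with a rough calculator like the sketch below. The byte counts mirror the assumptions in this section (bf16 parameters and gradients, fp32 Adam moments, naive per-layer attention activations at batch size 1); the printed figures are estimates, not measurements, and round slightly differently than the prose above.

```python
# Rough memory estimate for the 14B DiT example (decimal GB, batch size 1).

GB = 1e9
PARAMS = 14e9          # total parameters
LAYERS = 40
HIDDEN = 5_120
FFN = 13_824
SEQ = 72_000

# Static state: bf16 params + bf16 grads + fp32 Adam moments (m and v).
params_gb = PARAMS * 2 / GB            # ~28 GB
grads_gb = PARAMS * 2 / GB             # ~28 GB
optim_gb = PARAMS * 2 * 4 / GB         # ~112 GB
static_gb = params_gb + grads_gb + optim_gb

# Naive per-layer activations in bf16 (2 bytes each), as itemized above.
qkv = 3 * SEQ * HIDDEN * 2
scores = SEQ * SEQ * 2                 # the quadratic term
attn_out = SEQ * HIDDEN * 2
ffn_mid = SEQ * FFN * 2
ffn_out = SEQ * HIDDEN * 2
per_layer_gb = (qkv + scores + attn_out + ffn_mid + ffn_out) / GB
acts_gb = LAYERS * per_layer_gb

print(f"static ~ {static_gb:.0f} GB, activations ~ {acts_gb:.0f} GB, "
      f"total ~ {static_gb + acts_gb:.0f} GB")   # on the order of 800 GB
```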
What About Data Parallel (DDP)?
Data Parallel replicates the full model on each GPU. It only helps with throughput—each GPU still needs the full 804 GB. DDP doesn't solve the memory problem at all.
DDP (Data Parallel):
┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ GPU 0 │ │ GPU 1 │ │ GPU 2 │ │ GPU 3 │
│ ┌────────┐ │ │ ┌────────┐ │ │ ┌────────┐ │ │ ┌────────┐ │
│ │ Model │ │ │ │ Model │ │ │ │ Model │ │ │ │ Model │ │
│ │ FULL │ │ │ │ FULL │ │ │ │ FULL │ │ │ │ FULL │ │
│ │ 14B │ │ │ │ 14B │ │ │ │ 14B │ │ │ │ 14B │ │
│ │ 804 GB │ │ │ │ 804 GB │ │ │ │ 804 GB │ │ │ │ 804 GB │ │
│ └────────┘ │ │ └────────┘ │ │ └────────┘ │ │ └────────┘ │
│ │ │ │ │ │ │ │
│ Batch 0 │ │ Batch 1 │ │ Batch 2 │ │ Batch 3 │
└──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘
│ │ │ │
└─────────────────┴─────────────────┴─────────────────┘
All-Reduce Gradients
(synchronize at end)
Problem: Each GPU needs 804 GB. Doesn't fit!
Use case: Only works when model FITS in single GPU.
DDP only works when the model fits in a single GPU's memory. For large models like 14B parameters, we need a fundamentally different approach.
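For reference, plain DDP in PyTorch is just a wrapper around a fully replicated module. The sketch below uses standard `torch.distributed` and `DistributedDataParallel` with a small placeholder model; it illustrates why DDP cannot help with memory, since the entire module must already fit on each device.

```python
# Minimal DDP sketch: the full model is replicated on every rank.
# Launch with torchrun, e.g.: torchrun --nproc_per_node=4 ddp_sketch.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Placeholder model; a real 14B DiT would not fit here, which is the point.
    model = torch.nn.Sequential(
        torch.nn.Linear(5_120, 13_824), torch.nn.GELU(), torch.nn.Linear(13_824, 5_120)
    ).cuda()

    ddp_model = DDP(model)                       # full replica; only grads are all-reduced
    opt = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)

    x = torch.randn(4, 5_120, device="cuda")     # each rank sees a different micro-batch
    loss = ddp_model(x).square().mean()
    loss.backward()                              # gradient all-reduce happens here
    opt.step()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```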
Scenario 2: Model Sharding with FSDP
Fully Sharded Data Parallel (FSDP)
FSDP shards the model parameters, gradients, and optimizer states across GPUs. Each GPU stores only a fraction. With 32 GPUs, each GPU holds 168 GB ÷ 32 = 5.25 GB of model state.
Now the model state fits! But activations are still a problem (636 GB ÷ 32 = 19.9 GB per GPU for activations alone).
Communication Primitives
1. All-Gather: Each GPU collects shards from all other GPUs to reconstruct the full tensor.
Before All-Gather (each GPU has 1/4 of weights):
GPU 0: [W₀][  ][  ][  ]  ← only has shard 0
GPU 1: [  ][W₁][  ][  ]  ← only has shard 1
GPU 2: [  ][  ][W₂][  ]  ← only has shard 2
GPU 3: [  ][  ][  ][W₃]  ← only has shard 3

All-Gather Operation:
GPU 0 sends W₀ to GPUs 1,2,3
GPU 1 sends W₁ to GPUs 0,2,3
GPU 2 sends W₂ to GPUs 0,1,3
GPU 3 sends W₃ to GPUs 0,1,2

After All-Gather (each GPU has full weights):
GPU 0: [W₀][W₁][W₂][W₃]  ← full tensor reconstructed
GPU 1: [W₀][W₁][W₂][W₃]  ← full tensor reconstructed
GPU 2: [W₀][W₁][W₂][W₃]  ← full tensor reconstructed
GPU 3: [W₀][W₁][W₂][W₃]  ← full tensor reconstructed

Bandwidth: Each GPU sends 1/4 and receives 3/4 = full tensor worth of data
2. Reduce-Scatter: Sum gradients across GPUs, then scatter shards back. This is the inverse of all-gather.
Before Reduce-Scatter (each GPU has full gradients from its local backward):
GPU 0: [g₀⁽⁰⁾][g₁⁽⁰⁾][g₂⁽⁰⁾][g₃⁽⁰⁾] ← gradients from batch 0
GPU 1: [g₀⁽¹⁾][g₁⁽¹⁾][g₂⁽¹⁾][g₃⁽¹⁾] ← gradients from batch 1
GPU 2: [g₀⁽²⁾][g₁⁽²⁾][g₂⁽²⁾][g₃⁽²⁾] ← gradients from batch 2
GPU 3: [g₀⁽³⁾][g₁⁽³⁾][g₂⁽³⁾][g₃⁽³⁾] ← gradients from batch 3
Reduce-Scatter Operation:
Step 1 (Reduce): Sum corresponding gradient shards
Σg₀ = g₀⁽⁰⁾ + g₀⁽¹⁾ + g₀⁽²⁾ + g₀⁽³⁾
Σg₁ = g₁⁽⁰⁾ + g₁⁽¹⁾ + g₁⁽²⁾ + g₁⁽³⁾
(etc. for g₂, g₃)
Step 2 (Scatter): Each GPU keeps only its shard
After Reduce-Scatter (each GPU has 1/4 of summed gradients):
GPU 0: [Σg₀] ← owns gradient shard 0
GPU 1: [Σg₁] ← owns gradient shard 1
GPU 2: [Σg₂] ← owns gradient shard 2
GPU 3: [Σg₃] ← owns gradient shard 3
Now each GPU updates only its parameter shard with its gradient shard.
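Both primitives are available directly in `torch.distributed`. The sketch below exercises them on toy tensors so the data movement described above is concrete; the shapes are illustrative only.

```python
# All-gather and reduce-scatter with torch.distributed (toy shapes).
# Launch with: torchrun --nproc_per_node=4 collectives_sketch.py
import torch
import torch.distributed as dist

def main():
    dist.init_process_group("nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # --- All-gather: every rank reconstructs the full tensor from shards ---
    shard = torch.full((4,), float(rank), device="cuda")      # this rank's shard
    full = torch.empty(4 * world, device="cuda")
    dist.all_gather_into_tensor(full, shard)                  # full = [shard_0 | shard_1 | ...]

    # --- Reduce-scatter: sum across ranks, each rank keeps one shard of the sum ---
    local_grads = torch.arange(4 * world, dtype=torch.float32, device="cuda")
    my_grad_shard = torch.empty(4, device="cuda")
    dist.reduce_scatter_tensor(my_grad_shard, local_grads, op=dist.ReduceOp.SUM)

    print(f"rank {rank}: gathered={full.tolist()} grad_shard={my_grad_shard.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```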
FSDP Forward/Backward Pass
Forward Pass (per layer):
- All-gather parameters: Temporarily reconstruct full layer weights (28 GB / 40 layers ≈ 700 MB per layer)
- Compute forward: Use full weights to compute layer output
- Free non-local shards: Keep only local 1/32, discard rest to save memory
- Result: Each GPU holds 700 MB / 32 ≈ 22 MB for this layer
Backward Pass (per layer):
- All-gather parameters again: Need full weights for gradient computation
- Compute gradients: Backprop through layer
- Reduce-scatter gradients: Sum gradients across GPUs, each keeps its 1/32 shard
- Free non-local shards: Discard non-local parameter shards again
Optimizer Step: Each GPU updates only its local parameter shard using its local gradient shard.
FSDP with 32 GPUs (showing 4 for clarity):

FSDP GROUP (32 GPUs) ─ All process the SAME batch
│
├── GPU 0:  Shard 0/32  (5.25 GB) ──┐
├── GPU 1:  Shard 1/32  (5.25 GB) ──┼── All-Gather (forward/backward)
├── GPU 2:  Shard 2/32  (5.25 GB) ──┼── Reduce-Scatter (gradients)
│   ...                             │
└── GPU 31: Shard 31/32 (5.25 GB) ──┘

Communication Pattern:
Forward:  All-Gather params → compute → discard non-local
Backward: All-Gather params → compute grads → Reduce-Scatter

Each GPU: 5.25 GB (model) + ~19.9 GB (activations) ≈ 25 GB
Still tight! Activations dominate.
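In PyTorch, this flow is what `FullyShardedDataParallel` implements. Below is a hedged sketch of how a DiT-style stack might be wrapped; the `Block` class and sizes are placeholders, and the per-block auto-wrap policy is one reasonable choice, not the only one.

```python
# FSDP (ZeRO-3 style) wrapping sketch for a transformer stack.
# Launch with torchrun; the tiny Block stands in for a real DiT block.
import functools
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class Block(nn.Module):                       # placeholder DiT block
    def __init__(self, d=512, ffn=1376):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, num_heads=8, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d, ffn), nn.GELU(), nn.Linear(ffn, d))
    def forward(self, x):
        x = x + self.attn(x, x, x, need_weights=False)[0]
        return x + self.ffn(x)

def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = nn.Sequential(*[Block() for _ in range(8)]).cuda()
    wrap = functools.partial(transformer_auto_wrap_policy,
                             transformer_layer_cls={Block})   # shard per block
    # Each block's params/grads/optimizer state are sharded across all ranks;
    # full weights for a block exist only transiently during its forward/backward.
    model = FSDP(model, auto_wrap_policy=wrap, use_orig_params=True)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x = torch.randn(1, 64, 512, device="cuda")
    model(x).mean().backward()                # all-gather + reduce-scatter under the hood
    opt.step()                                # each rank updates only its shard
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```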
Problem: Attention Memory Still Too Large
With FSDP-32, we've solved the model memory problem (168 GB → 5.25 GB per GPU). But the 72K × 72K attention matrix still requires 10.37 GB per layer. Even divided by 32 GPUs, we're using most of our memory budget on activations. We need Context Parallelism to split the sequence dimension.
Scenario 3: Adding Context Parallelism
Why Sequence Length is the Problem
The attention score matrix scales quadratically: $\mathcal{O}(n^2)$ where $n$ is sequence length. For $n = 72{,}000$: $72{,}000 \times 72{,}000 \times 2$ bytes = 10.37 GB per layer for the score matrix alone.
With CP=16, each GPU handles $n/16 = 4{,}500$ tokens. But attention needs to see ALL tokens, so each GPU's score block is $4{,}500 \times 72{,}000 \times 2$ bytes = 648 MB per layer.
That's a 16× reduction! From 10.37 GB to 648 MB per layer.
Ulysses Sequence Parallelism
Ulysses splits the sequence across 8 GPUs. The key insight: queries can be local (each GPU computes Q for its chunk), but keys and values must be global (need full K, V for correct attention).
Algorithm:
- Each GPU projects its local input to $Q_{\text{local}}$, $K_{\text{local}}$, $V_{\text{local}}$
- All-gather $K$ and $V$ from all 8 GPUs → get $K_{\text{full}}$, $V_{\text{full}}$
- Compute: $\text{softmax}(Q_{\text{local}} K_{\text{full}}^T / \sqrt{d_k}) \, V_{\text{full}}$
- Each GPU produces output for its local sequence chunk
Ulysses-8 (splitting 72K tokens across 8 GPUs):

Original: ████████████████████████████████████████ (72,000 tokens)

After Split:
GPU 0: █████ (tokens 0 - 8,999)       → 9K tokens
GPU 1: █████ (tokens 9,000 - 17,999)  → 9K tokens
GPU 2: █████ (tokens 18,000 - 26,999) → 9K tokens
GPU 3: █████ (tokens 27,000 - 35,999) → 9K tokens
GPU 4: █████ (tokens 36,000 - 44,999) → 9K tokens
GPU 5: █████ (tokens 45,000 - 53,999) → 9K tokens
GPU 6: █████ (tokens 54,000 - 62,999) → 9K tokens
GPU 7: █████ (tokens 63,000 - 71,999) → 9K tokens

Each GPU computes:
Q_local: [9K × 5,120] in bf16 → 90 MB
K_local: [9K × 5,120] in bf16 → 90 MB
V_local: [9K × 5,120] in bf16 → 90 MB

All-Gather K, V across 8 GPUs:
K_full: [72K × 5,120] → 720 MB
V_full: [72K × 5,120] → 720 MB

Attention Computation:
QK^T:    [9K × 5,120] @ [5,120 × 72K] = [9K × 72K] → 1.3 GB (scores)
Softmax: apply row-wise
Output:  [9K × 72K] @ [72K × 5,120] = [9K × 5,120] → 90 MB

Memory per GPU: 90 + 720 + 720 + 1,300 + 90 ≈ 2.9 GB per layer
Much better! From 15.9 GB per layer down to 2.9 GB per layer.
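A minimal distributed sketch of this attention pattern (local Q, gathered K/V) is below. The shapes are scaled down, and the function name is illustrative; at full scale the all-gathered K and V would be the ~720 MB tensors shown in the diagram.

```python
# Sequence-parallel attention with local Q and all-gathered K/V
# (the K/V-gathering variant described above; toy shapes).
# Launch with: torchrun --nproc_per_node=8 ulysses_sketch.py
import math
import torch
import torch.distributed as dist

def gathered_kv_attention(q_local, k_local, v_local):
    """q_local: [n_local, d]; K/V are gathered to [n_total, d] before attention."""
    world = dist.get_world_size()
    n_local, d = k_local.shape
    k_full = torch.empty(world * n_local, d, device=k_local.device, dtype=k_local.dtype)
    v_full = torch.empty_like(k_full)
    dist.all_gather_into_tensor(k_full, k_local.contiguous())   # full K across all ranks
    dist.all_gather_into_tensor(v_full, v_local.contiguous())   # full V across all ranks
    scores = q_local @ k_full.T / math.sqrt(d)                  # [n_local, n_total]
    return torch.softmax(scores, dim=-1) @ v_full               # output for local chunk only

def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    n_local, d = 1_024, 128                       # toy sizes (9,000 and 5,120 in the text)
    q = torch.randn(n_local, d, device="cuda")
    k = torch.randn(n_local, d, device="cuda")
    v = torch.randn(n_local, d, device="cuda")
    out = gathered_kv_attention(q, k, v)
    print(f"rank {rank}: local output {tuple(out.shape)}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```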
Communication Cost of Ulysses
All-gather of K, V requires transferring the full tensors. Per layer: $2 \times 72{,}000 \times 5{,}120 \times 2$ bytes ≈ 1.4 GB of K and V per GPU.
With NVLink 4.0 bandwidth (900 GB/s within node): 1.4 GB ÷ 900 GB/s ≈ 1.6 ms per layer.
Attention compute time is ~40 ms (FlashAttention-3 on H100), so communication is 1.6/40 = 4% overhead. Acceptable, but we can do better with Ring Attention.
Ring Attention: Overlapping Communication
Ring Attention adds a second context-parallel dimension: each of the 8 Ulysses partitions is split across a 2-GPU ring, whose communication is point-to-point (ring pattern) rather than an all-gather (broadcast pattern) and overlaps with compute.
Key idea: Don't gather all K, V upfront. Instead, compute attention in blocks while passing K, V around a ring. Use online softmax to maintain numerical stability.
Ring Attention (2-GPU rings within Ulysses-8):
16 GPUs total → 8 independent rings of 2 GPUs each
(8 Ulysses partitions × 2 GPUs per ring = CP-16)
Ring 0: GPU 0 ⟷ GPU 1 (tokens 0-8,999)
Ring 1: GPU 2 ⟷ GPU 3 (tokens 9,000-17,999)
Ring 2: GPU 4 ⟷ GPU 5 (tokens 18,000-26,999)
Ring 3: GPU 6 ⟷ GPU 7 (tokens 27,000-35,999)
Ring 4: GPU 8 ⟷ GPU 9 (tokens 36,000-44,999)
Ring 5: GPU 10 ⟷ GPU 11 (tokens 45,000-53,999)
Ring 6: GPU 12 ⟷ GPU 13 (tokens 54,000-62,999)
Ring 7: GPU 14 ⟷ GPU 15 (tokens 63,000-71,999)
Within each ring (e.g., Ring 0 handling 9K tokens):
GPU 0 holds: Q₀ (4.5K × 5,120), K₀ (4.5K × 5,120), V₀ (4.5K × 5,120)
GPU 1 holds: Q₁ (4.5K × 5,120), K₁ (4.5K × 5,120), V₁ (4.5K × 5,120)
Ring Attention Algorithm (2 steps for 2-GPU ring):
Step 0:
GPU 0: Compute attn(Q₀, K₀, V₀) → partial output₀⁽⁰⁾
GPU 1: Compute attn(Q₁, K₁, V₁) → partial output₁⁽⁰⁾
[Send K, V to ring neighbor]
Step 1:
GPU 0: Compute attn(Q₀, K₁, V₁) → partial output₀⁽¹⁾
(K₁, V₁ received from GPU 1)
GPU 1: Compute attn(Q₁, K₀, V₀) → partial output₁⁽¹⁾
(K₀, V₀ received from GPU 0)
Final Output (online softmax accumulation):
GPU 0: Accumulate(output₀⁽⁰⁾, output₀⁽¹⁾) → final output for tokens 0-4,499
GPU 1: Accumulate(output₁⁽⁰⁾, output₁⁽¹⁾) → final output for tokens 4,500-8,999
Communication: Only point-to-point transfers (4.5K × 5,120 × 2 × 2 bytes = 90 MB)
Overlap: While GPU 0 computes, data transfers in background → near-zero overhead!
Online Softmax for Numerical Stability: When accumulating attention outputs from multiple blocks, we use the numerically stable running-max formulation. For each new block of scores $S^{(b)} = Q_{\text{local}} K_b^T / \sqrt{d_k}$, maintain a running row-max $m$, normalizer $\ell$, and unnormalized output $O$:
$$m' = \max\big(m, \operatorname{rowmax}(S^{(b)})\big), \qquad \ell' = e^{m - m'}\,\ell + \operatorname{rowsum}\big(e^{S^{(b)} - m'}\big), \qquad O' = e^{m - m'}\,O + e^{S^{(b)} - m'} V_b$$
After the final block, the output is $O / \ell$.
This maintains correctness while computing attention incrementally without storing the full attention matrix.
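The recurrence is easy to verify numerically. The sketch below accumulates attention over K/V blocks with a running max and normalizer, then checks the result against full softmax attention (plain PyTorch, single process, toy sizes).

```python
# Online-softmax (blockwise) attention, checked against the full computation.
import math
import torch

def blockwise_attention(q, k_blocks, v_blocks):
    """q: [n, d]; k_blocks/v_blocks: lists of [b_i, d] chunks (as passed around a ring)."""
    n, d = q.shape
    m = torch.full((n, 1), float("-inf"))    # running row max
    l = torch.zeros(n, 1)                    # running normalizer
    o = torch.zeros(n, d)                    # running (unnormalized) output
    for kb, vb in zip(k_blocks, v_blocks):
        s = q @ kb.T / math.sqrt(d)          # scores for this block
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)
        scale = torch.exp(m - m_new)         # rescale previous accumulators
        l = scale * l + p.sum(dim=-1, keepdim=True)
        o = scale * o + p @ vb
        m = m_new
    return o / l

torch.manual_seed(0)
q, k, v = torch.randn(16, 8), torch.randn(32, 8), torch.randn(32, 8)
ref = torch.softmax(q @ k.T / math.sqrt(8), dim=-1) @ v
out = blockwise_attention(q, list(k.chunk(4)), list(v.chunk(4)))
print(torch.allclose(out, ref, atol=1e-5))   # True
```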
Memory with CP=16 (Ulysses-8 + Ring-2)
Each GPU now handles 72,000 ÷ 16 = 4,500 tokens.
- Q, K, V (local): $3 \times 4{,}500 \times 5{,}120 \times 2 = 135$ MB
- Attention scores (blockwise): $4{,}500 \times 4{,}500 \times 2 = 40$ MB
- Running statistics: $4{,}500 \times 8 \text{ bytes} = 36$ KB (negligible)
- FFN intermediate: $4{,}500 \times 13{,}824 \times 2 = 122$ MB
Total: ~300 MB per layer
All 40 layers: 12 GB for activations
From 636 GB → 12 GB. 53× reduction!
Complete 128-GPU Architecture
Production systems combine all three parallelism strategies. Here's the full setup with overlapping groups:
128 GPUs: CP=16, FSDP=32, DP=4 (with smart overlapping)
DATA PARALLEL (DP=4)
│
├─► 4 identical groups, each processing different batches
│
└─► FSDP GROUP (32 GPUs per DP replica)
│
├─► Model sharded 32 ways (5.25 GB per GPU)
│
└─► CONTEXT PARALLEL (CP=16)
│
├─► Ulysses-8: Sequence split into 8 partitions
│
│ Partition: 0 1 2 3 4 5 6 7
│ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
│ 9K 9K 9K 9K 9K 9K 9K 9K tokens
│
└─► Ring-2: Each partition handled by 2-GPU ring
Ring 0: GPU0 ↔ GPU1 → tokens 0-8,999
Ring 1: GPU2 ↔ GPU3 → tokens 9,000-17,999
Ring 2: GPU4 ↔ GPU5 → tokens 18,000-26,999
Ring 3: GPU6 ↔ GPU7 → tokens 27,000-35,999
Ring 4: GPU8 ↔ GPU9 → tokens 36,000-44,999
Ring 5: GPU10 ↔ GPU11 → tokens 45,000-53,999
Ring 6: GPU12 ↔ GPU13 → tokens 54,000-62,999
Ring 7: GPU14 ↔ GPU15 → tokens 63,000-71,999
Result: 72K ÷ 16 GPUs = 4.5K tokens per GPU
Activation memory: ~300 MB per layer
TOTAL: 4 DP groups × 32 GPUs = 128 GPUs
Communication Patterns:
• CP (Context Parallel):
- Ulysses: All-gather K,V within 8-GPU group (~1.4 GB, 4% overhead)
- Ring: P2P transfers overlap with compute (<0.5% overhead)
→ Total CP overhead: <1%
• FSDP (Fully Sharded Data Parallel):
- Forward: All-gather params per layer (~700 MB per layer)
- Backward: All-gather params + Reduce-scatter grads
→ Overhead: ~10% (well-optimized with overlap)
• DP (Data Parallel):
- All-reduce gradients at end of backward pass
- Overlapped with final layer backward compute
→ Overhead: <1% (fully overlapped)
Total Communication Overhead: ~11-12% (excellent for 128 GPUs!)
Final Memory Budget Per GPU
Let's calculate the complete memory footprint per GPU in the full 128-GPU setup (CP=16, FSDP=32):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Per-GPU Memory Breakdown (H100 80GB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
STATIC MEMORY (Model State with FSDP-32):
Model parameters (local shard): 875 MB
Gradients (local shard): 875 MB
Optimizer states (local shard): 3,500 MB
─────────
Static subtotal: 5,250 MB (5.25 GB)
DYNAMIC MEMORY (Activations with CP-16):
Per layer (40 layers total):
Q, K, V: 3 × 4.5K × 5,120 × 2 = 135 MB
Attention scores: 4.5K × 4.5K × 2 = 40 MB
FFN intermediate: 4.5K × 13,824 × 2= 122 MB
Residual + LayerNorm: 8 MB
─────────
Per layer total: 305 MB
With gradient checkpointing (save every 5 layers):
Stored layers: 40/5 = 8 checkpoints
Memory: 8 × 305 MB = 2,440 MB
Working memory for backward: 1,200 MB
─────────
Activations subtotal: 3,640 MB (3.6 GB)
COMMUNICATION BUFFERS:
All-gather buffers (FSDP): 800 MB
Reduce-scatter buffers: 400 MB
Ring attention buffers: 200 MB
─────────
Buffers subtotal: 1,400 MB (1.4 GB)
BATCH DATA:
Input latents (4.5K tokens × 16 ch): 180 MB
Text embeddings (512 tokens): 50 MB
Timestep embeddings: 10 MB
─────────
Data subtotal: 240 MB (0.24 GB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
TOTAL USED: 10.5 GB (13% of 80 GB)
HEADROOM: 69.5 GB (87% free!)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Result: Substantial headroom allows longer sequences or larger batches!
Additional Memory Optimizations
- FlashAttention-3: Fused kernel that never materializes full $n \times n$ attention matrix. Uses GPU SRAM (on-chip memory) for tiling. Achieves 200× memory reduction and 2-4× speedup versus naive attention.
- Mixed Precision (bf16): Forward/backward in bfloat16 (16 bits), master weights in fp32 (32 bits). Gradients accumulated in fp32 for numerical stability. Provides 2× memory reduction for activations and 2× speedup on Ampere/Hopper GPUs.
- Activation Offloading: For sequences >1M tokens, offload activation checkpoints to CPU RAM via PCIe. Transfer time (64 GB/s) overlaps with layer compute time. Enables virtually unlimited sequence length.
- Selective Checkpointing: Only checkpoint transformer blocks (expensive), not VAE encoder/decoder or text encoder (cheap). Reduces checkpoint count from 42 modules to 40, saving 5% memory.
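Two of these optimizations are essentially one-liners in PyTorch. The sketch below applies activation checkpointing to a small placeholder block and runs forward/backward under bf16 autocast; the module and sizes are illustrative, not the production model.

```python
# Activation checkpointing + bf16 autocast on a placeholder transformer block.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(512, 1376), nn.GELU(), nn.Linear(1376, 512)).cuda()
x = torch.randn(4, 512, device="cuda", requires_grad=True)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # Activations inside the block are discarded and recomputed during backward,
    # trading extra compute for memory (as with the "every 5 layers" policy above).
    y = checkpoint(block, x, use_reentrant=False)
    loss = y.float().square().mean()

loss.backward()
print(x.grad.shape)  # gradients flow through the recomputed block
```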
Other Parallelism Strategies
While FSDP and Context Parallelism are central to our video model training, large-scale LLM training often employs additional parallelism strategies. Understanding these is essential for comprehensive distributed training knowledge.
Tensor Parallelism (TP)
Tensor Parallelism splits individual layers across multiple GPUs. Unlike FSDP (which shards by parameter ownership), TP partitions the actual matrix operations.
Column-wise parallelism (for the first projection): split $W_1$ by columns into $[W_1^{(0)} \mid W_1^{(1)}]$, so GPU $i$ computes $Y_i = X W_1^{(i)}$ with no communication; together the $Y_i$ form the full activation.
Row-wise parallelism (for the second projection): split $W_2$ by rows, so GPU $i$ computes the partial product $Z_i = Y_i W_2^{(i)}$, and a single all-reduce sums $Z = Z_0 + Z_1$ to recover the layer output.
Tensor Parallelism (TP=2) for Attention:
Input X: [batch × seq × hidden]
↓
┌───────┴───────┐
↓ ↓
GPU 0 GPU 1
Q₀,K₀,V₀ Q₁,K₁,V₁
(heads 0-19) (heads 20-39)
↓ ↓
Attention₀ Attention₁
↓ ↓
└───────┬───────┘
↓ (All-Reduce)
Output: [batch × seq × hidden]
Memory savings: Each GPU stores 1/TP of attention weights
Communication: All-Reduce after attention output projection
When to use TP vs FSDP:
- TP: Low latency (single collective per layer), best within a node (NVLink)
- FSDP: Higher latency (all-gather + reduce-scatter), better across nodes
- Combined: TP within nodes, FSDP across nodes is common in production
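The column/row split can be simulated on one device to check the algebra. The sketch below partitions an MLP's first weight by columns and second by rows, then confirms the summed partial outputs match the unsharded result (pure tensor math, no collectives; the "+" is what the all-reduce computes in a real run).

```python
# Column-parallel then row-parallel MLP (Megatron-style), verified on one device.
import torch

torch.manual_seed(0)
d, d_ff, n = 64, 176, 8
x = torch.randn(n, d)
w1 = torch.randn(d, d_ff)    # first projection (split by columns)
w2 = torch.randn(d_ff, d)    # second projection (split by rows)

# Reference (unsharded) MLP.
ref = torch.relu(x @ w1) @ w2

# TP=2: each "GPU" holds half of w1's columns and the matching half of w2's rows.
w1_a, w1_b = w1.chunk(2, dim=1)
w2_a, w2_b = w2.chunk(2, dim=0)
y_a = torch.relu(x @ w1_a)   # GPU 0: no communication needed yet
y_b = torch.relu(x @ w1_b)   # GPU 1
z = y_a @ w2_a + y_b @ w2_b  # the "+" is what the all-reduce performs in practice

print(torch.allclose(z, ref, atol=1e-4))  # True
```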
Pipeline Parallelism (PP)
Pipeline Parallelism assigns different layers to different GPUs, forming a processing pipeline. This is particularly useful when individual layers fit in GPU memory but the full model doesn't.
Pipeline Parallelism (PP=4 stages, 40 layers):

GPU 0: Layers 0-9   ━━━▶
GPU 1: Layers 10-19 ━━━▶
GPU 2: Layers 20-29 ━━━▶
GPU 3: Layers 30-39 ━━━▶

Naive Schedule (high bubble overhead):
┌─────────────────────────────────────────────────┐
│ GPU 0: [F0]                            [B0]     │
│ GPU 1:     [F0]                    [B0]         │
│ GPU 2:         [F0]            [B0]             │
│ GPU 3:             [F0][B0]                     │
│                         ↑ bubble (idle time)    │
└─────────────────────────────────────────────────┘

1F1B Schedule (minimized bubbles):
┌─────────────────────────────────────────────────┐
│ GPU 0: [F0][F1][F2][F3][B0][B1][B2][B3]         │
│ GPU 1:     [F0][F1][F2][B0][F3][B1][B2][B3]     │
│ GPU 2:         [F0][F1][B0][F2][B1][F3][B2][B3] │
│ GPU 3:             [F0][B0][F1][B1][F2][B2]...  │
└─────────────────────────────────────────────────┘
F = Forward microbatch, B = Backward microbatch
Pipeline bubble overhead: $\text{bubble fraction} = \dfrac{p-1}{m+p-1}$, where $p$ = pipeline stages and $m$ = microbatches. With PP=4 and $m=16$: bubble = 3/19 ≈ 16%. Schedules like 1F1B hit this bound while bounding activation memory, and interleaved variants shrink the bubble further.
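As a quick check of the formula, the bubble fraction can be tabulated for a few microbatch counts; this is plain arithmetic with no framework assumptions.

```python
# Pipeline bubble fraction: (p - 1) / (m + p - 1).
def bubble_fraction(stages: int, microbatches: int) -> float:
    return (stages - 1) / (microbatches + stages - 1)

for m in (4, 8, 16, 32):
    print(f"PP=4, m={m:2d}: bubble = {bubble_fraction(4, m):.1%}")
# PP=4, m=16 gives ~15.8%, matching the ~16% above.
```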
ZeRO: The Foundation of FSDP
Zero Redundancy Optimizer (ZeRO) is the theoretical framework behind FSDP. Understanding its stages clarifies why FSDP works:
ZeRO Stages (14B model, 32 GPUs):
                    │ Per-GPU Memory │ Communication
────────────────────┼────────────────┼──────────────────
Baseline (DDP)      │ 168 GB         │ All-Reduce grads
                    │                │
ZeRO-1 (Optimizer)  │ 59.5 GB        │ All-Reduce grads
  Shard optimizer   │ 56 + 112/32    │ + All-Gather updated params
                    │                │
ZeRO-2 (+ Grads)    │ 32.4 GB        │ Reduce-Scatter grads
  Shard optimizer   │ 28 + 140/32    │ + All-Gather updated params
  + gradients       │                │
                    │                │
ZeRO-3 (+ Params)   │ 5.25 GB        │ All-Gather params (fwd + bwd)
  Shard everything  │ 168/32         │ + Reduce-Scatter grads
  = FSDP            │                │
────────────────────┴────────────────┴──────────────────
FSDP = ZeRO-3 implemented in PyTorch
Gradient Accumulation
Gradient Accumulation enables large effective batch sizes when GPU memory is limited. Instead of updating weights after each forward-backward pass, gradients are accumulated over multiple micro-batches:
Gradient Accumulation (accumulation_steps=4):
Step 1: forward(batch_0) → backward() → grad += ∇L₀
Step 2: forward(batch_1) → backward() → grad += ∇L₁
Step 3: forward(batch_2) → backward() → grad += ∇L₂
Step 4: forward(batch_3) → backward() → grad += ∇L₃
↓
optimizer.step() ← update with averaged gradients
grad = 0 ← reset for next accumulation cycle
Memory: Only micro-batch activations stored (not full batch)
Effective: Same as 4× larger batch, but 4× slower per step
For our video model with CP=16, FSDP=32, DP=4, and accumulation_steps=8, with one video clip per 32-GPU replica per micro-step, the effective batch size is 4 (DP) × 8 (accumulation) = 32 clips per optimizer update.
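In code, accumulation is just deferring `optimizer.step()`. A minimal PyTorch loop matching the schedule above is sketched below, with a placeholder model and data; the loss is scaled by the accumulation count so the accumulated gradient is an average rather than a sum. With DDP or FSDP, the intermediate backward passes would typically also run under `no_sync()` to avoid redundant gradient communication.

```python
# Gradient accumulation: average gradients over 8 micro-batches per optimizer step.
import torch
import torch.nn as nn

ACCUM_STEPS = 8
model = nn.Linear(512, 512).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step, batch in enumerate(torch.randn(32, 4, 512).cuda()):
    loss = model(batch).square().mean() / ACCUM_STEPS   # scale so grads average, not sum
    loss.backward()                                      # grads accumulate in .grad
    if (step + 1) % ACCUM_STEPS == 0:
        opt.step()                                       # one update per 8 micro-batches
        opt.zero_grad(set_to_none=True)
```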
3D/4D Parallelism: Combining Strategies
Production systems combine multiple parallelism strategies. The key is understanding which strategies are orthogonal and can be composed:
4D Parallelism for 1024-GPU Training:
DP=8 (Data Parallel - 8 replicas for different batches)
│
└─► PP=4 (Pipeline Parallel - 4 stages of layers)
│
└─► TP=8 (Tensor Parallel - 8 GPUs split each layer)
│
└─► CP=4 (Context Parallel - 4-way sequence split)
Total GPUs = DP × PP × TP × CP = 8 × 4 × 8 × 4 = 1,024
Communication hierarchy (fastest → slowest):
TP: NVSwitch within node (900 GB/s) ← tightest coupling
CP: NVLink across 2 nodes (400 GB/s)
PP: InfiniBand point-to-point (100 GB/s)
DP: InfiniBand all-reduce (100 GB/s) ← most independent
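PyTorch's `DeviceMesh` is one way to express this kind of nested layout. The sketch below builds a 4D mesh with the dimension names used above; it is illustrative only, and a real job would hand each named sub-mesh to the corresponding parallelism wrapper.

```python
# A 4D device mesh (dp, pp, tp, cp) for a 1,024-GPU layout.
# Run under torchrun with world_size = 8 * 4 * 8 * 4 = 1024 (or shrink the shape to test).
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def build_mesh():
    mesh = init_device_mesh(
        "cuda",
        mesh_shape=(8, 4, 8, 4),
        mesh_dim_names=("dp", "pp", "tp", "cp"),
    )
    # Each named sub-mesh yields the process group for that parallelism dimension,
    # e.g. mesh["tp"] groups the 8 ranks that split each layer's matmuls.
    return mesh

if __name__ == "__main__":
    dist.init_process_group("nccl")
    mesh = build_mesh()
    print(mesh)
    dist.destroy_process_group()
```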
Fault Tolerance and Checkpointing
Long training runs (days to weeks) inevitably encounter hardware failures. Robust checkpointing is essential:
- Synchronous checkpointing: All GPUs pause training, write to shared storage. Simple but causes ~5-10 minute stalls for large models.
- Asynchronous checkpointing: Copy model state to CPU memory in background, write to disk while training continues. Reduces overhead to seconds.
- Distributed checkpointing: Each GPU writes its local shard directly. With FSDP-32, creates 32 smaller files instead of one 168 GB file. 10× faster I/O.
- Elastic training: Frameworks like TorchElastic can restart training with different GPU counts after failures, reassigning shards automatically.
Checkpoint frequency trade-off: on average, a failure costs roughly half a checkpoint interval of recomputed work, plus restart time. With a Mean Time Between Failures (MTBF) of 24 hours for a 128-GPU cluster, checkpointing every 30 minutes means ~15 minutes of lost work per failure on average.
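The trade-off can be made quantitative with Young's classic approximation for the optimal checkpoint interval, sqrt(2 × MTBF × checkpoint cost). The numbers below are illustrative: the MTBF comes from this section, while the per-checkpoint write cost is an assumption.

```python
# Checkpoint interval trade-off (Young's approximation; illustrative numbers).
import math

MTBF_HOURS = 24.0          # mean time between failures for the cluster (from above)
CKPT_COST_MIN = 1.0        # assumed time to write one async/distributed checkpoint

optimal_interval_min = math.sqrt(2 * MTBF_HOURS * 60 * CKPT_COST_MIN)
expected_loss_min = optimal_interval_min / 2 + CKPT_COST_MIN   # avg rework per failure

print(f"optimal interval ~ {optimal_interval_min:.0f} min, "
      f"expected lost work per failure ~ {expected_loss_min:.0f} min")
# With a 30-minute interval instead, the expected rework is ~15 minutes, as noted above.
```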
Data Pipeline Overview
Training on O(10¹²) tokens requires sophisticated data infrastructure. The challenge: random access to millions of video files creates catastrophic I/O contention. The solution: multi-stage pipeline with offline preprocessing and sharding.
Five-Stage Architecture
- Raw Storage (122 PB): ProRes 4444 and DNxHR HQX 4K video files stored in MinIO object storage (S3-compatible). Typical file size: 900 GB - 1.35 TB per hour of 4K footage.
- Offline Preprocessing (6 months): Dedicated cluster of 50 CPU nodes (AMD EPYC 9654, 96 cores each) performs: video decode (hardware-accelerated), color space conversion (Log/PQ/HLG → Linear RGB), quality filtering (VMAF > 80, discard artifacts), temporal segmentation (32-frame clips with scene coherence), multi-resolution rendering (256px, 480px, 720p).
- Sharding (30 PB output): Pack ~3,000 clips into 120 GB TAR archives (WebDataset format). Creates 250,000 shards total. Each shard contains frames as 12-bit PNG (lossless) + JSON metadata (HDR tags, color space, motion statistics).
- Distributed Storage (Lustre): 20 storage servers (OSS) + 2 metadata servers (MDS) provide 2 TB/s aggregate bandwidth. Each shard striped across 4 servers with 1 MB stripe size. 100 GbE network per node.
- Local NVMe Caching: Each training node has 10 TB NVMe SSD cache. LRU eviction policy achieves 85% hit rate (frequently accessed shards stay cached). 8 CPU workers per GPU perform async prefetch from Lustre while GPUs compute.
Bandwidth Reality Check: Reading 30 PB over 300 training days requires only 3.5 GB/s sustained—trivial compared to 2 TB/s Lustre capacity. The key: data is read slowly over months, not all at once. With 85% cache hit rate, actual Lustre load is only 0.5 GB/s. Data loading is never the bottleneck.
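The bandwidth claim is easy to sanity-check with a short calculation. The epoch count below is an assumption on my part: roughly three passes over the 30 PB of shards during the 300-day run is what reproduces the quoted ~3.5 GB/s figure.

```python
# Sanity-checking the data-pipeline bandwidth numbers (decimal units).
# Assumption: ~3 passes over the 30 PB of shards during the 300-day run.
DATASET_PB = 30
EPOCHS = 3
DAYS = 300
CACHE_HIT_RATE = 0.85

bytes_read = DATASET_PB * 1e15 * EPOCHS
seconds = DAYS * 86_400
sustained_gbps = bytes_read / seconds / 1e9          # ~3.5 GB/s of reads overall
lustre_gbps = sustained_gbps * (1 - CACHE_HIT_RATE)  # ~0.5 GB/s actually hitting Lustre

print(f"sustained ~ {sustained_gbps:.1f} GB/s, Lustre after cache ~ {lustre_gbps:.1f} GB/s")
```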
Multi-Stage Training Curriculum
Direct 720p training is inefficient. Progressive resolution scaling enables better convergence:
| Stage | Resolution | Tokens/Sample | Batch Size | Purpose |
|---|---|---|---|---|
| 1 | 256px images | ~1,024 | 256 | Spatial composition, object semantics |
| 2 | 192px video + 256px images | ~11,520 | 32 images / 4 videos | Basic motion, temporal consistency |
| 3 | 480px video + images | ~72,000 | 8 images / 2 videos | Fine textures, improved details |
| 4 | 720p video + images | ~288,000 | 4 images / 1 video | Production quality, photorealism |
Why this works:
- Images are abundant (billions available) and cheap to process (1K tokens vs 288K)
- Low resolution enables large batches → stable gradients, faster iteration
- Spatial understanding from images transfers well to video
- Model builds strong priors before tackling expensive long sequences
- 10:1 image-to-video ratio maintained in stages 2-4 for balanced learning
Training Configuration Per Stage:
- Stage 1: DP=128 (no CP/FSDP needed), 50K samples/hour, ~10 GB/GPU
- Stage 2: CP=4, FSDP=8, DP=4 → 32 GPUs, 5K samples/hour, ~40 GB/GPU
- Stage 3: CP=8, FSDP=16, DP=4 → 64 GPUs, 1.2K samples/hour, ~65 GB/GPU
- Stage 4: CP=16, FSDP=32, DP=4 → 128 GPUs, 400 samples/hour, ~76 GB/GPU
Real-World Results
This infrastructure has been validated in production:
- WAN 2.2 (14B): Trained on 128 H20 GPUs (Alibaba Cloud) with 95% MFU (Model FLOPs Utilization). For comparison: typical transformers achieve 40-60% MFU; 95% is exceptional.
- Maximum Sequence Length: Up to 1M tokens demonstrated (40-second 720p videos or higher resolution). At 1M tokens, attention computation dominates 95% of total training time.
- Communication Overhead: 2D Context Parallelism reduces overhead from >10% (naive Ulysses) to <1%. FSDP adds ~10%. Total: 11-12% for 128 GPUs.
- Inference Performance: 8× A100 GPUs (80 GB each) generate 15-minute videos in real-time at 8 FPS. Total: 7,200 frames generated in 15 minutes = real-time generation.
- Consumer Accessibility: 1.3B parameter variant requires only 8.19 GB VRAM, runs on RTX 4090 or RTX 3090. Achieves competitive quality despite 10× fewer parameters.
Conclusion
Training 14B parameter video models on 720p sequences requires orchestrating multiple complementary techniques:
- FSDP (ZeRO-3) shards model state (params + grads + optimizer) across 32 GPUs: 168 GB → 5.25 GB per GPU
- Context Parallelism splits sequences across 16 GPUs: 636 GB → 12 GB activations per GPU
- 2D CP (Ulysses + Ring) reduces communication from >10% to <1% overhead through clever overlapping
- Tensor Parallelism splits attention heads and FFN layers within nodes for low-latency compute
- Pipeline Parallelism distributes layers across GPUs with 1F1B scheduling to minimize bubbles
- Gradient Accumulation enables large effective batch sizes despite per-GPU memory limits
- FlashAttention provides 200× memory reduction and 2-4× speedup for attention operations
- Gradient Checkpointing trades 33% extra compute for 84% activation memory savings
- Mixed Precision (bf16/fp32) cuts activation memory and compute time by 2×
- Data Pipeline with 5-stage architecture ensures 100% GPU utilization despite petabyte-scale datasets
- Curriculum Learning through 4-stage progressive resolution scaling accelerates convergence
- Fault Tolerance with distributed checkpointing enables recovery from inevitable hardware failures
The key insight: these techniques are complementary and necessary. FSDP solves model memory. Context Parallelism solves activation memory. Tensor and Pipeline Parallelism enable further scaling. FlashAttention solves compute efficiency. The data pipeline solves I/O. Remove any piece and the system breaks.
This infrastructure represents years of engineering across distributed systems, numerical optimization, and machine learning. But the core principles are timeless: partition what doesn't fit, coordinate through communication, overlap wherever possible. As we push toward photorealistic, high-resolution, long-form video generation, these patterns will continue to enable the impossible.