Training 14B Video Models: Infrastructure Deep Dive

January 12, 2026

Introduction

Training a 14 billion parameter video diffusion model on 720p sequences is an extreme engineering challenge. A single 5-second 720p video at 16 FPS produces roughly 72,000 tokens after VAE encoding—far exceeding typical language model contexts. The quadratic memory scaling of attention makes this infeasible on single GPUs.

This article walks through the complete distributed training infrastructure needed to make this work, with concrete memory calculations at every step using real numbers from production systems.

Foundations: Build Your Mental Model

Before we solve the problem, let's understand WHY it's hard. Work through each subsection—by the end, you'll see exactly why we need every technique in this article.

From Pixels to Tokens: The VAE Compression

A neural network doesn't see "pixels"—it sees tokens. For video, we use a 3D VAE (Variational Autoencoder) that compresses raw pixels into a latent space.

Our Video:
Resolution: 1280 × 720 pixels (720p)
Frame rate: 16 FPS
Duration: 5 seconds
Total frames: 16 × 5 = 80 frames
Raw pixels: 1280 × 720 × 80 = 73.7 million pixels

Step 1: Spatial Compression (16× in each dimension)

The VAE encoder compresses each frame spatially. Think of it like reducing a 1280×720 image to an 80×45 "semantic map" where each cell encodes meaning, not just color.

Spatial compression per frame:
Width: 1280 ÷ 16 = 80
Height: 720 ÷ 16 = 45
Spatial tokens per frame: 80 × 45 = 3,600 tokens
921,600 pixels → 3,600 semantic tokens per frame (a 256× reduction in spatial positions)

Step 2: Temporal Compression (4× across time)

Adjacent frames are similar. The 3D VAE also compresses across time, grouping every 4 frames into 1 temporal slice.

Temporal compression:
Original frames: 80
Temporal factor: 4
Temporal tokens: 80 ÷ 4 = 20 time steps

Final Token Count

Total tokens = Spatial × Temporal
= 3,600 tokens/frame × 20 time steps
= 72,000 tokens for a 5-second 720p video
Key Insight: 72,000 tokens is MASSIVE. GPT-4 has a 128K context window—and that's text, which is cheap. For video, each token is a 16-channel latent vector (not a single word). This sequence length is the root of all our problems.
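The token-count arithmetic above is worth sanity-checking in code. A minimal sketch (the function name and defaults are illustrative; the 16× and 4× factors match the VAE described above):

```python
def video_token_count(width, height, fps, seconds,
                      spatial_factor=16, temporal_factor=4):
    """Tokens produced by a 3D VAE with the given compression factors."""
    spatial_tokens = (width // spatial_factor) * (height // spatial_factor)
    frames = fps * seconds
    temporal_tokens = frames // temporal_factor
    return spatial_tokens * temporal_tokens

# 5 seconds of 720p at 16 FPS
print(video_token_count(1280, 720, 16, 5))  # → 72000
```

Halving the spatial compression factor quadruples the token count, which is why VAE design choices ripple through every memory number that follows.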

What Actually Lives in GPU Memory?

When training a neural network, your GPU holds 4 types of data. Let's do the math for a 14B parameter model.

1. Model Parameters

The weights themselves. With bfloat16 (2 bytes per parameter):

Parameters: 14,000,000,000
Bytes: 14B × 2 bytes = 28,000,000,000 bytes
= 28 GB for model weights

2. Gradients

During backprop, we compute ∂Loss/∂Weight for every parameter. Same size as parameters:

Gradients: 14B × 2 bytes (bf16)
= 28 GB for gradients

3. Optimizer States (The Hidden Monster)

AdamW maintains two running statistics per parameter:

AdamW update rule:
$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$   ← first moment (momentum)
$v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$   ← second moment (variance)
These are stored in fp32 (4 bytes) for numerical stability!
Optimizer states: 14B × 2 states × 4 bytes = 112 GB
Surprise! Optimizer states are 4× larger than the model itself! This is why Adam is memory-hungry. SGD with momentum would need only one state (56 GB in fp32), but typically converges worse for large transformers.

4. Activations (The Quadratic Killer)

Every layer saves intermediate values for backprop. This depends on sequence length and scales quadratically for attention.

Per layer for 72K sequence:
Q, K, V projections: 3 × 72K × 5,120 × 2 bytes = 2.16 GB
Attention scores: 72K × 72K × 2 bytes = 10.37 GB ← THE PROBLEM
FFN intermediates: 72K × 13,824 × 2 bytes = 1.94 GB
Per layer total: ~15.9 GB × 40 layers = 636 GB

The Full Picture: H100 GPU (80 GB)

Params (28 GB) | Grads (28 GB) | Optimizer (112 GB) | Activations (636 GB)

Total needed: 28 + 28 + 112 + 636 = 804 GB
H100 has: 80 GB
We need 10× what we have!
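The whole budget can be tallied in a few lines. This sketch assumes the numbers established above (bf16 weights and gradients, two fp32 AdamW states, ~15.9 GB of activations per layer across 40 layers); the helper name is illustrative:

```python
def training_memory_gb(params_b=14, layers=40, act_per_layer_gb=15.9):
    """Total training memory in decimal GB for the setup described above."""
    weights   = params_b * 2          # bf16: 2 bytes/param
    grads     = params_b * 2          # bf16
    optimizer = params_b * 2 * 4      # two fp32 states: 8 bytes/param
    activations = layers * act_per_layer_gb
    return weights + grads + optimizer + activations

total = training_memory_gb()
print(round(total, 1))   # → 804.0  (roughly 10x an H100's 80 GB)
```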

Why Attention Scales Quadratically

Attention is the core operation that lets tokens "see" other tokens. Let's work through it step by step.

Step 1: Project to Q, K, V

Each token gets projected into Query, Key, and Value vectors:

Input X: [n × d] where n=72,000 tokens, d=5,120 dimensions
$Q = XW_Q$   shape: [72K × 5,120]
$K = XW_K$   shape: [72K × 5,120]
$V = XW_V$   shape: [72K × 5,120]
Memory: 3 × 72K × 5,120 × 2 bytes = 2.16 GB ← Linear in n, OK!

Step 2: Compute Attention Scores (THE QUADRATIC STEP)

Each query asks "how relevant is each key?" This creates an n×n matrix:

$\text{Scores} = QK^T / \sqrt{d_k}$
Q shape: [72K × 5,120]
K^T shape: [5,120 × 72K]
Result: [72K × 72K] = 5.18 billion elements!
Memory: 72,000 × 72,000 × 2 bytes = 10.37 GB per layer
Q (queries)          K^T (keys transposed)        Scores
┌─────────────┐      ┌─────────────────────┐      ┌─────────────────────┐
│ token 0 →   │      │  ↓   ↓   ↓   ↓      │      │ 0→0  0→1  0→2  ...  │
│ token 1 →   │  @   │  t0  t1  t2  ...    │  =   │ 1→0  1→1  1→2  ...  │
│ token 2 →   │      │                     │      │ 2→0  2→1  2→2  ...  │
│ ...         │      │             t71999  │      │ ...                 │
│ token 71999 │      │                     │      │ 72K × 72K           │
└─────────────┘      └─────────────────────┘      └─────────────────────┘
 [72K × 5120]           [5120 × 72K]               [72K × 72K] = 10.37 GB!

Step 3: Softmax and Weighted Sum

$\text{Attention} = \text{softmax}(\text{Scores}) \cdot V$
Softmax: apply row-wise, still [72K × 72K]
Multiply by V [72K × 5120] → Output [72K × 5120]

The Scaling Problem

Memory vs Sequence Length:
n = 1,000 tokens: 1K × 1K × 2 = 2 MB ✓
n = 8,000 tokens (GPT-4 short): 8K × 8K × 2 = 128 MB ✓
n = 32,000 tokens: 32K × 32K × 2 = 2 GB ⚠️
n = 72,000 tokens (our video): 72K × 72K × 2 = 10.37 GB
n = 1,000,000 tokens: 1M × 1M × 2 = 2 TB ❌❌❌
Doubling sequence length → 4× memory. This is O(n²).
The Core Problem: Attention is the reason transformers work so well (every token can attend to every other token). But it's also why long sequences are impossible on single GPUs. We must SPLIT the sequence across GPUs.
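The scaling table reduces to a one-line function (decimal units; the helper name is illustrative):

```python
def attn_scores_bytes(n, dtype_bytes=2):
    """Memory for the n x n attention score matrix."""
    return n * n * dtype_bytes

for n in (1_000, 8_000, 32_000, 72_000):
    print(f"{n:>7} tokens: {attn_scores_bytes(n) / 1e9:.3f} GB")

# Doubling the sequence length quadruples the memory: O(n^2)
assert attn_scores_bytes(2 * 72_000) == 4 * attn_scores_bytes(72_000)
```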

How GPUs Talk to Each Other

When we split work across GPUs, they need to share data. There are 3 fundamental operations you must understand.

1. All-Reduce: Sum Across All GPUs

Every GPU has a local value. After all-reduce, every GPU has the SUM of all values.

BEFORE All-Reduce:           AFTER All-Reduce:
GPU 0: [3]                   GPU 0: [3+5+2+7] = [17]
GPU 1: [5]      ───────►     GPU 1: [3+5+2+7] = [17]
GPU 2: [2]                   GPU 2: [3+5+2+7] = [17]
GPU 3: [7]                   GPU 3: [3+5+2+7] = [17]

Use case: Averaging gradients in Data Parallel training
Communication cost:
Data moved per GPU: 2 × (N-1)/N × tensor_size ≈ 2 × tensor_size
For 28 GB gradients across 4 GPUs: 2 × 3/4 × 28 = 42 GB per GPU (approaching 56 GB as N grows)

2. All-Gather: Collect Shards from Everyone

Each GPU has 1/N of the data. After all-gather, every GPU has the FULL tensor.

BEFORE All-Gather:             AFTER All-Gather:
GPU 0: [A][_][_][_]            GPU 0: [A][B][C][D]  ← full tensor
GPU 1: [_][B][_][_]  ───────►  GPU 1: [A][B][C][D]  ← full tensor
GPU 2: [_][_][C][_]            GPU 2: [A][B][C][D]  ← full tensor
GPU 3: [_][_][_][D]            GPU 3: [A][B][C][D]  ← full tensor

Use case: FSDP reconstructs full layer weights before forward pass
Communication cost:
Each GPU sends: 1/N of tensor
Each GPU receives: (N-1)/N of tensor
For 875 MB shard across 32 GPUs: Each GPU receives 27 GB

3. Reduce-Scatter: Sum Then Distribute Shards

Every GPU has a full tensor. After reduce-scatter, each GPU has 1/N of the SUMMED result. This is the inverse of all-gather!

BEFORE Reduce-Scatter:            AFTER Reduce-Scatter:
GPU 0: [a₀][a₁][a₂][a₃]           GPU 0: [Σa₀]  (= a₀+b₀+c₀+d₀)
GPU 1: [b₀][b₁][b₂][b₃] ───────►  GPU 1: [Σa₁]  (= a₁+b₁+c₁+d₁)
GPU 2: [c₀][c₁][c₂][c₃]           GPU 2: [Σa₂]  (= a₂+b₂+c₂+d₂)
GPU 3: [d₀][d₁][d₂][d₃]           GPU 3: [Σa₃]  (= a₃+b₃+c₃+d₃)

Use case: FSDP after backward pass - sum gradients and keep only your shard
Communication cost:
Same as all-gather: (N-1)/N × tensor_size per GPU
All-gather + reduce-scatter together move the same data as one all-reduce; in fact, an all-reduce is typically implemented as a reduce-scatter followed by an all-gather.
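All three primitives can be simulated on plain Python lists, which also makes the cost identity concrete: summing shards then gathering them reproduces an all-reduce. A toy sketch (one inner list per "GPU"; function names are illustrative, not a real NCCL API):

```python
def all_gather(shards):
    """Each rank holds one shard; every rank ends with the concatenation."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def reduce_scatter(tensors):
    """Each rank holds a full tensor; rank i keeps the sum of chunk i."""
    n = len(tensors)
    chunk = len(tensors[0]) // n
    summed = [sum(t[j] for t in tensors) for j in range(len(tensors[0]))]
    return [summed[i * chunk:(i + 1) * chunk] for i in range(n)]

def all_reduce(tensors):
    """Reduce-scatter then all-gather: the sum lands on every rank."""
    return all_gather(reduce_scatter(tensors))

grads = [[1, 2, 3, 4], [10, 20, 30, 40],
         [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
print(all_reduce(grads)[0])  # → [1111, 2222, 3333, 4444]
```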

Network Bandwidth Reality

H100 connectivity:
NVLink 4.0 (within node, 8 GPUs): 900 GB/s bidirectional
InfiniBand NDR (across nodes): 400 Gb/s = 50 GB/s per link
Time to all-gather 28 GB within node: 28/900 = 31 ms
Time to all-gather 28 GB across nodes: 28/50 = 560 ms ← 18× slower!
Lesson: Keep communication within nodes when possible. Cross-node = expensive.
Key Insight: Every parallelism strategy is a different trade-off between computation and communication. The goal: overlap communication with compute so GPUs are never waiting.

The Three Ways to Split Work

There are only 3 fundamental axes to parallelize training. Understanding these is the key to everything.

Axis 1: Data Parallelism (DP)

Idea: Same model, different data. Each GPU processes a different batch, then average gradients.

Data Parallel (DP=4):

┌─────────────────────────────────────────┐
│         Same Model (14B params)         │
└─────────────────────────────────────────┘
              ↓ copied to each GPU
┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│    GPU 0     │ │    GPU 1     │ │    GPU 2     │ │    GPU 3     │
│  Model: 14B  │ │  Model: 14B  │ │  Model: 14B  │ │  Model: 14B  │
│  Batch: 0    │ │  Batch: 1    │ │  Batch: 2    │ │  Batch: 3    │
│      ↓       │ │      ↓       │ │      ↓       │ │      ↓       │
│    grad₀     │ │    grad₁     │ │    grad₂     │ │    grad₃     │
└──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘
        │               │               │               │
        └───────────────┴───────┬───────┴───────────────┘
                All-Reduce: avg(grad₀, grad₁, grad₂, grad₃)

✓ Easy to implement
✗ Every GPU needs full model memory (804 GB) - DOESN'T FIT!

Axis 2: Model Parallelism (split the model)

Idea: Different GPUs hold different parts of the model.

Model Parallel - Two flavors:

TENSOR PARALLEL (split layers horizontally):
┌─────────────────────────────────────────┐
│                 Layer N                 │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐  │
│  │  Heads  │  │  Heads  │  │  Heads  │  │
│  │  0-13   │  │  14-26  │  │  27-39  │  │
│  │  GPU 0  │  │  GPU 1  │  │  GPU 2  │  │
│  └─────────┘  └─────────┘  └─────────┘  │
└─────────────────────────────────────────┘
Each GPU: 1/3 of layer params, needs all-reduce after attention

PIPELINE PARALLEL (split layers vertically):
┌──────────────┐     ┌──────────────┐     ┌──────────────┐
│    Layers    │     │    Layers    │     │    Layers    │
│     0-13     │ ──► │    14-26     │ ──► │    27-39     │
│    GPU 0     │     │    GPU 1     │     │    GPU 2     │
└──────────────┘     └──────────────┘     └──────────────┘
Each GPU: 1/3 of layers, data flows through pipeline

✓ Reduces model memory per GPU
✗ Complex communication patterns

Axis 3: Sequence/Context Parallelism (CP)

Idea: Different GPUs process different chunks of the sequence.

Context Parallel (CP=4):

Full sequence: [████████████████████████████████] (72,000 tokens)
                              ↓ split
┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│    GPU 0     │ │    GPU 1     │ │    GPU 2     │ │    GPU 3     │
│    Tokens    │ │    Tokens    │ │    Tokens    │ │    Tokens    │
│   0-17999    │ │ 18000-35999  │ │ 36000-53999  │ │ 54000-71999  │
│  (18K each)  │ │  (18K each)  │ │  (18K each)  │ │  (18K each)  │
└──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘

Attention memory per GPU: 18K × 72K × 2 = 2.6 GB (not 10.37 GB!)
But needs communication to share K, V across GPUs for full attention

✓ Directly attacks the O(n²) attention problem
✗ Requires sophisticated communication (Ulysses, Ring Attention)

FSDP: The Best of Both Worlds

Fully Sharded Data Parallel combines DP with sharding. Each GPU holds 1/N of the model, reconstructs layers on-demand.

FSDP (DP + Sharding):

┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│    GPU 0     │ │    GPU 1     │ │    GPU 2     │ │    GPU 3     │
│  Shard 0/4   │ │  Shard 1/4   │ │  Shard 2/4   │ │  Shard 3/4   │
│   (42 GB)    │ │   (42 GB)    │ │   (42 GB)    │ │   (42 GB)    │
│   Batch 0    │ │   Batch 1    │ │   Batch 2    │ │   Batch 3    │
└──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘
        │               │               │               │
        └── All-Gather before compute (reconstruct full layer) ──┘
        └── Reduce-Scatter after backward (sum & shard grads) ───┘

✓ Model memory: 168 GB ÷ 4 = 42 GB per GPU ← FITS!
✓ Still data parallel (process different batches)
✗ More communication than pure DP
The Insight: We need ALL THREE:
  • FSDP to fit the model (168 GB → 5.25 GB per GPU with 32-way shard)
  • Context Parallel to fit activations (636 GB → 12 GB with 16-way split)
  • Data Parallel on top to increase throughput

Working Example: One 720p Video

Now that you understand the foundations, let's apply them. We'll establish concrete numbers and carry them through the entire article. You've seen the concepts—now watch them break things.

Let's establish concrete numbers we'll use throughout. Consider a 5-second video clip:

After encoding through a 3D VAE with 16× spatial compression and 4× temporal compression:

$$\text{Spatial tokens per frame} = \frac{1280}{16} \times \frac{720}{16} = 80 \times 45 = 3{,}600$$
$$\text{Temporal tokens} = \frac{80 \text{ frames}}{4} = 20$$
$$\text{Total tokens} = 3{,}600 \times 20 = 72{,}000$$

This single video clip produces 72,000 tokens. Now let's see why this breaks everything.

Scenario 1: Single GPU (Why It Fails)

Memory Requirements

Consider a 14B parameter DiT (Diffusion Transformer) with 40 layers, hidden dimension 5,120, and FFN dimension 13,824:

Model Parameters (bf16):

$$14{,}000{,}000{,}000 \times 2 \text{ bytes} = 28 \text{ GB}$$

Optimizer States (AdamW in fp32): AdamW maintains first moment $m_t$ (momentum) and second moment $v_t$ (variance) for each parameter.

$$14B \times 2 \text{ states} \times 4 \text{ bytes (fp32)} = 112 \text{ GB}$$

Gradients (bf16):

$$14B \times 2 \text{ bytes} = 28 \text{ GB}$$

Total static memory: 28 + 112 + 28 = 168 GB

Already impossible on an H100 with 80 GB! And we haven't even counted activations yet.

Activation Memory (per layer, batch size 1): For self-attention with sequence length $n = 72{,}000$ and hidden dimension 5,120:

Q, K, V projections: 2.16 GB; attention scores: 10.37 GB; attention output: 0.72 GB; FFN intermediates: 1.94 GB; residual/LayerNorm: 0.72 GB

Per layer total: 2.16 + 10.37 + 0.72 + 1.94 + 0.72 ≈ 15.9 GB

All 40 layers: 40 × 15.9 = 636 GB

Grand total: 168 GB (static) + 636 GB (activations) = 804 GB

This is 10× the capacity of a single H100 GPU (80 GB). Single GPU training is impossible.

What About Data Parallel (DDP)?

Data Parallel replicates the full model on each GPU. It only helps with throughput—each GPU still needs the full 804 GB. DDP doesn't solve the memory problem at all.

DDP (Data Parallel):

┌──────────────┐  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐
│   GPU 0      │  │   GPU 1      │  │   GPU 2      │  │   GPU 3      │
│  ┌────────┐  │  │  ┌────────┐  │  │  ┌────────┐  │  │  ┌────────┐  │
│  │ Model  │  │  │  │ Model  │  │  │  │ Model  │  │  │  │ Model  │  │
│  │  FULL  │  │  │  │  FULL  │  │  │  │  FULL  │  │  │  │  FULL  │  │
│  │  14B   │  │  │  │  14B   │  │  │  │  14B   │  │  │  │  14B   │  │
│  │ 804 GB │  │  │  │ 804 GB │  │  │  │ 804 GB │  │  │  │ 804 GB │  │
│  └────────┘  │  │  └────────┘  │  │  └────────┘  │  │  └────────┘  │
│              │  │              │  │              │  │              │
│  Batch 0     │  │  Batch 1     │  │  Batch 2     │  │  Batch 3     │
└──────────────┘  └──────────────┘  └──────────────┘  └──────────────┘
       │                 │                 │                 │
       └─────────────────┴─────────────────┴─────────────────┘
                    All-Reduce Gradients
                    (synchronize at end)

Problem: Each GPU needs 804 GB. Doesn't fit!
Use case: Only works when model FITS in single GPU.

DDP only works when the model fits in a single GPU's memory. For large models like 14B parameters, we need a fundamentally different approach.

Checkpoint: Connect to Foundations
If you worked through the GPU-memory and parallelism foundations above, you saw this coming. DDP copies the full model to each GPU—great for throughput, useless for memory. We need to SHARD the model state. That's FSDP.

Scenario 2: Model Sharding with FSDP

Fully Sharded Data Parallel (FSDP)

FSDP shards the model parameters, gradients, and optimizer states across GPUs. Each GPU stores only a fraction. With 32 GPUs:

$$\text{Parameters per GPU} = \frac{28 \text{ GB}}{32} = 875 \text{ MB}$$
$$\text{Gradients per GPU} = \frac{28 \text{ GB}}{32} = 875 \text{ MB}$$
$$\text{Optimizer states per GPU} = \frac{112 \text{ GB}}{32} = 3.5 \text{ GB}$$
$$\text{Total static per GPU} = 875 + 875 + 3{,}500 = 5{,}250 \text{ MB} = 5.25 \text{ GB}$$

Now the model state fits! But activations are still a problem (636 GB ÷ 32 = 19.9 GB per GPU for activations alone).
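The shard arithmetic as a function (decimal GB; an illustrative helper, not an FSDP API):

```python
def fsdp_per_gpu_gb(params_b=14, num_gpus=32):
    """Per-GPU static memory when params, grads, and optimizer are sharded."""
    params    = params_b * 2 / num_gpus   # bf16 parameter shard
    grads     = params_b * 2 / num_gpus   # bf16 gradient shard
    optimizer = params_b * 8 / num_gpus   # two fp32 Adam states
    return params, grads, optimizer, params + grads + optimizer

p, g, o, total = fsdp_per_gpu_gb()
print(p, g, o, total)  # → 0.875 0.875 3.5 5.25
```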

Communication Primitives

1. All-Gather: Each GPU collects shards from all other GPUs to reconstruct the full tensor.

Before All-Gather (each GPU has 1/4 of weights):

GPU 0: [W₀][  ][  ][  ]  ← only has shard 0
GPU 1: [  ][W₁][  ][  ]  ← only has shard 1
GPU 2: [  ][  ][W₂][  ]  ← only has shard 2
GPU 3: [  ][  ][  ][W₃]  ← only has shard 3

All-Gather Operation:
  GPU 0 sends W₀ to GPUs 1,2,3
  GPU 1 sends W₁ to GPUs 0,2,3
  GPU 2 sends W₂ to GPUs 0,1,3
  GPU 3 sends W₃ to GPUs 0,1,2

After All-Gather (each GPU has full weights):

GPU 0: [W₀][W₁][W₂][W₃]  ← full tensor reconstructed
GPU 1: [W₀][W₁][W₂][W₃]  ← full tensor reconstructed
GPU 2: [W₀][W₁][W₂][W₃]  ← full tensor reconstructed
GPU 3: [W₀][W₁][W₂][W₃]  ← full tensor reconstructed

Bandwidth: Each GPU sends 1/4 and receives 3/4 = full tensor worth of data

2. Reduce-Scatter: Sum gradients across GPUs, then scatter shards back. This is the inverse of all-gather.

Before Reduce-Scatter (each GPU has full gradients from its local backward):

GPU 0: [g₀⁽⁰⁾][g₁⁽⁰⁾][g₂⁽⁰⁾][g₃⁽⁰⁾]  ← gradients from batch 0
GPU 1: [g₀⁽¹⁾][g₁⁽¹⁾][g₂⁽¹⁾][g₃⁽¹⁾]  ← gradients from batch 1
GPU 2: [g₀⁽²⁾][g₁⁽²⁾][g₂⁽²⁾][g₃⁽²⁾]  ← gradients from batch 2
GPU 3: [g₀⁽³⁾][g₁⁽³⁾][g₂⁽³⁾][g₃⁽³⁾]  ← gradients from batch 3

Reduce-Scatter Operation:
  Step 1 (Reduce): Sum corresponding gradient shards
    Σg₀ = g₀⁽⁰⁾ + g₀⁽¹⁾ + g₀⁽²⁾ + g₀⁽³⁾
    Σg₁ = g₁⁽⁰⁾ + g₁⁽¹⁾ + g₁⁽²⁾ + g₁⁽³⁾
    (etc. for g₂, g₃)

  Step 2 (Scatter): Each GPU keeps only its shard

After Reduce-Scatter (each GPU has 1/4 of summed gradients):

GPU 0: [Σg₀]  ← owns gradient shard 0
GPU 1: [Σg₁]  ← owns gradient shard 1
GPU 2: [Σg₂]  ← owns gradient shard 2
GPU 3: [Σg₃]  ← owns gradient shard 3

Now each GPU updates only its parameter shard with its gradient shard.

FSDP Forward/Backward Pass

Forward Pass (per layer):

  1. All-gather parameters: Temporarily reconstruct full layer weights (28 GB ÷ 40 layers ≈ 700 MB per layer)
  2. Compute forward: Use full weights to compute layer output
  3. Free non-local shards: Keep only local 1/32, discard rest to save memory
  4. Result: Outside the brief all-gather window, each GPU holds only its shard of this layer (≈700 MB ÷ 32 ≈ 22 MB)

Backward Pass (per layer):

  1. All-gather parameters again: Need full weights for gradient computation
  2. Compute gradients: Backprop through layer
  3. Reduce-scatter gradients: Sum gradients across GPUs, each keeps its 1/32 shard
  4. Free non-local shards: Discard non-local parameter shards again

Optimizer Step: Each GPU updates only its local parameter shard using its local gradient shard.
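The per-layer choreography above can be sketched as a toy simulation on Python lists: no real GPUs or PyTorch, just the gather, compute, reduce-scatter, local-update sequence. All names (`fsdp_layer_step`, `local_grads_fn`) are illustrative:

```python
# Toy FSDP: each "rank" owns one shard of a layer's weight vector.
def fsdp_layer_step(weight_shards, local_grads_fn):
    n = len(weight_shards)
    # 1. All-gather: every rank reconstructs the full weight vector
    full_weights = [w for shard in weight_shards for w in shard]
    # 2. Each rank computes full gradients from its own batch
    per_rank_grads = [local_grads_fn(rank, full_weights) for rank in range(n)]
    # 3. Reduce-scatter: sum gradients, each rank keeps only its shard
    chunk = len(full_weights) // n
    summed = [sum(g[j] for g in per_rank_grads) for j in range(len(full_weights))]
    grad_shards = [summed[i * chunk:(i + 1) * chunk] for i in range(n)]
    # 4. Optimizer step: each rank updates only its local shard (SGD, lr=0.1)
    return [[w - 0.1 * g for w, g in zip(ws, gs)]
            for ws, gs in zip(weight_shards, grad_shards)]

# 2 ranks, 4 weights; each rank's "gradient" is just the weights themselves,
# so every weight shrinks by lr * 2 * w after the summed update.
shards = [[1.0, 2.0], [3.0, 4.0]]
print(fsdp_layer_step(shards, lambda rank, w: list(w)))
```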

FSDP with 32 GPUs (showing 4 for clarity):

FSDP GROUP (32 GPUs) ─ All process the SAME batch
│
├── GPU 0:  Shard 0/32  (5.25 GB) ──┐
├── GPU 1:  Shard 1/32  (5.25 GB) ──┼── All-Gather (forward/backward)
├── GPU 2:  Shard 2/32  (5.25 GB) ──┼── Reduce-Scatter (gradients)
│   ...                             │
└── GPU 31: Shard 31/32 (5.25 GB) ──┘

Communication Pattern:
  Forward:  All-Gather params → compute → discard non-local
  Backward: All-Gather params → compute grads → Reduce-Scatter

Each GPU: 5.25 GB (model) + ~19.9 GB (activations) ≈ 25 GB
Still tight! Activations dominate.

Problem: Attention Memory Still Too Large

With FSDP-32, we've solved the model memory problem (168 GB → 5.25 GB per GPU). But the 72K × 72K attention matrix still requires 10.37 GB per layer. Even divided by 32 GPUs, we're using most of our memory budget on activations. We need Context Parallelism to split the sequence dimension.

Checkpoint: Remember the Quadratic Scaling?
The attention matrix is O(n²). At 72K tokens, that's 10.37 GB per layer × 40 layers = 636 GB. FSDP doesn't touch this—it only shards model state (params, grads, optimizer). To attack activation memory, we need to split the SEQUENCE itself. That's Context Parallelism.

Scenario 3: Adding Context Parallelism

Why Sequence Length is the Problem

The attention score matrix scales quadratically: $\mathcal{O}(n^2)$ where $n$ is sequence length. For $n = 72{,}000$:

$$\text{Attention scores} = n \times n \times 2 \text{ bytes} = 72{,}000^2 \times 2 = 10.37 \text{ GB per layer}$$

With CP=16, each GPU handles $n/16 = 4{,}500$ tokens. But attention needs to see ALL tokens, so:

$$\text{Attention scores per GPU} = \frac{n}{16} \times n \times 2 = 4{,}500 \times 72{,}000 \times 2 = 648 \text{ MB}$$

That's a 16× reduction! From 10.37 GB to 648 MB per layer.
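The same reduction as a function of the CP degree (decimal MB; the helper name is illustrative):

```python
def scores_per_gpu_mb(n=72_000, cp=16, dtype_bytes=2):
    """Per-GPU attention-score memory when queries are split CP ways."""
    local_queries = n // cp
    return local_queries * n * dtype_bytes / 1e6

print(scores_per_gpu_mb(cp=1))   # → 10368.0  (the full 10.37 GB)
print(scores_per_gpu_mb(cp=16))  # → 648.0
```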

Ulysses Sequence Parallelism

Ulysses-style sequence parallelism splits the sequence across 8 GPUs. (We describe a simplified all-gather variant; DeepSpeed's Ulysses proper re-partitions attention heads with all-to-all collectives.) The key insight: queries can stay local (each GPU computes attention only for its chunk's queries), but keys and values must be global (every query needs the full K, V for correct attention).

Algorithm:

  1. Each GPU projects its local input to $Q_{\text{local}}$, $K_{\text{local}}$, $V_{\text{local}}$
  2. All-gather $K$ and $V$ from all 8 GPUs → get $K_{\text{full}}$, $V_{\text{full}}$
  3. Compute: $\text{softmax}(Q_{\text{local}} K_{\text{full}}^T / \sqrt{d_k}) \, V_{\text{full}}$
  4. Each GPU produces output for its local sequence chunk
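The four steps are easy to verify on toy data: attention is computed row by row over queries, so splitting Q while keeping the full K, V reproduces the full result exactly. A dependency-free Python sketch with tiny dimensions (values are arbitrary):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attention(Q, K, V):
    """Plain softmax(Q K^T / sqrt(d)) V on Python lists."""
    d = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        w = softmax(scores)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(d)])
    return out

# Toy sequence: 4 tokens, d=2, split across 2 "GPUs"
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
K = [[0.2, 0.1], [0.4, 0.3], [0.1, 0.9], [0.7, 0.2]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full = attention(Q, K, V)
# Ulysses-style: each GPU keeps its local Q but attends over all-gathered K, V
gpu0 = attention(Q[:2], K, V)   # tokens 0-1
gpu1 = attention(Q[2:], K, V)   # tokens 2-3
assert gpu0 + gpu1 == full      # concatenated local outputs == full attention
```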
Ulysses-8 (splitting 72K tokens across 8 GPUs):

Original: ████████████████████████████████████████ (72,000 tokens)

After Split:
GPU 0: █████  (tokens 0 - 8,999)        → 9K tokens
GPU 1: █████  (tokens 9,000 - 17,999)   → 9K tokens
GPU 2: █████  (tokens 18,000 - 26,999)  → 9K tokens
GPU 3: █████  (tokens 27,000 - 35,999)  → 9K tokens
GPU 4: █████  (tokens 36,000 - 44,999)  → 9K tokens
GPU 5: █████  (tokens 45,000 - 53,999)  → 9K tokens
GPU 6: █████  (tokens 54,000 - 62,999)  → 9K tokens
GPU 7: █████  (tokens 63,000 - 71,999)  → 9K tokens

Each GPU computes:
  Q_local: [9K × 5,120] in bf16  → 90 MB
  K_local: [9K × 5,120] in bf16  → 90 MB
  V_local: [9K × 5,120] in bf16  → 90 MB

All-Gather K, V across 8 GPUs:
  K_full: [72K × 5,120]  → 720 MB
  V_full: [72K × 5,120]  → 720 MB

Attention Computation:
  QK^T: [9K × 5,120] @ [5,120 × 72K] = [9K × 72K]  → 1.3 GB (scores)
  Softmax: apply row-wise
  Output: [9K × 72K] @ [72K × 5,120] = [9K × 5,120]  → 90 MB

Memory per GPU: 90 + 720 + 720 + 1,300 + 90 ≈ 2.9 GB per layer

Much better! From 15.9 GB per layer down to 2.9 GB per layer.

Communication Cost of Ulysses

All-gather of K, V requires transferring the full tensors. Per layer:

$$\text{Data} = 2 \times 72{,}000 \times 5{,}120 \times 2 \text{ bytes} = 1.44 \text{ GB}$$

With NVLink 4.0 bandwidth (900 GB/s within node):

$$\text{Transfer time} = \frac{1.44 \text{ GB}}{900 \text{ GB/s}} = 1.6 \text{ ms}$$

Attention compute time is ~40 ms (FlashAttention-3 on H100), so communication is 1.6/40 = 4% overhead. Acceptable, but we can do better with Ring Attention.

Ring Attention: Overlapping Communication

Ring Attention adds a second level of sequence splitting: each of the 8 Ulysses partitions is handled by a 2-GPU ring (16 GPUs total, CP=16), replacing the upfront all-gather (a broadcast pattern) with point-to-point transfers around the ring that overlap with compute.

Key idea: Don't gather all K, V upfront. Instead, compute attention in blocks while passing K, V around a ring. Use online softmax to maintain numerical stability.

Ring Attention (2-GPU rings within Ulysses-8):

16 GPUs total → 8 independent rings of 2 GPUs each
(8 Ulysses partitions × 2 GPUs per ring = CP-16)

Ring 0: GPU 0  ⟷ GPU 1   (tokens 0-8,999)
Ring 1: GPU 2  ⟷ GPU 3   (tokens 9,000-17,999)
Ring 2: GPU 4  ⟷ GPU 5   (tokens 18,000-26,999)
Ring 3: GPU 6  ⟷ GPU 7   (tokens 27,000-35,999)
Ring 4: GPU 8  ⟷ GPU 9   (tokens 36,000-44,999)
Ring 5: GPU 10 ⟷ GPU 11  (tokens 45,000-53,999)
Ring 6: GPU 12 ⟷ GPU 13  (tokens 54,000-62,999)
Ring 7: GPU 14 ⟷ GPU 15  (tokens 63,000-71,999)

Within each ring (e.g., Ring 0 handling 9K tokens):
  GPU 0 holds: Q₀ (4.5K × 5,120), K₀ (4.5K × 5,120), V₀ (4.5K × 5,120)
  GPU 1 holds: Q₁ (4.5K × 5,120), K₁ (4.5K × 5,120), V₁ (4.5K × 5,120)

Ring Attention Algorithm (2 steps for 2-GPU ring):

Step 0:
  GPU 0: Compute attn(Q₀, K₀, V₀)  → partial output₀⁽⁰⁾
  GPU 1: Compute attn(Q₁, K₁, V₁)  → partial output₁⁽⁰⁾
  [Send K, V to ring neighbor]

Step 1:
  GPU 0: Compute attn(Q₀, K₁, V₁)  → partial output₀⁽¹⁾
         (K₁, V₁ received from GPU 1)
  GPU 1: Compute attn(Q₁, K₀, V₀)  → partial output₁⁽¹⁾
         (K₀, V₀ received from GPU 0)

Final Output (online softmax accumulation):
  GPU 0: Accumulate(output₀⁽⁰⁾, output₀⁽¹⁾)  → final output for tokens 0-4,499
  GPU 1: Accumulate(output₁⁽⁰⁾, output₁⁽¹⁾)  → final output for tokens 4,500-8,999

Communication: Only point-to-point transfers (4.5K × 5,120 × 2 × 2 bytes = 90 MB)
Overlap: While GPU 0 computes, data transfers in background → near-zero overhead!

Online Softmax for Numerical Stability: When accumulating attention outputs from multiple blocks, we use the numerically stable formulation:

$$m_i = \max(m_{i-1}, \max(\text{scores}_i))$$
$$\ell_i = \ell_{i-1} \, e^{m_{i-1} - m_i} + \textstyle\sum_j e^{\text{scores}_{i,j} - m_i}$$
$$\text{out}_i = \text{out}_{i-1} \, e^{m_{i-1} - m_i} + e^{\text{scores}_i - m_i} V_i$$

After the final block, the accumulated output is divided by the running normalizer: $\text{final} = \text{out}_N / \ell_N$. This maintains correctness while computing attention incrementally, without ever materializing the full attention matrix.
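The recurrence can be checked against a direct softmax-weighted sum in a few lines. This sketch uses scalar values and two blocks, tracking the running max, normalizer, and unnormalized output; the function name is illustrative:

```python
import math

def online_softmax_attn(score_blocks, value_blocks):
    """Accumulate softmax(scores) . values one block at a time (scalar V)."""
    m = float("-inf")   # running max
    l = 0.0             # running normalizer
    out = 0.0           # running unnormalized output
    for scores, values in zip(score_blocks, value_blocks):
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)   # rescale old accumulators (exp(-inf)=0)
        l = l * scale + sum(math.exp(s - m_new) for s in scores)
        out = out * scale + sum(math.exp(s - m_new) * v
                                for s, v in zip(scores, values))
        m = m_new
    return out / l

scores = [0.5, 2.0, -1.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0]
# Direct computation over all scores at once
w = [math.exp(s - max(scores)) for s in scores]
direct = sum(wi * vi for wi, vi in zip(w, values)) / sum(w)
# Blockwise computation: same result without seeing all scores together
blocked = online_softmax_attn([scores[:2], scores[2:]], [values[:2], values[2:]])
assert abs(direct - blocked) < 1e-12
```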

Memory with CP=16 (Ulysses-8 + Ring-2)

Each GPU now handles 72,000 ÷ 16 = 4,500 tokens. Per layer: Q/K/V ≈ 135 MB, ring-block attention scores ≈ 40 MB, FFN intermediate ≈ 122 MB, residual/LayerNorm ≈ 8 MB.

Total: ~300 MB per layer

All 40 layers: 12 GB for activations

From 636 GB → 12 GB. 53× reduction!

Checkpoint: Communication Primitives in Action
Notice how Ulysses uses All-Gather to collect K,V from all GPUs (each GPU needs to see all keys/values for correct attention). Ring Attention uses point-to-point transfers instead, passing K,V around a ring while overlapping with compute. The communication primitives you learned are exactly what's being used here.

Complete 128-GPU Architecture

Production systems combine all three parallelism strategies. Here's the full setup with overlapping groups:

128 GPUs: CP=16, FSDP=32, DP=4 (with smart overlapping)

DATA PARALLEL (DP=4)
│
├─► 4 identical groups, each processing different batches
│
└─► FSDP GROUP (32 GPUs per DP replica)
    │
    ├─► Model sharded 32 ways (5.25 GB per GPU)
    │
    └─► CONTEXT PARALLEL (CP=16)
        │
        ├─► Ulysses-8: Sequence split into 8 partitions
        │
        │   Partition:  0     1     2     3     4     5     6     7
        │               ↓     ↓     ↓     ↓     ↓     ↓     ↓     ↓
        │              9K    9K    9K    9K    9K    9K    9K    9K tokens
        │
        └─► Ring-2: Each partition handled by 2-GPU ring

            Ring 0: GPU0  ↔ GPU1   → tokens 0-8,999
            Ring 1: GPU2  ↔ GPU3   → tokens 9,000-17,999
            Ring 2: GPU4  ↔ GPU5   → tokens 18,000-26,999
            Ring 3: GPU6  ↔ GPU7   → tokens 27,000-35,999
            Ring 4: GPU8  ↔ GPU9   → tokens 36,000-44,999
            Ring 5: GPU10 ↔ GPU11  → tokens 45,000-53,999
            Ring 6: GPU12 ↔ GPU13  → tokens 54,000-62,999
            Ring 7: GPU14 ↔ GPU15  → tokens 63,000-71,999

            Result: 72K ÷ 16 GPUs = 4.5K tokens per GPU
            Activation memory: ~300 MB per layer

TOTAL: 4 DP groups × 32 GPUs = 128 GPUs

Communication Patterns:
  • CP (Context Parallel):
    - Ulysses: All-gather K,V within 8-GPU group (~1.4 GB, 4% overhead)
    - Ring: P2P transfers overlap with compute (<0.5% overhead)
    → Total CP overhead: <1%

  • FSDP (Fully Sharded Data Parallel):
    - Forward: All-gather params per layer (28 GB ÷ 40 layers ≈ 700 MB per layer)
    - Backward: All-gather params + Reduce-scatter grads
    → Overhead: ~10% (well-optimized with overlap)

  • DP (Data Parallel):
    - All-reduce gradients at end of backward pass
    - Overlapped with final layer backward compute
    → Overhead: <1% (fully overlapped)

Total Communication Overhead: ~11-12% (excellent for 128 GPUs!)

Final Memory Budget Per GPU

Let's calculate the complete memory footprint per GPU in the full 128-GPU setup (CP=16, FSDP=32):

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
            Per-GPU Memory Breakdown (H100 80GB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

STATIC MEMORY (Model State with FSDP-32):
  Model parameters (local shard):          875 MB
  Gradients (local shard):                 875 MB
  Optimizer states (local shard):        3,500 MB
                                         ─────────
  Static subtotal:                       5,250 MB  (5.25 GB)

DYNAMIC MEMORY (Activations with CP-16):
  Per layer (40 layers total):
    Q, K, V: 3 × 4.5K × 5,120 × 2 =        135 MB
    Attention scores: 4.5K × 4.5K × 2 =     40 MB
    FFN intermediate: 4.5K × 13,824 × 2=   122 MB
    Residual + LayerNorm:                    8 MB
                                         ─────────
    Per layer total:                       305 MB

  With gradient checkpointing (save every 5 layers):
    Stored layers: 40/5 = 8 checkpoints
    Memory: 8 × 305 MB =                 2,440 MB

  Working memory for backward:           1,200 MB
                                         ─────────
  Activations subtotal:                  3,640 MB  (3.6 GB)

COMMUNICATION BUFFERS:
  All-gather buffers (FSDP):               800 MB
  Reduce-scatter buffers:                  400 MB
  Ring attention buffers:                  200 MB
                                         ─────────
  Buffers subtotal:                      1,400 MB  (1.4 GB)

BATCH DATA:
  Input latents (4.5K tokens × 16 ch):     180 MB
  Text embeddings (512 tokens):             50 MB
  Timestep embeddings:                      10 MB
                                         ─────────
  Data subtotal:                           240 MB  (0.24 GB)

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
TOTAL USED:                               10.5 GB  (13% of 80 GB)
HEADROOM:                                 69.5 GB  (87% free!)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Result: Substantial headroom allows longer sequences or larger batches!

Additional Memory Optimizations

Other Parallelism Strategies

While FSDP and Context Parallelism are central to our video model training, large-scale LLM training often employs additional parallelism strategies. Understanding these is essential for comprehensive distributed training knowledge.

Tensor Parallelism (TP)

Tensor Parallelism splits individual layers across multiple GPUs. Unlike FSDP (which shards by parameter ownership), TP partitions the actual matrix operations.

Column-wise parallelism (for first projection):

$$Y = XA = X[A_1, A_2] = [XA_1, XA_2]$$

Row-wise parallelism (for second projection):

$$Z = YB = [Y_1, Y_2]\begin{bmatrix}B_1\\B_2\end{bmatrix} = Y_1B_1 + Y_2B_2$$
Tensor Parallelism (TP=2) for Attention:

Input X: [batch × seq × hidden]
                ↓
        ┌───────┴───────┐
        ↓               ↓
     GPU 0           GPU 1
   Q₀,K₀,V₀        Q₁,K₁,V₁
   (heads 0-19)    (heads 20-39)
        ↓               ↓
   Attention₀      Attention₁
        ↓               ↓
        └───────┬───────┘
                ↓ (All-Reduce)
         Output: [batch × seq × hidden]

Memory savings: Each GPU stores 1/TP of attention weights
Communication: All-Reduce after attention output projection
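The column/row split above is a pure matrix identity, checkable on small matrices: A is split by columns across two "GPUs" and B by rows, and summing the partial products (the all-reduce) reproduces the unsharded result. The hand-rolled matmul keeps the sketch dependency-free:

```python
def matmul(X, Y):
    return [[sum(x * Y[k][j] for k, x in enumerate(row))
             for j in range(len(Y[0]))] for row in X]

X = [[1.0, 2.0], [3.0, 4.0]]
A = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 2.0]]     # [2 x 4]
B = [[1.0, 1.0], [2.0, 0.0], [0.0, 1.0], [1.0, 2.0]]  # [4 x 2]

# Column parallel: GPU 0 holds A's first 2 columns, GPU 1 the last 2
A1 = [row[:2] for row in A]
A2 = [row[2:] for row in A]
Y1, Y2 = matmul(X, A1), matmul(X, A2)     # no communication needed here

# Row parallel: GPU i holds the matching rows of B
B1, B2 = B[:2], B[2:]
Z1, Z2 = matmul(Y1, B1), matmul(Y2, B2)
Z = [[a + b for a, b in zip(r1, r2)]      # the "all-reduce" sums partials
     for r1, r2 in zip(Z1, Z2)]

assert Z == matmul(matmul(X, A), B)       # identical to the unsharded result
```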

When to use TP vs FSDP: TP pays an all-reduce every layer, so it wants the fastest links (within a node) and earns its complexity when even a single layer's weights or activations strain one GPU. FSDP is simpler and scales across nodes, since its per-layer weight gathering overlaps well with compute. In practice, TP stays inside a node while FSDP spans them.

Pipeline Parallelism (PP)

Pipeline Parallelism assigns different layers to different GPUs, forming a processing pipeline. This is particularly useful when individual layers fit in GPU memory but the full model doesn't.

Pipeline Parallelism (PP=4 stages, 40 layers):

GPU 0: Layers 0-9    ━━━▶
GPU 1: Layers 10-19  ━━━▶
GPU 2: Layers 20-29  ━━━▶
GPU 3: Layers 30-39  ━━━▶

Naive Schedule (high bubble overhead):
┌─────────────────────────────────────────────────┐
│ GPU 0: [F0]          [B0]                      │
│ GPU 1:      [F0]          [B0]                 │
│ GPU 2:           [F0]          [B0]            │
│ GPU 3:                [F0][B0]                 │
│                        ↑ bubble (idle time)    │
└─────────────────────────────────────────────────┘

1F1B Schedule (minimized bubbles):
┌─────────────────────────────────────────────────┐
│ GPU 0: [F0][F1][F2][F3][B0][B1][B2][B3]        │
│ GPU 1:     [F0][F1][F2][B0][F3][B1][B2][B3]    │
│ GPU 2:         [F0][F1][B0][F2][B1][F3][B2][B3]│
│ GPU 3:             [F0][B0][F1][B1][F2][B2]... │
└─────────────────────────────────────────────────┘

F = Forward microbatch, B = Backward microbatch

Pipeline bubble overhead:

$$\text{Bubble fraction} = \frac{p - 1}{m + p - 1}$$

Where $p$ = pipeline stages and $m$ = microbatches per step. With PP=4 and m=16: bubble = 3/19 ≈ 16%. Increasing $m$ shrinks the bubble; the 1F1B schedule additionally caps in-flight activations at $p$ microbatches instead of $m$, and interleaved (virtual-stage) schedules reduce the bubble further.
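The bubble formula is worth keeping handy when sizing microbatch counts; a one-liner reproducing the figures above:

```python
# Pipeline bubble fraction: (p - 1) / (m + p - 1),
# where p = pipeline stages, m = microbatches per optimizer step.

def bubble_fraction(p: int, m: int) -> float:
    return (p - 1) / (m + p - 1)

# PP=4, m=16 -> 3/19: about 16% of GPU time idle, as in the text
print(f"{bubble_fraction(4, 16):.1%}")   # → 15.8%
# Quadrupling the microbatch count shrinks the bubble substantially:
print(f"{bubble_fraction(4, 64):.1%}")   # → 4.5%
```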

ZeRO: The Foundation of FSDP

Zero Redundancy Optimizer (ZeRO) is the theoretical framework behind FSDP. Understanding its stages clarifies why FSDP works:

ZeRO Stages (14B model, 32 GPUs; 12 bytes/param:
2 B bf16 params + 2 B bf16 grads + 8 B fp32 Adam states = 168 GB total):

                    │ Per-GPU Memory │ Communication
────────────────────┼────────────────┼──────────────────
Baseline (DDP)      │    168 GB      │ All-Reduce grads
                    │                │
ZeRO-1 (Optimizer)  │    59.5 GB     │ All-Reduce grads
 Shard optimizer    │  56 + 112/32   │ + All-Gather updated params
                    │                │
ZeRO-2 (+ Grads)    │    32.4 GB     │ Reduce-Scatter grads
 Shard optimizer    │ 28 + 140/32    │ + All-Gather updated params
 + gradients        │                │
                    │                │
ZeRO-3 (+ Params)   │    5.25 GB     │ All-Gather params (fwd)
 Shard everything   │    168/32      │ All-Gather params (bwd)
 = FSDP             │                │ Reduce-Scatter grads
────────────────────┴────────────────┴──────────────────

FSDP = ZeRO-3 implemented in PyTorch
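A back-of-envelope calculator for these numbers, assuming the 12 bytes/param accounting above (bf16 params and grads, fp32 Adam moments); exact figures in practice depend on the precision recipe and whether fp32 master weights are kept separately:

```python
# Per-GPU memory (GB) under each ZeRO stage, for a model with `params_b`
# billion parameters: 2 B/param bf16 weights, 2 B/param bf16 grads,
# 8 B/param fp32 Adam states. GB and billions cancel conveniently.

def zero_memory_gb(params_b: float, n_gpus: int) -> dict:
    p, g, opt = 2 * params_b, 2 * params_b, 8 * params_b
    return {
        "ddp":   p + g + opt,              # everything replicated
        "zero1": p + g + opt / n_gpus,     # shard optimizer states
        "zero2": p + (g + opt) / n_gpus,   # shard grads too
        "zero3": (p + g + opt) / n_gpus,   # shard everything = FSDP
    }

mem = zero_memory_gb(params_b=14, n_gpus=32)
print(mem)  # ddp: 168, zero1: 59.5, zero2: 32.375, zero3: 5.25
```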

Gradient Accumulation

Gradient Accumulation enables large effective batch sizes when GPU memory is limited. Instead of updating weights after each forward-backward pass, gradients are accumulated over multiple micro-batches:

$$\text{Effective batch size} = \text{micro-batch} \times \text{accumulation steps} \times \text{DP degree}$$

Gradient Accumulation (accumulation_steps=4):

Step 1: forward(batch_0) → backward() → grad += ∇L₀
Step 2: forward(batch_1) → backward() → grad += ∇L₁
Step 3: forward(batch_2) → backward() → grad += ∇L₂
Step 4: forward(batch_3) → backward() → grad += ∇L₃
        ↓
        optimizer.step()  ← update with averaged gradients
        grad = 0          ← reset for next accumulation cycle

Memory: Only micro-batch activations stored (not full batch)
Effective: Same as 4× larger batch, but 4× slower per step

For our video model with CP=16, FSDP=32, DP=4, and accumulation_steps=8:

$$\text{Effective batch} = 1 \text{ video} \times 8 \times 4 = 32 \text{ videos per optimizer step}$$
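The equivalence between accumulated micro-batch gradients and one full-batch gradient can be checked on a toy scalar model (pure Python, hand-computed gradients; the 1/accum_steps scaling stands in for dividing the loss before `backward()`):

```python
# Toy check that gradient accumulation reproduces the full-batch gradient.
# Model: scalar weight w, loss L = mean((w*x - y)^2) over the batch.

def grad(w, xs, ys):
    """d/dw of mean squared error over a batch."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# One full batch of 4:
full = grad(w, xs, ys)

# Same 4 samples as 4 micro-batches of 1, each loss scaled by 1/accum_steps:
accum_steps = 4
acc = 0.0
for x, y in zip(xs, ys):
    acc += grad(w, [x], [y]) / accum_steps   # "backward() on scaled loss"

# optimizer.step() would now apply `acc`; the two gradients match exactly:
assert abs(acc - full) < 1e-12
```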

3D/4D Parallelism: Combining Strategies

Production systems combine multiple parallelism strategies. The key is understanding which strategies are orthogonal and can be composed:

4D Parallelism for 1024-GPU Training:

DP=8 (Data Parallel - 8 replicas for different batches)
│
└─► PP=4 (Pipeline Parallel - 4 stages of layers)
    │
    └─► TP=8 (Tensor Parallel - 8 GPUs split each layer)
        │
        └─► CP=4 (Context Parallel - 4-way sequence split)

Total GPUs = DP × PP × TP × CP = 8 × 4 × 8 × 4 = 1,024

Communication hierarchy (fastest → slowest):
  TP:  NVSwitch within node     (900 GB/s)  ← tightest coupling
  CP:  NVLink across 2 nodes    (400 GB/s)
  PP:  InfiniBand point-to-point (100 GB/s)
  DP:  InfiniBand all-reduce    (100 GB/s)  ← most independent
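One common convention maps a flat global rank to its position in each parallel group by treating the four degrees as mixed-radix digits, with the most tightly coupled dimension varying fastest (real frameworks let you choose this ordering). A sketch for the 1,024-GPU layout above:

```python
# Map a flat global rank to (dp, pp, tp, cp) coordinates for the
# DP=8 × PP=4 × TP=8 × CP=4 layout. CP varies fastest (tightest coupling,
# adjacent ranks share NVLink); DP varies slowest (most independent).

DP, PP, TP, CP = 8, 4, 8, 4

def rank_to_coords(rank: int) -> dict:
    cp, rank = rank % CP, rank // CP
    tp, rank = rank % TP, rank // TP
    pp, dp   = rank % PP, rank // PP
    return {"dp": dp, "pp": pp, "tp": tp, "cp": cp}

assert DP * PP * TP * CP == 1024
print(rank_to_coords(0))     # → {'dp': 0, 'pp': 0, 'tp': 0, 'cp': 0}
print(rank_to_coords(1023))  # → {'dp': 7, 'pp': 3, 'tp': 7, 'cp': 3}
```

Ranks 0–3 then form one CP group on the same NVLink island, while ranks that differ only in the DP digit are 128 apart and can sit on distant nodes, matching the communication hierarchy above.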

Fault Tolerance and Checkpointing

Long training runs (days to weeks) inevitably encounter hardware failures. Robust checkpointing is essential:

Checkpoint frequency trade-off:

$$\text{Expected lost work per failure} = \frac{\text{checkpoint interval}}{2}$$

With a Mean Time Between Failures (MTBF) of 24 hours for a 128-GPU cluster, checkpointing every 30 minutes means ~15 minutes of lost work per failure on average.
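The arithmetic behind picking an interval, in two small helpers (failure counts follow from the run length divided by MTBF):

```python
# Checkpoint-interval arithmetic: each failure discards, on average, half a
# checkpoint interval of work; failures arrive roughly once per MTBF.

def lost_minutes_per_failure(interval_min: float) -> float:
    return interval_min / 2

def expected_failures(run_days: float, mtbf_hours: float) -> float:
    return run_days * 24 / mtbf_hours

# 30-minute checkpoints, 24 h MTBF, 14-day run:
print(lost_minutes_per_failure(30))   # → 15.0 (minutes lost per failure)
print(expected_failures(14, 24))      # → 14.0 (failures expected)
```

Shortening the interval reduces lost work linearly but adds checkpoint-write overhead, so the optimum depends on how long a checkpoint takes to write.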

Data Pipeline Overview

Training on O(10¹²) tokens requires sophisticated data infrastructure. The challenge: random access to millions of video files creates catastrophic I/O contention. The solution: a multi-stage pipeline with offline preprocessing and sharding.

Five-Stage Architecture

  1. Raw Storage (122 PB): ProRes 4444 and DNxHR HQX 4K video files stored in MinIO object storage (S3-compatible). Typical file size: 900 GB - 1.35 TB per hour of 4K footage.
  2. Offline Preprocessing (6 months): Dedicated cluster of 50 CPU nodes (AMD EPYC 9654, 96 cores each) performs: video decode (hardware-accelerated), color space conversion (Log/PQ/HLG → Linear RGB), quality filtering (VMAF > 80, discard artifacts), temporal segmentation (32-frame clips with scene coherence), multi-resolution rendering (256px, 480px, 720p).
  3. Sharding (30 PB output): Pack ~3,000 clips into 120 GB TAR archives (WebDataset format). Creates 250,000 shards total. Each shard contains frames as 12-bit PNG (lossless) + JSON metadata (HDR tags, color space, motion statistics).
  4. Distributed Storage (Lustre): 20 storage servers (OSS) + 2 metadata servers (MDS) provide 2 TB/s aggregate bandwidth. Each shard striped across 4 servers with 1 MB stripe size. 100 GbE network per node.
  5. Local NVMe Caching: Each training node has 10 TB NVMe SSD cache. LRU eviction policy achieves 85% hit rate (frequently accessed shards stay cached). 8 CPU workers per GPU perform async prefetch from Lustre while GPUs compute.

Bandwidth Reality Check: Streaming the 30 PB dataset roughly three times over 300 training days requires only ~3.5 GB/s sustained—trivial compared to the 2 TB/s Lustre capacity. The key: data is read slowly over months, not all at once. With the 85% cache hit rate, the actual Lustre load is only ~0.5 GB/s. Data loading is never the bottleneck.
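A quick sanity check of those figures (the ~3.5 GB/s number corresponds to about three passes over the 30 PB dataset, which is an inference from the stated numbers rather than an explicit spec):

```python
# Sustained read bandwidth needed to stream a dataset over a training run.

PB = 1e15  # bytes

def sustained_gb_per_s(dataset_pb: float, passes: float, days: float) -> float:
    return dataset_pb * PB * passes / (days * 86_400) / 1e9

lustre = sustained_gb_per_s(30, passes=3, days=300)
print(f"{lustre:.2f} GB/s")                 # ≈ 3.47 GB/s sustained
print(f"{lustre * (1 - 0.85):.2f} GB/s")    # ≈ 0.52 GB/s after 85% cache hits
```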

Multi-Stage Training Curriculum

Direct 720p training is inefficient. Progressive resolution scaling enables better convergence:

Stage │ Resolution                │ Tokens/Sample │ Batch Size           │ Purpose
──────┼───────────────────────────┼───────────────┼──────────────────────┼──────────────────────────────────────
  1   │ 256px images              │ ~1,024        │ 256 images           │ Spatial composition, object semantics
  2   │ 192px video + 256px images│ ~11,520       │ 32 images + 4 videos │ Basic motion, temporal consistency
  3   │ 480px video + images      │ ~72,000       │ 8 images + 2 videos  │ Fine textures, improved details
  4   │ 720p video + images       │ ~288,000      │ 4 images + 1 video   │ Production quality, photorealism

Why this works: low-resolution stages are cheap—short sequences allow large batches—so the model learns semantics, composition, and basic motion from far more samples per GPU-hour. The expensive high-resolution stages then only need to refine fine detail on top of an already-capable model, rather than learning everything at 288K-token sequence lengths.


Real-World Results

This infrastructure has been validated in production.

Conclusion

Training 14B parameter video models on 720p sequences requires orchestrating multiple complementary techniques:

The key insight: these techniques are complementary and necessary. FSDP solves model memory. Context Parallelism solves activation memory. Tensor and Pipeline Parallelism enable further scaling. FlashAttention solves compute efficiency. The data pipeline solves I/O. Remove any piece and the system breaks.

This infrastructure represents years of engineering across distributed systems, numerical optimization, and machine learning. But the core principles are timeless: partition what doesn't fit, coordinate through communication, overlap wherever possible. As we push toward photorealistic, high-resolution, long-form video generation, these patterns will continue to enable the impossible.