Training Infrastructure — GPUs, Distributed Training & Optimisation
A model that trains in 8 hours when it could train in 2 hours is a 4x productivity tax. Multiply that by dozens of experiments and you are losing weeks of iteration time. Training infrastructure is not about raw hardware — it is about using hardware efficiently. This lesson covers every major technique for maximising GPU utilisation and making training infrastructure reliable enough to run unattended in production.
GPU Memory Hierarchy
Understanding GPU memory is prerequisite to understanding every optimisation in this lesson.
VRAM (HBM): the high-bandwidth memory stacked on the GPU package. Everything that participates in training must fit in VRAM: model weights, gradients, optimizer states (Adam keeps two momentum tensors per parameter, adding 2x the weight size), and activations from the forward pass. An 80GB A100 sounds like a lot until you store a 7B-parameter model in float32 (7B × 4 bytes = 28GB just for weights, before gradients and optimizer state).
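To make the arithmetic concrete, here is a back-of-the-envelope VRAM budget for plain FP32 Adam training (weights, gradients, and two optimizer moments at 4 bytes each, activations excluded). It is a rough sketch, not an exact accounting:

```python
def training_memory_gb(n_params: float, bytes_per_param: int = 4) -> dict:
    """Rough VRAM budget for FP32 Adam training, excluding activations.

    Uses decimal GB (1e9 bytes), matching the 7B × 4 bytes = 28GB figure.
    """
    gb = 1e9
    weights = n_params * bytes_per_param / gb
    gradients = weights        # one gradient tensor per parameter
    optimizer = 2 * weights    # Adam: exp_avg + exp_avg_sq
    return {
        "weights_gb": weights,
        "gradients_gb": gradients,
        "optimizer_gb": optimizer,
        "total_gb": weights + gradients + optimizer,
    }

budget = training_memory_gb(7e9)
# 7B params: 28GB of weights, 112GB total -- far beyond one 80GB A100
```

This is exactly why techniques like mixed precision, gradient checkpointing, and FSDP sharding exist: the weights are the smallest part of the bill.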
Bandwidth vs compute: most operations in deep learning are bandwidth-limited, not compute-limited. Moving 1GB of activations between VRAM and the CUDA cores takes time proportional to memory bandwidth (HBM2e on an 80GB A100: about 2 TB/s). Operations that do very little compute per byte read (element-wise activations, layer norms) are bandwidth-bound. Operations that do a lot of compute per byte (large matrix multiplications) are compute-bound. Mixed precision and fused kernels help by reducing the bandwidth needed per operation.
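A roofline-style way to see this is to compare FLOPs per byte moved (arithmetic intensity). The estimate below is deliberately simplified, ignoring caching and kernel fusion; any op whose intensity falls below the hardware's FLOPs-to-bandwidth ratio is bandwidth-bound:

```python
def matmul_intensity(m: int, n: int, k: int, bytes_per_el: int = 2) -> float:
    """FLOPs per byte for an (m×k)·(k×n) matmul: read A and B, write C."""
    flops = 2 * m * n * k
    bytes_moved = (m * k + k * n + m * n) * bytes_per_el
    return flops / bytes_moved

def elementwise_add_intensity(bytes_per_el: int = 2) -> float:
    """FLOPs per byte for y = x + z: two reads, one write, one FLOP."""
    return 1 / (3 * bytes_per_el)

# A 4096-cubed FP16 matmul does ~1365 FLOPs per byte; an FP16 elementwise
# add manages ~0.17 -- it cannot possibly keep the compute units busy.
```

This gap of roughly four orders of magnitude is why fusing elementwise ops into the surrounding matmul kernels pays off so well.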
Mixed Precision Training
Mixed precision training runs the forward and backward passes in FP16 or BF16, while keeping a master copy of weights in FP32 for the optimizer update. This halves memory usage for activations and nearly doubles throughput on Tensor Cores (which natively accelerate FP16/BF16 matmuls).
FP16 vs BF16: BF16 has the same 8-bit exponent as FP32 (so the same numeric range), but only 7 mantissa bits (vs 23 in FP32 and 10 in FP16). FP16 has more mantissa precision but only a 5-bit exponent, so it can overflow on large activation values. BF16 is generally preferred for training; FP16 is more common for inference, where the activation range is predictable.
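The range difference follows directly from the bit layout. A small sketch (standard IEEE-style formula, not a torch API) computing the largest finite value of each format from its exponent and mantissa widths:

```python
def max_finite(exp_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-style float format."""
    bias = 2 ** (exp_bits - 1) - 1
    return (2 - 2 ** -mantissa_bits) * 2.0 ** bias

fp32_max = max_finite(8, 23)   # ~3.4e38
fp16_max = max_finite(5, 10)   # 65504 -- large activations overflow easily
bf16_max = max_finite(8, 7)    # ~3.4e38 -- FP32 range, less precision
```

An attention logit or unnormalised loss above 65504 becomes `inf` in FP16, which is why FP16 training needs loss scaling while BF16 usually does not.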
python
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler

def train_one_epoch_mixed_precision(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: GradScaler,
    device: torch.device,
    gradient_clip_norm: float = 1.0,
) -> float:
    model.train()
    total_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)  # set_to_none=True saves memory

        # Forward pass in FP16/BF16
        with autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Backward pass: scaler prevents underflow in FP16 gradients
        scaler.scale(loss).backward()

        # Unscale before gradient clipping (required before clip_grad_norm_)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_norm)

        # Optimizer step: only updates if gradients are finite (no NaN/inf)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    return total_loss / len(loader)

# Initialise the scaler once and pass it through the training loop
scaler = GradScaler("cuda")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Gradient Accumulation
When your batch size is constrained by VRAM, gradient accumulation simulates a larger effective batch by accumulating gradients over multiple forward/backward passes before stepping the optimizer.
python
def train_with_gradient_accumulation(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: GradScaler,
    device: torch.device,
    accumulation_steps: int = 8,  # effective_batch = batch_size × 8
) -> float:
    model.train()
    total_loss = 0.0
    optimizer.zero_grad(set_to_none=True)
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        with autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(inputs)
            # Normalise loss so gradient magnitude is equivalent to
            # computing it on the full effective batch at once
            loss = criterion(outputs, targets) / accumulation_steps

        scaler.scale(loss).backward()

        # Only step after accumulation_steps mini-batches
        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        total_loss += loss.item() * accumulation_steps
    return total_loss / len(loader)
DataLoader Optimisation
The DataLoader is frequently the bottleneck in training pipelines. A GPU that is starved of data will show near-zero utilisation between batches.
python
import multiprocessing

from torch.utils.data import DataLoader, Dataset

def make_optimised_dataloader(
    dataset: Dataset,
    batch_size: int,
    training: bool = True,
) -> DataLoader:
    """Production-quality DataLoader with all performance settings."""
    num_workers = min(8, multiprocessing.cpu_count())
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=training,
        num_workers=num_workers,    # parallel data loading on CPU
        pin_memory=True,            # pinned (page-locked) memory for
                                    # faster CPU→GPU (H2D) transfer
        persistent_workers=True,    # keep worker processes alive between
                                    # epochs (avoids fork overhead each epoch)
        prefetch_factor=2,          # each worker pre-fetches 2 batches ahead
        drop_last=training,         # drop incomplete final batch in training
    )

# Tip: profile DataLoader vs GPU time with a simple benchmark
def benchmark_dataloader(loader: DataLoader, n_batches: int = 50) -> None:
    import time
    start = time.perf_counter()
    for i, batch in enumerate(loader):
        if i >= n_batches:
            break
    elapsed = time.perf_counter() - start
    print(f"DataLoader: {n_batches} batches in {elapsed:.2f}s "
          f"({elapsed / n_batches * 1000:.1f}ms/batch)")
Distributed Training
When a single GPU is not enough — either because the model is too large or because you need to train faster — you move to distributed training.
DataParallel (single node, easy but inefficient)
DataParallel wraps a model and splits each batch across multiple GPUs on a single machine. Simple to use but suboptimal: it runs in a single process (so Python's GIL serialises work), the model is re-replicated to every GPU each iteration, and outputs and gradients are gathered on the primary GPU, which becomes a bottleneck.
python
# DataParallel: wraps model, splits batch across GPUs
model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
model = model.cuda()
# Training and inference code is otherwise unchanged
DistributedDataParallel (the production choice)
DDP spawns one process per GPU. Each process has its own model replica. Gradients are synchronised with an AllReduce operation (using NCCL on NVIDIA hardware). No primary GPU bottleneck.
python
import os

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

def setup_ddp(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup_ddp() -> None:
    dist.destroy_process_group()

def train_ddp(rank: int, world_size: int, model_cls, dataset) -> None:
    setup_ddp(rank, world_size)
    device = torch.device(f"cuda:{rank}")
    model = model_cls().to(device)
    model = DDP(model, device_ids=[rank])  # wrap with DDP

    # Each process gets a non-overlapping slice of the data
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                        num_workers=4, pin_memory=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()

    for epoch in range(10):
        sampler.set_epoch(epoch)  # reshuffle per epoch
        train_one_epoch_mixed_precision(
            model, loader, optimizer, criterion, scaler, device
        )

    cleanup_ddp()

# Launch with:
# torch.multiprocessing.spawn(train_ddp, args=(4, MyModel, dataset), nprocs=4)
FSDP (Fully Sharded Data Parallel)
FSDP shards model parameters, gradients, and optimizer state across GPUs. This allows training models that are too large to fit on a single GPU. PyTorch FSDP is the open-source equivalent of ZeRO-3 (DeepSpeed). Use it when a DDP model exceeds single-GPU VRAM.
python
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Auto-wrap: shard layers with more than 1M parameters
my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1_000_000,
)
model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
Gradient Checkpointing
Gradient checkpointing trades compute for memory. During the forward pass, activations are discarded instead of stored. During the backward pass, they are recomputed on the fly as needed. This reduces activation memory from O(layers) to O(sqrt(layers)), at the cost of ~30% more compute.
python
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

def forward_with_checkpointing(
    model_layers: nn.Sequential, x: torch.Tensor
) -> torch.Tensor:
    """
    Recompute activations during the backward pass instead of storing them.
    Reduces memory by ~50-70% for large models.
    `segments` controls how many checkpointed segments to split into.
    """
    segments = 4
    return checkpoint_sequential(model_layers, segments, x)

# For transformer blocks: wrap individual blocks
class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, block: nn.Module) -> None:
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return checkpoint(self.block, x, use_reentrant=False)
PyTorch Profiler
Profile before optimising. The profiler tells you exactly where time is spent and where VRAM is consumed.
python
from torch.profiler import profile, record_function, ProfilerActivity

def profile_training_step(
    model: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
) -> None:
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        with record_function("forward"):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        with record_function("backward"):
            loss.backward()
        with record_function("optimizer"):
            optimizer.step()
            optimizer.zero_grad()

    # Print top operations by CUDA time
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # Export to TensorBoard or a Chrome trace
    prof.export_chrome_trace("trace.json")
Training on Spot Instances
Spot (preemptible) instances cost 60-90% less than on-demand capacity but can be reclaimed with as little as two minutes' notice. Reliable spot training therefore requires two things: checkpoint-on-interrupt and resume-from-checkpoint.
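A minimal sketch of the interrupt-handling side, assuming a `save_checkpoint` callback that would wrap `torch.save` of model, optimizer, and scaler state (the callback name and loop shape are illustrative, not a fixed API):

```python
import signal

class SpotInterruptHandler:
    """Flip a flag when the cloud provider sends SIGTERM (reclaim warning)."""

    def __init__(self) -> None:
        self.interrupted = False
        signal.signal(signal.SIGTERM, self._on_sigterm)

    def _on_sigterm(self, signum, frame) -> None:
        self.interrupted = True

def train_steps(n_steps, handler, save_checkpoint, checkpoint_every=200):
    """Run n_steps, checkpointing periodically and on interrupt."""
    last_step = 0
    for step in range(n_steps):
        # ... forward / backward / optimizer step would go here ...
        if step % checkpoint_every == 0 or handler.interrupted:
            save_checkpoint(step)   # e.g. torch.save({...}, ckpt_path)
        last_step = step
        if handler.interrupted:
            break                   # exit cleanly before the VM disappears
    return last_step
```

On resume, the job loads the latest checkpoint and continues from the saved step; checking the flag once per step keeps the save inside the two-minute window as long as a single step plus a checkpoint fits in it.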
Key Takeaways
GPU memory holds weights + gradients + optimizer state + activations; mixed precision halves activation memory and nearly doubles throughput via Tensor Cores.
Use BF16 for training (same range as FP32, avoids overflow) and FP16 for inference where range is predictable; always use GradScaler with FP16 to prevent gradient underflow.
Gradient accumulation allows large effective batch sizes without large VRAM; divide the loss by accumulation_steps to keep gradient magnitude consistent.
Set pin_memory=True, persistent_workers=True, prefetch_factor=2, and a num_workers matched to available CPU cores (e.g. min(8, cpu_count())) on every training DataLoader; DataLoader starvation wastes expensive GPU time.
Prefer DDP over DataParallel for multi-GPU training — DDP has no primary GPU bottleneck and scales linearly. Use FSDP when the model does not fit on a single GPU.
Gradient checkpointing reduces activation memory by 50-70% at a 30% compute cost; always enable it when VRAM is the constraint.
Profile before optimising: the PyTorch profiler will tell you whether you are CPU-bound, memory-bandwidth-bound, or actually compute-bound.
Checkpoint every few hundred steps and register SIGTERM handlers when training on spot instances; a 2-minute warning is enough to save state if your handler is in place.