GadaaLabs
Machine Learning Engineering
Lesson 3

Training Loops in PyTorch

PyTorch gives you full control of the training loop. That flexibility is powerful, but it also means every mistake is yours to make: forgetting to zero gradients, applying the wrong loss for the task type, or letting exploding gradients silently corrupt your weights. This lesson builds a production-quality training loop from the ground up.

Choosing the Right Loss Function

| Task | Loss function | PyTorch class |
|---|---|---|
| Binary classification | Binary cross-entropy | nn.BCEWithLogitsLoss |
| Multi-class classification | Cross-entropy | nn.CrossEntropyLoss |
| Regression | Mean squared error | nn.MSELoss |
| Regression (robust to outliers) | Huber / smooth L1 | nn.HuberLoss |
| Ranking / contrastive | Triplet margin | nn.TripletMarginLoss |

Always prefer BCEWithLogitsLoss over BCELoss(sigmoid(logits)): it fuses the sigmoid and the loss into a single log-sum-exp computation, which stays numerically stable when logits are large in magnitude.
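
To make the stability argument concrete, here is a minimal sketch that feeds one deliberately extreme logit through both routes. The logit value of -200 is an illustrative assumption, chosen so that the sigmoid underflows to zero in float32.

```python
import torch
import torch.nn as nn

# A deliberately extreme logit: sigmoid(-200) underflows to exactly 0.0 in float32.
logits = torch.tensor([-200.0])
targets = torch.tensor([1.0])

# Manual route: sigmoid first, then BCELoss. log(0) would be -inf, so BCELoss
# clamps its log outputs at -100 and the reported loss saturates.
manual = nn.BCELoss()(torch.sigmoid(logits), targets)

# Fused route: BCEWithLogitsLoss works on the raw logit and keeps the true value.
fused = nn.BCEWithLogitsLoss()(logits, targets)

print(f"BCELoss(sigmoid(x)): {manual.item():.1f}")
print(f"BCEWithLogitsLoss:   {fused.item():.1f}")
```

The manual route reports a clamped loss rather than the true one, so the most badly misclassified examples are exactly the ones whose signal gets distorted.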

Optimizer and Scheduler Setup

python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

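NUM_EPOCHS = 50  # assumed value; set this to your own training budget
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MyModel().to(device)  # MyModel is a placeholder for your own nn.Module
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
# Anneal the learning rate from 3e-4 down to eta_min over NUM_EPOCHS epochs
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)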

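AdamW decouples weight decay from the gradient update: the decay is applied directly to the weights instead of being added to the gradient as an L2 penalty. With plain Adam, an L2 penalty is rescaled by the adaptive per-parameter step size, so the effective regularisation varies from weight to weight. label_smoothing=0.1 replaces each one-hot target with a mix of 90% true label and 10% uniform distribution over the classes, which discourages the model from becoming overconfident on training labels.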
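
As a sanity check on what label smoothing does to the targets, the sketch below builds the smoothed target distribution by hand and compares the resulting loss with the built-in option. The batch size, class count, and random logits are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, eps = 5, 0.1
logits = torch.randn(4, num_classes)   # fake model outputs for 4 samples
labels = torch.tensor([0, 2, 1, 4])    # class indices

# Smoothed targets: (1 - eps) on the true class plus eps / C spread over every class.
one_hot = F.one_hot(labels, num_classes).float()
smoothed = (1.0 - eps) * one_hot + eps / num_classes

# Cross-entropy against the soft targets, averaged over the batch.
manual = -(smoothed * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, labels)

print(manual.item(), builtin.item())  # the two values should agree
```

Because every class keeps a small target probability, the loss can no longer be driven towards zero by pushing one logit to infinity.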

The Training Loop

python
def train_epoch(model, loader, optimizer, criterion, device, max_grad_norm=1.0):
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    for batch_idx, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()           # 1. clear stale gradients
        logits = model(inputs)          # 2. forward pass
        loss = criterion(logits, labels)  # 3. compute loss
        loss.backward()                 # 4. backprop

        # 5. clip gradients before optimizer step
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()                # 6. update weights

        total_loss += loss.item() * inputs.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += inputs.size(0)

    return total_loss / total_samples, total_correct / total_samples

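Step 5, gradient clipping, is easy to leave out and often is. Without it, a single bad batch can produce gradients with a very large norm (1e6, say), making the parameter update massive and potentially irrecoverable. clip_grad_norm_ rescales all gradients in place so that their combined L2 norm is at most max_norm, and it returns the norm measured before clipping, which is worth logging.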
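
The effect of clipping is easy to observe in isolation. The following sketch uses a tiny nn.Linear model and an artificially inflated loss as a stand-in for a bad batch; both are assumptions made for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)

# Inflate the loss on purpose so the gradients are huge (a stand-in for a bad batch).
loss = 1e6 * model(torch.randn(8, 10)).pow(2).mean()
loss.backward()

max_grad_norm = 1.0
# clip_grad_norm_ rescales gradients in place and returns the norm measured BEFORE clipping.
norm_before = nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

# Recompute the global L2 norm across all parameter gradients after clipping.
norm_after = torch.norm(torch.stack([p.grad.norm(2) for p in model.parameters()]), 2)

print(f"before: {norm_before.item():.3e}  after: {norm_after.item():.4f}")
```

Logging the returned pre-clip norm every few steps is a cheap way to see whether clipping fires occasionally or on every batch; the latter usually points at a learning rate or data problem rather than a one-off spike.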

Validation and Early Stopping

python
def validate(model, loader, criterion, device):
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            loss = criterion(logits, labels)
            total_loss += loss.item() * inputs.size(0)
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_samples += inputs.size(0)

    return total_loss / total_samples, total_correct / total_samples

best_val_loss = float("inf")
patience, patience_counter = 5, 0

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc     = validate(model, val_loader, criterion, device)
    scheduler.step()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
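            print(f"Early stopping at epoch {epoch}")
            break

# The in-memory weights come from the last epoch run, not the best one,
# so reload the best checkpoint before evaluating or exporting.
model.load_state_dict(torch.load("best_model.pt", map_location=device))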
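
The patience bookkeeping in the loop above can also be factored into a small helper so the epoch loop stays readable. This is a sketch, not part of PyTorch: the EarlyStopping class and its min_delta threshold are hypothetical names introduced here.

```python
from typing import Tuple


class EarlyStopping:
    """Tracks the best validation loss and signals when to stop (hypothetical helper)."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # assumed extra: smallest change that counts as improvement
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss: float) -> Tuple[bool, bool]:
        """Return (improved, should_stop) for this epoch's validation loss."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience


# Two real improvements, one change too small to count, then a regression.
stopper = EarlyStopping(patience=2, min_delta=1e-3)
history = [stopper.step(v) for v in [0.90, 0.80, 0.7995, 0.81]]
print(history)
```

In the epoch loop, the first flag would decide whether to save a checkpoint and the second whether to break. A non-zero min_delta keeps noise-level fluctuations in validation loss from resetting the patience counter.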

Reproducibility Seeds

python
import random

import numpy as np
import torch

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
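
Global seeds fix the random number streams, but the shuffle order of a DataLoader still depends on how much of the global stream has been consumed before the loader is iterated. One way to pin it down, sketched below, is to give the loader its own torch.Generator. The make_loader helper and the toy ten-element dataset are illustrative assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(seed: int) -> DataLoader:
    """Build a shuffling loader whose order depends only on `seed` (hypothetical helper)."""
    dataset = TensorDataset(torch.arange(10))
    # A dedicated generator controls the shuffle independently of the global RNG state.
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=5, shuffle=True, generator=g)


order_a = [batch[0].tolist() for batch in make_loader(42)]
order_b = [batch[0].tolist() for batch in make_loader(42)]
print(order_a == order_b)  # same seed, same batch order
```

With num_workers greater than zero, the PyTorch reproducibility notes additionally recommend passing a worker_init_fn that seeds random and NumPy inside each worker process.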

Summary

  • Match the loss function to the task type; prefer BCEWithLogitsLoss over manual sigmoid + BCE for numerical stability.
  • Use AdamW (not Adam) with explicit weight decay, and pair it with a cosine or step LR scheduler.
  • Structure the training loop as: zero gradients → forward → loss → backward → clip gradients → step.
  • Clip the gradient norm (1.0 is a reasonable default): it is cheap insurance against exploding gradients.
  • Save the best checkpoint based on validation loss and implement patience-based early stopping.
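  • Fix all random seeds before building the model and data loaders, since weight initialisation and shuffling both consume random numbers; also set cudnn.deterministic = True and cudnn.benchmark = False.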