PyTorch gives you full control of the training loop. That flexibility is powerful, but it also means every mistake is yours to make: forgetting to zero gradients, applying the wrong loss for the task type, or letting exploding gradients silently corrupt your weights. This lesson builds a production-quality training loop from the ground up.
## Choosing the Right Loss Function
| Task | Loss function | PyTorch class |
|---|---|---|
| Binary classification | Binary cross-entropy | nn.BCEWithLogitsLoss |
| Multi-class classification | Cross-entropy | nn.CrossEntropyLoss |
| Regression | Mean squared error | nn.MSELoss |
| Regression (robust to outliers) | Huber / smooth L1 | nn.HuberLoss |
| Ranking / contrastive | Triplet margin | nn.TripletMarginLoss |
Always prefer `BCEWithLogitsLoss` over applying `BCELoss` to `sigmoid(logits)`: it fuses the sigmoid and the loss into a single numerically stable operation.
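To see why, compare the gradients the two formulations produce for a confidently wrong prediction. This is a minimal sketch with a hand-picked extreme logit:

```python
import torch
import torch.nn as nn

# A confidently wrong prediction: very negative logit, positive target.
x1 = torch.tensor([-100.0], requires_grad=True)
x2 = torch.tensor([-100.0], requires_grad=True)
target = torch.tensor([1.0])

nn.BCEWithLogitsLoss()(x1, target).backward()
nn.BCELoss()(torch.sigmoid(x2), target).backward()

print(x1.grad)  # -1.0: the correct gradient, so the model can recover
print(x2.grad)  # not -1.0: sigmoid(-100) saturates to exactly 0.0 in float32,
                # and the gradient signal is lost through the chain rule
```

The fused version computes the gradient analytically as `sigmoid(x) - target`, so it stays well-behaved even when the sigmoid itself has saturated.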
## Optimizer and Scheduler Setup
```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

model = MyModel()
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
```
`AdamW` decouples weight decay from the gradient update, which is more correct than the L2-regularisation-style decay in the standard Adam implementation, where the decay term gets entangled with the adaptive learning rates. `label_smoothing=0.1` prevents the model from becoming overconfident on training labels.
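The schedule itself is easy to sanity-check in isolation. This toy sketch uses a single dummy parameter standing in for the model, with `T_max=10` in place of `NUM_EPOCHS`:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = AdamW([param], lr=3e-4, weight_decay=1e-2)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

lrs = []
for _ in range(10):       # one scheduler.step() per epoch
    optimizer.step()      # in real training, preceded by loss.backward()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(f"{lrs[0]:.2e} -> {lrs[-1]:.2e}")  # decays smoothly toward eta_min
```

The learning rate follows a half-cosine from the initial `lr` down to `eta_min` over `T_max` scheduler steps, which is why `T_max` should match the number of epochs when you step the scheduler once per epoch.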
## The Training Loop
```python
def train_epoch(model, loader, optimizer, criterion, device, max_grad_norm=1.0):
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    for batch_idx, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()                 # 1. clear stale gradients
        logits = model(inputs)                # 2. forward pass
        loss = criterion(logits, labels)      # 3. compute loss
        loss.backward()                       # 4. backprop
        # 5. clip gradients before the optimizer step
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()                      # 6. update weights
        total_loss += loss.item() * inputs.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += inputs.size(0)
    return total_loss / total_samples, total_correct / total_samples
```
Step 5, gradient clipping, is the most commonly skipped step. Without it, a single bad batch can produce gradients with a norm of 1e6, making the parameter update massive and potentially irrecoverable.
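The effect is easy to verify in isolation. This minimal sketch hand-builds an exploding gradient on a single parameter:

```python
import torch
import torch.nn as nn

w = torch.nn.Parameter(torch.ones(4))
w.grad = torch.full((4,), 500.0)   # simulated exploding gradient: norm = 1000

pre_clip_norm = nn.utils.clip_grad_norm_([w], max_norm=1.0)
print(pre_clip_norm.item())        # 1000.0, the total norm *before* clipping
print(w.grad.norm().item())        # ~1.0, the gradient rescaled in place
```

Note that `clip_grad_norm_` returns the pre-clip norm; logging it per batch is a cheap way to spot training instability early.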
## Validation and Early Stopping
```python
def validate(model, loader, criterion, device):
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            loss = criterion(logits, labels)
            total_loss += loss.item() * inputs.size(0)
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_samples += inputs.size(0)
    return total_loss / total_samples, total_correct / total_samples

best_val_loss = float("inf")
patience, patience_counter = 5, 0

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    scheduler.step()
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
```
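The patience logic can be exercised on its own with a list of hypothetical validation losses. `early_stop_trace` below is an illustrative helper, not part of the loop above:

```python
def early_stop_trace(val_losses, patience=5):
    """Return the epoch at which patience-based early stopping triggers,
    or None if it never does."""
    best, counter = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, counter = loss, 0   # new best: reset the counter
        else:
            counter += 1              # no improvement this epoch
            if counter >= patience:
                return epoch
    return None

# The last improvement is at epoch 2; with patience=3 we stop at epoch 5.
print(early_stop_trace([1.0, 0.8, 0.7, 0.75, 0.72, 0.9], patience=3))  # → 5
```

Note that the counter resets on *any* improvement, so a noisy-but-slowly-improving validation loss will keep training alive.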
## Reproducibility Seeds
```python
import random
import numpy as np

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
```
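A quick check that re-seeding actually reproduces the same random draws:

```python
import torch

torch.manual_seed(42)
a = torch.randn(3)
torch.manual_seed(42)
b = torch.randn(3)
print(torch.equal(a, b))  # True: identical draws after re-seeding
```

The same holds for `random` and NumPy; seeding all three generators matters because data augmentation and shuffling often use them alongside PyTorch's own RNG.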
## Summary

- Match the loss function to the task type; prefer `BCEWithLogitsLoss` over manual sigmoid + BCE for numerical stability.
- Use `AdamW` (not Adam) with explicit weight decay, and pair it with a cosine or step LR scheduler.
- Structure the training loop as: zero gradients → forward → loss → backward → clip gradients → step.
- Clip gradient norm to 1.0 by default: it is cheap insurance against exploding gradients.
- Save the best checkpoint based on validation loss and implement patience-based early stopping.
- Fix all random seeds before the first forward pass; also set `cudnn.deterministic = True`.