⏱ 9 min read 📊 Intermediate 🗓 Updated Jan 2025

🎯 The Learning Problem

A neural network starts as a random function. "Training" means systematically adjusting its millions (often billions) of parameters so that its output becomes increasingly accurate on the training data — and generalises to unseen data. The entire process reduces to a single question: how do we change the weights to reduce the error?

What Training Means

Training is an optimisation problem: find the weight values θ that minimise a loss function L(θ) measuring prediction error over the training data. If we could solve this analytically (set the derivative to zero and solve for θ), we would — but neural networks have billions of parameters and non-convex loss landscapes, making closed-form solutions impossible.

  • Goal: θ* = argmin L(θ) over training dataset
  • Loss landscape: think of a mountainous terrain; we want the lowest valley
  • Billions of parameters = billions of dimensions in the loss landscape
  • Not convex: multiple local minima, saddle points, plateaus

The Forward Pass

Before we can compute how wrong the network is, we must run it: feed an input through every layer sequentially, computing activations at each step. This is the forward pass. All intermediate activations are stored in memory because backpropagation will need them to compute gradients.

  • Input → Layer 1 → Layer 2 → … → Output (prediction)
  • Compare prediction to ground truth label using loss function
  • Memory cost: O(number of layers × batch size × layer width)
  • Gradient checkpointing trades compute for memory in very deep models
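The forward pass above can be sketched in plain Python. The layer sizes and weight values below are made-up illustrations, not from any real model; the point is that every intermediate (z1, a1, z2) is kept in a cache for the backward pass:

```python
def relu(v):
    return [max(0.0, x) for x in v]

def linear(W, b, x):
    # y_i = sum_j W[i][j] * x[j] + b[i]
    return [sum(w * xj for w, xj in zip(row, x)) + bi
            for row, bi in zip(W, b)]

def forward(x, W1, b1, W2, b2):
    z1 = linear(W1, b1, x)    # pre-activation, layer 1
    a1 = relu(z1)             # activation, layer 1
    z2 = linear(W2, b2, a1)   # output
    cache = (x, z1, a1, z2)   # stored for backpropagation
    return z2, cache

# Made-up weights: 2 inputs -> 3 hidden units -> 1 output
W1 = [[0.1, -0.2], [0.4, 0.3], [-0.5, 0.2]]
b1 = [0.0, 0.1, 0.0]
W2 = [[0.2, -0.1, 0.3]]
b2 = [0.05]
out, cache = forward([1.0, 2.0], W1, b1, W2, b2)
```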

Why We Can't Just Solve Analytically

For a single linear regression, we can compute the optimal weights in one matrix operation: θ = (XᵀX)⁻¹Xᵀy. But for a 7-billion-parameter neural network with ReLU activations, softmax outputs, and layer normalisation, the loss function is a non-convex, non-smooth function of billions of variables. No closed-form solution exists. We must use iterative gradient-based methods, updating weights step by step in the direction that reduces the loss.
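To make the contrast concrete, the closed-form solution really is one step for a toy model. A sketch for a single-feature, no-intercept regression (the data values are hypothetical), where the normal equation collapses to a scalar ratio:

```python
# Normal equation for a one-feature, no-intercept model y ≈ θ·x:
# θ = (XᵀX)⁻¹Xᵀy reduces to Σxy / Σx².
def solve_theta(xs, ys):
    return sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)

xs = [1.0, 2.0, 3.0, 4.0]       # hypothetical data, roughly y = 2x
ys = [2.1, 3.9, 6.2, 7.8]
theta = solve_theta(xs, ys)     # one computation, no iteration needed
```

No such shortcut exists once non-linear activations enter the picture; that is what forces the iterative methods below.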

📉 Loss Functions

A loss function (also called a cost function or objective function) quantifies the discrepancy between the network's prediction and the true label. The choice of loss function encodes what "error" means for your specific task.

  • Mean Squared Error (MSE), for regression
    L = (1/n)Σ(yᵢ − ŷᵢ)²
    Penalises large errors quadratically. Sensitive to outliers. Differentiable everywhere. Standard for regression tasks.
  • Mean Absolute Error (MAE), for regression
    L = (1/n)Σ|yᵢ − ŷᵢ|
    Linear penalty, so more robust to outliers than MSE. Non-differentiable at zero, but a subgradient exists. Used when large outliers are common.
  • Binary Cross-Entropy, for binary classification
    L = −[y·log(ŷ) + (1−y)·log(1−ŷ)]
    Used with a sigmoid output. Penalises confident wrong predictions very heavily (log(ŷ) → −∞ as ŷ → 0 when y = 1). Derived from maximum likelihood under a Bernoulli distribution.
  • Categorical Cross-Entropy, for multi-class classification
    L = −Σᵢ yᵢ·log(ŷᵢ)
    Used with a softmax output. For one-hot labels, reduces to −log(ŷ_correct_class). Standard loss for image classification and NLP token prediction. Also called Negative Log-Likelihood.
  • Huber Loss, for regression with outliers
    L = ½e² if |e| ≤ δ, else δ·(|e| − δ/2)
    Quadratic for small errors (like MSE), linear for large errors (like MAE). Balances both regimes; δ is a tunable threshold. Used in object detection (bounding-box regression).
  • KL Divergence, for distribution matching
    D_KL(P‖Q) = Σ P·log(P/Q)
    Measures how much distribution Q diverges from reference P. Used in VAEs, knowledge distillation, and reinforcement learning from human feedback (RLHF). Not symmetric: D_KL(P‖Q) ≠ D_KL(Q‖P).
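Several of these losses are one-liners in plain Python. A minimal sketch of MSE, MAE, and binary cross-entropy (the eps guard is a common implementation detail, not part of the mathematical definition):

```python
import math

def mse(y, yhat):
    # Mean squared error: quadratic penalty
    return sum((t - p) ** 2 for t, p in zip(y, yhat)) / len(y)

def mae(y, yhat):
    # Mean absolute error: linear penalty, robust to outliers
    return sum(abs(t - p) for t, p in zip(y, yhat)) / len(y)

def bce(y, yhat, eps=1e-12):
    # Binary cross-entropy; eps guards log() against p of exactly 0 or 1
    return -sum(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
                for t, p in zip(y, yhat)) / len(y)

y_true = [1.0, 0.0, 1.0]
good = bce(y_true, [0.9, 0.1, 0.8])   # confident and right: small loss
bad = bce(y_true, [0.1, 0.9, 0.2])    # confident and wrong: much larger
```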

Choosing the Right Loss Function

The loss function must match the output type and task. Using MSE for classification ignores the probabilistic interpretation of softmax outputs and provides poor gradient signal for wrong-class probabilities. Using cross-entropy for regression is nonsensical. The loss also encodes what kinds of errors matter: if false negatives are catastrophic in a medical context, consider weighted cross-entropy or focal loss. Always verify that the loss function's assumptions match your data distribution.

⛰️ Gradient Descent

The gradient ∇L(θ) points in the direction of steepest increase of the loss. By stepping in the opposite direction — the negative gradient — we descend the loss landscape toward lower loss values. This is gradient descent.

# The gradient descent update rule
θ ← θ − α · ∇L(θ)

# Where:
#   θ  = all model parameters (weights and biases)
#   α  = learning rate (step size, typically 1e-3 to 1e-5)
#   ∇L = gradient of loss with respect to θ
#   −  = minus: move AGAINST the gradient (downhill)

# Example: scalar parameter with loss L = (w − 3)²
# Minimum at w = 3, gradient = dL/dw = 2(w − 3)
#
# w₀ = 0.0
# gradient = 2(0 − 3) = −6
# w₁ = 0.0 − 0.1 × (−6) = 0.6
#
# w₁ = 0.6, gradient = 2(0.6 − 3) = −4.8
# w₂ = 0.6 − 0.1 × (−4.8) = 1.08
#
# Converging toward w = 3 with each step.
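The worked iteration above runs directly as code. A minimal sketch, with the same α = 0.1 and starting point w₀ = 0:

```python
def grad(w):
    return 2.0 * (w - 3.0)        # dL/dw for L = (w - 3)^2

w, alpha = 0.0, 0.1
history = [w]
for _ in range(50):
    w -= alpha * grad(w)          # step against the gradient
    history.append(w)
# history begins 0.0, 0.6, 1.08, ... and approaches w = 3
```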
      
  • Batch GD: compute the gradient over the entire dataset, then update once.
    Pros: exact gradient; smooth, stable convergence. Cons: memory-intensive; only one update per epoch; slow on large datasets.
  • Stochastic GD (SGD): compute the gradient from a single random example and update immediately.
    Pros: very fast updates; low memory; noisy oscillations can escape local minima. Cons: high variance; noisy loss curve; may not converge precisely.
  • Mini-Batch GD: compute the gradient over a small batch (e.g. 32 or 256), then update.
    Pros: balances noise and stability; GPU-parallelisable; standard in practice. Cons: batch size is another hyperparameter; some noise remains in gradient estimates.
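Mini-batch GD can be sketched end-to-end on a toy one-parameter linear model. The synthetic data and hyperparameters below are illustrative choices:

```python
import random

random.seed(0)
# Synthetic data for y = 2x with a little uniform noise
xs = [i / 10 for i in range(1, 101)]
data = [(x, 2.0 * x + random.uniform(-0.1, 0.1)) for x in xs]

w, alpha, batch_size = 0.0, 0.01, 32
for _ in range(200):                        # epochs
    random.shuffle(data)                    # fresh batches each epoch
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        # dL/dw for L = mean (w*x - y)^2 over the batch
        g = sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)
        w -= alpha * g
# w converges near the true slope 2.0 despite the per-batch gradient noise
```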

Saddle Points vs Local Minima in High Dimensions

In a 1D loss landscape, local minima are common and concerning. In a billion-dimensional space, a point where the loss curves upward along every dimension simultaneously is extraordinarily rare. Most "stuck" points in deep learning are saddle points, where the loss curves up along some dimensions and down along others. These are escaped naturally by the noise in mini-batch gradient estimates. This partially explains why gradient descent on deep networks works better in practice than 1D intuitions would suggest.

🔁 Backpropagation

Backpropagation (Rumelhart, Hinton & Williams, 1986) is the algorithm for efficiently computing ∇L(θ) — the gradient of the loss with respect to every parameter in the network. It applies the chain rule of calculus systematically, working backwards from the output layer to the input layer.

# Backpropagation — chain rule applied to a 2-layer network
#
# Architecture: x → [W1, b1] → ReLU → [W2, b2] → sigmoid → ŷ → BCE loss
#
# Forward pass (store all intermediate values):
z1 = W1·x + b1          # pre-activation, layer 1
a1 = ReLU(z1)           # activation, layer 1
z2 = W2·a1 + b2         # pre-activation, layer 2
ŷ  = sigmoid(z2)        # prediction
L  = -[y·log(ŷ) + (1−y)·log(1−ŷ)]   # binary cross-entropy loss
#
# Backward pass (chain rule, right to left):
#
# ∂L/∂ŷ  = (ŷ − y) / [ŷ·(1−ŷ)]         # derivative of BCE wrt ŷ
# ∂ŷ/∂z2 = ŷ·(1−ŷ)                       # derivative of sigmoid wrt z2
# ∂L/∂z2 = ∂L/∂ŷ · ∂ŷ/∂z2 = ŷ − y      # simplified: error signal
#
# ∂L/∂W2 = ∂L/∂z2 · ∂z2/∂W2 = (ŷ−y) · a1ᵀ
# ∂L/∂b2 = ∂L/∂z2 = ŷ − y
#
# ∂L/∂a1 = W2ᵀ · ∂L/∂z2 = W2ᵀ·(ŷ−y)
# ∂a1/∂z1 = ReLU'(z1) = 1 if z1>0 else 0   # ReLU gradient
# ∂L/∂z1 = ∂L/∂a1 · ∂a1/∂z1              # element-wise
#
# ∂L/∂W1 = ∂L/∂z1 · xᵀ
# ∂L/∂b1 = ∂L/∂z1
#
# Update: W1 -= α·∂L/∂W1,  W2 -= α·∂L/∂W2, etc.
# Modern frameworks (PyTorch, JAX) compute all this automatically via autograd.
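The hand-derived error signal ∂L/∂z = ŷ − y can be sanity-checked numerically. A sketch for the single-output case (scalar w and b with hypothetical values), comparing the analytic gradient against central finite differences:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def loss(w, b, x, y):
    # Binary cross-entropy for a one-parameter logistic model
    yhat = sigmoid(w * x + b)
    return -(y * math.log(yhat) + (1 - y) * math.log(1 - yhat))

x, y = 1.5, 1.0          # hypothetical training example
w, b = 0.3, -0.1         # hypothetical parameters
yhat = sigmoid(w * x + b)

# Analytic gradients from the chain rule above: dL/dz = yhat - y
analytic_dw = (yhat - y) * x
analytic_db = yhat - y

# Central finite differences as an independent check
eps = 1e-6
numeric_dw = (loss(w + eps, b, x, y) - loss(w - eps, b, x, y)) / (2 * eps)
numeric_db = (loss(w, b + eps, x, y) - loss(w, b - eps, x, y)) / (2 * eps)
```

Gradient checking like this is a standard way to debug a hand-written backward pass.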
      

The Chain Rule Intuition

The chain rule states: if y depends on u, and u depends on x, then dy/dx = dy/du × du/dx. Backprop applies this recursively. The gradient at each layer is computed by multiplying the gradient coming from the next layer by the local gradient (derivative of the layer's own transformation). This decomposes what would be a computationally infeasible direct computation into a sequence of manageable local operations.

Automatic Differentiation

Modern frameworks (PyTorch, JAX, TensorFlow) implement reverse-mode automatic differentiation ("autograd"). During the forward pass, the framework builds a computational graph. The backward pass traverses this graph in reverse, computing exact gradients without any manual derivation. This is not symbolic differentiation or numerical approximation — it is exact, efficient automatic calculus.
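The core idea of reverse-mode autodiff fits in a few dozen lines with a toy scalar class (in the spirit of micrograd; real frameworks are vastly more sophisticated, so treat this purely as a sketch of the mechanism):

```python
class Value:
    """Toy scalar supporting reverse-mode autodiff (illustrative only)."""
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def _backward():
            self.grad += out.grad            # d(a+b)/da = 1
            other.grad += out.grad           # d(a+b)/db = 1
        out._backward = _backward
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def _backward():
            self.grad += other.data * out.grad   # d(a*b)/da = b
            other.grad += self.data * out.grad   # d(a*b)/db = a
        out._backward = _backward
        return out

    def backward(self):
        # Topological order so each node runs after all of its consumers
        order, seen = [], set()
        def build(v):
            if v not in seen:
                seen.add(v)
                for p in v._parents:
                    build(p)
                order.append(v)
        build(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward()

# y = a*b + a  =>  dy/da = b + 1 = 4, dy/db = a = 2
a, b = Value(2.0), Value(3.0)
y = a * b + a
y.backward()
```

The forward pass records the graph (each `Value` remembers its parents and a local backward rule); calling `backward()` replays it in reverse, exactly the pattern the frameworks implement at scale.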

Vanishing and Exploding Gradients

During backprop, gradients are multiplied layer by layer. If weights are initialised such that activations consistently shrink (e.g. sigmoid outputs in [0,1] multiplied many times), gradients reaching early layers can be effectively zero — the vanishing gradient problem. Conversely, if weights are large, gradients can grow exponentially — exploding gradients. Solutions: careful initialisation (He for ReLU, Glorot/Xavier for tanh/sigmoid), batch normalisation, residual/skip connections, and gradient clipping.
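The multiplication effect is easy to see numerically: the sigmoid derivative σ′(z) = σ(z)(1 − σ(z)) never exceeds 0.25, so even in the best case a 50-layer sigmoid chain scales gradients by 0.25⁵⁰:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_deriv(z):
    s = sigmoid(z)
    return s * (1.0 - s)          # peaks at 0.25 when z = 0

# Best case for a 50-layer sigmoid chain: every pre-activation is 0,
# so each layer multiplies the gradient by at most 0.25.
grad_scale = 1.0
for _ in range(50):
    grad_scale *= sigmoid_deriv(0.0)
# grad_scale = 0.25**50 ≈ 8e-31: early layers receive essentially no signal
```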

🚀 Optimisers

Plain SGD applies the same learning rate to all parameters. In practice, some parameters need large updates (if they are rarely activated), others need small ones (if they are frequently updated). Adaptive optimisers address this by maintaining per-parameter learning rates that are adjusted based on gradient history.

  • SGD + Momentum (lr; momentum, typically 0.9)
    Accumulates an exponentially weighted moving average of past gradients ("velocity"), which dampens oscillations and accelerates convergence in the right direction.
    Pros: simple, well understood, good generalisation with tuning. Cons: sensitive to lr; requires careful scheduling. Still used for large-scale vision training.
  • AdaGrad (lr; epsilon for numerical stability)
    Divides the learning rate by the square root of the cumulative sum of squared past gradients: parameters with many large gradients get a smaller effective lr; sparse parameters get a larger one.
    Pros: good for sparse data (NLP bag-of-words). Cons: the accumulated sum increases monotonically, so the learning rate eventually decays to zero and learning stops. Rarely used today.
  • RMSProp (lr; decay, typically 0.9; epsilon)
    Fixes AdaGrad's decaying lr by using an exponentially weighted moving average of squared gradients instead of the full cumulative sum.
    Pros: works well for RNNs; lr no longer decays monotonically. Cons: still requires lr tuning. Proposed by Hinton in a Coursera lecture (unpublished).
  • Adam (lr, typically 1e-3; β₁ = 0.9; β₂ = 0.999; epsilon = 1e-8)
    Combines momentum (first moment: mean of gradients) with RMSProp (second moment: mean of squared gradients). Both moments are bias-corrected for the first few steps, when they are underestimated.
    Pros: fast convergence; robust to lr choice; the de facto standard. Cons: can generalise slightly worse than well-tuned SGD + momentum on some tasks.
  • AdamW (lr, typically 1e-4; β₁; β₂; weight_decay 0.01–0.1)
    Adam with decoupled weight decay: L2 regularisation is applied directly to the weights rather than folded into the gradient (which would distort the adaptive scaling). Better regularisation.
    Pros: the standard for Transformer training; better generalisation than Adam. Nearly universal in LLM pre-training (GPT, BERT, LLaMA all use AdamW).
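The Adam update is compact enough to sketch for a single scalar parameter. The hyperparameter values follow the defaults above (lr raised to 0.01 for this toy problem); the quadratic test loss is illustrative:

```python
def adam_step(w, g, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    # Exponential moving averages of the gradient and its square
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias correction: both averages start at 0 and underestimate early on
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w -= lr * m_hat / (v_hat ** 0.5 + eps)
    return w, m, v

# Minimise the toy loss L = (w - 3)^2
w, m, v = 0.0, 0.0, 0.0
for t in range(1, 2001):          # t starts at 1 for the bias correction
    g = 2.0 * (w - 3.0)           # dL/dw
    w, m, v = adam_step(w, g, m, v, t)
```

Note how the effective step is roughly lr · (gradient direction) regardless of the gradient's magnitude, which is what makes Adam so robust to the lr choice.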

Adam is the Default Starting Point

When in doubt, start with AdamW, learning rate 1e-4, default β values. It will almost certainly work reasonably well out of the box. If you are training a large vision model from scratch (ImageNet, etc.), SGD + momentum + cosine annealing schedule is worth the extra tuning effort as it often achieves 1-2% better top-1 accuracy. For fine-tuning pre-trained models or training Transformers, AdamW is typically the right choice. The framework default values (PyTorch, Keras) are sensible starting points.

Learning Rate Schedulers

The learning rate typically needs to decrease over training. A too-large lr in late training causes the parameters to oscillate around minima without settling. Common schedules:

  • Step decay: reduce lr by factor γ every N epochs
  • Cosine annealing: smooth decrease following a cosine curve; widely used with SGD
  • Warmup + decay: linear warmup for K steps, then decay; standard for Transformers
  • OneCycleLR: increases lr to max then decreases; fast convergence (super-convergence)
  • ReduceLROnPlateau: reduce when validation loss stops improving
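The warmup-plus-decay pattern is a few lines of code. A sketch with hypothetical max_lr, warmup_steps, and total_steps values:

```python
import math

def lr_schedule(step, max_lr=3e-4, warmup_steps=1000, total_steps=10_000):
    # Linear warmup to max_lr, then cosine decay to zero
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```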

Regularisation During Training

Preventing overfitting is as important as learning quickly. Key techniques used alongside backprop:

  • L2 weight decay (AdamW): penalises large weights, keeps parameters small and general
  • Dropout: randomly zero activations during training; forces redundant representations
  • Batch Normalisation: normalises activations; mild regularisation effect
  • Data augmentation: synthetic training variety; not technically a loss term but acts as regulariser
  • Early stopping: halt training when validation loss stops improving
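Dropout in particular is simple to sketch. This is the inverted-dropout variant commonly implemented in frameworks: surviving activations are scaled up at train time so that inference needs no rescaling (the values and drop rate below are illustrative):

```python
import random

def dropout(activations, p=0.5, training=True):
    # Inverted dropout: keep each unit with probability 1-p and scale
    # survivors by 1/(1-p), preserving the expected activation.
    if not training or p == 0.0:
        return list(activations)
    keep = 1.0 - p
    return [a / keep if random.random() < keep else 0.0
            for a in activations]

random.seed(0)
acts = [0.5] * 10_000
dropped = dropout(acts, p=0.3)
# Roughly 30% of entries are zeroed, yet the mean is preserved in expectation
```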