🎯 The Learning Problem
A neural network starts as a random function. "Training" means systematically adjusting its millions of parameters so that its output becomes increasingly accurate on the training data — and generalises to unseen data. The entire process reduces to a single question: how do we change the weights to reduce the error?
What Training Means
Training is an optimisation problem: find the weight values θ that minimise a loss function L(θ) measuring prediction error over the training data. If we could solve this analytically (set the derivative to zero and solve for θ), we would — but neural networks have billions of parameters and non-convex loss landscapes, making closed-form solutions impossible.
- Goal: θ* = argmin L(θ) over training dataset
- Loss landscape: think of a mountainous terrain; we want the lowest valley
- Billions of parameters = billions of dimensions in the loss landscape
- Not convex: multiple local minima, saddle points, plateaus
The Forward Pass
Before we can compute how wrong the network is, we must run it: feed an input through every layer sequentially, computing activations at each step. This is the forward pass. All intermediate activations are stored in memory because backpropagation will need them to compute gradients.
- Input → Layer 1 → Layer 2 → … → Output (prediction)
- Compare prediction to ground truth label using loss function
- Memory cost: O(number of layers × batch size × layer width)
- Gradient checkpointing trades compute for memory in very deep models
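As a concrete sketch (NumPy, with made-up layer sizes), a forward pass that caches every intermediate activation for the later backward pass might look like:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 2-layer network: 4 inputs -> 8 hidden units -> 3 outputs
W1, b1 = rng.normal(size=(8, 4)), np.zeros(8)
W2, b2 = rng.normal(size=(3, 8)), np.zeros(3)

def forward(x):
    """Run the network, caching every intermediate value for backprop."""
    z1 = W1 @ x + b1          # pre-activation, layer 1
    a1 = np.maximum(z1, 0.0)  # ReLU activation
    z2 = W2 @ a1 + b2         # output logits
    cache = {"x": x, "z1": z1, "a1": a1, "z2": z2}
    return z2, cache

logits, cache = forward(rng.normal(size=4))
print(logits.shape)  # (3,)
```

The `cache` dict is exactly the memory cost described above: it grows with depth, batch size, and layer width.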
Why We Can't Just Solve Analytically
For a single linear regression, we can compute the optimal weights in one matrix operation: θ = (XᵀX)⁻¹Xᵀy. But for a 7-billion-parameter neural network with ReLU activations, softmax outputs, and layer normalisation, the loss function is a non-convex, non-smooth function of billions of variables. No closed-form solution exists. We must use iterative gradient-based methods, updating weights step by step in the direction that reduces the loss.
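The contrast is easy to demonstrate. A minimal NumPy sketch (synthetic, noise-free data) solving linear regression in closed form:

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic noise-free data generated by y = 2*x0 - 1*x1
X = rng.normal(size=(100, 2))
y = X @ np.array([2.0, -1.0])

# Normal equation: theta = (X^T X)^{-1} X^T y
# (np.linalg.solve is preferred over forming the inverse explicitly)
theta = np.linalg.solve(X.T @ X, X.T @ y)
print(theta)  # ≈ [2., -1.]
```

No such one-line solve exists for a deep network's loss, which is why the rest of this section is about iterative methods.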
📉 Loss Functions
A loss function (also called a cost function or objective function) quantifies the discrepancy between the network's prediction and the true label. The choice of loss function encodes what "error" means for your specific task.
| Loss Function | Formula | Task | Notes |
|---|---|---|---|
| Mean Squared Error (MSE) | L = (1/n)Σ(yᵢ − ŷᵢ)² | Regression | Penalises large errors quadratically. Sensitive to outliers. Differentiable everywhere. Standard for regression tasks. |
| Mean Absolute Error (MAE) | L = (1/n)Σ\|yᵢ − ŷᵢ\| | Regression | Linear penalty — more robust to outliers than MSE. Non-differentiable at zero, but a subgradient exists. Used when large outliers are common. |
| Binary Cross-Entropy | L = −[y·log(ŷ) + (1−y)·log(1−ŷ)] | Binary classification | Used with sigmoid output. Penalises confident wrong predictions very heavily (log → −∞ as ŷ → 0 when y=1). Derived from maximum likelihood under a Bernoulli distribution. |
| Categorical Cross-Entropy | L = −Σᵢ yᵢ·log(ŷᵢ) | Multi-class classification | Used with softmax output. For one-hot labels, reduces to −log(ŷ_correct_class). Standard loss for image classification and NLP token prediction. Also called Negative Log-Likelihood. |
| Huber Loss | ½e² if \|e\|≤δ, else δ·(\|e\|−δ/2) | Regression with outliers | Quadratic for small errors (like MSE), linear for large errors (like MAE). Balances both regimes. δ is a tunable threshold. Used in object detection (bounding box regression). |
| KL Divergence | D_KL(P‖Q) = ΣP·log(P/Q) | Distribution matching | Measures how much distribution Q diverges from reference P. Used in VAEs, knowledge distillation, and reinforcement learning from human feedback (RLHF). Not symmetric. |
Choosing the Right Loss Function
The loss function must match the output type and task. Using MSE for classification ignores the probabilistic interpretation of softmax outputs and provides poor gradient signal for wrong-class probabilities. Using cross-entropy for regression is nonsensical. The loss also encodes what kinds of errors matter: if false negatives are catastrophic in a medical context, consider weighted cross-entropy or focal loss. Always verify that the loss function's assumptions match your data distribution.
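As a reference sketch (NumPy; the `eps` clipping is a standard numerical guard against log(0), not part of the mathematical definitions), the three most common losses from the table:

```python
import numpy as np

def mse(y, y_hat):
    """Mean squared error."""
    return np.mean((y - y_hat) ** 2)

def binary_cross_entropy(y, y_hat, eps=1e-12):
    """BCE for sigmoid outputs; clip predictions to avoid log(0)."""
    y_hat = np.clip(y_hat, eps, 1 - eps)
    return -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

def categorical_cross_entropy(y_onehot, y_hat, eps=1e-12):
    """CCE for softmax outputs; with one-hot targets this is -log of
    the probability assigned to the correct class."""
    return -np.mean(np.sum(y_onehot * np.log(y_hat + eps), axis=1))

# A confident wrong answer costs far more than a hesitant right one
y = np.array([1.0])
print(binary_cross_entropy(y, np.array([0.9])))   # -log(0.9)  ≈ 0.105
print(binary_cross_entropy(y, np.array([0.01])))  # -log(0.01) ≈ 4.605
```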
⛰️ Gradient Descent
The gradient ∇L(θ) points in the direction of steepest increase of the loss. By stepping in the opposite direction — the negative gradient — we descend the loss landscape toward lower loss values. This is gradient descent.
# The gradient descent update rule
θ ← θ − α · ∇L(θ)
# Where:
# θ = all model parameters (weights and biases)
# α = learning rate (step size, typically 1e-3 to 1e-5)
# ∇L = gradient of loss with respect to θ
# − = minus: move AGAINST the gradient (downhill)
# Example: scalar parameter with loss L = (w − 3)²
# Minimum at w = 3, gradient = dL/dw = 2(w − 3)
#
# w₀ = 0.0
# gradient = 2(0 − 3) = −6
# w₁ = 0.0 − 0.1 × (−6) = 0.6
#
# w₁ = 0.6, gradient = 2(0.6 − 3) = −4.8
# w₂ = 0.6 − 0.1 × (−4.8) = 1.08
#
# Converging toward w = 3 with each step.
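The worked example above runs directly as plain Python (same learning rate of 0.1):

```python
# Gradient descent on L(w) = (w - 3)**2; dL/dw = 2*(w - 3)
w, alpha = 0.0, 0.1

history = []
for step in range(25):
    grad = 2 * (w - 3)    # gradient at the current w
    w -= alpha * grad     # step against the gradient
    history.append(w)

print(round(history[0], 2), round(history[1], 2))  # 0.6 1.08, as above
print(abs(w - 3) < 0.05)                           # True: converged near w = 3
```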
| Variant | How It Works | Pros | Cons |
|---|---|---|---|
| Batch GD | Compute gradient over entire dataset, then update once | Exact gradient; smooth convergence; stable | Memory intensive; one update per epoch; slow on large datasets |
| Stochastic GD (SGD) | Compute gradient from a single random example, update immediately | Very fast updates; noisy oscillations can escape local minima; low memory | High variance; noisy loss curve; may not converge precisely |
| Mini-Batch GD | Compute gradient over a small batch (e.g. 32, 256), then update | Balance between noise and stability; GPU-parallelisable; standard in practice | Batch size is another hyperparameter; some noise in gradient estimates |
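A mini-batch loop, sketched for a single-parameter model on made-up data (the dataset, batch size, and learning rate here are all illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up dataset generated by y = 3*x; one parameter theta to fit
X = rng.normal(size=200)
y = 3.0 * X

theta, alpha, batch_size = 0.0, 0.1, 32

for epoch in range(20):
    perm = rng.permutation(len(X))            # reshuffle every epoch
    for i in range(0, len(X), batch_size):
        idx = perm[i:i + batch_size]
        xb, yb = X[idx], y[idx]
        # MSE gradient on this batch only: d/dtheta mean((theta*x - y)^2)
        grad = np.mean(2 * (theta * xb - yb) * xb)
        theta -= alpha * grad                 # one update per batch

print(round(theta, 2))  # ≈ 3.0
```

Note the contrast with batch GD: the gradient is an estimate from 32 examples, and there are many updates per epoch rather than one.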
Saddle Points vs Local Minima in High Dimensions
In a 1D loss landscape, local minima are common and concerning. In a billion-dimensional space, a point where every dimension points upward simultaneously is extraordinarily rare. Most "stuck" points in deep learning are saddle points — where some dimensions point up and others point down. These are escaped naturally by the noise in mini-batch gradient estimates. This partially explains why deep networks with gradient descent work better in practice than 1D intuitions would suggest.
🔁 Backpropagation
Backpropagation (Rumelhart, Hinton & Williams, 1986) is the algorithm for efficiently computing ∇L(θ) — the gradient of the loss with respect to every parameter in the network. It applies the chain rule of calculus systematically, working backwards from the output layer to the input layer.
# Backpropagation — chain rule applied to a 2-layer network
#
# Architecture: x → [W1, b1] → ReLU → [W2, b2] → sigmoid → ŷ → BCE loss
#
# Forward pass (store all intermediate values):
z1 = W1·x + b1 # pre-activation, layer 1
a1 = ReLU(z1) # activation, layer 1
z2 = W2·a1 + b2 # pre-activation, layer 2
ŷ = sigmoid(z2) # prediction
L = -[y·log(ŷ) + (1−y)·log(1−ŷ)] # binary cross-entropy loss
#
# Backward pass (chain rule, right to left):
#
# ∂L/∂ŷ = (ŷ − y) / [ŷ·(1−ŷ)] # derivative of BCE wrt ŷ
# ∂ŷ/∂z2 = ŷ·(1−ŷ) # derivative of sigmoid wrt z2
# ∂L/∂z2 = ∂L/∂ŷ · ∂ŷ/∂z2 = ŷ − y # simplified: error signal
#
# ∂L/∂W2 = ∂L/∂z2 · ∂z2/∂W2 = (ŷ−y) · a1ᵀ
# ∂L/∂b2 = ∂L/∂z2 = ŷ − y
#
# ∂L/∂a1 = ∂L/∂z2 · ∂z2/∂a1 = (ŷ−y) · W2ᵀ
# ∂a1/∂z1 = ReLU'(z1) = 1 if z1>0 else 0 # ReLU gradient
# ∂L/∂z1 = ∂L/∂a1 · ∂a1/∂z1 # element-wise
#
# ∂L/∂W1 = ∂L/∂z1 · xᵀ
# ∂L/∂b1 = ∂L/∂z1
#
# Update: W1 -= α·∂L/∂W1, W2 -= α·∂L/∂W2, etc.
# Modern frameworks (PyTorch, JAX) compute all this automatically via autograd.
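The derivation above can be verified numerically. A minimal NumPy sketch (layer sizes chosen arbitrarily) that runs the same forward and backward passes and checks one analytic gradient against a central finite difference:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Tiny instance of the 2-layer network above (sizes are illustrative)
x = rng.normal(size=4)
y = 1.0
W1, b1 = rng.normal(size=(5, 4)), np.zeros(5)
W2, b2 = rng.normal(size=(1, 5)), np.zeros(1)

# Forward pass (cache intermediates)
z1 = W1 @ x + b1
a1 = np.maximum(z1, 0.0)       # ReLU
z2 = W2 @ a1 + b2
y_hat = sigmoid(z2)

# Backward pass, following the derivation line by line
dz2 = y_hat - y                # dL/dz2 = y_hat - y (BCE + sigmoid)
dW2 = np.outer(dz2, a1)        # dL/dW2
db2 = dz2
da1 = W2.T @ dz2               # dL/da1
dz1 = da1 * (z1 > 0)           # ReLU gate (element-wise)
dW1 = np.outer(dz1, x)         # dL/dW1
db1 = dz1

# Check dL/dW2[0,0] against a central finite difference
def bce_of_W2(W2_):
    p = sigmoid(W2_ @ a1 + b2)[0]
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

eps = 1e-6
E = np.zeros_like(W2)
E[0, 0] = eps
numeric = (bce_of_W2(W2 + E) - bce_of_W2(W2 - E)) / (2 * eps)
print(abs(numeric - dW2[0, 0]) < 1e-6)  # True: analytic matches numeric
```

This finite-difference check is the standard way to debug a hand-written backward pass.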
The Chain Rule Intuition
The chain rule states: if y depends on u, and u depends on x, then dy/dx = dy/du × du/dx. Backprop applies this recursively. The gradient at each layer is computed by multiplying the gradient coming from the next layer by the local gradient (derivative of the layer's own transformation). This decomposes what would be a computationally infeasible direct computation into a sequence of manageable local operations.
Automatic Differentiation
Modern frameworks (PyTorch, JAX, TensorFlow) implement reverse-mode automatic differentiation ("autograd"). During the forward pass, the framework builds a computational graph. The backward pass traverses this graph in reverse, computing exact gradients without any manual derivation. This is not symbolic differentiation or numerical approximation — it is exact, efficient automatic calculus.
Vanishing and Exploding Gradients
During backprop, gradients are multiplied layer by layer. If weights are initialised such that activations consistently shrink (e.g. sigmoid outputs in [0,1] multiplied many times), gradients reaching early layers can be effectively zero — the vanishing gradient problem. Conversely, if weights are large, gradients can grow exponentially — exploding gradients. Solutions: careful initialisation (He for ReLU, Glorot/Xavier for tanh/sigmoid), batch normalisation, residual/skip connections, and gradient clipping.
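The shrinkage is easy to quantify: the sigmoid derivative σ′(z) = σ(z)(1 − σ(z)) is at most 0.25 (at z = 0), so a stack of sigmoid layers scales the backward signal by at most 0.25 per layer, even in the best case:

```python
# Best-case gradient scaling through stacked sigmoid layers:
# each layer multiplies the gradient by at most 0.25
best_case = 0.25
for depth in (5, 10, 20):
    # gradients vanish geometrically with depth
    print(depth, best_case ** depth)
```

At 20 layers the best-case factor is below 1e-12, which is why ReLU (gradient exactly 1 on the active side) and residual connections matter so much for depth.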
🚀 Optimisers
Plain SGD applies the same learning rate to every parameter. In practice, rarely updated parameters (e.g. embeddings for rare tokens) benefit from larger steps, while frequently updated parameters need smaller ones. Adaptive optimisers address this by maintaining per-parameter learning rates adjusted from gradient history.
| Optimiser | Key Idea | Hyperparameters | Pros / Cons |
|---|---|---|---|
| SGD + Momentum | Accumulate an exponentially weighted moving average of past gradients ("velocity"). This dampens oscillations and accelerates convergence in the right direction. | lr, momentum (typically 0.9) | Pros: simple, well-understood, good generalisation with tuning. Cons: sensitive to lr; requires careful scheduling. Still used for large-scale vision training. |
| AdaGrad | Divides the learning rate by the square root of the cumulative sum of squared past gradients. Parameters with many large gradients get smaller effective lr; sparse parameters get larger lr. | lr, epsilon (numerical stability) | Pros: good for sparse data (NLP bag-of-words). Cons: accumulated sum monotonically increases → learning rate eventually → 0 (learning stops). Rarely used today. |
| RMSProp | Fixes AdaGrad's decaying lr by using an exponentially weighted moving average of squared gradients instead of the full cumulative sum. | lr, decay (typically 0.9), epsilon | Pros: works well for RNNs; non-monotonic lr decay. Cons: still requires lr tuning. Proposed by Hinton in a Coursera lecture (unpublished). |
| Adam | Combines momentum (first moment: mean of gradients) and RMSProp (second moment: mean of squared gradients). Both moments are bias-corrected for the first few steps when they are underestimated. | lr (1e-3), β₁ (0.9), β₂ (0.999), epsilon (1e-8) | Pros: fast convergence; robust to lr choice; de facto standard. Cons: can generalise slightly worse than well-tuned SGD+momentum on some tasks. |
| AdamW | Adam with decoupled weight decay: L2 regularisation is applied directly to weights rather than being folded into the gradient (which distorts the adaptive scaling). Better regularisation. | lr (1e-4), β₁, β₂, weight_decay (0.01-0.1) | Pros: the standard for Transformer training; better generalisation than Adam. Nearly universal in LLM pre-training (GPT, BERT, LLaMA all use AdamW). |
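A minimal sketch of a single Adam update (the moment estimates `m`, `v` and the step counter `t` are carried between calls; hyperparameter defaults follow the table):

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on parameters theta given gradient grad."""
    m = beta1 * m + (1 - beta1) * grad        # first moment (momentum)
    v = beta2 * v + (1 - beta2) * grad ** 2   # second moment (RMSProp-style)
    m_hat = m / (1 - beta1 ** t)              # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Minimise L(w) = w**2 starting from w = 1.0
w, m, v = 1.0, 0.0, 0.0
for t in range(1, 3001):
    w, m, v = adam_step(w, 2 * w, m, v, t, lr=1e-3)
print(abs(w) < 0.02)  # True: settled near the minimum
```

AdamW differs only in applying the decay term `weight_decay * theta` directly to the parameters rather than adding it to `grad`, so the regularisation is not distorted by the adaptive scaling.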
Adam is the Default Starting Point
When in doubt, start with AdamW, learning rate 1e-4, default β values. It will almost certainly work reasonably well out of the box. If you are training a large vision model from scratch (ImageNet, etc.), SGD + momentum + cosine annealing schedule is worth the extra tuning effort as it often achieves 1-2% better top-1 accuracy. For fine-tuning pre-trained models or training Transformers, AdamW is typically the right choice. The framework default values (PyTorch, Keras) are sensible starting points.
Learning Rate Schedulers
The learning rate typically needs to decrease over the course of training: a learning rate that is too large late in training causes the parameters to oscillate around a minimum without settling. Common schedules:
- Step decay: reduce lr by factor γ every N epochs
- Cosine annealing: smooth decrease following a cosine curve; widely used with SGD
- Warmup + decay: linear warmup for K steps, then decay; standard for Transformers
- OneCycleLR: increases lr to max then decreases; fast convergence (super-convergence)
- ReduceLROnPlateau: reduce when validation loss stops improving
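As an illustration, the warmup-plus-cosine schedule commonly used for Transformers can be written as a pure function of the step count (all constants here are illustrative):

```python
import math

def warmup_cosine_lr(step, max_lr=1e-4, warmup_steps=1000, total_steps=10000):
    """Linear warmup to max_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

print(warmup_cosine_lr(500))    # mid-warmup: ≈ 5e-05
print(warmup_cosine_lr(1000))   # peak: 0.0001
print(warmup_cosine_lr(10000))  # end of training: ≈ 0.0
```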
Regularisation During Training
Preventing overfitting is as important as learning quickly. Key techniques used alongside backprop:
- L2 weight decay (AdamW): penalises large weights, keeps parameters small and general
- Dropout: randomly zero activations during training; forces redundant representations
- Batch Normalisation: normalises activations; mild regularisation effect
- Data augmentation: synthetic training variety; not technically a loss term but acts as regulariser
- Early stopping: halt training when validation loss stops improving
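Of these, dropout is the easiest to sketch in a few lines. The standard "inverted dropout" formulation (a NumPy sketch) scales the surviving activations by 1/(1−p) so that inference needs no correction:

```python
import numpy as np

rng = np.random.default_rng(0)

def dropout(a, p=0.5, training=True):
    """Inverted dropout: zero each activation with probability p during
    training and scale survivors by 1/(1-p); a no-op at inference."""
    if not training:
        return a
    mask = rng.random(a.shape) >= p
    return a * mask / (1.0 - p)

a = np.ones(10_000)
dropped = dropout(a, p=0.5)
print(round(float(dropped.mean()), 1))  # 1.0: expected value preserved
print(float((dropped == 0).mean()))     # roughly half the units zeroed
```

Because the expected activation is unchanged, the same forward-pass code works at training and inference time with only the `training` flag flipped.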