Gradient Descent
The simple algorithm behind every neural network, explained from first principles to production-scale training.
Recently, I've been spending most of my weekends building stuff with AI. My current project is a Gujarati (my mother tongue) Q&A app - it listens to live religious discourses and answers questions about them in real time, using the religious books as its source. It's mostly plumbing: speech-to-text APIs, vector databases, LLMs - nothing fancy. I wire things together, tweak prompts, and ship features. The models learn; I just use them.
Which is fine, until it's not. The more I build, the more it feels like I'm a tourist in my own field. I can tell you which embedding model to use, which chunking strategy works, how to structure a . But ask me how the model actually learned any of that - what's really happening inside - and I'd draw a blank. I have very little idea.
The moment that got me was a conversation with a friend. I was rambling about some AI thing, and he cut me off with a simple question: "Wait, how does the model even know which direction to go, though?" I opened my mouth to explain and... I couldn't. Not in a way that made sense. I could recite the gradient descent formula and draw the diagram, but I couldn't tell him the actual story of what was happening. I didn't have an intuitive understanding of how the model learned, and that bothered me a lot that day.
So I did what I should have done the first time: start from zero. No equations for the first two days. Just hills, valleys, and blindfolded hikers. I rebuilt my understanding piece by piece until the math finally felt obvious instead of arbitrary.
This article is that rebuild. I wrote it because I needed it to exist: something that explains gradient descent intuitively, from first principles, so that I could finally understand how each piece works. We'll start with a single ball rolling down a hill and end with how trillion-parameter models get trained across thousands of GPUs. Let's go.
The Problem: Learning at Scale
Before we dive into gradient descent, let's understand the problem it solves.
What does it mean for a machine to "learn"?
To explain this, I'm going to use one of the most-used examples in all of machine learning: predicting house prices from their size. I know, I know. Every tutorial uses this. But here's the thing - it's used for a reason. You already understand houses, prices, and sizes. There's no mental overhead. That means we can pour all our attention into the actual concept, instead of decoding the scenario.
Say you have data on houses - their sizes and their actual sale prices. You want to find a relationship that lets you predict the price of a new house given its size. In mathematical terms, you're looking for a function that maps inputs (size) to outputs (price).
Let's start simple. Assume the relationship is linear:
price = w × size + b
Here, w (weight) is the slope and b (bias) is the y-intercept. The question is: what values of w and b give us the best predictions for the price?
Here's the critical insight: for any choice of w and b, our line makes predictions. Some choices produce predictions close to reality. Others are way off. We need a way to score how wrong we are - a single number that tells us whether a particular line is good or terrible.
That number is computed by the loss function. It takes your predictions and the actual prices, and returns a score. A common one is Mean Squared Error (MSE): take the difference between predicted and actual price for each house, square it (to make all errors positive), and average them. If your line passes exactly through every data point, the loss is zero. If your line is nowhere close, the loss is huge.
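As a concrete sketch, here is MSE for the linear price model. It uses a tiny made-up dataset: the numbers are illustrative, not real house prices, and the helper names `predict` and `mse_loss` are hypothetical.

```python
def predict(w, b, size):
    """Linear model from above: price = w * size + b."""
    return w * size + b

def mse_loss(w, b, sizes, prices):
    """Mean Squared Error: average of squared (prediction - actual) differences."""
    errors = [predict(w, b, s) - p for s, p in zip(sizes, prices)]
    return sum(e * e for e in errors) / len(errors)

# Hypothetical toy data: sizes (in 1000 sq ft) and prices (in $100k),
# generated exactly from price = 2 * size + 1.
sizes = [1.0, 1.5, 2.0, 3.0]
prices = [3.0, 4.0, 5.0, 7.0]

print(mse_loss(2.0, 1.0, sizes, prices))  # line through every point -> 0.0
print(mse_loss(1.0, 0.0, sizes, prices))  # a poor line -> 8.8125
```

Every candidate (w, b) gets exactly one score. That single number is what the rest of this article tries to push down.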
This transforms learning into a pure optimization problem. We are no longer "finding a line." We are minimizing a function. The loss function defines a landscape: a surface where every point represents a possible choice of w and b, and the height at that point is how badly that choice performs (its loss, i.e. how far the predicted values are from the actual values). Training means finding the lowest point in this landscape: the values of w and b that make the loss as small as possible. In other words, we are finding the parameter values that make our model's predictions match the actual values as closely as possible.
But here's where it gets interesting. In modern AI, we're not optimizing 2 parameters. Large language models have hundreds of billions of parameters. Even "small" models like BERT-base have about 110 million parameters. You can't just try every combination. Even at a coarse two-points-per-axis grid, you'd be enumerating 2^(110,000,000) configurations. That dwarfs the ~10^80 atoms in the observable universe by an astronomical margin.
We need an algorithm that can navigate this enormous landscape without checking every point. So, we need a way to look at where we currently stand, figure out which direction goes downhill (because we want to find the minimum), and take a step. That's the idea behind gradient descent.
Brute Force vs Gradient Descent
Brute force checks every point in the search space. Gradient descent uses local information to find minima in far fewer steps. Green dots mark global minima, red is a local maximum.
The demo above compares brute force search with gradient descent on a simple one-dimensional problem. Brute force checks thousands of points to find the minimum. Gradient descent uses local slope information to find it in just a few steps. Now imagine scaling this to millions of dimensions - brute force becomes literally impossible, while gradient descent would keep working.
The Intuition: Following the Slope
So we need something smarter than brute force. And here's where the intuition comes in. Imagine you're blindfolded and dropped somewhere on a hilly landscape. Your goal is to find the lowest point - the valley. How would you do it?
You'd feel the ground around you. If it slopes down to your left, you'd step left. If it slopes down forward, you'd step forward. You'd keep taking small steps in the direction that goes downhill. Eventually, you'd reach a point where every direction leads uphill - that's the bottom of the valley.
That's gradient descent. Instead of blindly searching, you use the slope of the landscape to guide your steps. The "landscape" is your loss function plotted against your parameters. The "slope" is the gradient.
The key insight is that local information is enough. You don't need to see the entire landscape to find the bottom. You just need to know which direction is down from where you're standing. This scales beautifully: whether you have 2 parameters or 2 billion, the approach is the same. Feel the slope, step downhill, repeat.
Gradient Descent: Ball Rolling Downhill
x_new = x_old - α × ∂L/∂x
The green arrow shows the direction of descent (opposite to the gradient). The ball moves in that direction by α × gradient each step.
In the demo above, you can see how gradient descent works. The ball starts at some position on the curve. At each step, it computes the gradient (the slope at that point) and moves in the opposite direction - downhill. Click anywhere on the curve to place the ball, then watch it roll down to the minimum at x = 0.
Notice something important: the ball doesn't jump directly to the minimum. It takes small steps, gradually moving downhill. This is crucial - gradient descent is an iterative algorithm. It improves gradually over time, not all at once. This might seem like a limitation, but it's actually what makes it scalable. Each iteration is cheap, so we can afford to do millions of them.
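The ball-rolling demo fits in a few lines of Python. This sketch assumes the curve is f(x) = x². The demo's exact curve isn't given, only that its minimum is at x = 0. With that assumption, the slope is f'(x) = 2x. The starting point and α = 0.1 are arbitrary choices.

```python
def gradient_descent_1d(x, learning_rate=0.1, steps=50):
    """Repeatedly apply x_new = x_old - alpha * df/dx for f(x) = x**2."""
    history = [x]
    for _ in range(steps):
        grad = 2 * x                  # slope of f(x) = x**2 at the current x
        x = x - learning_rate * grad  # step against the slope (downhill)
        history.append(x)
    return history

path = gradient_descent_1d(x=3.0)
print(path[:4])   # each step shrinks x by a factor of (1 - 2 * 0.1) = 0.8
print(path[-1])   # very close to the minimum at x = 0
```

Each step only needs the slope at the current point. That is why the same loop works unchanged when x becomes a vector of millions of parameters.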
The Math: Gradients and Derivatives
The gradient is just a fancy word for "which direction is uphill, and how steep is it?" More precisely, it's a vector of partial derivatives. Each element tells you: if I increase this parameter slightly, how much does the loss change?
Let's say you have two parameters, w and b. The gradient is:
∇L = [∂L/∂w, ∂L/∂b]
If ∂L/∂w is positive, it means increasing w increases the loss - so we should decrease w. If it's negative, increasing w decreases the loss - so we should increase w.
The gradient points in the direction of steepest ascent. Since we want to minimize, we go in the opposite direction - the direction of steepest descent. Hence the name.
Think of partial derivatives this way: imagine you're standing on a hillside and you want to know how steep it is. But hills slope differently in different directions. The partial derivative ∂L/∂w tells you how steep the hill is if you walk purely east (changing w while keeping b fixed). The partial derivative ∂L/∂b tells you how steep it is if you walk purely north (changing b while keeping w fixed). Together, they tell you the full slope at your current position.
Partial Derivative Explorer
Each 1D plot freezes one variable and shows how loss changes with the other. The green tangent has slope = the partial derivative. The orange triangle shows a finite step Δ — its hypotenuse is the secant, slope = ΔL/Δ. Shrink Δ and the secant snaps onto the tangent.
The demo above visualizes partial derivatives as the slope of a 1D curve. The 2D contour plot on top is the full loss landscape; click anywhere to move the orange point. The two plots below are slices through that landscape: the left freezes b and shows L as you vary w, the right freezes w and shows L as you vary b. The green tangent line on each slice has slope equal to the partial derivative at that point. The orange right-triangle is a finite step Δ - its horizontal leg is Δw (or Δb), its vertical leg is the resulting ΔL, and its hypotenuse is the secant whose slope is the numerical approximation ΔL/Δ. As you shrink Δ, the secant snaps onto the tangent - which is exactly the limit definition of a derivative.
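You can check the secant-to-tangent idea numerically. This sketch assumes the elliptical bowl L(w, b) = w² + 4b², which this article uses later in the learning-rate section; the explorer's own loss surface isn't specified. Its exact partials are ∂L/∂w = 2w and ∂L/∂b = 8b.

```python
def loss(w, b):
    return w ** 2 + 4 * b ** 2

def secant_slope_w(w, b, delta):
    """dL/dw estimate: freeze b, nudge w by delta, measure the change in loss."""
    return (loss(w + delta, b) - loss(w, b)) / delta

def secant_slope_b(w, b, delta):
    """dL/db estimate: freeze w, nudge b by delta."""
    return (loss(w, b + delta) - loss(w, b)) / delta

w, b = 1.5, -0.5
for delta in [1.0, 0.1, 0.001]:
    print(delta, secant_slope_w(w, b, delta), secant_slope_b(w, b, delta))
# As delta shrinks, the slopes approach 2w = 3.0 and 8b = -4.0.
```

For this bowl, the w-secant works out to exactly 2w + Δ. The gap between secant and tangent is Δ itself, which is the limit definition of the derivative in action.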
The update rule is beautifully simple:
w_new = w_old - α × ∂L/∂w
Where α (alpha) is the learning rate - how big of a step we take. This single equation is the heart of nearly all modern AI training.
Let's break down what this equation is saying. At each step:
- Compute the gradient ∂L/∂w: which direction is uphill
- Multiply by -1 to get the downhill direction
- Scale by α to control step size
- Add this to your current position to get the new position
This is repeated thousands or millions of times until convergence - when the gradient becomes so small that further steps don't meaningfully improve the loss.
The Algorithm: Step by Step
Let's look at the complete gradient descent algorithm in pseudocode. This is the same algorithm whether you're fitting a line to 10 points or training a neural network with a billion parameters:
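```python
# Gradient Descent Algorithm
import math

def gradient_descent(data, initial_params, learning_rate, max_iterations,
                     compute_loss, compute_gradients, threshold=1e-6):
    # compute_loss / compute_gradients are supplied by the caller (model-specific)
    params = dict(initial_params)  # copy so the caller's dict isn't modified
    for _ in range(max_iterations):
        # 1. Compute loss (handy for logging) and gradients
        loss = compute_loss(params, data)
        gradients = compute_gradients(params, data)
        # 2. Update parameters: step against each gradient
        for p in params:
            params[p] = params[p] - learning_rate * gradients[p]
        # 3. Check convergence: stop once the gradient norm is tiny
        if math.sqrt(sum(g * g for g in gradients.values())) < threshold:
            break
    return params
```
That's it. Three main steps: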
- compute gradients
- update parameters
- check if we're done
The simplicity is deceptive - this basic structure powers everything from linear regression to GPT-4.
Initialization - Where do we start?
Before the first update, every parameter needs a starting value. For simple problems like linear regression, initializing all parameters to zero works fine. But for neural networks, this is a disaster. If all neurons in a layer start with identical weights, they compute the same output, receive the same gradient, and so stay identical forever (this is called the "symmetry problem").
Instead, we typically initialize with small random values. Techniques like Xavier (Glorot) and He (Kaiming) initialization scale the initial values based on a layer's number of inputs (and, for Xavier, outputs). This helps gradients flow properly in the early stages of training.
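Here is a sketch of two widely used scaling rules, Xavier/Glorot uniform and He (Kaiming) normal, written with only Python's standard library. The layer sizes are arbitrary examples. Real frameworks ship their own initializers, so treat this as illustration only.

```python
import math
import random

def xavier_uniform(fan_in, fan_out, rng):
    """Glorot/Xavier: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))."""
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-limit, limit) for _ in range(fan_in)] for _ in range(fan_out)]

def he_normal(fan_in, fan_out, rng):
    """He/Kaiming (aimed at ReLU layers): N(0, std^2) with std = sqrt(2 / fan_in)."""
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]

rng = random.Random(0)            # seeded so the example is reproducible
W1 = xavier_uniform(fan_in=256, fan_out=128, rng=rng)   # 128 rows x 256 columns
W2 = he_normal(fan_in=128, fan_out=64, rng=rng)         # 64 rows x 128 columns

# Every weight is small and different, so no two neurons start out identical.
print(len(W1), len(W1[0]), max(abs(x) for row in W1 for x in row))
```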
Convergence Criteria
When do we stop? There are several options:
- Fixed iterations: Run for a set number of steps (e.g., 100 epochs)
- Gradient threshold: Stop when gradients become very small (near zero)
- Loss plateau: Stop when loss hasn't improved for several iterations
- Early stopping: Monitor validation loss and stop when it starts increasing (prevents overfitting)
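The "loss plateau" and "early stopping" criteria both boil down to a patience counter. Below is a minimal sketch. The helper name `should_stop_at`, the `patience` value, and the validation curve are all illustrative, not a specific library's API.

```python
def should_stop_at(val_losses, patience=3, min_delta=0.0):
    """Return the epoch index where training would stop, or None.

    Stops once the validation loss has failed to improve on its best value
    (by more than min_delta) for `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None

# Hypothetical validation curve: improves, then starts rising (overfitting).
history = [1.00, 0.70, 0.55, 0.50, 0.52, 0.56, 0.61, 0.70]
print(should_stop_at(history, patience=3))  # -> 6
```

In practice you'd also keep a copy of the parameters from the best epoch (epoch 3 here) and restore them when you stop.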
Step-by-Step Gradient Descent
∂L/∂w = 4.2500, ∂L/∂b = 2.5000
w ← 2.0000 + (0.1 × -4.2500) = 1.5750
b ← 1.5000 + (0.1 × -2.5000) = 1.2500
Each step computes gradients, then updates parameters by moving against the gradient. Watch how the orange line converges to the green dashed line (the true relationship).
The demo above lets you trace through each iteration of gradient descent step by step. Watch how the parameters (w and b) update, how the loss decreases, and how the line gradually converges to fit the data. You can adjust the learning rate and see how it affects convergence speed and stability.
Seeing It Work: Linear Regression
Let's make this concrete with the simplest possible example: fitting a line to some points. No neural networks, no complexity - just a line.
We have some data points. We want to find the line y = w × x + b that best fits them. "Best" means minimizing the mean squared error between our predictions and the actual values.
Watch gradient descent find the optimal line:
Linear Regression: Fit the Line
Red lines show the error between predictions and actual values. Gradient descent minimizes the mean squared error.
In this demo, you can actually drag the line to see how different values of w (slope) and b (intercept) affect the error. The red lines show the error for each point - gradient descent minimizes the sum of these squared errors. Try dragging the line manually to find a good fit, then click "Auto-Fit" to watch gradient descent do it perfectly.
Notice how the line gradually moves toward the best fit. This is gradient descent in its purest form - iterative improvement guided by gradients. Even though this problem has a closed-form solution (you could compute the optimal w and b directly using linear algebra), gradient descent finds it through successive approximation. This iterative approach is what scales to problems with billions of parameters where closed-form solutions don't exist.
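Here is a minimal sketch of what the Auto-Fit button is doing, assuming plain batch gradient descent on MSE. The demo's data, starting point, and learning rate aren't specified, so the ones below are made up. With ŷᵢ = w·xᵢ + b, the MSE partial derivatives work out to ∂L/∂w = (2/n) Σ (ŷᵢ − yᵢ)·xᵢ and ∂L/∂b = (2/n) Σ (ŷᵢ − yᵢ).

```python
def fit_line(xs, ys, learning_rate=0.05, steps=2000):
    """Fit y = w * x + b by batch gradient descent on mean squared error."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        residuals = [(w * x + b) - y for x, y in zip(xs, ys)]   # y_hat - y
        grad_w = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
        grad_b = 2.0 / n * sum(residuals)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b

# Hypothetical points lying exactly on y = 2x + 1.
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [2.0 * x + 1.0 for x in xs]
w, b = fit_line(xs, ys)
print(round(w, 4), round(b, 4))  # -> 2.0 1.0
```

The closed-form least-squares answer here is also w = 2, b = 1. Gradient descent just gets there by repeated small corrections instead of solving the system directly.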
The Learning Rate: A Critical Choice
The learning rate is arguably the most important hyperparameter in machine learning. Get it wrong, and your model either learns nothing or explodes.
Too small: Your model learns extremely slowly. It takes forever to converge. Each step is so tiny that you might need millions of iterations to get anywhere. It's like trying to walk from Mumbai to Delhi by taking baby steps. Worse, you might get stuck in flat regions or near saddle points because you don't have enough momentum to escape.
Too large: Your model overshoots the minimum. It bounces around wildly, never settling down. The loss might even increase instead of decreasing. Imagine trying to land a plane by alternating between diving and climbing - you'd crash. In extreme cases, the loss can explode to infinity (numerical overflow).
Just right: The sweet spot. Fast enough to make progress, small enough to converge smoothly. Finding this sweet spot is one of the key skills in training ML models.
Learning Rate Comparison
Loss function: L(w,b) = w² + 4b². Elliptical contours show equal loss levels. Watch how different learning rates behave from the same starting point.
In the demo above, you can see multiple learning rates racing simultaneously on the elliptical bowl L(w, b) = w² + 4b². Toggle different rates on/off to compare them. Watch how a tiny learning rate (0.01) creeps along slowly, a moderate one (0.1) settles into the minimum cleanly, and an aggressive one (0.3) overshoots so badly along the steep b-axis that the path blows up and gets tagged as Diverged!. Click anywhere on the contour plot to pick a new starting point and re-run.
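You can reproduce the race without the demo. On L(w, b) = w² + 4b², the gradient is (2w, 8b), so each step multiplies w by (1 − 2α) and b by (1 − 8α). Once α > 0.25, |1 − 8α| exceeds 1, and the b coordinate grows instead of shrinking. In this sketch, the starting point (1, 1), the 100 steps, and the "loss above a million means diverged" cutoff are arbitrary choices.

```python
def run(learning_rate, w=1.0, b=1.0, steps=100):
    """Gradient descent on L(w, b) = w**2 + 4 * b**2.
    Returns the final loss, or None if the run diverged."""
    loss = w ** 2 + 4 * b ** 2
    for _ in range(steps):
        grad_w, grad_b = 2 * w, 8 * b
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
        loss = w ** 2 + 4 * b ** 2
        if loss > 1e6:              # runaway loss: call it diverged
            return None
    return loss

for lr in [0.01, 0.1, 0.3]:
    result = run(lr)
    print(lr, "Diverged!" if result is None else f"final loss {result:.3g}")
```

Here are the results. With α = 0.01, the loss is still around 0.02 after 100 steps. With α = 0.1, it is essentially zero. With α = 0.3, it blows up within about 20 steps.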
Learning Rate Schedules
In practice, using a fixed learning rate is often suboptimal. Think of it like driving: you want to go fast on the highway, but slow down when you approach your destination. Learning rate schedules implement this intuition:
- Step decay: Reduce LR by a factor (e.g., 0.1) every N epochs
- Exponential decay: LR decreases exponentially: LR = LR₀ × e^(-kt)
- Cosine annealing: LR follows a cosine curve, starting high and smoothly decreasing
- Warm restarts: Periodically reset LR to a higher value (helps escape local minima)
- One-cycle: LR increases then decreases in a single cycle (often works best in practice)
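Three of these schedules are simple enough to write as plain functions of the step count t. The constants (initial LR, drop factor, decay rate k, total steps) are placeholder values for illustration. Deep learning frameworks ship ready-made scheduler implementations, so in practice you'd typically use those.

```python
import math

def step_decay(step, lr0=0.1, drop=0.1, every=30):
    """Multiply the LR by `drop` every `every` steps (or epochs)."""
    return lr0 * drop ** (step // every)

def exponential_decay(step, lr0=0.1, k=0.05):
    """LR = LR0 * e^(-k t)."""
    return lr0 * math.exp(-k * step)

def cosine_annealing(step, total_steps=100, lr0=0.1, lr_min=0.0):
    """Follow half a cosine wave from lr0 down to lr_min over total_steps."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr0 - lr_min) * (1 + math.cos(math.pi * progress))

for t in [0, 30, 50, 100]:
    print(t, step_decay(t), round(exponential_decay(t), 5), round(cosine_annealing(t), 5))
```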
Finding a Good Learning Rate
So, how do we know what learning rate to use when training a model? There is no definitive answer. However, a technique called the "LR Range Test" helps find a good starting point. The idea: start with a very small LR, train briefly while exponentially increasing the LR, and plot the loss. The LR where the loss decreases fastest (steepest slope) is a good choice.
Learning Rate Finder
Train briefly at each candidate LR. Plotted left to right: loss plateaus when the rate is too low, drops sharply through the sweet spot, then explodes once the step size pushes us off the loss surface (too high). The good pick is wherever the slope is steepest — one decade below the divergence cliff.
The LR Finder demo above simulates this process. It tests different learning rates and identifies where the loss decreases fastest - a good learning rate to start from. Too small and learning is slow; too large and the loss explodes.
In practice, many practitioners use a related rule of thumb: find the LR where the loss bottoms out, right before the divergence cliff, and back off by roughly an order of magnitude. For example, if the loss is lowest at 1e-1, you'd train at 1e-2. That usually lands close to the steepest-slope point. The ~10× headroom also means a noisy batch is unlikely to push you over the edge.
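Here is a toy version of the range test. Like the demo, it trains briefly at each candidate LR. It assumes a one-parameter quadratic loss L(w) = w² as a stand-in for a real model. On a real network, you'd run mini-batches through your own training loop instead.

```python
def short_run_loss(lr, w=1.0, steps=5):
    """Train L(w) = w**2 for a few steps at a fixed LR and report the final loss."""
    for _ in range(steps):
        w -= lr * 2 * w
    return w ** 2

# Candidate LRs spaced exponentially from 1e-4 to 10 (one per half-decade).
lrs = [10 ** (k / 2) for k in range(-8, 3)]
for lr in lrs:
    print(f"lr={lr:.0e}  loss={short_run_loss(lr):.3g}")
# Loss barely moves at tiny LRs, drops sharply in the middle, then explodes
# once lr > 1 (each step multiplies w by 1 - 2*lr, whose magnitude exceeds 1).
```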
Stochasticity: Faster but Noisier
So far, I've described computing gradients using all the data. This is called batch gradient descent. It gives you the true gradient - the exact direction of steepest descent. But there's a problem.
If you have millions of data points, you need to process all of them just to take a single step. That's extremely slow. And in the age of big data, datasets can be billions of examples. Waiting to see every example before updating would make training impossibly slow.
Stochastic Gradient Descent (SGD)
The solution: compute the gradient using just one randomly chosen data point, then take a step. Repeat. This is Stochastic Gradient Descent.
This is much faster per step - you process one example instead of millions. But the gradient estimate is noisy. One data point might not be representative of the whole dataset. The path to the minimum becomes jagged and erratic.
Surprisingly, this noise can actually help! It can help you escape local minima and find better solutions. Sometimes random jiggling helps you discover a path you wouldn't find by being too precise. The noise acts as a form of implicit regularization, helping the model generalize better.
Mini-Batch Gradient Descent
The best of both worlds, and what everyone actually uses in practice. Compute the gradient using a small batch of data - typically 32, 64, or 128 examples.
You get a reasonably accurate gradient estimate (averaging over multiple examples reduces noise) while still being computationally efficient. Plus, modern GPUs are optimized for processing batches in parallel - you get much better hardware utilization.
Mini-batch size is another important hyperparameter. Larger batches give more accurate gradients but require more memory and compute per step. Smaller batches are faster per step but noisier. Some research suggests that large-batch training can get stuck in sharp minima that generalize poorly, while small batches tend to find flatter minima.
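One function covers all three variants, because they differ only in how many examples feed each gradient estimate. Here is a sketch for the linear model y = w·x + b with MSE, on made-up data with some label noise. Setting batch_size=1 gives SGD, batch_size=len(data) gives full-batch gradient descent, and anything in between is mini-batch.

```python
import random

def minibatch_gd(xs, ys, batch_size, learning_rate=0.1, epochs=200, seed=0):
    """Fit y = w*x + b; each update uses the MSE gradient of one random mini-batch."""
    rng = random.Random(seed)
    indices = list(range(len(xs)))
    w, b = 0.0, 0.0
    for _ in range(epochs):
        rng.shuffle(indices)                     # new random order every epoch
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            residuals = [(w * xs[i] + b) - ys[i] for i in batch]
            m = len(batch)
            grad_w = 2.0 / m * sum(r * xs[i] for r, i in zip(residuals, batch))
            grad_b = 2.0 / m * sum(residuals)
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b

# Hypothetical noisy data around y = 2x + 1.
data_rng = random.Random(42)
xs = [data_rng.uniform(-1, 1) for _ in range(200)]
ys = [2 * x + 1 + data_rng.gauss(0, 0.1) for x in xs]

for bs in [1, 16, 200]:   # SGD, mini-batch, full batch
    print(bs, minibatch_gd(xs, ys, batch_size=bs))
```

All three land near (2, 1). With the same number of epochs, SGD takes 200 updates per epoch and full batch takes one, which is the trade-off the demo's sample counter makes visible.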
Batch vs Mini-batch vs SGD
Features are pre-orthogonalized so the Hessian is exactly diag(1, 4) — that's the elliptical bowl above. The label noise is the only source of stochasticity: it's what makes single-sample gradients disagree with the true gradient.
Each tick is one update for every variant — but look at the samples counter: Batch burns 200 gradient evaluations per step, Mini-batch uses 16, SGD uses 1. SGD reaches the basin in far fewer evaluations but jitters around w* because its gradient is the gradient of a single noisy sample, not the full loss.
Watch how the three variants behave differently in the demo above. Batch gradient descent (green) takes smooth, confident steps because it uses the exact gradient. SGD (orange) is erratic and noisy because each step uses a different random sample. Mini-batch (purple) strikes a balance. Toggle the gradient arrows to see the direction each method computes at each step.
The Landscape: What Makes Optimization Hard
Gradient descent sounds elegant, but real-world loss surfaces are messy. Here are some challenges that kept me up at night when training my models.
Local Minima
Remember the blindfolded hiker analogy? What if there are multiple valleys? You might end up in a small valley when there's a much deeper one nearby. Gradient descent only sees the local landscape - it has no way to know if there's something better over the next hill.
For simple problems like linear regression, this isn't an issue - the loss surface is convex (bowl-shaped), with exactly one minimum. But neural networks have incredibly complex, non-convex loss surfaces with countless local minima.
The good news: research has shown that in high-dimensional spaces (millions of parameters), most local minima are actually pretty good - close to the global minimum in terms of loss. The landscape is more like an egg carton than a mountain range. Additionally, the noise from SGD helps escape shallow local minima.
Saddle Points
These are trickier than local minima. A saddle point is where the gradient is zero, but it's not a minimum - it's a minimum in some directions and a maximum in others. Like the middle of a horse's saddle.
At a saddle point, basic gradient descent gets stuck. The gradient is zero, so there's no direction to move. Yet there are clearly better solutions nearby - you just can't see them from where you're standing. In high-dimensional spaces, saddle points are actually much more common than local minima.
The solution? Noise! SGD and mini-batch gradient descent naturally escape saddle points because the gradient estimates are noisy. The noise perturbs you off the saddle, and the gradient in the escape direction takes over. Momentum also helps - it can carry you through saddle points.
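The textbook saddle is L(w, b) = w² − b²: a minimum along w, a maximum along b, and zero gradient at the origin. (It has no global minimum, since −b² keeps falling, but that doesn't matter for watching the escape.) This sketch starts exactly on the ridge b = 0. It compares exact gradients with a tiny one-time nudge standing in for SGD noise; the nudge size is an arbitrary choice.

```python
def descend_saddle(w, b, learning_rate=0.1, steps=60):
    """Gradient descent on L(w, b) = w**2 - b**2, whose gradient is (2w, -2b)."""
    for _ in range(steps):
        grad_w, grad_b = 2 * w, -2 * b
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b

print(descend_saddle(1.0, 0.0))    # b stays exactly 0: stuck on the saddle
print(descend_saddle(1.0, 1e-6))   # the tiny perturbation in b grows 1.2x per step
```

Without the nudge, gradient descent slides down to the origin along w and parks there. With it, the unstable b direction amplifies the perturbation every step until it dominates.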
Plateaus and Flat Regions
Sometimes the loss surface is nearly flat over a large region. The gradient is close to zero, so steps are tiny. Training seems to stall. You might think you've converged, but actually you're just in a flat area with better regions beyond.
Plateaus are especially problematic because they slow learning to a crawl. You might need thousands of iterations just to get across a flat region. This is where adaptive learning rates (like in Adam) really shine - they increase the effective learning rate in flat regions.
Sharp vs Flat Minima
Not all minima are created equal. Some are sharp valleys - the loss increases rapidly as you move away. Others are flat plateaus - the loss stays low over a wide region. Research suggests that flat minima generalize better than sharp ones.
Why? Because the training set is just a sample of the true data distribution. A sharp minimum might fit the training data perfectly but be sensitive to small changes in the data. A flat minimum is robust - small perturbations don't hurt performance much. This is one reason small batch sizes often work better: the noise encourages finding flatter minima.
Optimization Challenges
Two valleys - gradient descent settles in the nearest one
Global minimum at x ≈ -2 (deeper), local minimum at x ≈ 2 (shallower)
The demo above shows these scenarios. Switch between the three tabs to explore each challenge. Click anywhere on the curve to place the ball at a new starting point. Notice how gradient descent can settle in local minima, stall at stationary points where the gradient vanishes but it's not a minimum (the 1D inflection point shown here shares this failure mode with true higher-dimensional saddle points), or slow down dramatically on plateaus.
Modern Optimizers: Beyond Vanilla Gradient Descent
Vanilla gradient descent has limitations. Over the years, researchers have developed smarter variants that address these challenges. When I was fine-tuning models for my app, understanding these made a real difference.
Momentum
Imagine a ball rolling down a hill. It doesn't just follow the local slope - it builds up speed. Even if it hits a small bump (local minimum), its momentum carries it through.
Momentum does exactly this. Instead of stepping in the direction of the current gradient, it maintains a vector that carries forward the running direction of motion:
v ← β · v + ∇L
w ← w − α · v
Reading the update piece by piece:
- v is the velocity: initially zero, it accumulates a discounted sum of past gradients.
- β (typically 0.9) is the momentum coefficient: each step keeps a β fraction of the old velocity and adds the fresh gradient on top. With β = 0.9, a gradient from t steps ago still contributes with weight 0.9^t to today's velocity, giving an effective horizon of about 1 / (1 − β) ≈ 10 steps.
- α is the learning rate, applied to the velocity (not to the raw gradient).
The intuition: if the gradient keeps pointing the same way (consistent slope), successive contributions add up and v grows large - you accelerate. If gradients oscillate (jagged terrain pulling you side to side), they partially cancel inside the running sum and the velocity stays small in those directions, so the path smooths out. This is what lets momentum blast through saddle points, escape shallow local minima, and damp the zigzag you'd otherwise see in narrow valleys.
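The two-line update translates directly into code. This sketch reuses the elliptical bowl L(w, b) = w² + 4b² from the learning-rate section and compares plain gradient descent with momentum at the same α. The bowl, the starting point (1, 1), and α = 0.01 are illustrative choices.

```python
def grad(w, b):
    """Gradient of L(w, b) = w**2 + 4 * b**2."""
    return 2 * w, 8 * b

def loss(w, b):
    return w ** 2 + 4 * b ** 2

def descend(alpha, beta, steps=100, w=1.0, b=1.0):
    """Momentum GD: v <- beta * v + grad; params <- params - alpha * v.
    With beta = 0 this reduces to plain gradient descent."""
    v_w, v_b = 0.0, 0.0                  # velocity starts at zero
    for _ in range(steps):
        g_w, g_b = grad(w, b)
        v_w = beta * v_w + g_w           # discounted running sum of gradients
        v_b = beta * v_b + g_b
        w -= alpha * v_w                 # step along the velocity, not the raw gradient
        b -= alpha * v_b
    return loss(w, b)

print("plain GD:", descend(alpha=0.01, beta=0.0))
print("momentum:", descend(alpha=0.01, beta=0.9))
```

With the same small α, plain gradient descent is still crawling down the shallow w direction after 100 steps. Momentum has built up speed and ends up with a loss more than a hundred times smaller.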
RMSprop
Different parameters might need different learning rates. A parameter that rarely updates should take bigger steps when it does. A parameter that updates constantly should take smaller steps. This is especially true in neural networks where different layers learn at different speeds.
RMSprop (Root Mean Square Propagation) adapts the learning rate per parameter based on the recent history of gradients:
s ← β · s + (1 − β) · (∇L)²   # running average of squared gradients
w ← w − α · ∇L / (√s + ε)   # rescaled update
Parameters with large recent gradients get smaller effective learning rates. Parameters with small recent gradients get larger ones. It's automatic adjustment. The squared gradients (remembered in s) act as a kind of memory of how "active" each parameter has been.
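Here is a minimal RMSprop sketch for a single step on a list of parameters. The values β = 0.9 and ε = 1e-8 are common defaults, but they are assumptions here, not something this article prescribes. Notice how two parameters with gradients 10,000× apart end up taking steps of about the same size.

```python
import math

def rmsprop_step(params, grads, state, lr=0.01, beta=0.9, eps=1e-8):
    """One RMSprop update. `state` holds the running average s of squared gradients."""
    for i, g in enumerate(grads):
        state[i] = beta * state[i] + (1 - beta) * g * g      # s <- beta*s + (1-beta)*g^2
        params[i] -= lr * g / (math.sqrt(state[i]) + eps)    # rescaled update
    return params

params = [1.0, 1.0]
state = [0.0, 0.0]
before = list(params)
rmsprop_step(params, grads=[100.0, 0.01], state=state)
steps = [b - a for a, b in zip(params, before)]
print(steps)  # both roughly lr / sqrt(1 - beta), about 0.0316
```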
Adam
Why choose between momentum and adaptive learning rates when you can have both? Adam (Adaptive Moment Estimation) combines the best of both worlds. It maintains both a momentum term (first moment) and an adaptive learning rate term (second moment).
Adam computes:
m ← β₁ · m + (1 − β₁) · ∇L   # 1st moment (momentum)
v ← β₂ · v + (1 − β₂) · (∇L)²   # 2nd moment (adaptive LR)
m̂ ← m / (1 − β₁ᵗ)   # bias correction for m
v̂ ← v / (1 − β₂ᵗ)   # bias correction for v
w ← w − α · m̂ / (√v̂ + ε)   # parameter update
Here t is the step number, starting at 1. Adam is the default choice for most practitioners. When in doubt, start with Adam. It works well across a wide range of problems with minimal tuning. That said, it's not always the best - for some problems, plain SGD with momentum actually generalizes better. The adaptivity of Adam can sometimes lead to overfitting or to sharper minima.
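The five Adam lines map one-to-one onto code. This sketch handles a single scalar parameter. It uses the commonly cited defaults β₁ = 0.9, β₂ = 0.999, ε = 1e-8 (treat them as assumptions) and a toy loss L(w) = (w − 3)² chosen purely for illustration.

```python
import math

def adam(grad_fn, w, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    """Minimize a 1-parameter function with Adam, given its gradient function."""
    m, v = 0.0, 0.0
    for t in range(1, steps + 1):              # t starts at 1 for bias correction
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g        # 1st moment (momentum)
        v = beta2 * v + (1 - beta2) * g * g    # 2nd moment (adaptive LR)
        m_hat = m / (1 - beta1 ** t)           # bias correction for m
        v_hat = v / (1 - beta2 ** t)           # bias correction for v
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w

w_final = adam(lambda w: 2 * (w - 3.0), w=0.0)
print(w_final)  # should end up close to the minimum at w = 3
```

Early on, m̂ / √v̂ is roughly ±1, so each step moves about lr regardless of how large the raw gradient is. That is the "adaptive" part in action.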
AdamW
A variant of Adam called AdamW separates weight decay from the gradient update. In regular Adam, weight decay is added to the gradient. It therefore gets rescaled by the adaptive learning rate along with everything else, which can make regularization behave weirdly. AdamW applies weight decay directly to the weights, which works better in practice. Most modern training setups use AdamW instead of Adam.
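The difference is easiest to see side by side. This sketch shows one update for a single scalar weight, contrasting L2 regularization folded into Adam's gradient with AdamW's decoupled decay. The function names, lr = 1e-3, and wd = 0.01 are illustrative assumptions.

```python
import math

def adam_l2_step(w, g, m, v, t, lr=1e-3, wd=0.01, b1=0.9, b2=0.999, eps=1e-8):
    """Adam with L2 regularization: wd * w is folded into the gradient,
    so the decay term is also divided by sqrt(v_hat)."""
    g = g + wd * w
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat, v_hat = m / (1 - b1 ** t), v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

def adamw_step(w, g, m, v, t, lr=1e-3, wd=0.01, b1=0.9, b2=0.999, eps=1e-8):
    """AdamW: shrink the weight directly, then take a plain Adam step
    whose moments only ever see the loss gradient."""
    w = w * (1 - lr * wd)                     # decoupled weight decay
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat, v_hat = m / (1 - b1 ** t), v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

# First step (t=1) from w = 1.0 with zero loss gradient: only decay acts.
print(adam_l2_step(1.0, 0.0, 0.0, 0.0, t=1)[0])  # decay inflated to a full lr-sized step
print(adamw_step(1.0, 0.0, 0.0, 0.0, t=1)[0])    # shrinks by exactly lr * wd * w
```

The opposite happens when the loss gradient is large. The L2 term in the first function gets divided by a large √v̂ and nearly disappears, while AdamW's shrink stays at lr × wd × w either way.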
When to Use Which?
- Adam/AdamW: Default choice. Good starting point for most problems.
- SGD + Momentum: When you need best generalization. Often beats Adam on vision tasks after longer training.
- RMSprop: Rarely used alone now, but conceptually important.
- Lion: A newer optimizer (2023) that uses only the sign of a momentum-blended gradient for each update. Shows promise for large models.
Optimizer Race: SGD vs Momentum vs Adam
The Rosenbrock "banana" function is a classic optimization test. Watch how Adam typically navigates the curved valley most efficiently.
Watch the three optimizers race on the Rosenbrock "banana" function above. The curved valley makes this a challenging optimization problem. Notice how basic SGD struggles to navigate the curve, while Momentum helps it build up speed in consistent directions. Adam usually reaches the minimum (the red target) fastest by adapting its learning rate for each parameter.
Second-Order Methods: Using Curvature
All the methods we've discussed so far are first-order methods - they only use the gradient (first derivative). But there's more information available: the curvature of the loss surface, captured by the second derivative (the Hessian).
Think of it this way: if you're standing in a valley, the gradient tells you which direction is up. But if the valley is narrow and steep-walled, you might want to take smaller steps to avoid bouncing back and forth. If it's wide and gentle, you can take bigger steps. The curvature tells you about the shape of the valley.
Newton's Method
The classic second-order method is Newton's method. Instead of just following the gradient, it solves for the optimal step using both gradient and curvature:
w_new = w_old - H⁻¹ × ∇L
Where H is the Hessian matrix of second derivatives. Newton's method can converge in far fewer iterations than gradient descent - often 10x to 100x fewer. In fact, for quadratic functions, it converges in a single step!
So why don't we use Newton's method for deep learning? The problem is the Hessian. For n parameters, the Hessian is an n×n matrix. For a modest neural network with 1 million parameters, that's 1 trillion entries. Computing and inverting this matrix is impossibly expensive.
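On the elliptical bowl L(w, b) = w² + 4b² from the learning-rate section, the Hessian is the constant diagonal matrix diag(2, 8). Inverting it is trivial, and a single Newton step lands exactly on the minimum from any starting point. In this sketch, the starting points are arbitrary. A general implementation would need a real linear solve instead of dividing by the diagonal entries.

```python
def newton_step_diagonal_bowl(w, b):
    """One Newton step w <- w - H^-1 * grad L for L(w, b) = w**2 + 4 * b**2."""
    grad_w, grad_b = 2 * w, 8 * b      # gradient
    h_ww, h_bb = 2.0, 8.0              # Hessian is diag(2, 8); off-diagonals are 0
    return w - grad_w / h_ww, b - grad_b / h_bb

print(newton_step_diagonal_bowl(5.0, -3.0))   # -> (0.0, 0.0)
```

With two parameters, H has only 4 entries. With n parameters it has n², which is exactly where the approach stops scaling.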
Quasi-Newton Methods
Researchers developed methods that approximate the Hessian without computing it directly. L-BFGS (Limited-memory Broyden-Fletcher-Goldfarb-Shanno) maintains a low-rank approximation of the Hessian using only recent gradient information. It works well for smaller problems but doesn't scale to deep neural networks.
Approximate Second-Order Methods
Some modern approaches try to get second-order benefits without the full cost:
- Natural Gradient: Uses the Fisher information matrix instead of the Hessian
- K-FAC (Kronecker-Factored Approximate Curvature): Approximates the Fisher matrix efficiently for neural networks
- Second-order optimization in subspaces: Only compute curvature in important directions
These methods show promise but haven't displaced first-order methods like Adam as the default choice. For now, first-order methods rule deep learning.
Backpropagation: Computing Gradients Efficiently
I've been glossing over something important. I've said "compute the gradient" many times - as if it were a single operation - but for a neural network with millions of parameters, what does that even mean concretely? You have to compute a partial derivative for every weight. So how does that happen?
The naive answer is: compute each partial derivative numerically. Nudge a weight by a tiny ε, run the whole network forward, see how much the loss changed, divide by ε. Repeat for every weight. For a 110-million-parameter model, that's 110 million forward passes through the network to produce one gradient. And you need a fresh gradient for every single training step. The math is right, but the bookkeeping is unworkable. Even at an optimistic 10 milliseconds per forward pass, one gradient would take about 13 days, and a training run of tens of thousands of steps would take centuries.
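Here is the naive approach written out, assuming a generic loss function over a flat list of parameters. The toy loss is made up and stands in for "run the network forward". The count makes the problem concrete: every parameter costs one extra full evaluation, so n parameters means n + 1 forward passes per gradient.

```python
def numerical_gradient(loss_fn, params, eps=1e-6):
    """Forward-difference gradient: nudge each parameter, re-run the whole loss."""
    base = loss_fn(params)                  # 1 forward pass at the current point
    grad = []
    for i in range(len(params)):            # + 1 forward pass per parameter
        nudged = list(params)
        nudged[i] += eps
        grad.append((loss_fn(nudged) - base) / eps)
    return grad

call_count = [0]                            # mutable counter for the toy loss

def toy_loss(p):
    """Hypothetical stand-in for 'run the network forward': a sum of squares."""
    call_count[0] += 1
    return sum(x * x for x in p)

g = numerical_gradient(toy_loss, [1.0, -2.0, 0.5])
print(g)               # close to the exact gradient [2.0, -4.0, 1.0]
print(call_count[0])   # 4 evaluations for 3 parameters

n_params = 110_000_000
print(n_params + 1, "forward passes for one gradient of a 110M-parameter model")
```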
So how does it actually happen?
The answer is backpropagation, and the trick - it really is one trick - is to reuse work. The loss depends on the output. The output depends on the last layer's weights. The last layer's inputs depend on the second-to-last layer's weights. And so on, all the way back. Once you've already computed how the loss responds to layer N's output, computing how it responds to layer N-1's weights is mostly a matter of multiplying by one more thing. You don't redo the work you already did.
The Chain Rule
That "multiplying by one more thing" is the chain rule from calculus. If z depends on y, and y depends on x, then:
$$\frac{dz}{dx} = \frac{dz}{dy} \times \frac{dy}{dx}$$
Read it as: "if I wiggle x by a tiny amount, that wiggles y, which wiggles z." The total effect is just the product of the two sensitivities. Calculus 101.
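A quick numeric check of the chain rule, using the illustrative pair y = x² and z = sin(y) (chosen only for this example): multiplying the two local sensitivities gives the same dz/dx as directly wiggling x and watching z.

```python
import math


def y_of(x):
    return x ** 2


def z_of(y):
    return math.sin(y)


x = 0.7
y = y_of(x)
dz_dy = math.cos(y)    # local derivative of sin at y
dy_dx = 2 * x          # local derivative of x^2 at x
chain = dz_dy * dy_dx  # chain rule: dz/dx = dz/dy * dy/dx

# Direct estimate: wiggle x a tiny amount in both directions and watch z move
h = 1e-6
direct = (z_of(y_of(x + h)) - z_of(y_of(x - h))) / (2 * h)
print(chain, direct)
```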
In a neural network, the chain is much longer:
input → layer1 → layer2 → ... → layerN → output → loss
To find how the loss changes with respect to a weight buried in layer 1, you multiply derivatives all the way through:
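$$\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial \mathrm{output}} \times \frac{\partial \mathrm{output}}{\partial \mathrm{layer}_N} \times \cdots \times \frac{\partial \mathrm{layer}_2}{\partial \mathrm{layer}_1} \times \frac{\partial \mathrm{layer}_1}{\partial w_1}$$
Most of those terms - the ones in the middle - are the same no matter which weight in layer 1 you're asking about. So you compute them once, cache them, and reuse them for every weight. That sharing is what turns "millions of forward passes" into "one forward pass and one backward pass."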
Forward and Backward Passes
Training is two phases that always run back-to-back:
- Forward pass: push the input through the network, layer by layer, stashing intermediate values along the way. End up with a prediction and a loss.
- Backward pass: walk the same chain in reverse. At each layer, multiply the gradient flowing in by that layer's local derivative to get the gradient flowing out the other side. By the time you reach the first layer, you have ∂L/∂w for every weight in the network.
The key insight worth pausing on: you compute gradients for millions of parameters at a small constant multiple of the cost of one forward pass (the backward pass typically costs about twice as much as the forward pass). Not millions of forward passes - one pass forward, one pass backward, gradients for everything. That ratio is what makes deep learning feasible at all. Take it away and the entire field stops; there's no clever optimizer that fixes a gradient you can't afford to compute.
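To make the two passes concrete, here is a minimal hand-written sketch for a hypothetical two-layer network: three inputs, one sigmoid hidden unit, a linear output, and a squared-error loss. The architecture and numbers are illustrative. Frameworks like PyTorch and JAX automate this bookkeeping for arbitrary graphs.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def forward_backward(w1, w2, x, target):
    """One forward and one backward pass for a tiny two-layer network:
    x (3 inputs) -> sigmoid hidden unit -> linear output -> squared-error loss."""
    # Forward pass: compute and stash the intermediate values
    a1 = sum(wi * xi for wi, xi in zip(w1, x))  # layer-1 pre-activation
    h = sigmoid(a1)                              # layer-1 activation
    y_hat = w2 * h                               # layer-2 output
    loss = 0.5 * (y_hat - target) ** 2

    # Backward pass: walk the chain rule from the loss back toward the inputs
    dL_dyhat = y_hat - target        # derivative of 0.5 * (y_hat - target)^2
    dL_dw2 = dL_dyhat * h            # gradient for the layer-2 weight
    dL_dh = dL_dyhat * w2            # gradient flowing back into layer 1
    dL_da1 = dL_dh * h * (1.0 - h)   # sigmoid'(a1) = h * (1 - h), reusing the stashed h
    # dL_da1 is computed once and shared by every layer-1 weight
    dL_dw1 = [dL_da1 * xi for xi in x]
    return loss, dL_dw1, dL_dw2


loss, g_w1, g_w2 = forward_backward([0.5, -0.3, 0.8], 1.5, [1.0, 2.0, -1.0], target=0.25)
print(loss, g_w1, g_w2)
```

Note how `dL_da1` is computed once and then multiplied by each input: that is the reuse from the chain-rule section in miniature.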
Forward & Backward Pass
Forward pass (green) computes activations left-to-right. Backward pass (orange) walks the chain rule right-to-left, multiplying local derivatives to produce ∂L/∂w at every edge. Animated edges show the currently active layer.
Step through the demo above to watch a forward pass compute activations left-to-right, then a backward pass walk the chain rule right-to-left and produce a gradient at every edge. Crank the depth slider to 6 and watch the per-layer gradient norms at the bottom: with a weight scale near 1, gradients shrink to a quarter or less at each layer (sigmoid's maximum derivative is 0.25), so by the time they reach the first layer they're effectively zero. That's the vanishing-gradient problem in one screen. Bump the weight scale up past 3 and the opposite happens.
Vanishing and Exploding Gradients
The same chain rule that makes backprop efficient is also what makes deep networks fragile. You're multiplying many derivatives together - one per layer the gradient has to travel through. If each is, say, 0.5 and you have 50 layers, the gradient reaching layer 1 is scaled by $0.5^{50} \approx 10^{-15}$. Effectively zero. The early layers stop learning. That's the vanishing gradient problem.
Flip it: if each derivative is 1.5, you get $1.5^{50} \approx 6 \times 10^{8}$. The gradient blows up, the weight update is enormous, the loss flips to NaN, and training is over. Equally bad in the opposite direction.
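A few lines of Python reproduce the arithmetic above and the sigmoid case from the demo. The sigmoid part assumes the best case for sigmoid (every pre-activation exactly 0 and every weight equal to 1), so with unit weights the signal can only shrink faster than this.

```python
import math

DEPTH = 50
print(0.5 ** DEPTH)  # about 8.9e-16: vanishing
print(1.5 ** DEPTH)  # about 6.4e8: exploding


def sigmoid_derivative(a):
    s = 1.0 / (1.0 + math.exp(-a))
    return s * (1.0 - s)  # peaks at 0.25 when a = 0


# Best case for sigmoid: every pre-activation is 0 and every weight is 1,
# so each layer multiplies the backward signal by 0.25 * 1.0
weight = 1.0
grad = 1.0
per_layer = []
for _ in range(DEPTH):
    grad *= sigmoid_derivative(0.0) * weight
    per_layer.append(grad)
print(per_layer[0], per_layer[9], per_layer[-1])
```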
For years after backpropagation was popularized in the 1980s, this was the wall. Networks more than a few layers deep just wouldn't train - the early layers received gradients indistinguishable from noise. The fix wasn't one breakthrough but a stack of them, each attacking the problem from a different angle:
- Residual (skip) connections: Provide shortcuts for gradient flow
- Careful initialization: Xavier/He initialization keeps gradients in a healthy range
- Batch normalization: Normalizes activations to prevent gradient issues
- ReLU activations: Avoid the vanishing gradient problem of sigmoid/tanh, since their derivative is 1 for any positive input
Backpropagation deserves its own article - there's a lot more to say about computational graphs, automatic differentiation, and how PyTorch and JAX actually pull this off behind the scenes. Maybe I'll write that one next. For now, the thing to internalize is that backprop is the unsung machinery underneath every model you've ever used. It quietly computes millions of gradients per training step, almost nobody thinks about it, and without it the whole field wouldn't exist.
Training at Scale: Distributed and Large Models
So far, we've talked about training on a single machine. But modern AI models - GPT-4, Claude, Gemini - are trained on thousands of accelerators (GPUs or TPUs) working together. How does gradient descent work at that scale?
Data Parallelism
The simplest form of distributed training. You have multiple GPUs, each with a copy of the model. You split the batch across GPUs, each computes gradients on its subset, then you average the gradients and update all models together.
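The loop below is a single-process simulation of data parallelism on a made-up linear-regression problem. The "workers" are just slices of a list, and the averaging lines stand in for the all-reduce a real framework would perform across GPUs; all names are hypothetical. With equal-sized shards, the average of the per-worker mean gradients equals the full-batch gradient, which is why every replica can apply the same update and stay in sync.

```python
def mse_grad(w, b, batch):
    """Gradient of mean squared error for the model w*x + b on one batch."""
    n = len(batch)
    gw = sum(2 * (w * x + b - y) * x for x, y in batch) / n
    gb = sum(2 * (w * x + b - y) for x, y in batch) / n
    return gw, gb


data = [(float(x), 3.0 * x + 1.0) for x in range(16)]  # toy data from y = 3x + 1
N_WORKERS = 4
# Each simulated "GPU" gets an equal slice of the batch (4 examples each)
shards = [data[i::N_WORKERS] for i in range(N_WORKERS)]
w, b, lr = 0.0, 0.0, 0.01

for step in range(5000):
    # Every worker holds a full copy of (w, b) and computes gradients on its shard
    local_grads = [mse_grad(w, b, shard) for shard in shards]
    # Stand-in for an all-reduce: average the per-worker gradients
    gw = sum(g[0] for g in local_grads) / N_WORKERS
    gb = sum(g[1] for g in local_grads) / N_WORKERS
    # Every replica applies the same averaged update, so the copies stay identical
    w -= lr * gw
    b -= lr * gb

print(w, b)
```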
This is straightforward but has limits. You can only use as many GPUs as there are slices of the batch to hand out. Very large batches can hurt generalization (they tend to find sharper minima). And each GPU needs enough memory to hold the entire model, plus its gradients and optimizer state.
Model Parallelism
When a model is too big for one GPU, you split it across multiple GPUs. Different layers live on different GPUs. During the forward pass, activations are passed from GPU to GPU. During backprop, gradients flow backward through the same path.
This is more complex than data parallelism because GPUs must communicate activations and gradients during training. But it allows training models that wouldn't fit on any single GPU.
Pipeline Parallelism
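A refinement of model parallelism that also splits the data. The model is divided into sequential stages, each on a different GPU, and each batch is split into micro-batches that flow through the stages one after another: while the second GPU processes micro-batch 1, the first GPU is already working on micro-batch 2, and so on. This keeps most GPUs busy most of the time and raises throughput, although some idle time (the "pipeline bubble") remains while the pipeline fills and drains.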
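A small sketch of a GPipe-style forward schedule shows where the idle time comes from. It assumes S stages and M equal micro-batches, with stage s working on micro-batch m at time step s + m, and it ignores the backward pass; the function names are hypothetical. Under those assumptions the idle fraction is (S - 1) / (M + S - 1), so more micro-batches shrink the bubble.

```python
def pipeline_schedule(n_stages, n_micro):
    """Forward-only schedule: stage s works on micro-batch m at time step s + m."""
    n_steps = n_micro + n_stages - 1
    grid = [["  .  "] * n_steps for _ in range(n_stages)]
    for s in range(n_stages):
        for m in range(n_micro):
            grid[s][s + m] = f" mb{m} "
    return grid


def bubble_fraction(n_stages, n_micro):
    """Share of stage-time steps spent idle: (S - 1) / (M + S - 1)."""
    busy = n_stages * n_micro
    total = n_stages * (n_micro + n_stages - 1)
    return 1 - busy / total


for s, row in enumerate(pipeline_schedule(4, 6)):
    print(f"GPU {s}: " + "".join(row))
print(bubble_fraction(4, 6))   # 3/9, about 0.33
print(bubble_fraction(4, 32))  # 3/35, about 0.086
```

The dots in the printed grid are the bubble: the later GPUs wait at the start, the earlier ones sit idle at the end.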
Gradient Accumulation and Compression
Communication between GPUs is a bottleneck. Techniques to reduce it:
- Gradient accumulation: Compute gradients on multiple batches before updating, simulating a larger batch
- Mixed precision: Use 16-bit floats instead of 32-bit, halving communication
- Gradient compression: Quantize or sparsify gradients before sending
- Local SGD: Each GPU updates its own copy independently for several steps, then the copies are synchronized by averaging
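Gradient accumulation in miniature, assuming a one-weight model, a made-up batch of 32 examples, and equal-sized micro-batches (all names are illustrative). Summing four micro-batch gradients, each divided by the number of accumulation steps, reproduces the full-batch gradient; in a real framework, only one micro-batch's activations need to be in memory at a time, and gradients are synchronized once per optimizer step.

```python
def batch_grad(w, batch):
    """Gradient of mean squared error for the one-weight model w*x on a batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


big_batch = [(i / 10, 2.0 * i / 10) for i in range(1, 33)]  # 32 toy examples from y = 2x
ACCUM_STEPS = 4
micro_size = len(big_batch) // ACCUM_STEPS                  # 8 examples per micro-batch
w, lr = 0.0, 0.1

accumulated = 0.0
for k in range(ACCUM_STEPS):
    micro_batch = big_batch[k * micro_size:(k + 1) * micro_size]
    # Scale each micro-batch gradient so the running sum becomes an average
    accumulated += batch_grad(w, micro_batch) / ACCUM_STEPS

full = batch_grad(w, big_batch)  # reference: one big batch at the same w
w -= lr * accumulated            # one optimizer step per 4 micro-batches
print(accumulated, full)
```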
Federated Learning
An interesting variant: train on data distributed across many devices (like smartphones) without centralizing the data. Each device computes gradients on its local data, sends only the gradients (not the data) to a central server, which averages them and sends back updates. This keeps raw data on the devices, improving privacy while still allowing collective learning, although the shared gradients can still leak some information about the underlying data.
Advanced Challenges in Deep Learning
Training large neural networks presents unique optimization challenges beyond what we've discussed. Here are some that practitioners face:
Catastrophic Forgetting
When you train a neural network on task A, then train it on task B, it often "forgets" how to do task A. The gradients for task B overwrite the knowledge from task A. This is a fundamental limitation of gradient descent - each update modifies all parameters, potentially destroying previous learning.
Solutions include:
- Rehearsal: Mix old data with new data during training
- Elastic Weight Consolidation (EWC): Protect important parameters from changing too much
- Progressive networks: Add new capacity for new tasks
- Meta-learning: Learn how to learn without forgetting
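EWC adds a penalty of the form (λ/2) Σᵢ Fᵢ (θᵢ - θ*ᵢ)² to task B's loss, where θ* are the weights after task A and Fᵢ measures how important parameter i was for task A. The toy below uses two parameters and, as a simplifying assumption, uses task A's loss curvature in place of the Fisher information estimate used by EWC. All names and numbers are illustrative.

```python
def task_a_loss(theta):
    # Task A only cares about theta[0]: it wants theta[0] = 1
    return (theta[0] - 1.0) ** 2


def task_b_grad(theta):
    # Task B loss (theta[0] + theta[1] - 3)^2 wants the sum to equal 3
    r = theta[0] + theta[1] - 3.0
    return [2.0 * r, 2.0 * r]


theta_a = [1.0, 0.0]     # parameters after finishing task A
importance = [2.0, 0.0]  # task A's curvature per parameter (stand-in for the Fisher diagonal)


def train_on_task_b(lam, steps=2000, lr=0.01):
    theta = list(theta_a)
    for _ in range(steps):
        g = task_b_grad(theta)
        # Add the gradient of the EWC penalty sum_i (lam / 2) * F_i * (theta_i - theta_a_i)^2
        g = [gi + lam * fi * (ti - ai) for gi, fi, ti, ai in zip(g, importance, theta, theta_a)]
        theta = [ti - lr * gi for ti, gi in zip(theta, g)]
    return theta


plain = train_on_task_b(lam=0.0)  # ordinary training on task B
ewc = train_on_task_b(lam=10.0)   # task B plus the EWC penalty
print(plain, task_a_loss(plain))
print(ewc, task_a_loss(ewc))
```

Plain training on task B drags theta[0] from 1 to 2 and loses task A (its loss rises to 1). With the penalty, the model satisfies task B by moving only theta[1], the parameter task A didn't care about.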
Mode Collapse in GANs
In Generative Adversarial Networks, a generator and discriminator play a minimax game. The generator tries to fool the discriminator, the discriminator tries to detect fakes. This isn't pure gradient descent - it's a dynamic system that can be unstable.
One common failure is mode collapse: the generator learns to produce only a few types of outputs that reliably fool the discriminator, instead of capturing the full diversity of the training data. Training GANs is notoriously finicky and requires careful tuning of learning rates and architecture.
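The instability is easy to see in the simplest possible minimax game, f(x, y) = x·y, where one player minimizes over x and the other maximizes over y. This is a toy stand-in, not an actual GAN, and the step size is an arbitrary choice. With simultaneous gradient descent-ascent, each step multiplies the distance from the equilibrium (0, 0) by exactly √(1 + lr²), so the players spiral outward instead of converging.

```python
import math

# Toy two-player game f(x, y) = x * y: x wants f small, y wants f large.
# The unique equilibrium is (0, 0).
x, y, lr = 1.0, 1.0, 0.1
radii = [math.hypot(x, y)]
for step in range(100):
    gx, gy = y, x                    # df/dx = y, df/dy = x
    x, y = x - lr * gx, y + lr * gy  # simultaneous descent (x) and ascent (y)
    radii.append(math.hypot(x, y))

print(radii[0], radii[-1])
```

A smaller learning rate only slows the spiral; it never turns it inward, which is one illustration of why adversarial training needs extra care.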
Gradient Starvation
In very deep networks or certain architectures, gradients can become concentrated in a few neurons while others receive almost no gradient. This "rich get richer" dynamic means some parts of the network never learn. Solutions include careful initialization, skip connections, and normalization techniques.
Non-Stationary Objectives
In reinforcement learning, the target (what you're trying to predict) keeps changing as the model improves. This makes optimization much harder than supervised learning where targets are fixed. Techniques like target networks (slowly updating copies of the model) help stabilize training.
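A target network is often updated with a soft (Polyak) average rather than copied wholesale. The original DQN copied the weights every fixed number of steps; the soft update sketched here is a common variant (used, for example, in DDPG). `tau` and the weight lists are illustrative assumptions.

```python
def soft_update(target, online, tau=0.005):
    """Polyak averaging: move each target weight a small fraction tau toward the online weight."""
    return [(1 - tau) * t + tau * o for t, o in zip(target, online)]


online = [1.0, -2.0]  # weights being trained every step (held fixed here for clarity)
target = [0.0, 0.0]   # slow copy used to compute the regression targets
for step in range(100):
    target = soft_update(target, online)
print(target)
```

With tau = 0.005 the target closes only about 39% of the gap to the online weights after 100 steps, which keeps the regression targets from shifting under the model too quickly.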
Where is Gradient Descent Used?
Everywhere. I mean it. If you've interacted with any AI system today, gradient descent was involved in training it.
Large Language Models: ChatGPT, Claude, Gemini - all trained using gradient descent. When these models "learn" to predict the next word, they're doing millions of gradient updates across billions of parameters. The scale is mind-boggling, but the core algorithm is the same one we've been discussing.
Models at this scale can have hundreds of billions of parameters and are trained on thousands of GPUs for months. Each training step involves computing gradients for all those parameters, averaging them across GPUs, and updating the weights. Yet it's still gradient descent - just at an extraordinary scale.
Image Recognition: Every photo app that recognizes faces, every self-driving car that identifies pedestrians, every medical system that spots tumors - trained with gradient descent. Convolutional neural networks (CNNs) use backpropagation to learn features from pixels, from simple edges in early layers to complex objects in deep layers.
Recommendation Systems: Netflix suggesting your next show, Spotify creating playlists, Amazon recommending products - all powered by models trained with gradient descent. These systems learn embeddings (vector representations) of users and items by optimizing prediction accuracy.
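A minimal matrix-factorization sketch shows what "learning embeddings by optimizing prediction accuracy" means: each user and item gets a small vector, a rating is predicted as their dot product, and SGD nudges both vectors to shrink the squared error. The ratings, embedding size, learning rate, and L2 penalty are all made up; production recommenders add bias terms, implicit feedback, and much more.

```python
import random

random.seed(0)
# (user, item, rating) triples for a hypothetical 3-user x 4-item catalog
ratings = [(0, 0, 5.0), (0, 1, 4.0), (0, 3, 1.0),
           (1, 0, 4.0), (1, 2, 1.0), (1, 3, 2.0),
           (2, 1, 1.0), (2, 2, 5.0), (2, 3, 4.0)]
K, LR, REG = 2, 0.05, 0.01
users = [[random.gauss(0, 0.1) for _ in range(K)] for _ in range(3)]
items = [[random.gauss(0, 0.1) for _ in range(K)] for _ in range(4)]


def predict(u, i):
    return sum(a * b for a, b in zip(users[u], items[i]))


def mse():
    return sum((predict(u, i) - r) ** 2 for u, i, r in ratings) / len(ratings)


start = mse()
for epoch in range(500):
    random.shuffle(ratings)
    for u, i, r in ratings:
        err = predict(u, i) - r
        for d in range(K):
            pu, qi = users[u][d], items[i][d]
            # Half the gradient of err^2 + REG * (pu^2 + qi^2) with respect to each entry
            users[u][d] -= LR * (err * qi + REG * pu)
            items[i][d] -= LR * (err * pu + REG * qi)

print(start, mse())
```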
Voice Assistants: Siri, Alexa, Google Assistant - the speech recognition and natural language understanding models are trained with... you guessed it.
Scientific Applications: Drug discovery (predicting molecular properties), protein folding (AlphaFold), climate modeling, physics simulations - gradient descent helps optimize complex models in all these domains.
My Gujarati Q&A App: Even my small project uses gradient descent. Fine-tuning embedding models for Gujarati text, training the question-answering components - all gradient descent under the hood. Every time I train, I'm reminded that the same algorithm training GPT-4 is training my tiny model.
It's remarkable that such a simple idea - "take small steps downhill" - is the foundation of all these diverse applications.
Why Gradient Descent Won
Gradient descent isn't the only optimization algorithm. Why did it become the dominant approach for training neural networks? Several factors:
Universality
Gradient descent works for any differentiable function. Whether you're fitting a line, training a neural network, or optimizing a physics simulation - if you can compute gradients, you can use gradient descent. This generality makes it applicable across domains.
Scalability
Gradient descent scales from 2 parameters to 2 trillion parameters. The computational cost per iteration grows linearly with the number of parameters, which is remarkably efficient. Each step is cheap, so you can afford to take millions of them.
Parallelizability
Computing gradients across a batch of data is embarrassingly parallel. Modern GPUs can compute gradients for thousands of examples simultaneously. This hardware efficiency is crucial for training large models. The algorithm fits modern compute hardware perfectly.
Composability
Neural networks are compositions of simple functions. Backpropagation efficiently computes gradients through these compositions using the chain rule. This composability enables arbitrarily deep and complex architectures while keeping gradient computation tractable.
Incremental Improvement
Gradient descent provides a smooth path of improvement. At each step, the gradient tells you which direction reduces the loss locally, so you always know what to try next. This iterative approach allows you to stop early if needed (early stopping for regularization) and provides visibility into training progress (the loss curve).
Robustness
While we discussed challenges (local minima, saddle points), in practice gradient descent is remarkably robust. With proper initialization and modern optimizers (Adam, etc.), it reliably finds good solutions even in complex, non-convex landscapes. The high dimensionality that makes analysis hard also makes optimization easier (fewer truly bad local minima).
What It Led To: The Gradient Descent Family Tree
Gradient descent isn't an endpoint - it's the foundation of a rich research area. Here are some directions it has spawned:
Evolutionary Algorithms
What if instead of following gradients, you used principles from biological evolution? Genetic algorithms and evolution strategies optimize by maintaining a population of solutions, selecting the best, and creating new ones through mutation and recombination. These gradient-free methods work for non-differentiable problems but are generally slower for differentiable objectives.
Bayesian Optimization
When function evaluations are expensive (e.g., tuning model hyperparameters), Bayesian optimization builds a probabilistic model of the objective and intelligently chooses where to sample next. It's more sample-efficient than grid search but doesn't scale to high dimensions or many function evaluations.
Meta-Learning
Can we learn to learn better? Meta-learning algorithms (like MAML) optimize the learning process itself. The goal is to find initializations that enable fast adaptation to new tasks, or even to learn an optimizer that outperforms gradient descent on specific problem classes.
Neural Architecture Search (NAS)
What if we optimize not just the weights but the architecture itself? NAS uses gradient descent (or other search methods) to find optimal network architectures. This is computationally expensive but has discovered architectures that outperform human-designed ones.
Differentiable Programming
The success of gradient descent in deep learning has inspired a broader movement: making everything differentiable. If your program is differentiable end-to-end, you can optimize it with gradient descent. This includes differentiable rendering, differentiable physics, differentiable logic gates - essentially turning programming into optimization.
Learned Optimizers
Instead of hand-designing update rules (like Adam), can we learn them? Researchers have trained neural networks to generate parameter updates. These learned optimizers can outperform human-designed ones on specific tasks, though they often don't generalize as well to new problems.
The Future: What's Next?
Gradient descent has carried us far, but there are signs we might be approaching its limits - at least in its current form. Some challenges and research directions:
Sample Efficiency
Humans learn from far fewer examples than neural networks. A child can learn what a "giraffe" is from seeing a few pictures. LLMs are trained on trillions of tokens. This gap suggests there's something fundamental about learning that gradient descent hasn't captured yet. Research into few-shot learning, transfer learning, and inductive biases aims to close this gap.
Biological Plausibility
Backpropagation requires the backward pass to reuse exactly the same weights as the forward pass (the "weight transport problem") - something the brain doesn't seem to have. Research into biologically plausible alternatives (target propagation, equilibrium propagation) seeks learning algorithms that could work in neural tissue.
Hardware Trends
As Moore's Law slows, we're seeing specialized hardware for AI: TPUs, neuromorphic chips, analog accelerators. These might favor different algorithms than GPUs do. Neuromorphic computing, for example, naturally implements spike-based learning rules different from gradient descent.
Beyond Gradient Descent
Some researchers are exploring entirely different paradigms:
- Discrete optimization: Directly optimize discrete structures without relaxation
- Symbolic regression: Search for explicit formulas rather than fitting parameters
- Program synthesis: Generate code that solves the task
- Energy-based models: Use physics-inspired dynamics rather than explicit optimization
Will gradient descent remain dominant? Probably for the near future. The ecosystem around it - frameworks, hardware, research expertise - is too entrenched. But eventually, something better may emerge. That's the nature of progress.
When NOT to Use Gradient Descent
Gradient descent is powerful, but it's not always the right tool. Understanding its limitations is as important as understanding its strengths.
Non-differentiable functions: Gradient descent needs gradients. If your function has discontinuities or isn't differentiable, you can't compute gradients. Examples include:
- Decision trees (discrete splits)
- Integer programming (discrete variables)
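- Sort and rank operations (the resulting ranks are piecewise constant, so their gradient is zero almost everywhere)
- Some regularization terms (L1 has a kink at zero, where it is continuous but not differentiable)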
For these, techniques like evolutionary algorithms, reinforcement learning, or subgradient methods might be better choices.
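As a small illustration of a subgradient method, take the made-up objective 0.5·(w - 2)² + λ|w| with λ = 1. At the kink w = 0 the absolute value has no derivative, but any value in [-1, 1] is a valid subgradient, and a diminishing step size still drives w to the minimizer w = 1.

```python
def sign(v):
    return (v > 0) - (v < 0)  # a valid subgradient of |v|; picks 0 at the kink


lam = 1.0


def objective(w):
    return 0.5 * (w - 2.0) ** 2 + lam * abs(w)


w = -3.0
for k in range(5000):
    subgrad = (w - 2.0) + lam * sign(w)
    w -= subgrad / (k + 1)  # diminishing step size, standard for subgradient methods

print(w, objective(w))  # the exact minimizer is w = 2 - lam = 1
```

Convergence is slow (the error here shrinks roughly like 1/k), which is one reason proximal methods are often preferred for L1 terms.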
Discrete optimization: If your parameters are discrete (integers, categories), gradients don't exist. You can't take "half a step" from category A to category B. Techniques like genetic algorithms, simulated annealing, or integer programming are more appropriate.
Problems with closed-form solutions: Linear regression actually has a closed-form solution - you can compute the optimal weights directly without iteration: w = (X^T X)^(-1) X^T y. Using gradient descent for such problems is like taking a flight when you could teleport. But for neural networks, no closed-form solution exists, so gradient descent is our best option.
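For one feature plus an intercept, X^T X is just a 2×2 matrix, so the closed form can be written out by hand. The sketch below, on made-up points, compares that direct solution with plain gradient descent on the same mean squared error; the learning rate and step count are arbitrary choices.

```python
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.1, 2.9, 5.2, 7.1, 8.8]  # made-up points, roughly y = 2x + 1
n = len(xs)

# Closed form: solve the 2x2 normal equations (X^T X) [w, b] = X^T y with Cramer's rule
sx, sy = sum(xs), sum(ys)
sxx = sum(x * x for x in xs)
sxy = sum(x * y for x, y in zip(xs, ys))
det = n * sxx - sx * sx
w_closed = (n * sxy - sx * sy) / det
b_closed = (sxx * sy - sx * sxy) / det

# Gradient descent on mean squared error reaches the same answer, just iteratively
w, b, lr = 0.0, 0.0, 0.05
for _ in range(5000):
    gw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    gb = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    w, b = w - lr * gw, b - lr * gb

print(w_closed, b_closed)
print(w, b)
```

Both arrive at the same line; the closed form simply skips the iterations.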
Very noisy gradients: If your gradient estimates are extremely noisy (high variance), gradient descent might never converge. This happens in some reinforcement learning settings where returns have high variance. Variance reduction techniques or alternative optimization methods might be needed.
Extremely non-convex landscapes: While gradient descent often works surprisingly well on non-convex problems, there are pathological cases where it fails completely. Some optimization problems have loss surfaces that are essentially random - no local structure to exploit. For these, random search or specialized heuristics might work better.
Practical Tips: Lessons from the Trenches
If you're just starting to train models, here are some things I've learned (often the hard way):
Start with Adam or AdamW
Don't overthink the optimizer choice at first. Adam or AdamW with default hyperparameters (lr=1e-3, betas=(0.9, 0.999)) is a solid starting point for most problems. Once you have a baseline, you can experiment with others if needed. SGD with momentum sometimes generalizes better for vision tasks, but only after careful tuning.
Normalize Your Inputs
If your features have wildly different scales (e.g., age 0-100 vs. salary 0-1000000), gradient descent will struggle. The salary feature will dominate the gradients, and the learning rate appropriate for salary will be terrible for age. Normalize everything to similar scales - typically zero mean and unit variance (standardization) or scaled to [0, 1] (normalization).
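A minimal standardization sketch using only the standard library, with made-up age and salary columns. One assumption worth stating: in practice the mean and standard deviation should be computed on the training set only and then reused unchanged on validation and test data.

```python
import statistics

ages = [23.0, 35.0, 47.0, 59.0, 71.0]
salaries = [32_000.0, 48_000.0, 61_000.0, 90_000.0, 150_000.0]


def standardize(column):
    """Rescale to zero mean and unit variance (population standard deviation)."""
    mu = statistics.fmean(column)
    sigma = statistics.pstdev(column)
    return [(v - mu) / sigma for v in column], mu, sigma


ages_z, age_mu, age_sigma = standardize(ages)
sal_z, sal_mu, sal_sigma = standardize(salaries)
print([round(v, 2) for v in ages_z])
print([round(v, 2) for v in sal_z])
```

After this, both features live on the same scale, so one learning rate suits both.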
Monitor the Loss Curve
The loss curve is your window into training. If it's not going down, something is wrong - check your data, your model, your learning rate. If it's going down too slowly, try a larger learning rate. If it's oscillating wildly, try a smaller one. Plateaus might indicate a need for LR decay or a different optimizer.
Use Learning Rate Schedules
Start with a higher rate to make quick progress, then reduce it as training progresses to fine-tune. Cosine annealing and one-cycle policies often work well. The learning rate is probably the most important hyperparameter to tune, so invest time in getting it right.
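Cosine annealing in a few lines, as a sketch: the learning rate follows half a cosine wave from `lr_max` down to `lr_min` over `total_steps`. The default values are illustrative, not recommendations. PyTorch provides the same idea as `torch.optim.lr_scheduler.CosineAnnealingLR`, and many setups also add a short warmup at the start, which this sketch omits.

```python
import math


def cosine_lr(step, total_steps, lr_max=1e-3, lr_min=1e-5):
    """Cosine annealing from lr_max at step 0 down to lr_min at total_steps."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


for step in [0, 250, 500, 750, 1000]:
    print(step, cosine_lr(step, 1000))
```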
Watch for Overfitting
If training loss goes down but validation loss goes up, you're memorizing the training data instead of learning general patterns. This is overfitting. Solutions:
- Get more training data
- Add regularization (L2, dropout)
- Use early stopping (stop when validation loss starts increasing)
- Reduce model capacity (fewer parameters)
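Early stopping usually comes with a "patience" window so a single noisy epoch doesn't end training. Here is a minimal sketch; `patience`, `min_delta`, and the validation curve are made-up illustrations.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return (best_epoch, stop_epoch): stop once validation loss has not improved
    by more than min_delta for `patience` consecutive epochs."""
    best, best_epoch, bad_epochs = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, best_epoch, bad_epochs = loss, epoch, 0  # new best: save a checkpoint here
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses) - 1


# Made-up validation curve: improves, then starts rising (overfitting)
val = [1.00, 0.80, 0.65, 0.58, 0.56, 0.57, 0.59, 0.62, 0.66]
print(early_stopping_epoch(val))  # (4, 7)
```

On this curve the best epoch is 4 and training stops at epoch 7; you would then restore the checkpoint saved at epoch 4.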
Check Your Gradients
If training isn't working, verify your gradients are correct. Numerical gradient checking (comparing analytical gradients to finite difference approximations) can catch bugs in backpropagation implementations. Also watch for:
- Vanishing gradients: Gradients become very small in early layers
- Exploding gradients: Gradients become very large, causing NaNs
- Dead neurons: Some units (often ReLUs) always output zero, so they receive no gradient and never learn
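A minimal gradient check, assuming a made-up two-variable function and hand-written gradients, one of them deliberately buggy. It compares each analytic partial derivative against a central-difference estimate and reports the worst relative error. As a rough rule of thumb (an assumption, not a hard threshold), errors far below 1e-6 suggest a correct gradient, while errors of 1e-2 or more point to a bug.

```python
def check_gradient(f, grad_f, w, eps=1e-5):
    """Largest relative error between an analytic gradient and central differences."""
    analytic = grad_f(w)
    worst = 0.0
    for i in range(len(w)):
        plus, minus = list(w), list(w)
        plus[i] += eps
        minus[i] -= eps
        numeric = (f(plus) - f(minus)) / (2 * eps)
        denom = max(abs(numeric), abs(analytic[i]), 1e-12)
        worst = max(worst, abs(numeric - analytic[i]) / denom)
    return worst


def f(w):
    return w[0] ** 2 * w[1] + 3 * w[1] ** 3


def good_grad(w):
    return [2 * w[0] * w[1], w[0] ** 2 + 9 * w[1] ** 2]


def buggy_grad(w):
    return [2 * w[0] * w[1], w[0] ** 2 + 3 * w[1] ** 2]  # forgot the power-rule factor of 3


w = [1.5, -0.7]
print(check_gradient(f, good_grad, w))   # well below 1e-6
print(check_gradient(f, buggy_grad, w))  # about 0.44
```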
Initialize Properly
Bad initialization can make training impossible. Don't initialize all weights to the same value (symmetry problem). Small random values work, but techniques like Xavier (Glorot) and He initialization set the scale based on layer size, helping gradients flow properly.
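A small simulation makes the effect of weight scale visible. It pushes a random input through 10 fully connected ReLU layers of width 256 (no biases) and measures the mean squared activation at the end, for three weight standard deviations: a tiny fixed value, Xavier/Glorot (√(2 / (fan_in + fan_out))), and He (√(2 / fan_in)). The sizes and seed are arbitrary assumptions, and it takes a few seconds in pure Python. Roughly, expect the tiny init to collapse toward zero, Xavier to shrink by about half per ReLU layer, and He to stay near 1.

```python
import math
import random

random.seed(0)
WIDTH, DEPTH = 256, 10


def final_activation_scale(weight_std):
    """Send a random input through DEPTH fully connected ReLU layers whose weights are
    drawn from N(0, weight_std^2); return the mean squared activation of the last layer."""
    h = [random.gauss(0.0, 1.0) for _ in range(WIDTH)]  # mean square about 1 at the input
    for _ in range(DEPTH):
        h = [max(0.0, sum(random.gauss(0.0, weight_std) * hj for hj in h))
             for _ in range(WIDTH)]
    return sum(v * v for v in h) / WIDTH


tiny = final_activation_scale(0.01)                                # "small random values"
xavier = final_activation_scale(math.sqrt(2.0 / (WIDTH + WIDTH)))  # Xavier/Glorot (normal)
he = final_activation_scale(math.sqrt(2.0 / WIDTH))                # He (normal)
print(f"tiny: {tiny:.2e}  xavier: {xavier:.2e}  he: {he:.2e}")
```

The factor of 2 in He initialization compensates for ReLU zeroing out roughly half of each layer's pre-activations.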
Batch Size Matters
Small batches (32-128) often generalize better than large batches. They provide regularization through noise and tend to find flatter minima. Large batches train faster per epoch (better GPU utilization) but may require learning rate adjustment and often generalize worse. The "linear scaling rule" suggests increasing learning rate linearly with batch size, up to a point.
Don't Trust Default Hyperparameters Blindly
Default hyperparameters are defaults for a reason - they work okay on many problems but aren't optimal for any specific one. Spend time tuning learning rate, batch size, and regularization for your specific problem. Even small improvements in hyperparameters can make the difference between a model that doesn't work and one that does.
Conclusion
When I started digging back into machine learning, I expected complexity. Sophisticated algorithms, advanced mathematics, intricate techniques. And yes, there's plenty of that. But at the core of it all is this surprisingly simple idea: take small steps in the direction that reduces your error.
That's gradient descent. Follow the slope downhill. It's almost embarrassingly straightforward. And yet it powers every neural network, every language model, every image classifier, every recommendation system. The same basic algorithm that fits a line to points also trains GPT-4. The scale is different, but the principle is identical.
Like Merkle trees in my previous article, gradient descent is another example of how simple ideas, applied correctly, can achieve remarkable things. It's not about complexity - it's about finding the right abstraction. A hill-climbing metaphor that happens to be mathematically tractable, computationally efficient, and universally applicable.
I'm still working on my Gujarati Q&A app. Every time I fine-tune a model, every time I watch the loss curve go down, I think about gradient descent. Millions of tiny steps, each one making the model slightly better. It's a beautiful thing.
And every time I use ChatGPT or see an AI-generated image, I think about the scale. Trillions of parameters, quintillions of gradient computations, all following that same simple rule: go downhill. Somewhere in the past, a simple algorithm took billions of small steps to make that possible.
So next time you ask an AI a question, remember the gradient descent running silently in the background, optimizing, learning, improving - one small step at a time.