Okay, if you are looking for the reason behind the “learning rate”, can I assume you already have some idea of gradient descent? If not, there are plenty of resources online, including this.
Now, you are probably looking for the reason why gradients are multiplied by a learning rate before being applied to the parameters. I will try to explain it here.
Let’s learn through illustrative examples
Consider a case where we have only one parameter, say x, which we want to update through training. Our ultimate objective is to end training with a value of x that produces the lowest possible loss, ideally 0.
Let’s make things a bit more concrete. Notice that the loss is a function of x, which simply means the loss changes when we change x. Let us imagine:
loss = f(x) = x²
Find the “impact”
Now, to minimize the loss by manipulating x, we need to know what impact x has on the loss. In other words, if we increase x a little, will the loss increase or decrease? To answer that question, we need the slope of the function, defined as the rate of change of the function for an infinitesimally small change in its input (x in this case). Don’t worry if that sounds like a mouthful; things will become clear as we go.
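A minimal Python sketch of the "small change" idea, assuming our imaginary loss = x². It approximates the slope with a central difference: nudge x a tiny bit in both directions and measure how much the loss changes. The helper name `numerical_slope` and the step `h = 1e-6` are illustrative choices, not from any library.

```python
def loss(x):
    """Our imaginary loss function: loss = x²."""
    return x ** 2


def numerical_slope(f, x, h=1e-6):
    """Approximate the slope of f at x with a central difference.

    The small h stands in for the "infinitesimally small change" in x.
    """
    return (f(x + h) - f(x - h)) / (2 * h)


print(numerical_slope(loss, 2))   # close to 4
print(numerical_slope(loss, -2))  # close to -4
```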
Slope is not necessarily constant
Notice that this slope is not necessarily a fixed number. For a straight line it is indeed constant, since a straight line changes at the same rate and in the same direction everywhere. In other words, if the line points upward at x = 4, it also points upward at x = -100, at x = +100, and at any other value of x.
But in a more complicated case, such as our imaginary loss function (loss = x²), the direction of change is not the same everywhere.
For example, 2² = 4 and 2.0001² = 4.00040001. Since 4.00040001 is bigger than 4, we can say that x² has an upward direction at x = 2.
But notice that (-2)² = 4, while (-2 + 0.0001)² = (-1.9999)² = 3.99960001. So x² has a downward direction at x = -2.
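The nudge test above can be sketched in Python, using the same 0.0001 nudge as the text. The straight line `3x + 1` is a hypothetical example chosen only for contrast; the helper name `direction` is also illustrative.

```python
NUDGE = 0.0001  # the small change in x used in the examples above


def direction(f, x, nudge=NUDGE):
    """Return "upward" if f grows when x increases slightly, else "downward"."""
    return "upward" if f(x + nudge) > f(x) else "downward"


def line(x):
    return 3 * x + 1  # hypothetical straight line with a positive slope


def loss(x):
    return x ** 2


for x in (-100, -2, 2, 4, 100):
    print(f"x = {x:>4}: line is {direction(line, x)}, x² is {direction(loss, x)}")
```

The line reports "upward" at every point, while x² flips from "downward" on the negative side to "upward" on the positive side.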
Deriving the slope function
Can we derive a formula that gives the slope of x² at any point from the value of x alone? Yes: it is 2x. Recall from high school math that we derive the slope function of a curve using differential calculus. By the power rule, the slope of xⁿ is n·xⁿ⁻¹, so:
d/dx (x²) = 2x
As another example, the slope function of the curve f(x) = 5x³ is:
f′(x) = d/dx (5x³) = 15x²
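As a sanity check, here is a small Python sketch comparing the slope functions derived with the power rule (2x and 15x²) against a central-difference estimate. The sample points and the step `h = 1e-6` are arbitrary illustrative choices.

```python
def numerical_slope(f, x, h=1e-6):
    # Central-difference approximation of the slope of f at x
    return (f(x + h) - f(x - h)) / (2 * h)


# Curves from the text, paired with their slope functions from the power rule
curves = {
    "x²": (lambda x: x ** 2, lambda x: 2 * x),
    "5x³": (lambda x: 5 * x ** 3, lambda x: 15 * x ** 2),
}

for name, (f, slope_fn) in curves.items():
    for x in (-2.0, 1.0, 3.0):
        approx = numerical_slope(f, x)
        exact = slope_fn(x)
        print(f"{name} at x = {x}: formula {exact}, numerical {approx:.6f}")
```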
What information do we get from the slope?
Going back to our example, loss = f(x) = x²: when x = 2, the slope 2x = 4 (positive), and when x = -2, the slope is -4 (negative). This explains why x² points upward at x = 2 (the slope is positive there) but downward at x = -2 (the slope is negative there).
Notice that, in addition to the direction, the slope gives us another piece of information: a number that indicates the rate of change. For example, the slope is 10 at x = 5, which is bigger than the slope at x = 2 (4). As we see from the graph below, the curve is indeed steeper at x = 5 (rising more rapidly) than it is at x = 2.
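A short Python sketch of reading both pieces of information from the slope function 2x: the sign gives the direction, and the absolute value gives the steepness. The labels "upward", "downward", and "flat" are just illustrative wording.

```python
def slope(x):
    """Slope function of loss = x², derived above as 2x."""
    return 2 * x


for x in (-2, 2, 5):
    s = slope(x)
    # The sign of the slope gives the direction of change
    direction = "upward" if s > 0 else "downward" if s < 0 else "flat"
    # The absolute value of the slope gives how steep the curve is there
    print(f"x = {x:>2}: slope = {s:>3} -> {direction}, steepness {abs(s)}")
```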
I got the slope. What do I do now?
Suppose we are at x = 2. Remember, our ultimate objective is to find the point where the loss is around 0. From the graph, it is easy to see that we need to move exactly 2 units to the left (to x = 0), where the loss reaches its minimum possible value, 0. In reality, however, the loss depends on many more parameters (GPT-3, for example, has 175 billion of them!). Drawing the actual curve is far harder than computing the slope (gradient) with respect to a parameter at a specific point.
Still, since we know that (0, 0) is our destination in this example, look at how we would step toward it from different starting points:
- When we are at x = -2, slope = -4: go right (in the direction where x increases) along the curve.
- When we are at x = 2, slope = 4: go left (in the direction where x decreases) along the curve.
- When we are at x = 5, slope = 10: go left along the curve, but take a bigger step than we did at x = 2 (since there is a longer distance to cover).
What do the three cases above tell us? We go right when the slope is negative and left when it is positive; in other words, we move in the direction opposite to the slope. And the size of our step is proportional to the magnitude of the slope.
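Those two observations can be sketched in a few lines of Python: the step is the slope multiplied by a negative constant. `STEP_SCALE = 0.25` is a hypothetical value chosen only for illustration; nothing in the text fixes it.

```python
def slope(x):
    return 2 * x  # slope of loss = x²


STEP_SCALE = 0.25  # hypothetical proportionality constant, for illustration only

for x in (-2, 2, 5):
    step = -STEP_SCALE * slope(x)  # the minus sign moves us against the slope
    side = "right" if step > 0 else "left"
    print(f"x = {x:>2}: slope = {slope(x):>3}, go {side} by {abs(step)}")
```

At x = -2 this goes right, at x = 2 it goes left, and at x = 5 it goes left by a larger amount, matching the three bullet points.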
At x = 2 the slope is positive (2x = 4), so we know we have to go left. But exactly how far left?
Take 1 step to the left, for example, meaning we decrease x by 1. Now x = 2 - 1 = 1. Here loss = x² = 1² = 1, and slope = 2x = 2 * 1 = 2. The loss at this point is lower than the previous loss (4), which tells us that we made the right move.
The slope is still positive here (2x = 2), so the next step will also be toward the left. Now suppose we get greedy and take a bigger step, say 5 units: x = 1 - 5 = -4. Here the loss is 16, much bigger than before (1), so we know we made a bad decision. The slope at this point is 2x = 2 * (-4) = -8. Since it is negative, we now have to move right along the curve to make things right.
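The two manual moves above can be replayed in Python to confirm the arithmetic (a step of 1, then a greedy step of 5):

```python
def loss(x):
    return x ** 2


def slope(x):
    return 2 * x


x = 2
print(f"start: x = {x}, loss = {loss(x)}, slope = {slope(x)}")

# Cautious move: the slope is positive, so go left by 1
x -= 1
print(f"after step of 1: x = {x}, loss = {loss(x)}, slope = {slope(x)}")

# Greedy move: the slope is still positive, so go left again, but by 5
x -= 5
print(f"after step of 5: x = {x}, loss = {loss(x)}, slope = {slope(x)}")
```

The loss drops from 4 to 1 after the cautious step, then jumps to 16 after the greedy one, and the slope turns negative (-8).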
So taking such big steps is not wise; we should take smaller steps instead. If we take steps of 0.0001, for example, x will either gradually reach 0 or, in the worst case, keep jumping between -0.0001 and 0.0001 until training finishes. That is okay: in real life we may never reach the ultimate loss of 0, and getting somewhere close to 0 is fantastic.
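A minimal Python sketch of this fixed-step strategy, assuming we always move exactly 0.0001 against the sign of the slope (the gradient's magnitude is ignored here). The iteration budget of 25,000 is an arbitrary illustrative choice, a bit more than the 20,000 steps needed to cover the distance from 2 to 0.

```python
STEP = 0.0001  # fixed step size from the text


def slope(x):
    return 2 * x  # slope of loss = x²


x = 2.0
for _ in range(25_000):  # 20,000 steps of 0.0001 cover the distance from 2 to 0
    s = slope(x)
    if s == 0:
        break  # exactly at the minimum, nothing left to do
    # Move a fixed distance against the slope, ignoring how steep it is
    x = x - STEP if s > 0 else x + STEP

print(f"x = {x:.6f}, loss = {x ** 2:.10f}")  # x ends up within about one step of 0
```

Because of floating-point rounding, x usually does not land exactly on 0; it ends up bouncing around 0 within one step, as the paragraph describes.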
Since in real life we have many parameters, instead of taking a fixed step of 0.0001, we move each parameter by 0.0001 multiplied by the slope (gradient) of the loss with respect to that parameter at the current point, against the direction of that gradient. This ensures every parameter moves by an amount proportional to its own gradient.
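Here is a sketch of that multi-parameter update in Python, assuming a hypothetical two-parameter loss, loss(x, y) = x² + 3y², whose gradients are 2x and 6y. The starting point and the 20,000 iterations are illustrative choices; only the 0.0001 multiplier comes from the text.

```python
LEARNING_RATE = 0.0001  # the small multiplier from the text


def loss(params):
    x, y = params
    return x ** 2 + 3 * y ** 2  # hypothetical two-parameter loss


def gradients(params):
    x, y = params
    return [2 * x, 6 * y]  # slope of the loss with respect to x and to y


params = [2.0, -1.5]
initial_loss = loss(params)
for _ in range(20_000):
    grads = gradients(params)
    # Each parameter moves against its own gradient, scaled by the same multiplier
    params = [p - LEARNING_RATE * g for p, g in zip(params, grads)]

print(initial_loss, loss(params))
```

Both parameters head toward 0, but y, whose term in the loss is steeper, gets larger gradients and therefore shrinks faster than x.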
Where is the goddam learning rate?
Oh, I forgot to tell you: this small value we have been talking about (0.0001 in the example) is our learning rate. Each update is therefore: new x = x − learning rate × slope.
Why not a smaller value like 0.000000000000001?
Okay, so 0.0001 is a reasonable choice (at least for our loss = x² example). Would it be better to use 0.01, or 0.000000000000001? As you have probably understood by now, steps that are too big make us jump from one side of the minimum to the other (overshooting), while an extremely small step makes training slower, requiring many more iterations. So we choose the value based on the problem domain and/or our experience.
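A rough Python experiment on loss = x² contrasts a learning rate that is too big, the 0.0001 from the text, and a tiny one. The value 1.1 as the "too big" case, the tolerance of 1e-3, the divergence threshold of 1e6, and the 200,000-step budget are all hypothetical choices for illustration.

```python
def descend(learning_rate, x=2.0, tolerance=1e-3, max_steps=200_000):
    """Run gradient descent on loss = x² starting from x and report what happened."""
    for step in range(1, max_steps + 1):
        x -= learning_rate * 2 * x  # the slope of x² is 2x
        if abs(x) < tolerance:
            return f"reached |x| < {tolerance} after {step} steps"
        if abs(x) > 1e6:
            return f"diverged after {step} steps (overshooting)"
    return f"still at x = {x:.6f} after {max_steps} steps (too slow)"


for lr in (1.1, 0.0001, 1e-15):
    print(f"learning rate {lr:g}: {descend(lr)}")
```

For this particular loss, each update is x_new = x(1 − 2·lr), so any learning rate above 1 makes |x| grow on every step while flipping sign: that is the overshooting the paragraph warns about.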
Where to go from here?
Keep in mind that the learning rate can either remain constant or, in more advanced algorithms, be adjusted as training progresses. Look up “learning rate decay” to learn more.
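As a starting point, here is a sketch of one common decay schedule, exponential decay, where lr = initial_lr × decay_rate^step. The values 0.4 and 0.97 are hypothetical, and this is only one of several schedules in use (step decay is another); frameworks such as PyTorch ship ready-made versions in `torch.optim.lr_scheduler`.

```python
def exponential_decay(initial_lr, decay_rate, step):
    """One common decay schedule: lr = initial_lr * decay_rate ** step."""
    return initial_lr * decay_rate ** step


x = 2.0
for step in range(100):
    lr = exponential_decay(initial_lr=0.4, decay_rate=0.97, step=step)
    x -= lr * 2 * x  # gradient descent on loss = x², with a shrinking learning rate

print(x)
```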