Nov 21, 2017

Ordinary least squares, ℓ² (ridge), and ℓ¹ (lasso) linear regressions

Preface

I wrote this in 2017, and am posting it now in 2021. I was surprised how difficult it was to find complete information about linear regressions in one place: the derivations of the gradients, how they get their properties (e.g., lasso’s sparsity requiring coordinate descent), and some simple code to implement them. I tried to be careful about vector shapes and algebra, but there are probably still minor errors, which are of course my own.

One big goof I had was running this on MNIST, which ought to be treated as a per-class classification problem (e.g., with logistic regression), rather than trying to regress each digit to a number (e.g., the digit “1” to the number 1, and the digit “5” to the number 5). I should have run this code on a true regression dataset instead, where you do want real numbers (rather than class decisions) as output.

However, the silver lining is that after this goof, I was in a computer architecture class where we needed to run MNIST classification on FPGAs, and the starter code had made exactly this same mistake—they were doing linear instead of logistic regression! Making that simple switch resulted in such an accuracy boost that the classifier became one of the Pareto-optimal ones.

The repository for this project, which contains the full writeup below, as well as simple pytorch code to implement it, is here:

rndjam1

Regression derivations (+ basic code running on MNIST): ordinary least squares, ridge (ℓ²), and lasso (ℓ¹).

Enjoy!

– Max from 2021

Goal

Build linear regression for MNIST from scratch using pytorch.

Data splits

MNIST (csv version) has a 60k/10k train/test split.

I pulled the last 10k off of train for a val set.

My final splits are then 50k/10k/10k train/val/test.
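As a minimal sketch of that split (assuming the common label-first CSV layout and a hypothetical filename):

```python
import numpy as np
import torch

# Hypothetical filename; assumes the usual "label, 784 pixel values" CSV layout.
train_full = torch.from_numpy(np.loadtxt("mnist_train.csv", delimiter=","))
y_full, X_full = train_full[:, 0], train_full[:, 1:]
X_train, y_train = X_full[:50000], y_full[:50000]   # first 50k for training
X_val,   y_val   = X_full[50000:], y_full[50000:]   # last 10k for validation
```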

Viewing an image

Here’s an MNIST image:

the first mnist datum

Here it is expanded 10x:

the first mnist datum, expanded
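A tiny sketch of how to produce those images (x0 here is a hypothetical name for one 784-pixel row with values 0–255, as a numpy array):

```python
import numpy as np
from PIL import Image

# x0: a single MNIST pixel row, shape (784,), values 0-255 (hypothetical name).
img = Image.fromarray(x0.reshape(28, 28).astype(np.uint8))
img.save("digit.png")                                                  # 28x28
img.resize((280, 280), resample=Image.NEAREST).save("digit_10x.png")   # 10x expansion
```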

Data loading: CSV vs binary (“tensor”)

The y-axis is seconds taken to load the file; lower is better. Result: binary is way faster.

data loading speeds, csv vs binary
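A sketch of the comparison (filenames are hypothetical): parse the CSV once with numpy, save the tensor with torch.save, and time both loading paths.

```python
import time
import numpy as np
import torch

def timed(label, fn):
    start = time.time()
    out = fn()
    print(f"{label}: {time.time() - start:.2f}s")
    return out

# One-time conversion from CSV text to torch's binary format.
csv_data = timed("csv load", lambda: np.loadtxt("mnist_train.csv", delimiter=","))
torch.save(torch.from_numpy(csv_data), "mnist_train.pt")

# Subsequent loads skip text parsing entirely.
tensor_data = timed("binary load", lambda: torch.load("mnist_train.pt"))
```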

Naive regression to scalar

In this approach, we regress each image to a scalar that is the number represented in that image. For example, we regress the first MNIST image shown above to the number 5.

Disclaimer: this is a suboptimal approach. If you’re going to treat what is really a classification problem (like MNIST) as regression, you should regress to each class independently (i.e., do 10 regression problems at once instead of a single regression). Explaining why would take math that I would have to talk to people smarter than me to produce. I think the intuition is that you’re making the learning problem harder by forcing these distinct classes to exist as points in a 1D real space, when they really have no relation to each other. This is better treated as a logistic regression problem.

However: (a) if you’re confused like I was, you might try it, (b) if you’re bad at math like me, it’s simpler to start out with a “normal” regression than 10 of them, (c) I’m kind of treating this like a notebook, so might as well document the simple → complex progression of what I tried.

So here we go.

Notation

Definitions:

definitions

Math reminders and my notation choices:

math reminders

NB: While the derivative of a function f : ℝⁿ → ℝ is technically a row vector, people™ have decided that gradients of functions are column vectors, which is why I have transposes sprinkled below. (Thanks to Chris Xie for explaining this.)
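In symbols, the convention is roughly (a sketch of the shapes only):

```latex
% derivative (row vector) vs. gradient (column vector), for f : R^n -> R
\frac{\partial f}{\partial x} \in \mathbb{R}^{1 \times n},
\qquad
\nabla f(x) = \left( \frac{\partial f}{\partial x} \right)^{\top} \in \mathbb{R}^{n}
```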

Ordinary least squares (OLS)

Loss (average per datum):

least squares loss

Using the average loss per datum is nice because it is invariant to the dataset (or (mini)batch) size, which will come into play when we do gradient descent. Expanding the loss function out for my noob math:

least squares loss expanded

Taking the derivative of the loss function with respect to the weight vector:

least squares loss expanded derivative

We can set the gradient equal to 0 (the zero vector) and solve for the analytic solution (omitting second derivative check):

least squares analytic expanded

Doing a little bit of algebra to clean up the gradient, we’ll get our gradient for gradient descent:

least squares gradient
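For reference, in my notation (X is the n×d data matrix, y is the n-vector of targets, and w is the d-vector of weights), the loss, gradient, and analytic solution work out to roughly the following; this is a sketch, so the constants may be arranged a bit differently than in the figures above:

```latex
% OLS: loss, gradient, and analytic solution (sketch)
L(w) = \frac{1}{n} \| Xw - y \|_2^2,
\qquad
\nabla_w L = \frac{2}{n} X^\top (Xw - y),
\qquad
\nabla_w L = 0 \;\Rightarrow\; w^* = (X^\top X)^{-1} X^\top y
```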

We can plot the loss as we take more gradient descent steps:

ols gradient descent linear plot

… but it’s hard to see what’s happening. That’s because the loss starts so high and the y-axis is on a linear scale. A log scale is marginally more informative:

ols gradient descent log plot
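The descent loop itself is small. Here’s a minimal pytorch sketch (the learning rate and step count are arbitrary placeholders):

```python
import torch

def ols_gradient_descent(X, y, lr=1e-6, steps=10000):
    """Gradient descent on the average squared error (1/n) * ||Xw - y||^2."""
    n, d = X.shape
    w = torch.zeros(d, dtype=X.dtype, device=X.device)
    for _ in range(steps):
        grad = (2.0 / n) * (X.t() @ (X @ w - y))  # d-dimensional gradient
        w -= lr * grad
    return w
```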

To instead do coordinate descent, we optimize a single coordinate at a time, keeping all others fixed. We take the derivative of the loss function with respect to a single weight:

least squares derivative single weight

Setting the derivative equal to zero, we can solve for the optimal value for that single weight:

least squares derivative single weight zero

However, this is an expensive update to a single weight. We can speed this up. If we define the residual,

residual

then we can rewrite the inner term above as,

least squares residual rewrite

and, using (t) and (t+1) to clarify old and new values for the weight, rewrite the single weight optimum as:

least squares coord descent

After updating that weight, r is immediately stale, so we must update it as well:

least squares coord descent r update

We can compute an initial r, and we can precompute all of the squared column norms (the denominators) because they do not change. That means that each weight update involves just an n-dimensional dot product (the numerator) and an n-dimensional update of r. Because of this, one full round of coordinate descent (updating all weight coordinates once) is said to have the same update time complexity as one step of gradient descent (O(nd)).
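Here’s a minimal pytorch sketch of that bookkeeping (assuming the residual is defined as r = y − Xw; all-zero columns, like MNIST’s always-blank border pixels, are skipped to avoid dividing by zero):

```python
import torch

def ols_coordinate_descent(X, y, sweeps=50):
    """One sweep updates every weight coordinate once, keeping the residual fresh."""
    n, d = X.shape
    w = torch.zeros(d, dtype=X.dtype, device=X.device)
    r = y.clone()                         # residual y - Xw, with w = 0
    col_sq_norms = (X * X).sum(dim=0)     # precomputed x_j^T x_j (the denominators)
    for _ in range(sweeps):
        for j in range(d):
            if col_sq_norms[j] == 0:
                continue                  # skip all-zero columns
            x_j = X[:, j]
            rho = x_j @ (r + x_j * w[j])  # x_j^T (r + x_j w_j), the numerator
            w_new = rho / col_sq_norms[j]
            r -= x_j * (w_new - w[j])     # keep r in sync with the new weight
            w[j] = w_new
    return w
```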

However, I found that in practice, one step of (vanilla) gradient descent is much faster. I think this is because my implementation of coordinate descent requires moving values to and from the GPU (for bookkeeping old values), whereas gradient descent can run entirely on the GPU. I’m not sure if I can remedy this. With that said, coordinate descent converges with 10x fewer iterations.

ols coordinate descent plot

But how well do we do in regressing to a scalar with OLS?

ols accuracy

Not very well.

Ridge regression (RR)

Loss:

ridge loss

NB: For all regularization methods (e.g., ridge and lasso), we shouldn’t be regularizing the weight corresponding to the bias term (which I added as an extra feature column of 1s). You can remedy this by either (a) centering the ys and omitting the bias term, or (b) removing the regularization of the bias weight in the loss and gradient. I tried doing (b) but I think I failed (GD wasn’t getting nearly close enough to the analytic loss), so I’ve left the regularization in there for now (!).
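For approach (b), the gradient-side change is small. Here’s a sketch (assuming the bias is the last feature column; names are hypothetical):

```python
import torch

def ridge_grad_no_bias_penalty(X, y, w, lam):
    """Ridge gradient with the l2 penalty removed from the bias weight,
    assumed here to be the last coordinate (matching a trailing column of 1s)."""
    n = X.shape[0]
    reg = w.clone()
    reg[-1] = 0.0                         # don't penalize the bias weight
    return (2.0 / n) * (X.t() @ (X @ w - y)) + 2.0 * lam * reg
```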

Derivative:

(Being a bit more liberal with my hand waving of vector and matrix derivatives than above)

ridge derivative

Analytic:

NB: I think some solutions combine n into λ because it looks cleaner. In order to get the analytic solution and gradient (descent) to reach the same solution, I needed to be consistent with how I applied n, so I’ve left it in for completeness.

ridge analytic

Gradient:

(Just massaging the derivative we found a bit more.)

ridge gradient
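In the same notation as the OLS sketch above (keeping the 1/n on the data term and λ separate, per the note about n), these are roughly:

```latex
% ridge: loss, gradient, and analytic solution (sketch)
L(w) = \frac{1}{n} \| Xw - y \|_2^2 + \lambda \| w \|_2^2,
\qquad
\nabla_w L = \frac{2}{n} X^\top (Xw - y) + 2 \lambda w,
\qquad
w^* = (X^\top X + n \lambda I)^{-1} X^\top y
```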

Coordinate descent:

The derivative of the regularization term with respect to a single weight is:

ridge cd 0

with that in mind, the derivative of the loss function with respect to a single weight is:

ridge cd 1

In setting this equal to 0 and solving, I’m going to do some serious hand waving about “previous” versus “next” values of the weight. (I discovered what seems (empirically) to be the correct form by modifying the later equations of the lasso coordinate descent update, but I’m not sure of the correct way to do the derivation here.) We’ll also make use of the residual r defined above.

ridge cd 2

As above, we update the residual after each weight update:

residual update
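For what it’s worth, the form that falls out of setting the single-weight derivative to zero, assuming the residual r = y − Xw and the 1/n data term, is roughly (a sketch, not a careful derivation):

```latex
% ridge coordinate descent: single-weight update and residual refresh (sketch)
w_j^{(t+1)} = \frac{x_j^\top \left( r + x_j w_j^{(t)} \right)}{x_j^\top x_j + n \lambda},
\qquad
r \leftarrow r - x_j \left( w_j^{(t+1)} - w_j^{(t)} \right)
```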

Lasso

Loss:

lasso loss

Derivative:

lasso derivative part 1

Focusing on the final term, we’ll use the subgradient, and pick 0 (any value in [-1, 1] is valid) at the nondifferentiable point. This means we can use sgn(x) as the “derivative” of |x|.

lasso derivative part 2

Substitute in to get the final term for the (sub)gradient:

lasso derivative part 3
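In symbols, a sketch of the subdifferential of |wⱼ| and the resulting (sub)gradient of the loss, taking sgn(0) = 0:

```latex
% subgradient of the lasso loss (sketch; cases needs amsmath)
\partial |w_j| =
  \begin{cases}
    \{ \mathrm{sgn}(w_j) \} & w_j \neq 0 \\
    {[-1, 1]}               & w_j = 0
  \end{cases},
\qquad
g(w) = \frac{2}{n} X^\top (Xw - y) + \lambda \, \mathrm{sgn}(w)
```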

NB: There’s no soft thresholding (sparsity-encouraging) property of the lasso when you use gradient descent. You need something like coordinate descent to get that. Speaking of which…

Coordinate descent:

lasso cd 1

Setting this equal to 0, and again using the residual r, we have:

lasso cd 2

NB: I think that here (and below) we might really be saying that 0 is in the set of subgradients, rather than that it equals zero.

There’s a lot going on. Let’s define two variables to clean up our equation:

lasso cd 3

From this, we can more clearly see the solution to this 1D problem:

lasso cd 4

This solution is exactly the soft threshold operator:

lasso cd 5
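For reference, the soft threshold operator is usually written as (a sketch):

```latex
% soft threshold operator (sketch)
S_{\gamma}(\rho) = \mathrm{sgn}(\rho) \, \max\left( |\rho| - \gamma,\; 0 \right)
```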

Rewriting this into its full form:

lasso cd 6

As with coordinate descent above, we need to update the residual r after each weight update (skipping the derivation; same as above for OLS):

residual update
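Putting the lasso coordinate descent loop together, here’s a minimal pytorch sketch. It assumes the per-datum loss (1/n)‖Xw − y‖² + λ‖w‖₁ and the residual r = y − Xw, so the nλ/2 threshold below depends on exactly where the 1/n and the factor of 2 live; treat it as illustrative.

```python
import torch

def soft_threshold(rho, gamma):
    """sgn(rho) * max(|rho| - gamma, 0)."""
    return torch.sign(rho) * torch.clamp(torch.abs(rho) - gamma, min=0.0)

def lasso_coordinate_descent(X, y, lam, sweeps=50):
    """Lasso coordinate descent with incremental residual updates (a sketch)."""
    n, d = X.shape
    w = torch.zeros(d, dtype=X.dtype, device=X.device)
    r = y.clone()                         # residual y - Xw, with w = 0
    col_sq_norms = (X * X).sum(dim=0)
    for _ in range(sweeps):
        for j in range(d):
            if col_sq_norms[j] == 0:
                continue                  # all-zero column: leave w_j at 0
            x_j = X[:, j]
            rho = x_j @ (r + x_j * w[j])  # x_j^T (r + x_j w_j)
            w_new = soft_threshold(rho, n * lam / 2) / col_sq_norms[j]
            r -= x_j * (w_new - w[j])     # keep the residual in sync
            w[j] = w_new
    return w
```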

Acknowledgements

Many thanks to Chris Xie and John Thickstun for helping me out with math. All errors are my own.