Machine Learning 101

11.10.2015 | etc

This post is an adaptation from a talk given at artsec and a workshop given at a few other places (Buenos Aires Media Party 2015, MozFest 2015). The goal is to provide an intuition on machine learning for people without advanced math backgrounds - given that machine learning increasingly rules everything around us, everyone should probably have some mental model of how it works.

The materials for the workshop are available on GitHub.

What is "machine learning" anyways?

There are plenty of introductory machine learning tutorials and courses available, but they tend to focus on cursory tours of algorithms - here's linear regression, here are SVMs, here are neural networks, etc - without building an intuition of what machine learning is actually doing under the hood. This intuition is obvious to practitioners but not to people new to the field. To get the full value of the other machine learning resources out there, that intuition really helps.

But first, to clarify a bit what machine learning is: a basic working definition could be "algorithms for computers to learn patterns". Pattern recognition is something that people are very good at, but that computers find difficult.

Here we'll walk through a very simple machine learning task which is prototypical of many real-world machine learning problems. First we'll go through it by hand, noting where our human superpower of pattern recognition comes into play, and then think about how we can translate what we did into something a computer is capable of executing.

Learning functions by hand

A common machine learning goal is to take a dataset and learn the pattern (i.e. relationship) between different variables of the data. Then that pattern can be used to predict values of those variables for new data.

Consider the following data:

Some data

If I were to ask you to describe the data by a pattern you see in that data, you'd likely draw a line. It is quite obvious to us that, even though the data points don't fall exactly in a line, a line seems to satisfactorily represent the data's pattern.

But a drawn line is no good - what are we supposed to do with it? It's hard to make use of that in a program.

A better way of describing a pattern is as a mathematical equation (i.e. a function). In this form, we can easily plug in new inputs to get predicted outputs. So we can restate our goal of learning a pattern as learning a function.

You may remember that lines are typically expressed in the form:

$$ y = mx + b $$

As a refresher: $y$ is the output, $x$ is the input, $m$ is the slope, and $b$ is the intercept.

Lines are uniquely identified by values of $m$ and $b$:

Some lines

Variables like $m$ and $b$ which uniquely identify a function are called parameters (we also say that $m$ and $b$ parameterize the line). So finding a particular line means finding particular values of $m$ and $b$. When we say we want to learn a function, we are really saying that we want to learn the parameters of a function, since they effectively define the function.
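As a minimal sketch of what "parameters define the function" means in code (the helper name `make_line` is made up for illustration), each choice of $m$ and $b$ gives back a different line function:

```python
def make_line(m, b):
    """Return the line function y = m*x + b for the given parameters."""
    def f(x):
        return m * x + b
    return f

# Two different parameter choices give two different lines.
steep = make_line(12, 2)
shallow = make_line(4.6, 40)

# The same input produces different outputs depending on the parameters.
print(steep(10))
print(shallow(10))
```

Learning a line then amounts to choosing which `(m, b)` pair to pass in.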

So how can we find the right values of $m, b$ for the line that fits our data?

Trial-and-error seems like a reasonable approach. Let's start with $m=12, b=2$ (these values were picked arbitrarily, you could start with different values too).

Trying $m=12, b=2$

The line is quite far from the data, so let's try lowering the slope to $m=8$.

Trying $m=8, b=2$

The line's closer, but still far from the data, so we can try lowering the slope again. Then we can check the resulting line, and continue adjusting until we feel satisfied with the result. Let's say we carry out this trial-and-error and end up with $m=4.6, b=40$.

Trying $m=4.6, b=40$

So now we have a function, $y = 4.6x + 40$, and if we got new inputs we could predict outputs that would follow the pattern of this dataset.

But this was quite a laborious process. Can we get a computer to do this for us?

Learning functions by computer

The basic approach we took by hand was just to:

  1. Pick some random values for $m$ and $b$
  2. Compare the resulting line to the data to see if it's a good fit
  3. Go back to 1 if the resulting line is not a good fit

This is a fairly repetitive process and so it seems like a good candidate for a computer program. There's one snag though - we could easily eyeball whether or not a line was a "good" fit. A computer can't eyeball a line like that.

So we have to get a bit more explicit about what a "good" fit is, in terms that a computer can make use of. To put it another way, we want the computer to be able to calculate how "wrong" its current guesses for $m$ and $b$ are. We can define a function to do so.

Let's think for a moment about what we would consider a bad-fitting line. The further the line is from the datapoints, the worse we consider it. So we want lines that are close to our datapoints. We can restate this in more concrete terms.

First, let's notate our line function guess as $f(x)$. Each datapoint we have is represented as $(x, y)$, where $x$ is the input and $y$ is the output. For a datapoint $(x, y)$, we predict an output $\hat y$ based on our line function, i.e. $\hat y = f(x)$.

We can look at each datapoint and calculate how far it is from our line guess with $y - \hat y$. Then we can combine all these distances (which are called errors) in some way to get a final "wrongness" value.

There are many ways we can do this, but a common way is to square all of these errors (only the magnitude of the error is important) and then take their mean:

$$ \frac{\sum (y - \hat y)^2 }{n} $$

Where $n$ is the number of datapoints we have.

It also helps to think about this as:

$$ \frac{\sum (y - f(x))^2 }{n} $$

This makes it a bit clearer that the important part here is our guess at the function $f(x)$.
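The mean of squared errors above could be sketched in Python roughly as follows. The four datapoints are made up for illustration and are not the dataset from the figures; the function names are likewise hypothetical.

```python
def line(x, m, b):
    """Our guess f(x) = m*x + b."""
    return m * x + b

def mse_loss(xs, ys, m, b):
    """Mean of the squared errors (y - f(x))^2 over all datapoints."""
    n = len(xs)
    return sum((y - line(x, m, b)) ** 2 for x, y in zip(xs, ys)) / n

# Made-up data for illustration only.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [45.0, 49.0, 54.0, 58.0]

loss_bad = mse_loss(xs, ys, 12, 2)     # a guess that is far from the data
loss_good = mse_loss(xs, ys, 4.6, 40)  # a guess that is close to the data
print(loss_bad, loss_good)
```

The guess that sits closer to the datapoints gets the smaller number, which is the "eyeballing" step turned into something a computer can calculate.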

A function like this which calculates the "wrongness" of our current line guess is called a loss function (or a cost function). Clearly, we want to find the line (that is, the values of $m$ and $b$) which is the "least wrong".

Another way of saying this is that we want to find parameters which minimize this loss function. This is basically what we're doing when we eyeball a "good" line.

When we guessed the line by hand, we just iteratively tried different $m$ and $b$ values until it looked good enough. We could use a similar approach with a computer, but when we did it by hand we could, again, eyeball how we should change $m$ and $b$ to get closer (e.g. if the line is going above the datapoints, we know we need to lower $b$ or lower $m$ or both). How then can a computer know in which direction to change $m$ and/or $b$?

Changing $m$ and $b$ to get a line that is closer to the data is the same as changing $m$ and $b$ to lower the loss. It's just a matter of having the computer figure out in what direction changes to $m$ and/or $b$ will lower the loss.

Remember that a derivative of a function tells us the rate of change at a specific point in that function. We could compute the derivative of the loss function with respect to $m$ and with respect to $b$, evaluated at our current guesses. The sign of each derivative tells us which direction raises the loss, so moving $m$ and $b$ a small step in the opposite direction lowers the loss and gives us new guesses for $m$ and $b$.
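As a sketch of what those derivatives look like for the mean squared error above with $f(x) = mx + b$ (writing the loss as $L$ and the step size as $\alpha$, both symbols introduced here for illustration):

$$ L(m, b) = \frac{\sum (y - (mx + b))^2 }{n} $$

$$ \frac{\partial L}{\partial m} = -\frac{2}{n} \sum x \, (y - (mx + b)) \qquad \frac{\partial L}{\partial b} = -\frac{2}{n} \sum (y - (mx + b)) $$

One common way to turn these into new guesses is to take a small step against each derivative:

$$ m \leftarrow m - \alpha \frac{\partial L}{\partial m} \qquad b \leftarrow b - \alpha \frac{\partial L}{\partial b} $$

Here $\alpha$ is a small positive number that controls how far each step moves.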

Then it's just a matter of repeating this procedure - with our new guesses for $m$ and $b$, use the derivative of the loss function to figure out what direction(s) to move in and get new guesses, and so on.

To summarize, we took our trial-and-error-by-hand approach and turned it into the following computer-suited approach:

  1. Pick some random values for $m$ and $b$
  2. Use a loss function to compare our guess $f(x)$ to the data
  3. Determine how to change $m$ and $b$ by computing the derivative of the loss function with respect to $m$ and $b$
  4. Update $m$ and $b$ accordingly, then go back to 2 and repeat until the loss function can't get any lower (or until it's low enough for our purposes)
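The procedure above can be sketched in a few lines of Python. This is one possible reading of it, not a definitive implementation: the data is made up, the names `gradients` and `fit_line` are hypothetical, and the step size (`learning_rate`) and the fixed number of `steps` are assumptions standing in for "until the loss is low enough".

```python
import random

def loss(xs, ys, m, b):
    """Mean squared error of the line y = m*x + b on the data."""
    return sum((y - (m * x + b)) ** 2 for x, y in zip(xs, ys)) / len(xs)

def gradients(xs, ys, m, b):
    """Derivatives of the mean squared error with respect to m and b."""
    n = len(xs)
    errors = [y - (m * x + b) for x, y in zip(xs, ys)]
    dm = -2 * sum(x * e for x, e in zip(xs, errors)) / n
    db = -2 * sum(errors) / n
    return dm, db

def fit_line(xs, ys, learning_rate=0.05, steps=2000, seed=0):
    rng = random.Random(seed)
    # Pick some random starting values for m and b.
    m, b = rng.uniform(-10, 10), rng.uniform(-10, 10)
    for _ in range(steps):
        # The derivatives say how the loss changes as m and b change.
        dm, db = gradients(xs, ys, m, b)
        # Move a small step in the direction that lowers the loss.
        m -= learning_rate * dm
        b -= learning_rate * db
    return m, b

# Made-up data for illustration only.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [45.0, 49.0, 54.0, 58.0]

m, b = fit_line(xs, ys)
print(f"m={m:.2f}, b={b:.2f}, loss={loss(xs, ys, m, b):.4f}")
```

For this made-up data the result should land close to $m = 4.4, b = 40.5$, which is the best-fitting line in the mean-squared-error sense. If the step size is too large the guesses can overshoot and the loss grows instead of shrinking, so in practice it usually needs some tuning.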

A lot of machine learning is just different takes on this basic procedure. If it had to be summarized in one sentence: an algorithm learns some function by figuring out parameters which minimize some loss function.

Different problems and methods involve variations on these pieces - different loss functions, different ways of changing the parameters, etc - or include additional flourishes to help get around the problems that can crop up.

Beyond lines

Here we worked with lines, i.e. functions of the form $y = mx + b$, but not all data fits neatly onto lines. The good news is that this general approach applies to sundry functions, both linear and nonlinear. For example, neural networks can approximate a very wide range of functions, from lines to really crazy-looking and complicated ones.
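To illustrate that the same loop is not tied to lines, here is a rough sketch that fits a curve of the form $y = ax^2 + bx + c$ instead. Everything in it is an assumption made for illustration: the data is generated from a known quadratic, the function names are hypothetical, and the derivatives are estimated numerically (by nudging each parameter slightly up and down) rather than worked out by hand.

```python
def quadratic(x, params):
    """A nonlinear function with three parameters: a*x^2 + b*x + c."""
    a, b, c = params
    return a * x * x + b * x + c

def mse(f, params, xs, ys):
    """Mean squared error of f with the given parameters on the data."""
    return sum((y - f(x, params)) ** 2 for x, y in zip(xs, ys)) / len(xs)

def numeric_gradient(f, params, xs, ys, eps=1e-6):
    """Estimate each derivative of the loss by nudging one parameter at a time."""
    grads = []
    for i in range(len(params)):
        up = list(params)
        down = list(params)
        up[i] += eps
        down[i] -= eps
        grads.append((mse(f, up, xs, ys) - mse(f, down, xs, ys)) / (2 * eps))
    return grads

def fit(f, params, xs, ys, learning_rate=0.05, steps=2000):
    """Same loop as for lines: repeatedly step each parameter against its derivative."""
    params = list(params)
    for _ in range(steps):
        grads = numeric_gradient(f, params, xs, ys)
        params = [p - learning_rate * g for p, g in zip(params, grads)]
    return params

# Made-up data generated from y = 2x^2 - 3x + 1.
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [2 * x * x - 3 * x + 1 for x in xs]

a, b, c = fit(quadratic, [0.0, 0.0, 0.0], xs, ys)
print(f"a={a:.3f}, b={b:.3f}, c={c:.3f}")
```

Only the function being fitted changed; the loss and the update loop are the same idea as before.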

To be honest, there is quite a bit more to machine learning than just this - figuring out how to best represent data is another major concern, for example. But hopefully this intuition will help provide some direction in a field which can feel like a disconnected parade of algorithms.