Machine Learning 101
This post is adapted from a talk given at artsec and a workshop given at a few other places (Buenos Aires Media Party 2015, MozFest 2015). The goal is to provide an intuition on machine learning for people without advanced math backgrounds - given that machine learning increasingly rules everything around us, everyone should probably have some mental model of how it works.
The materials for the workshop are available on GitHub.
What is "machine learning" anyways?
There are plenty of introductory machine learning tutorials and courses available, but they tend to focus on cursory tours of algorithms - here's linear regression, here's SVMs, here's neural networks, etc - without building an intuition of what machine learning is actually doing under the hood. This intuition is obvious to practitioners but not to people new to the field. To get the full value of the other machine learning resources out there, that intuition really helps.
But first, to clarify a bit what machine learning is: a basic working definition could be "algorithms for computers to learn patterns". Pattern recognition is something that people are very good at but that is difficult for computers to do.
Here we'll walk through a very simple machine learning task which is prototypical of many real-world machine learning problems. First we'll go through it by hand, noting where our human superpower of pattern recognition comes into play, and then think about how we can translate what we did into something a computer is capable of executing.
Learning functions by hand
A common machine learning goal is to take a dataset and learn the pattern (i.e. relationship) between different variables of the data. Then that pattern can be used to predict values of those variables for new data.
Consider the following data:
If I asked you to describe the pattern you see in that data, you'd likely draw a line. It is quite obvious to us that, even though the data points don't fall exactly in a line, a line seems to satisfactorily represent the data's pattern.
But a drawn line is no good - what are we supposed to do with it? It's hard to make use of that in a program.
A better way of describing a pattern is as a mathematical equation (i.e. a function). In this form, we can easily plug in new inputs to get predicted outputs. So we can restate our goal of learning a pattern as learning a function.
You may remember that lines are typically expressed in the form:
$$
y = mx + b
$$
As a refresher:
Lines are uniquely identified by values of $m$ (the slope) and $b$ (the intercept).
Variables like $m$ and $b$, which uniquely identify a function, are called the function's parameters.
So how can we find the right values of $m$ and $b$ for our data?
Trial-and-error seems like a reasonable approach. Let's start with a guess for $m$ and $b$ and see how the resulting line compares to the data.
The line is still quite far from the data, so let's try lowering the slope $m$.
The line's closer, but still far from the data, so we can try lowering the slope again. Then we can check the resulting line, and continue adjusting until we feel satisfied with the result. Let's say we carry out this trial-and-error and end up with satisfactory values for $m$ and $b$.
So now we have a function, $y = mx + b$ with our fitted values plugged in for $m$ and $b$, that describes the data's pattern.
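To make this concrete, here's what using such a learned function looks like in code (a minimal sketch; the parameter values are hypothetical stand-ins for whatever our trial-and-error produced):

```python
# Hypothetical values for m and b, standing in for whatever
# our trial-and-error produced.
m = 2.5
b = 1.0

def f(x):
    """Our learned line: predicts an output y for a new input x."""
    return m * x + b

# We can now predict outputs for inputs that weren't in the data.
print(f(4))   # 11.0
print(f(10))  # 26.0
```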
But this was quite a laborious process. Can we get a computer to do this for us?
Learning functions by computer
The basic approach we took by hand was just to:
1. Pick some random values for $m$ and $b$
2. Compare the resulting line to the data to see if it's a good fit
3. Go back to 1 if the resulting line is not a good fit
This is a fairly repetitive process and so it seems like a good candidate for a computer program. There's one snag though - we could easily eyeball whether or not a line was a "good" fit. A computer can't eyeball a line like that.
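To see the snag concretely, here's the by-hand procedure sketched as a loop (the dataset is made up for illustration). Everything is easy to write except the `is_good_fit` check, which is deliberately left unimplemented - it's the "eyeball" step we don't know how to express yet:

```python
import random

# A made-up dataset of (x, y) pairs, purely for illustration.
data = [(1, 3.2), (2, 6.1), (3, 8.4), (4, 10.9)]

def is_good_fit(m, b, data):
    # The "eyeball" step: a computer can't eyeball a line,
    # so we don't yet know how to write this.
    raise NotImplementedError

# 1. Pick some random values for m and b
# 2. See if the resulting line is a good fit
# 3. Go back to 1 if it isn't
while True:
    m = random.uniform(-10, 10)
    b = random.uniform(-10, 10)
    if is_good_fit(m, b, data):
        break
```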
So we have to get a bit more explicit about what a "good" fit is, in terms that a computer can make use of. To put it another way, we want the computer to be able to calculate how "wrong" its current guesses for $m$ and $b$ are.
Let's think for a moment about what we would consider a bad-fitting line. The further the line is from the data, the worse we consider the fit. So we want lines that are close to our datapoints. We can restate this in more concrete terms.
First, let's write our line function guess as $f(x) = mx + b$; for a given input $x$, the line's prediction is $\hat y = f(x)$.
We can look at each datapoint and calculate how far it is from our line guess with the error $y - \hat y$, where $y$ is the datapoint's true output value.
There are many ways we can do this, but a common way is to square all of these errors (squaring keeps everything positive, since only the magnitude of the error is important) and then take their mean:
$$
\frac{\sum (y - \hat y)^2 }{n}
$$
Where $n$ is the number of datapoints.
It also helps to think about this as:
$$
\frac{\sum (y - f(x))^2 }{n}
$$
This makes it a bit clearer that the important part here is our guess at the function $f(x)$.
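In code, this loss is only a few lines (a minimal sketch, assuming `data` is a list of (x, y) pairs like before):

```python
def f(x, m, b):
    """Our current guess at the function."""
    return m * x + b

def loss(m, b, data):
    """Mean squared error: the average of the squared differences
    between each true y and the line's prediction f(x)."""
    n = len(data)
    return sum((y - f(x, m, b)) ** 2 for x, y in data) / n
```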
A function like this which calculates the "wrongness" of our current line guess is called a loss function (or a cost function). Clearly, we want to find a line (that is, values of $m$ and $b$) with the lowest possible "wrongness".
Another way of saying this is that we want to find parameters which minimize this loss function. This is basically what we're doing when we eyeball a "good" line.
When we guessed the line by hand, we just iteratively tried different values of $m$ and $b$ until the line looked good enough. But how does a computer know in which direction to adjust them?
Changing $m$ and $b$ to make the loss smaller is where derivatives come in.
Remember that a derivative of a function tells us the rate of change at a specific point in that function. We could compute the derivatives of the loss function with respect to our current guesses for $m$ and $b$; those derivatives tell us which way to nudge each parameter so that the loss goes down, and we adjust $m$ and $b$ a small amount in that direction.
Then it's just a matter of repeating this procedure - with our new guesses for $m$ and $b$, we recompute the loss and its derivatives, adjust again, and so on until the loss stops shrinking (or is small enough).
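One common way of using the derivatives like this is gradient descent: subtract a small multiple of each derivative from the corresponding parameter. Here's a sketch, where the derivative formulas come from applying the chain rule to the loss above, and `learning_rate` is a knob we introduce to control the size of each nudge:

```python
def gradients(m, b, data):
    """Derivatives of the mean squared error loss with respect
    to m and b, worked out with the chain rule."""
    n = len(data)
    dm = sum(-2 * x * (y - (m * x + b)) for x, y in data) / n
    db = sum(-2 * (y - (m * x + b)) for x, y in data) / n
    return dm, db

def step(m, b, data, learning_rate=0.01):
    """Nudge m and b a small distance in the direction
    that decreases the loss."""
    dm, db = gradients(m, b, data)
    return m - learning_rate * dm, b - learning_rate * db
```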
To summarize, we took our trial-and-error-by-hand approach and turned it into the following computer-suited approach:
1. Pick some random values for $m$ and $b$
2. Use a loss function to compare our guess $f(x)$ to the data
3. Determine how to change $m$ and $b$ by computing the derivatives of the loss function with respect to $m$ and $b$
4. Go back to 2 and repeat until the loss function can't get any lower (or until it's low enough for our purposes)
A lot of machine learning is just different takes on this basic procedure. If it had to be summarized in one sentence: an algorithm learns some function by figuring out parameters which minimize some loss function.
Different problems and methods involve variations on these pieces - different loss functions, different ways of changing the parameters, etc - or include additional flourishes to help get around the problems that can crop up.
Beyond lines
Here we worked with lines, i.e. functions of the form $y = mx + b$, but not all data will be well described by a line. The same procedure works for learning more complicated functions - there are simply more parameters to adjust, and the loss function still tells us how well we're doing.
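For instance, if we guessed a quadratic instead of a line, the recipe would be unchanged - there's just one more parameter to nudge. A hypothetical sketch following the same pattern as before:

```python
def f(x, a, b, c):
    """A quadratic guess: three parameters to learn instead of two."""
    return a * x ** 2 + b * x + c

def gradients(a, b, c, data):
    """Derivatives of the same mean squared error loss, now with
    respect to each of the three parameters."""
    n = len(data)
    da = sum(-2 * x ** 2 * (y - f(x, a, b, c)) for x, y in data) / n
    db = sum(-2 * x * (y - f(x, a, b, c)) for x, y in data) / n
    dc = sum(-2 * (y - f(x, a, b, c)) for x, y in data) / n
    return da, db, dc
```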
To be honest, there is quite a bit more to machine learning than just this - figuring out how to best represent data is another major concern, for example. But hopefully this intuition will help provide some direction in a field which can feel like a disconnected parade of algorithms.