How to Determine What Machine Learning Model to Use

With all of the different machine learning models out there, spanning unsupervised learning, supervised learning, and reinforcement learning, how do you go about deciding which model to use for a particular problem?

One approach is to try every possible machine learning model and then examine which model yields the best results. The problem with this approach is that it could take a VERY long time. There are dozens of machine learning algorithms, and each one has different run times. Depending on the data set, some algorithms may take hours or even days to complete.

Another risk of the “try-all-models” approach is that you might end up applying a machine learning algorithm to a type of problem it is not a good fit for. It would be like using a hammer to tighten a screw. A hammer is a useful tool, but only when used for its intended purpose. If you want to tighten a screw, use a screwdriver, not a hammer.

When deciding which type of machine learning algorithm to use, you first have to understand the problem thoroughly and then decide what you want to achieve. Here is a helpful framework for algorithm selection:

Are you trying to divide an unlabeled data set into groups such that each group has similar attributes (e.g. customer segmentation)?

If yes, use a clustering algorithm (unsupervised learning) like k-means, hierarchical clustering, or Gaussian Mixture Models.

Are you trying to predict a continuous value given a set of attributes (e.g. house price prediction)?

If yes, use a regression algorithm (supervised learning) like linear regression.
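As a concrete illustration of the regression case, here is a minimal sketch of simple linear regression (a single attribute) fitted by ordinary least squares in plain Python. The house sizes and prices are made-up illustration data, and `fit_simple_linear_regression` is a hypothetical helper, not a library API.

```python
def fit_simple_linear_regression(xs, ys):
    """Return (slope, intercept) minimizing squared error for y ≈ slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Least-squares slope = sum of co-deviations / sum of squared x-deviations.
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Made-up data: house size (square meters) vs. price (thousands), following price = 3 * size + 50.
sizes = [50, 80, 100, 120, 150]
prices = [200, 290, 350, 410, 500]
slope, intercept = fit_simple_linear_regression(sizes, prices)
predicted = slope * 90 + intercept  # predicted price of a 90-square-meter house
```

With several attributes you would solve the multivariate least-squares problem instead, typically through a library rather than by hand.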

Are you trying to predict discrete classes (e.g. spam/not spam)? Do you have a data set that is already labeled with the classes?

If yes to both questions, use a classification algorithm (supervised learning) like Naive Bayes, K-Nearest Neighbors, logistic regression, ID3, neural networks, or support vector machines.
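To show what a classifier from that list looks like in practice, here is a minimal sketch of multinomial Naive Bayes with add-one (Laplace) smoothing for the spam/not-spam example. It assumes whitespace tokenization and uses a handful of invented messages; `train_naive_bayes` and `predict_naive_bayes` are hypothetical names, not a library API.

```python
import math
from collections import Counter


def train_naive_bayes(docs, labels):
    """Collect log priors and per-class word counts for multinomial Naive Bayes."""
    classes = sorted(set(labels))  # sorted for a deterministic class order
    vocab = {w for d in docs for w in d.split()}
    priors, word_counts, totals = {}, {}, {}
    for c in classes:
        class_docs = [d for d, lab in zip(docs, labels) if lab == c]
        priors[c] = math.log(len(class_docs) / len(docs))
        counts = Counter(w for d in class_docs for w in d.split())
        word_counts[c] = counts
        totals[c] = sum(counts.values())
    return priors, word_counts, totals, vocab


def predict_naive_bayes(model, doc):
    """Return the class with the highest log posterior (up to a constant)."""
    priors, word_counts, totals, vocab = model
    best_class, best_score = None, -math.inf
    for c in priors:
        score = priors[c]
        for w in doc.split():
            if w not in vocab:
                continue  # ignore words never seen during training
            # Add-one smoothing keeps unseen (word, class) pairs from zeroing the product.
            score += math.log((word_counts[c][w] + 1) / (totals[c] + len(vocab)))
        if score > best_score:
            best_class, best_score = c, score
    return best_class


docs = ["win money now", "free money offer", "win a free prize",
        "meeting at noon", "project update attached", "lunch at noon tomorrow"]
labels = ["spam", "spam", "spam", "ham", "ham", "ham"]
model = train_naive_bayes(docs, labels)
```

Working in log space avoids numeric underflow when a message contains many words.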

Are you trying to reduce a large number of attributes to a smaller number of attributes?

If yes, use a dimensionality reduction algorithm like stepwise forward selection or principal component analysis.
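Principal component analysis finds the direction along which the data varies most and re-expresses each instance by its coordinate along that direction. The sketch below assumes only two attributes, so it can use the closed-form eigenvector of a 2×2 covariance matrix instead of a general eigensolver; `first_principal_component` and `project_to_1d` are hypothetical helpers, and the points are invented.

```python
import math


def first_principal_component(points):
    """Return (mean, unit direction of maximum variance) for a list of 2-D points."""
    n = len(points)
    mx = sum(p[0] for p in points) / n
    my = sum(p[1] for p in points) / n
    # Entries of the 2x2 sample covariance matrix [[a, b], [b, c]].
    a = sum((p[0] - mx) ** 2 for p in points) / (n - 1)
    c = sum((p[1] - my) ** 2 for p in points) / (n - 1)
    b = sum((p[0] - mx) * (p[1] - my) for p in points) / (n - 1)
    # Largest eigenvalue of a symmetric 2x2 matrix, in closed form.
    lam = (a + c) / 2 + math.sqrt(((a - c) / 2) ** 2 + b ** 2)
    if abs(b) > 1e-12:
        vx, vy = b, lam - a  # solves (a - lam) * vx + b * vy = 0
    else:
        vx, vy = (1.0, 0.0) if a >= c else (0.0, 1.0)  # already axis-aligned
    norm = math.hypot(vx, vy)
    return (mx, my), (vx / norm, vy / norm)


def project_to_1d(points, mean, direction):
    """Reduce each 2-D point to one attribute: its coordinate along the direction."""
    return [(x - mean[0]) * direction[0] + (y - mean[1]) * direction[1] for x, y in points]


points = [(0, 0), (1, 2), (2, 4), (3, 6)]  # invented data lying on y = 2x
mean, direction = first_principal_component(points)
reduced = project_to_1d(points, mean, direction)
```

Because these points lie exactly on a line, the single projected attribute keeps all of their variance; with real data, the share of variance kept tells you how much information the reduction discards.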

Do you need an algorithm that reacts to its environment, continuously learning from experience, the way humans do (e.g. autonomous vehicles and robots)?

If yes, use reinforcement learning methods.
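The top-level questions above can be written as a small lookup function that makes the decision order explicit. This is a hypothetical helper: the goal strings such as `"group"` and `"predict_discrete"` are invented labels, and it returns only an algorithm family, since the follow-up questions are what narrow the choice to a specific algorithm.

```python
from typing import Optional


def suggest_algorithm_family(goal: str, has_labels: bool = False) -> Optional[str]:
    """Map the framework's top-level questions to an algorithm family."""
    if goal == "group":  # divide an unlabeled data set into similar groups
        return "clustering"
    if goal == "predict_continuous":  # e.g. house price prediction
        return "regression"
    if goal == "predict_discrete":  # e.g. spam/not spam
        # The framework recommends classification only when labels already exist;
        # otherwise it gives no recommendation, so return None.
        return "classification" if has_labels else None
    if goal == "reduce_attributes":
        return "dimensionality reduction"
    if goal == "learn_from_environment":  # e.g. autonomous vehicles and robots
        return "reinforcement learning"
    raise ValueError(f"unknown goal: {goal!r}")
```

For example, `suggest_algorithm_family("predict_discrete", has_labels=True)` returns `"classification"`.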

For each of the questions above, you can ask follow-up questions to home in on the appropriate algorithm for that type of problem.

For example:

  • Do we need an algorithm that can be built, trained, and tested quickly?
  • Do we need a model that can make fast predictions?
  • How accurate does the model need to be?
  • Is the number of attributes greater than the number of instances?
  • Do we need a model that is easy to interpret? 
  • How scalable a model do we need?
  • What evaluation criteria are important for meeting business needs?
  • How much data preprocessing do we want to do?

Here is a really useful flowchart from Microsoft that can help you decide which algorithm to use when:

Figure: machine learning algorithm decision chart
Source: Microsoft

Here is another useful flowchart from scikit-learn:

Figure: scikit-learn algorithm selection flowchart

Slide 11 of this link shows the interpretability vs. accuracy tradeoffs for the different machine learning models.

This link provides a quick rundown of the different types of machine learning models.

Advantages of K-Means Clustering

The K-means clustering algorithm groups unlabeled data set instances into clusters based on similar attributes. It has a number of advantages over other types of machine learning models, including linear models such as logistic regression and Naive Bayes.

Here are the advantages:

Unlabeled Data Sets

A lot of real-world data comes unlabeled, without any particular class. The benefit of an algorithm like K-means clustering is that it does not need labels, which matters because we often do not know in advance how the instances in a data set should be grouped.

For example, consider the problem of grouping Netflix viewers into clusters based on similar viewing behavior. We know that there are clusters, but we do not know what those clusters are. Supervised linear models such as logistic regression need labeled examples to learn from, so they cannot help us with this kind of problem.

Nonlinearly Separable Data

Consider the data set below, which contains three concentric circles. It is nonlinearly separable: there is no straight line or plane we could draw on the graph that cleanly separates the red, blue, and green classes. However, if we convert the coordinates from Cartesian to polar, K-means clustering on the radius alone can produce concentric clusters.

Figure: three concentric circles of points colored red, blue, and green
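The polar-coordinate trick for concentric circles can be sketched as follows: compute each point's radius, then run one-dimensional K-means on the radii. This assumes the circles share a center at the origin (otherwise subtract the center first). The ring data is synthetic, the quantile-based initialization is one choice among many, and `make_rings` and `kmeans_1d` are hypothetical helpers.

```python
import math


def make_rings(radii=(1.0, 3.0, 5.0), points_per_ring=36):
    """Generate three concentric rings in Cartesian coordinates (synthetic data)."""
    pts = []
    for r in radii:
        for i in range(points_per_ring):
            theta = 2 * math.pi * i / points_per_ring
            wobble = 0.1 * math.sin(7 * theta)  # small deterministic noise on the radius
            pts.append(((r + wobble) * math.cos(theta), (r + wobble) * math.sin(theta)))
    return pts


def kmeans_1d(values, k, iterations=20):
    """Plain K-means on a single attribute, initialized at evenly spaced quantiles."""
    ordered = sorted(values)
    centroids = [ordered[(2 * j + 1) * len(ordered) // (2 * k)] for j in range(k)]
    labels = []
    for _ in range(iterations):
        # Cluster assignment step: each value joins its nearest centroid.
        labels = [min(range(k), key=lambda j: abs(v - centroids[j])) for v in values]
        # Move centroid step: each centroid moves to the mean of its members.
        for j in range(k):
            members = [v for v, lab in zip(values, labels) if lab == j]
            if members:
                centroids[j] = sum(members) / len(members)
    return labels, centroids


points = make_rings()
radii = [math.hypot(x, y) for x, y in points]  # Cartesian -> polar radius
labels, centroids = kmeans_1d(radii, k=3)
```

Running K-means on the raw (x, y) coordinates instead would cut the rings into wedge-shaped pieces, because K-means favors compact, roughly round clusters; the change of coordinates is what makes the rings compact.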

Simplicity

The meat of the K-means clustering algorithm is just two steps: the cluster assignment step and the move centroid step. If we’re looking for an unsupervised learning algorithm that is easy to implement and can handle large data sets, K-means clustering is a good starting point.
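The two steps fit in a few lines of plain Python. This is a minimal sketch of the standard iteration (often called Lloyd's algorithm) on 2-D points. It assumes the initial centroids are passed in explicitly; in practice they are usually chosen at random or with a seeding scheme such as k-means++. `kmeans` here is a hypothetical helper, and the points are invented.

```python
import math


def kmeans(points, initial_centroids, max_iterations=100):
    """Basic K-means on 2-D points; returns (labels, centroids)."""
    centroids = [tuple(c) for c in initial_centroids]
    labels = None
    for _ in range(max_iterations):
        # Cluster assignment step: each point joins its nearest centroid.
        new_labels = [
            min(range(len(centroids)), key=lambda j: math.dist(p, centroids[j]))
            for p in points
        ]
        if new_labels == labels:
            break  # assignments stopped changing, so the algorithm has converged
        labels = new_labels
        # Move centroid step: each centroid moves to the mean of its assigned points.
        for j in range(len(centroids)):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:  # an empty cluster keeps its old centroid
                centroids[j] = (
                    sum(p[0] for p in members) / len(members),
                    sum(p[1] for p in members) / len(members),
                )
    return labels, centroids


points = [(1, 1), (1.5, 2), (1, 1.5), (8, 8), (8.5, 9), (9, 8)]
labels, centroids = kmeans(points, initial_centroids=[(1, 1), (9, 8)])
```

Because the result depends on the starting centroids, implementations commonly run several initializations and keep the one with the lowest within-cluster sum of squared distances.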

Availability

Most of the popular machine learning packages contain an implementation of K-means clustering.

Speed

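In my experience, K-means clustering does its work quickly, even on really big data sets. Each iteration computes the distance from every instance to every centroid, so its cost grows in proportion to n × k × d for n instances, k clusters, and d attributes.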

Advantages of Decision Trees

Decision tree algorithms such as ID3 provide a convenient way to show all of the possible outcomes of a decision. Decision trees can be used for either classification or regression. Below are some of the advantages of using decision trees as opposed to other types of machine learning models. 

Simplicity

One of the things that I like about decision trees is that you can easily explain the model to somebody who has a non-technical background. Decision trees create straightforward if-then-else rules that can be communicated to a boss, project manager, product manager, or outside stakeholder.
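To see how directly a tree turns into rules, here is a minimal sketch that walks a nested tree and emits one if-then rule per leaf. The tree is hypothetical, written in the style of the well-known "play tennis" teaching example; internal nodes are `(attribute, branches)` pairs and leaves are class labels, a representation chosen here for illustration.

```python
def tree_to_rules(tree, conditions=()):
    """Turn a nested decision tree into a list of if-then rules, one per leaf."""
    if isinstance(tree, str):  # a leaf holds the predicted class label
        condition = " AND ".join(conditions) or "TRUE"
        return [f"IF {condition} THEN {tree}"]
    attribute, branches = tree  # an internal node tests one attribute
    rules = []
    for value, subtree in branches.items():
        rules.extend(tree_to_rules(subtree, conditions + (f"{attribute} = {value}",)))
    return rules


# Hypothetical tree: should we play tennis today?
tree = ("outlook", {
    "sunny": ("humidity", {"high": "no", "normal": "yes"}),
    "overcast": "yes",
    "rain": ("wind", {"strong": "no", "weak": "yes"}),
})

rules = tree_to_rules(tree)
for rule in rules:
    print(rule)  # e.g. "IF outlook = overcast THEN yes"
```

Each printed rule reads as a plain sentence, which is what makes the model easy to walk through with a non-technical audience.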

Contrast decision trees with more black-box machine learning algorithms such as neural networks or reinforcement learning methods, and you can see that decision trees provide a refreshing level of transparency not always common in machine learning.

No Large Data Requirement

If you take a look at an algorithm like k-nearest neighbors, which classifies an unseen instance based on the training instances most similar to it, you need a lot of data in order to get accurate results. The more data you have, the better.
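A minimal k-nearest neighbors sketch makes the data dependence visible: there is no training step, and every prediction is a majority vote among the stored instances closest to the query, so accuracy hinges on having enough nearby examples. It assumes 2-D numeric attributes and Euclidean distance; `knn_classify` is a hypothetical helper and the points are invented.

```python
import math
from collections import Counter


def knn_classify(train, query, k=3):
    """Classify query by majority vote among the k closest training instances.

    train is a list of ((x, y), label) pairs; distance is Euclidean.
    """
    neighbors = sorted(train, key=lambda item: math.dist(item[0], query))[:k]
    votes = Counter(label for _, label in neighbors)
    return votes.most_common(1)[0][0]


train = [((1, 1), "A"), ((1, 2), "A"), ((2, 1), "A"),
         ((6, 6), "B"), ((7, 6), "B"), ((6, 7), "B")]
```

An odd k avoids tied votes in two-class problems. Note that each prediction scans the whole training set, which is also why k-NN predictions get slower as the data grows.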

However, for some problems a lot of data is simply not available. The benefit of using decision trees is that you do not need a lot of data in order to create something useful.

Best and Worst Case

In some settings, you want to be able to determine a worst case, a best case, and a management case. With a decision tree, you can easily see all of the possible outcomes. Each test instance ends up in exactly one of those outcomes, so you pretty much know what to expect ahead of time. Even outliers don’t faze a decision tree.

Continuous and Discrete Data

Decision trees can handle both continuous and discrete data, depending on which decision tree model you use: ID3 works with discrete attributes, while C4.5 and CART can also split on continuous ones. In contrast, many machine learning algorithms can handle either continuous or discrete data, but not both.

Non-Linearity

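Decision trees can capture nonlinear relationships, because a sequence of splits on individual attributes can carve the attribute space into regions that no single straight line or plane could separate.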

Fast

Classifying a test instance is fast and just depends on the depth of the tree.
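Classification only follows one root-to-leaf path, so the work is at most one attribute test per level, regardless of how many instances the tree was trained on. The sketch below counts those tests. It reuses a hypothetical "play tennis" style tree, where internal nodes are `(attribute, branches)` pairs and leaves are labels; `classify` is a hypothetical helper.

```python
def classify(tree, instance):
    """Follow one root-to-leaf path; return (label, number_of_tests_performed).

    Raises KeyError if the instance has a value the tree has no branch for.
    """
    tests = 0
    node = tree
    while not isinstance(node, str):  # internal nodes are (attribute, branches) pairs
        attribute, branches = node
        node = branches[instance[attribute]]
        tests += 1
    return node, tests


# Hypothetical tree of depth 2.
tree = ("outlook", {
    "sunny": ("humidity", {"high": "no", "normal": "yes"}),
    "overcast": "yes",
    "rain": ("wind", {"strong": "no", "weak": "yes"}),
})
```

An overcast day needs a single test, while a sunny day needs two, matching the depth of the branch it follows.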

Irrelevant Attributes

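Because decision trees such as ID3 choose each split using the Information Gain, an irrelevant attribute, one that tells us nothing about the class, scores a gain near zero and is rarely chosen for a split. Irrelevant attributes are therefore handled with ease.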
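Information Gain is the drop in class entropy after splitting on an attribute. The sketch below computes it for a relevant and an irrelevant attribute. The four rows and the attribute names are invented so the irrelevant attribute's gain is exactly zero; with real, noisy data an irrelevant attribute usually gets a small positive gain by chance. `entropy` and `information_gain` are hypothetical helpers.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * math.log2(n / total) for n in Counter(labels).values())


def information_gain(rows, attribute, target):
    """Entropy of the target minus the weighted entropy after splitting on attribute."""
    base = entropy([r[target] for r in rows])
    remainder = 0.0
    for value in {r[attribute] for r in rows}:
        subset = [r[target] for r in rows if r[attribute] == value]
        remainder += len(subset) / len(rows) * entropy(subset)
    return base - remainder


rows = [
    {"rain": "yes", "shirt_color": "red", "umbrella": "yes"},
    {"rain": "yes", "shirt_color": "blue", "umbrella": "yes"},
    {"rain": "no", "shirt_color": "red", "umbrella": "no"},
    {"rain": "no", "shirt_color": "blue", "umbrella": "no"},
]
gain_rain = information_gain(rows, "rain", "umbrella")  # relevant attribute
gain_shirt = information_gain(rows, "shirt_color", "umbrella")  # irrelevant attribute
```

Here splitting on `rain` removes all uncertainty (a gain of 1 bit), while splitting on `shirt_color` leaves it unchanged (a gain of 0), so ID3 would split on `rain`.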