Decision trees are supervised models that can perform either regression or classification.
Let's take a look at some Major League Baseball player data from 1986-1987. Each dot represents a single player in the league:
The preceding data is our training data. The idea is to build a model that predicts the salary of future players based on Years and Hits. A decision tree aims to make splits on our data in order to segment data points that behave similarly to each other, but differently from the rest. The tree makes multiple such splits in order to make the most accurate prediction possible. Let's see a tree built for the preceding data:
Let's read this from top to bottom:
When a splitting rule is `true`, you follow the left branch. When a splitting rule is `false`, you follow the right branch. So, for a new player who has been playing for less than 4.5 years, we go down the left branch.

This tree doesn't just give us predictions; it also provides some more information about our data:
Modern decision tree algorithms tend to use a recursive binary splitting approach, sketched in the example below:

1. At the root, consider every feature and every possible split point, and choose the split that produces the lowest mean squared error (MSE) for the two resulting groups.
2. Repeat the same search on each of the two resulting regions.
3. Stop when a halting criterion is met (for example, a maximum depth is reached or too few samples remain in a node).
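This is not scikit-learn's actual implementation (which is heavily optimized); it is a minimal sketch of a single greedy step for a regression tree, assuming `X` is a 2-D NumPy array of features and `y` a 1-D array of targets:

```
import numpy as np

def best_split(X, y):
    """Scan every feature and every candidate threshold, and return the
    (feature, threshold) pair whose two children have the lowest total
    squared error around their own mean predictions. A full tree builder
    would apply this recursively to each child region."""
    best_feature, best_threshold, best_sse = None, None, np.inf
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j]):
            left, right = y[X[:, j] < t], y[X[:, j] >= t]
            if len(left) == 0 or len(right) == 0:
                continue  # not a real split
            sse = ((left - left.mean()) ** 2).sum() + \
                  ((right - right.mean()) ** 2).sum()
            if sse < best_sse:
                best_feature, best_threshold, best_sse = j, t, sse
    return best_feature, best_threshold, best_sse
```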
For classification trees, the algorithm is very similar, the biggest difference being the metric we optimize over. Because MSE is only defined for regression problems, we cannot use it; and rather than optimizing accuracy directly, classification trees optimize over either the Gini index or entropy.
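For reference, for a node whose samples fall into $K$ classes with proportions $p_1, \dots, p_K$, the two measures are defined as:

$$\text{Gini} = 1 - \sum_{k=1}^{K} p_k^2 \qquad\qquad \text{Entropy} = -\sum_{k=1}^{K} p_k \log_2 p_k$$

Both are 0 for a perfectly pure node (all samples in one class) and largest when the classes are evenly mixed.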
Similar to a regression tree, a classification tree is built by optimizing over a metric (in this case, the Gini index) and choosing the best split to make this optimization. More formally, at each node, the tree takes the following steps:

1. Calculate the Gini index of the node before any split.
2. For every candidate split, calculate the Gini index of each resulting group and combine them into a weighted Gini index for the split.
3. Choose the split with the lowest weighted Gini index.
Let's say that we are predicting the likelihood of death aboard a luxury cruise ship given demographic features. Suppose we start with 25 people, 10 of whom survived, and 15 of whom died:
Before split | All
---|---
Survived | 10
Died | 15
We first calculate the Gini index of the node before doing anything. In this example, the classes are survived and died, so:

$$\text{Gini} = 1 - \left(\frac{10}{25}\right)^2 - \left(\frac{15}{25}\right)^2 = 1 - 0.16 - 0.36 = 0.48$$

This means that the impurity of the dataset before any split is 0.48.
Now let's consider a potential split on gender. We first calculate the Gini index for each gender:
The Gini index within each gender is computed from that group's own survival proportions:

$$\text{Gini}_{male} = 1 - p_{survived \mid male}^2 - p_{died \mid male}^2 \qquad \text{Gini}_{female} = 1 - p_{survived \mid female}^2 - p_{died \mid female}^2$$

Once we have the Gini index for each gender, we calculate the overall Gini index for the split on gender as the average of the two, weighted by each group's share of the 25 passengers:

$$\text{Gini}_{gender} = \frac{n_{male}}{25}\,\text{Gini}_{male} + \frac{n_{female}}{25}\,\text{Gini}_{female}$$
So, the Gini index for splitting on gender is 0.27. We then follow this procedure for each of the three potential splits and compare the resulting values. In this example, we would choose to split on gender, as it yields the lowest Gini index of the three.
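A minimal sketch of this calculation in Python follows. The before-split counts (10 survived, 15 died) come from the table above; the per-gender (survived, died) counts are illustrative assumptions, chosen only to be consistent with the chapter's totals and the 0.27 result, since the split table isn't reproduced in this excerpt:

```
def gini(counts):
    """Gini index of a node, given the per-class sample counts."""
    n = sum(counts)
    return 1 - sum((c / n) ** 2 for c in counts)

def split_gini(groups):
    """Weighted Gini index of a split: average the children's Gini
    indices, weighted by each child's share of the samples."""
    n = sum(sum(g) for g in groups)
    return sum(sum(g) / n * gini(g) for g in groups)

print(gini([10, 15]))                 # before any split -> 0.48
# assumed (survived, died) counts per gender, for illustration only
print(split_gini([(8, 2), (2, 13)]))  # split on gender -> ~0.27
```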
The following table briefly summarizes the differences between classification and regression decision trees:
Regression trees | Classification trees
---|---
Predict a quantitative response | Predict a qualitative response
Prediction is the average value in each leaf | Prediction is the most common label in each leaf
Splits are chosen to minimize MSE | Splits are chosen to minimize the Gini index (usually)
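In scikit-learn, this difference amounts to choosing the estimator (and, with it, the splitting criterion). A minimal sketch, assuming a recent scikit-learn (1.0+), where the regression criterion is spelled `'squared_error'`:

```
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# regression tree: splits minimize squared error, each leaf predicts
# the mean of its training targets
reg = DecisionTreeRegressor(criterion='squared_error', max_depth=3)

# classification tree: splits minimize Gini impurity, each leaf predicts
# its most common training label
clf = DecisionTreeClassifier(criterion='gini', max_depth=3)
```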
Let's use scikit-learn's built-in decision tree implementation to build a decision tree:
```
import pandas as pd

# read in the data
titanic = pd.read_csv('short_titanic.csv')

# encode female as 0 and male as 1
titanic['Sex'] = titanic.Sex.map({'female': 0, 'male': 1})

# fill in the missing values for age with the median age
titanic.Age.fillna(titanic.Age.median(), inplace=True)

# create a DataFrame of dummy variables for Embarked
embarked_dummies = pd.get_dummies(titanic.Embarked, prefix='Embarked')
embarked_dummies.drop(embarked_dummies.columns[0], axis=1, inplace=True)

# concatenate the original DataFrame and the dummy DataFrame
titanic = pd.concat([titanic, embarked_dummies], axis=1)

# define X and y
feature_cols = ['Pclass', 'Sex', 'Age', 'Embarked_Q', 'Embarked_S']
X = titanic[feature_cols]
y = titanic.Survived

X.head()
```
Note that we are going to use class, sex, age, and dummy variables for the port of embarkation as our features:
```
# fit a classification tree with max_depth=3 on all of the data
from sklearn.tree import DecisionTreeClassifier

treeclf = DecisionTreeClassifier(max_depth=3, random_state=1)
treeclf.fit(X, y)
```
`max_depth` is a limit on the depth of our tree. It means that, for any data point, our tree can ask at most three questions and make at most three splits before arriving at a prediction. We can output our tree in a visual format, and we will obtain the following:
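The excerpt doesn't show which tool it uses for the visual; as one option, scikit-learn's own `plot_tree` (available since version 0.21, and assuming matplotlib is installed) can render the fitted tree:

```
import matplotlib.pyplot as plt
from sklearn import tree

fig, ax = plt.subplots(figsize=(12, 6))
tree.plot_tree(treeclf, feature_names=feature_cols,
               class_names=['died', 'survived'], filled=True, ax=ax)
plt.show()
```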
We can notice a few things:
* `Embarked_Q` was never used in any split

For either classification or regression trees, we can also do something very interesting: we can output a number that represents each feature's importance in the prediction of our data points:
```
# compute the feature importances
pd.DataFrame({'feature': feature_cols,
              'importance': treeclf.feature_importances_})
```
The importance scores measure the average reduction in the Gini index contributed by each variable, with higher values corresponding to greater importance for the prediction. We can use this information to select fewer features in the future. For example, both of the Embarked variables have very low importance compared to the rest of the features, so we may be able to say that they are not important in predicting life or death.
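As a minimal sketch of that kind of selection (the 0.05 cutoff is an arbitrary illustrative threshold, not a value from the text):

```
# keep only the features whose importance clears the chosen cutoff
importances = pd.Series(treeclf.feature_importances_, index=feature_cols)
selected = importances[importances > 0.05].index.tolist()
print(selected)
```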