The scikit-learn DummyClassifier class implements several simple strategies, such as random guessing, which can serve as a baseline for classifiers. The strategies are as follows:
stratified: This samples predictions at random according to the training set class distribution
most_frequent: This always predicts the most frequent class
prior: This is available as of scikit-learn 0.17 and predicts by maximizing the class prior
uniform: This uses a uniform distribution to randomly sample classes
constant: This always predicts a user-specified class

As you can see, some strategies of the DummyClassifier class always predict the same class. This can lead to warnings from some scikit-learn metrics functions. We will perform the same analysis as we did in the Computing precision, recall, and F1 score recipe, but with dummy classifiers added.
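Before turning to the rain data, the strategies listed above can be compared on a tiny example. The following is only a minimal sketch: the toy arrays X and y and the dictionary names are hypothetical and are not part of the recipe's code bundle.

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Hypothetical toy data: 8 samples of class 0 and 2 samples of class 1.
# Dummy classifiers ignore the feature values, so zeros are enough.
X = np.zeros((10, 1))
y = np.array([0] * 8 + [1] * 2)

strategies = {
    'stratified': DummyClassifier(strategy='stratified', random_state=0),
    'most_frequent': DummyClassifier(strategy='most_frequent'),
    'prior': DummyClassifier(strategy='prior'),
    'uniform': DummyClassifier(strategy='uniform', random_state=0),
    # The constant strategy needs the class to predict (assumed to be 1 here).
    'constant': DummyClassifier(strategy='constant', constant=1),
}

toy_preds = {}
for name, clf in strategies.items():
    clf.fit(X, y)
    toy_preds[name] = clf.predict(X)
    print(name, toy_preds[name])
```

On this data, most_frequent and prior both predict class 0 for every sample and constant predicts class 1 for every sample, while stratified and uniform produce a random mix of both classes.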
# Imports (ch10util and dautil ship with this book's code bundle)
import numpy as np
from sklearn import metrics
import ch10util
from sklearn.dummy import DummyClassifier
from IPython.display import HTML
import dautil as dl
# Load the train and test data
y_test = np.load('rain_y_test.npy')
X_train = np.load('rain_X_train.npy')
X_test = np.load('rain_X_test.npy')
y_train = np.load('rain_y_train.npy')
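# Create the dummy classifiers and predict with them.
# The stratified strategy is named explicitly because the default
# changed from 'stratified' to 'prior' in scikit-learn 0.24.
stratified = DummyClassifier(strategy='stratified', random_state=28)
frequent = DummyClassifier(strategy='most_frequent', random_state=28)
prior = DummyClassifier(strategy='prior', random_state=29)
uniform = DummyClassifier(strategy='uniform', random_state=29)
preds = ch10util.rain_preds()

for clf in [stratified, frequent, prior, uniform]:
    clf.fit(X_train, y_train)
    preds.append(clf.predict(X_test))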
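The most_frequent and prior strategies predict the same labels, so their bars in the plots coincide. They differ only in the class probabilities they report. The following minimal sketch illustrates this with hypothetical toy arrays (X_toy and y_toy are not part of the recipe):

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Hypothetical toy data: three samples of class 0 and one of class 1.
X_toy = np.zeros((4, 1))
y_toy = np.array([0, 0, 0, 1])

frequent_toy = DummyClassifier(strategy='most_frequent').fit(X_toy, y_toy)
prior_toy = DummyClassifier(strategy='prior').fit(X_toy, y_toy)

# Both strategies predict the majority class for every sample.
same_labels = np.array_equal(frequent_toy.predict(X_toy),
                             prior_toy.predict(X_toy))

# most_frequent reports a one-hot probability vector,
# while prior reports the empirical class distribution.
frequent_proba = frequent_toy.predict_proba(X_toy[:1])
prior_proba = prior_toy.predict_proba(X_toy[:1])

print(same_labels)
print(frequent_proba)
print(prior_proba)
```

The random_state argument only matters for the strategies that sample at random (stratified and uniform). The other strategies are deterministic.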
# Calculate the metrics for each set of predictions
accuracies = [metrics.accuracy_score(y_test, p) for p in preds]
precisions = [metrics.precision_score(y_test, p) for p in preds]
recalls = [metrics.recall_score(y_test, p) for p in preds]
f1s = [metrics.f1_score(y_test, p) for p in preds]
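The warnings mentioned earlier arise when a classifier never predicts the positive class, which leaves precision undefined. The following minimal sketch reproduces the situation on hypothetical toy labels. It assumes scikit-learn 0.22 or later, where precision_score accepts the zero_division parameter:

```python
import warnings

import numpy as np
from sklearn import metrics

# Hypothetical labels: the positive class (1) is the minority.
y_true = np.array([0, 0, 0, 1])
# What a most_frequent dummy classifier would predict for this data.
y_const = np.zeros(4, dtype=int)

# With the default settings, scikit-learn warns that precision is
# ill-defined and falls back to 0.0.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    precision_default = metrics.precision_score(y_true, y_const)

# Passing zero_division explicitly returns the same value without a warning.
precision_quiet = metrics.precision_score(y_true, y_const, zero_division=0)

accuracy_const = metrics.accuracy_score(y_true, y_const)
recall_const = metrics.recall_score(y_true, y_const)

print(precision_default, precision_quiet, accuracy_const, recall_const)
print([str(w.message) for w in caught])
```

The accuracy of 0.75 looks respectable here even though the classifier never finds a positive sample, which is why such a baseline is useful when judging the real classifiers.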
# Plot the metrics for the dummy and regular classifiers.
# NOTE: context is not defined in this listing; it is expected to
# come from the notebook setup.
labels = ch10util.rain_labels()
labels.extend(['stratified', 'frequent', 'prior', 'uniform'])

sp = dl.plotting.Subplotter(2, 2, context)
ch10util.plot_bars(sp.ax, accuracies, labels, rotate=True)
sp.label()

ch10util.plot_bars(sp.next_ax(), precisions, labels, rotate=True)
sp.label()

ch10util.plot_bars(sp.next_ax(), recalls, labels, rotate=True)
sp.label()

ch10util.plot_bars(sp.next_ax(), f1s, labels, rotate=True)
sp.label()
sp.fig.text(0, 1, ch10util.classifiers(), fontsize=10)
HTML(sp.exit())
Refer to the following screenshot for the end result:
The code is in the dummy_clf.ipynb file in this book's code bundle.
The DummyClassifier class is documented at http://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyClassifier.html (retrieved November 2015)