The scikit-learn DummyRegressor class implements several simple strategies that ignore the input features and can serve as a baseline for regressors. The strategies are as follows:
mean: This predicts the mean of the training set.
median: This predicts the median of the training set.
quantile: This predicts a specified quantile of the training set when provided with the quantile parameter. We will apply this strategy by specifying the first and third quartiles.
constant: This predicts a constant value that is provided by the user.

We will compare the dummy regressors with the regressors from Chapter 9, Ensemble Learning and Dimensionality Reduction, using R-squared, mean squared error (MSE), median absolute error (MedAE), and mean percentage error (MPE).
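To make the four strategies concrete, here is a minimal sketch on a small synthetic data set (the random data and the constant value 1.5 are assumptions for illustration only). Each fitted dummy regressor returns the same number for every row, and that number is simply a statistic of y_train.

```python
import numpy as np
from sklearn.dummy import DummyRegressor

# Small synthetic training set; every strategy ignores the features
rng = np.random.RandomState(42)
X_train = rng.rand(100, 3)
y_train = rng.exponential(scale=2.0, size=100)

strategies = {
    'mean': DummyRegressor(strategy='mean'),
    'median': DummyRegressor(strategy='median'),
    'q1': DummyRegressor(strategy='quantile', quantile=0.25),
    'q3': DummyRegressor(strategy='quantile', quantile=0.75),
    # constant=1.5 is an arbitrary user-supplied value
    'constant': DummyRegressor(strategy='constant', constant=1.5),
}

X_new = rng.rand(5, 3)
baseline_preds = {}
for name, reg in strategies.items():
    reg.fit(X_train, y_train)
    # Every row of X_new receives the same prediction
    baseline_preds[name] = reg.predict(X_new)
    print(name, baseline_preds[name][0])
```

Comparing the printed values with np.mean, np.median, and np.percentile on y_train is a quick way to confirm what each strategy learns.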
import numpy as np
from sklearn.dummy import DummyRegressor
from sklearn import metrics
from IPython.display import HTML
# ch10util and dautil are helper modules from this book's code bundle
import ch10util
import dautil as dl
# Load the training and test sets stored as NumPy .npy files
X_train = np.load('temp_X_train.npy')
y_train = np.load('temp_y_train.npy')
X_test = np.load('temp_X_test.npy')
y_test = np.load('temp_y_test.npy')
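mean = DummyRegressor()  # strategy='mean' is the default
median = DummyRegressor(strategy='median')
q1 = DummyRegressor(strategy='quantile', quantile=0.25)
q3 = DummyRegressor(strategy='quantile', quantile=0.75)

# Predictions of the Chapter 9 regressors, obtained via the book's helper module
preds = ch10util.temp_preds()

# Append the predictions of each dummy regressor
for reg in [mean, median, q1, q3]:
    reg.fit(X_train, y_train)
    preds.append(reg.predict(X_test))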
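Because the temp_*.npy files and ch10util are only available in the code bundle, the following self-contained sketch reproduces the same fit-and-predict loop on synthetic data. It is an assumption-laden stand-in: the linear data generator is invented, and LinearRegression replaces the Chapter 9 ensemble regressors purely so that there is a real model to compare against.

```python
import numpy as np
from sklearn.dummy import DummyRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Synthetic regression data (assumption: the recipe loads .npy files instead)
rng = np.random.RandomState(0)
X = rng.rand(300, 4)
y = 10 + X @ np.array([3.0, -2.0, 1.0, 0.5]) + rng.normal(scale=0.5, size=300)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0)

# LinearRegression is a stand-in for the Chapter 9 regressors
models = {
    'linear': LinearRegression(),
    'mean': DummyRegressor(),  # default strategy='mean'
    'median': DummyRegressor(strategy='median'),
    'q1': DummyRegressor(strategy='quantile', quantile=0.25),
    'q3': DummyRegressor(strategy='quantile', quantile=0.75),
}

labels, preds = [], []
for label, reg in models.items():
    reg.fit(X_train, y_train)
    labels.append(label)
    preds.append(reg.predict(X_test))
```

Keeping labels and preds in parallel lists mirrors how the recipe extends the label list with 'mean', 'median', 'q1', and 'q3'.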
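# scikit-learn metrics expect the true values first: metric(y_true, y_pred).
# This matters for R-squared, which is not symmetric in its arguments.
r2s = [metrics.r2_score(y_test, p) for p in preds]
mses = [metrics.mean_squared_error(y_test, p) for p in preds]
maes = [metrics.median_absolute_error(y_test, p) for p in preds]  # MedAE
mpes = [dl.stats.mpe(y_test, p) for p in preds]

labels = ch10util.temp_labels()
labels.extend(['mean', 'median', 'q1', 'q3'])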
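The four metrics can be illustrated on a tiny hand-checkable example. This is a sketch under stated assumptions: mean_percentage_error is a hypothetical helper using the common definition 100 * mean((y_true - y_pred) / y_true), and dautil's dl.stats.mpe may use a different sign convention. It also shows a useful property of the mean baseline: scored on the same values it was fit on, its R-squared is exactly 0.

```python
import numpy as np
from sklearn import metrics


def mean_percentage_error(y_true, y_pred):
    """Hypothetical MPE helper: 100 * mean((y_true - y_pred) / y_true).

    Assumes y_true has no zeros; dl.stats.mpe may differ in sign convention.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return 100 * np.mean((y_true - y_pred) / y_true)


y_true = np.array([3.0, 5.0, 2.0, 8.0, 7.0])
# Mean baseline computed from y_true itself (mean = 5.0)
mean_pred = np.full_like(y_true, y_true.mean())

# Ground truth goes first in every scikit-learn metric
r2 = metrics.r2_score(y_true, mean_pred)                 # 0.0
mse = metrics.mean_squared_error(y_true, mean_pred)      # 5.2
medae = metrics.median_absolute_error(y_true, mean_pred) # 2.0
mpe = mean_percentage_error(y_true, mean_pred)           # about -30.12
```

On held-out data the mean baseline usually scores slightly below 0 in R-squared, which is why any useful regressor should clear that bar.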
# context is the dautil notebook context created elsewhere in dummy_reg.ipynb (not shown)
sp = dl.plotting.Subplotter(2, 2, context)
ch10util.plot_bars(sp.ax, r2s, labels)
sp.label()

ch10util.plot_bars(sp.next_ax(), mses, labels)
sp.label()

ch10util.plot_bars(sp.next_ax(), maes, labels)
sp.label()

ch10util.plot_bars(sp.next_ax(), mpes, labels)
sp.label()

# Add the list of compared regressors to the figure
sp.fig.text(0, 1, ch10util.regressors())
HTML(sp.exit())
Refer to the following screenshot for the end result:
The code is in the dummy_reg.ipynb file in this book's code bundle.
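The bar charts in the recipe rely on the book's ch10util.plot_bars and dautil's Subplotter. If those helpers are not at hand, a rough plain matplotlib stand-in could look like the sketch below; plot_metric_bars is a hypothetical name, and the scores in the usage lines are random placeholders that only exercise the plotting code, not real results.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_metric_bars(scores, labels):
    """Hypothetical helper: one bar chart per metric in a 2x2 grid.

    scores maps a metric name to a sequence with one value per label.
    """
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    positions = np.arange(len(labels))
    for ax, (metric_name, values) in zip(axes.ravel(), scores.items()):
        ax.bar(positions, values)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_title(metric_name)
    fig.tight_layout()
    return fig


# Usage with hypothetical labels and random placeholder scores
labels = ['model_a', 'mean', 'median', 'q1', 'q3']
rng = np.random.RandomState(1)
scores = {name: rng.rand(len(labels))
          for name in ['R-squared', 'MSE', 'MedAE', 'MPE']}
fig = plot_metric_bars(scores, labels)
```

Placing the four metrics side by side makes it easy to see whether a regressor beats every dummy baseline on all of them, not just one.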
The DummyRegressor class is documented at http://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyRegressor.html (retrieved November 2015).