Understanding the task by understanding the data

What is always the first step in tackling a new machine learning problem?

You are absolutely right: getting a sense of the data. The better we understand the data, the better we understand the problem we are trying to solve. In our future endeavors, this will also help us to choose an appropriate machine learning algorithm.

The first thing to realize is that the drug column is not actually a feature like the other columns. Since our goal is to predict which drug will be prescribed based on a patient's blood values, the drug column effectively becomes the target label. In other words, the inputs to our machine learning algorithm will be the blood values, age, and gender of a patient, and the output will be a prediction of which drug to prescribe. Since the drug column is categorical rather than numerical, we know that we are facing a classification task.

Hence, it would be a good idea to remove all drug entries from the dictionaries listed in the data variable and store them in a separate variable:

  1. For this, we need to go through the list and extract the drug entry, which is easiest to do with a list comprehension:
In [2]: target = [d['drug'] for d in data]
... target
Out[2]: ['A', 'D', 'B', 'C', 'D', 'C', 'A', 'B', 'D', 'C',
'A', 'B', 'C', 'B', 'D', 'A', 'C', 'B', 'D', 'A']
  2. Since we also want to remove the drug entries from all of the dictionaries, we need to go through the list again and pop the drug key. Because this list comprehension evaluates to the list of popped drug labels, we add ; to suppress the output, since we don't want to see those labels again:
In [3]: [d.pop('drug') for d in data];
  3. Sweet! Now, let's look at the data. For the sake of simplicity, we may want to focus on the numerical features first: age, K, and Na. These are relatively easy to plot using Matplotlib's scatter function. We first import Matplotlib as usual:
In [4]: import matplotlib.pyplot as plt
... %matplotlib inline
... plt.style.use('ggplot')
  4. Then, if we want to plot the potassium level against the sodium level (and both against age) for every data point in the dataset, we need to go through the list and extract the feature values. We start with age:
In [5]: age = [d['age'] for d in data]
... age
Out[5]: [33, 77, 88, 39, 43, 82, 40, 88, 29, 53,
36, 63, 60, 55, 35, 23, 49, 27, 51, 38]

  5. We do the same for the sodium and potassium levels:
In [6]: sodium = [d['Na'] for d in data]
... potassium = [d['K'] for d in data]
  6. These lists can then be passed to Matplotlib's scatter function:
In [7]: plt.scatter(sodium, potassium)
... plt.xlabel('sodium')
... plt.ylabel('potassium')
Out[7]: <matplotlib.text.Text at 0x14346584668>
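Since dict.pop returns the value it removes, the two passes over data can also be collapsed into one. The sketch below shows that variant, plus a non-mutating alternative that builds label-free copies instead. It runs on three made-up records that only carry the keys mentioned in the text (age, Na, K, drug); the values are invented for illustration, and the real data list has more patients and more fields.

```python
# Three made-up records for illustration only; the values are invented.
data = [
    {'age': 30, 'Na': 0.50, 'K': 0.05, 'drug': 'A'},
    {'age': 45, 'Na': 0.70, 'K': 0.03, 'drug': 'D'},
    {'age': 60, 'Na': 0.80, 'K': 0.06, 'drug': 'B'},
]

# Variant 1: pop() returns the removed value, so one pass both collects the
# labels and strips 'drug' from each record. We pop from shallow copies here
# so that `data` itself stays intact for the second variant.
records = [dict(d) for d in data]
target = [d.pop('drug') for d in records]

# Variant 2: leave the originals alone and build label-free copies instead.
features = [{k: v for k, v in d.items() if k != 'drug'} for d in data]

# Pull out the numerical columns we want to plot.
age = [d['age'] for d in features]
sodium = [d['Na'] for d in features]
potassium = [d['K'] for d in features]

print(target)       # ['A', 'D', 'B']
print(features[0])  # {'age': 30, 'Na': 0.5, 'K': 0.05}
```

Variant 1 matches what In [3] does (mutating data in place) while avoiding a list comprehension whose result is thrown away; variant 2 is handy if the original records are still needed later.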

This will produce the following plot:

However, this plot is not very informative, because all data points have the same color. What we really want is for each data point to be colored according to the drug that was prescribed. For this to work, we need to somehow convert the drug labels, A through D, into numerical values. A nice trick is to use the ASCII value of a character.

In Python, this value is accessible via the ord function. For example, the A character has a value of 65 (that is, ord('A') == 65), B has 66, C has 67, and D has 68. Hence, we can turn the characters A through D into integers between 0 and 3 by calling ord on each of them and subtracting 65. We do this for every element of target, again using a list comprehension:

In [8]: target = [ord(t) - 65 for t in target]
... target
Out[8]: [0, 3, 1, 2, 3, 2, 0, 1, 3, 2, 0, 1, 2, 1, 3, 0, 2, 1, 3, 0]
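The ord trick relies on the labels being single, consecutive uppercase letters. The sketch below applies it to the 20 labels printed in Out[2], writing ord('A') instead of the bare 65. It then shows a more general alternative (our addition, not the book's method) that numbers the distinct labels in sorted order, which would also work for labels such as 'drugA' or 'placebo'.

```python
# The 20 drug labels as printed in Out[2].
labels = ['A', 'D', 'B', 'C', 'D', 'C', 'A', 'B', 'D', 'C',
          'A', 'B', 'C', 'B', 'D', 'A', 'C', 'B', 'D', 'A']

# The ord trick from the text, with ord('A') spelled out instead of 65.
encoded = [ord(t) - ord('A') for t in labels]

# General alternative: map each distinct label to its position in sorted order.
classes = sorted(set(labels))
index = {label: i for i, label in enumerate(classes)}
encoded_general = [index[t] for t in labels]

print(encoded)
# [0, 3, 1, 2, 3, 2, 0, 1, 3, 2, 0, 1, 2, 1, 3, 0, 2, 1, 3, 0]
print(classes)  # ['A', 'B', 'C', 'D']
```

For labels A through D, both encodings agree; the dictionary-based one also keeps classes around, so a predicted integer can be mapped back to its drug name with classes[i].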

We can then pass these integer values to Matplotlib's scatter function, which will choose a different color for each target label (c=target in the following code). Let's also increase the size of the dots (s=100 in the following code) and label our axes, so that we know what we are looking at:

In [9]: plt.subplot(221)
... plt.scatter(sodium, potassium, c=target, s=100)
... plt.xlabel('sodium (Na)')
... plt.ylabel('potassium (K)')
... plt.subplot(222)
... plt.scatter(age, potassium, c=target, s=100)
... plt.xlabel('age')
... plt.ylabel('potassium (K)')
... plt.subplot(223)
... plt.scatter(age, sodium, c=target, s=100)
... plt.xlabel('age')
... plt.ylabel('sodium (Na)')
Out[9]: <matplotlib.text.Text at 0x1b36a669e48>
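The actual colors in such a plot depend on the Matplotlib version, colormap, and style, so a legend that names each drug can make the figure easier to read. The following is a sketch of an alternative to In [9] using Matplotlib's object-oriented interface (plt.subplots); it assumes target already holds the integers 0 to 3, and the function name plot_slices is hypothetical. Each drug gets its own scatter call so that it appears as a separate legend entry, and the unused fourth cell is switched off.

```python
import matplotlib
matplotlib.use('Agg')  # off-screen rendering; omit this line in a notebook
import matplotlib.pyplot as plt


def plot_slices(age, sodium, potassium, target, names='ABCD'):
    """Draw three pairwise scatter plots on a 2 x 2 grid, one color per drug."""
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    panels = [
        (axes[0, 0], sodium, potassium, 'sodium (Na)', 'potassium (K)'),
        (axes[0, 1], age, potassium, 'age', 'potassium (K)'),
        (axes[1, 0], age, sodium, 'age', 'sodium (Na)'),
    ]
    for ax, xs, ys, xlabel, ylabel in panels:
        # One scatter call per class, so that each class gets a legend entry
        # and the same color (from the axes color cycle) in every panel.
        for label, name in enumerate(names):
            pts = [(x, y) for x, y, t in zip(xs, ys, target) if t == label]
            if pts:
                px, py = zip(*pts)
                ax.scatter(px, py, s=100, label='drug ' + name)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
    axes[0, 0].legend()
    axes[1, 1].axis('off')  # the fourth cell stays empty, as in In [9]
    fig.tight_layout()
    return fig, axes
```

With the lists from In [5], In [6], and In [8], a call would look like plot_slices(age, sodium, potassium, target).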

The preceding code will produce a graph with four subplots in a 2 x 2 grid, of which the first three subplots show different slices of the dataset, and every data point is colored according to its target label (that is, the prescribed drug):

What do these plots tell us, other than that the dataset is kind of messy? Can you spot any apparent relationship between feature values and target labels?

There are some interesting observations we can make. For example, the first and third subplots suggest that the light blue points are clustered around high sodium levels. Similarly, all red points seem to have both low sodium and low potassium levels. The rest is less clear. So, let's see how a decision tree would solve the problem.
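Visual impressions like these can be cross-checked numerically by comparing the average value of a feature for each drug. The sketch below uses only the standard library; class_means is a hypothetical helper, not part of the book's code, and with only 20 patients such averages are a rough guide at best.

```python
from collections import defaultdict
from statistics import mean


def class_means(values, target):
    """Return {class index: mean feature value} for one feature column."""
    groups = defaultdict(list)
    for value, label in zip(values, target):
        groups[label].append(value)
    # Sort by class index so the result reads 0, 1, 2, 3.
    return {label: mean(vals) for label, vals in sorted(groups.items())}
```

Applied to the lists from In [6] and In [8], class_means(sodium, target) and class_means(potassium, target) would show whether the class that looks clustered at high sodium really has a higher average sodium level than the others.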
