Python Machine Learning Blueprints

The matplotlib library

The first library we'll take a look at is matplotlib. The matplotlib library is the center of the Python plotting library universe. Originally created to emulate the plotting functionality of MATLAB, it grew into a fully featured library in its own right with an enormous range of functionality. If you have not come from a MATLAB background, it can be hard to understand how all the pieces work together to create the graphs you see. I'll do my best to break down the pieces into logical components so you can get up to speed quickly. But before diving into matplotlib in full, let's set up our Jupyter Notebook to allow us to see our graphs inline. To do this, add the following lines alongside your import statements:

import matplotlib.pyplot as plt 
plt.style.use('ggplot') 
%matplotlib inline 

The first line imports matplotlib's pyplot module, the second line sets the styling to approximate R's ggplot library (requires matplotlib 1.4.1 or greater), and the last line sets the plots so that they are visible within the notebook.
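The %matplotlib inline magic only exists inside Jupyter. As a rough sketch of the same setup in a plain Python script, the following checks that the ggplot style is available before applying it and renders a figure to memory instead of showing it inline. The Agg backend, the in-memory buffer, and the toy data points are assumptions made for this illustration, not part of the original setup:

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no display needed
import matplotlib.pyplot as plt

# plt.style.available lists the style names this matplotlib install ships with
if "ggplot" in plt.style.available:
    plt.style.use("ggplot")

# Toy data, only here to have something to draw
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot([0, 1, 2], [0, 1, 4])

# Render to an in-memory PNG; in a script you could pass a file name instead
buf = io.BytesIO()
fig.savefig(buf, format="png")
```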

Now, let's generate our first graph using our iris dataset:

fig, ax = plt.subplots(figsize=(6,4)) 
ax.hist(df['petal width (cm)'], color='black'); 
ax.set_ylabel('Count', fontsize=12) 
ax.set_xlabel('Width', fontsize=12) 
plt.title('Iris Petal Width', fontsize=14, y=1.01) 

The preceding code generates the following output:

There is a lot going on even in this simple example, but we'll break it down line by line. The first line creates a single subplot with a width of 6 inches and a height of 4 inches. We then plot a histogram of the petal width from our iris DataFrame by calling .hist() and passing in our data. We also set the bar color to black here. The next two lines place labels on our y and x axes, respectively, and the final line sets the title for our graph. We tweak the title's y position relative to the top of the graph with the y parameter, and increase the font size slightly over the default. This gives us a nice histogram of our petal width data. Let's now expand on that, and generate histograms for each column of our iris dataset:

fig, ax = plt.subplots(2,2, figsize=(6,4)) 
 
ax[0][0].hist(df['petal width (cm)'], color='black'); 
ax[0][0].set_ylabel('Count', fontsize=12) 
ax[0][0].set_xlabel('Width', fontsize=12) 
ax[0][0].set_title('Iris Petal Width', fontsize=14, y=1.01) 
 
ax[0][1].hist(df['petal length (cm)'], color='black'); 
ax[0][1].set_ylabel('Count', fontsize=12) 
ax[0][1].set_xlabel('Length', fontsize=12) 
ax[0][1].set_title('Iris Petal Length', fontsize=14, y=1.01) 
 
ax[1][0].hist(df['sepal width (cm)'], color='black'); 
ax[1][0].set_ylabel('Count', fontsize=12) 
ax[1][0].set_xlabel('Width', fontsize=12) 
ax[1][0].set_title('Iris Sepal Width', fontsize=14, y=1.01) 
 
ax[1][1].hist(df['sepal length (cm)'], color='black'); 
ax[1][1].set_ylabel('Count', fontsize=12) 
ax[1][1].set_xlabel('Length', fontsize=12) 
ax[1][1].set_title('Iris Sepal Length', fontsize=14, y=1.01) 
 
plt.tight_layout()  
 

The output for the preceding code is shown in the following diagram:

Obviously, this is not the most efficient way to code this, but it is useful for demonstrating how matplotlib works. Notice that instead of the single subplot object, ax, as we had in the first example, we now have four subplots, which are accessed through what is now the ax array. A new addition to the code is the call to plt.tight_layout(); this function will nicely auto-space your subplots to avoid crowding.
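One possible way to remove the repetition is to loop over the array of subplots. The sketch below assumes a stand-in for the iris DataFrame: a dictionary of random values keyed by the same column names, so the block runs on its own. With the real DataFrame, df[col] would work the same way. The Agg backend line is only there so the code runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: run without a display
import matplotlib.pyplot as plt
import numpy as np

columns = ['petal width (cm)', 'petal length (cm)',
           'sepal width (cm)', 'sepal length (cm)']

# Stand-in data (NOT the iris measurements): random values per column
rng = np.random.default_rng(0)
df = {col: rng.normal(loc=3.0, scale=1.0, size=150) for col in columns}

fig, ax = plt.subplots(2, 2, figsize=(6, 4))

# ax is a 2x2 NumPy array of Axes objects; .flat walks it row by row
for axis, col in zip(ax.flat, columns):
    part, dimension = col.split()[:2]  # for example 'petal', 'width'
    axis.hist(df[col], color='black')
    axis.set_ylabel('Count', fontsize=12)
    axis.set_xlabel(dimension.capitalize(), fontsize=12)
    axis.set_title(f'Iris {part.capitalize()} {dimension.capitalize()}',
                   fontsize=14, y=1.01)

plt.tight_layout()
```

The column order in the list determines which subplot each histogram lands in, so the layout matches the longhand version.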

Let's now take a look at a few other types of plots available in matplotlib. One useful plot is a scatterplot. Here, we will plot the petal width against the petal length:

fig, ax = plt.subplots(figsize=(6,6)) 
ax.scatter(df['petal width (cm)'], df['petal length (cm)'],
           color='green')
ax.set_xlabel('Petal Width') 
ax.set_ylabel('Petal Length') 
ax.set_title('Petal Scatterplot') 

The preceding code generates the following output:

As before, we could add in multiple subplots to examine each facet.
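As a rough sketch of what that could look like, the following draws two scatterplots side by side, one for the petal measurements and one for the sepal measurements. The data is a random stand-in keyed by the iris column names (an assumption so the block runs on its own), and the 1x2 layout and figure size are arbitrary choices:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: run without a display
import matplotlib.pyplot as plt
import numpy as np

# Pairs of (x column, y column) to plot, one pair per subplot
pairs = [('petal width (cm)', 'petal length (cm)'),
         ('sepal width (cm)', 'sepal length (cm)')]

# Stand-in data (NOT the iris measurements): random values per column
rng = np.random.default_rng(0)
df = {col: rng.uniform(0.5, 7.0, size=150) for pair in pairs for col in pair}

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# With one row of subplots, axes is a 1-D array and can be zipped directly
for axis, (x_col, y_col) in zip(axes, pairs):
    axis.scatter(df[x_col], df[y_col], color='green')
    axis.set_xlabel(x_col)
    axis.set_ylabel(y_col)
    axis.set_title(x_col.split()[0].capitalize() + ' Scatterplot')

plt.tight_layout()
```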

Another plot we could examine is a simple line plot. Here, we will look at a plot of the petal length:

fig, ax = plt.subplots(figsize=(6,6)) 
ax.plot(df['petal length (cm)'], color='blue') 
ax.set_xlabel('Specimen Number') 
ax.set_ylabel('Petal Length') 
ax.set_title('Petal Length Plot') 

The preceding code generates the following output:

We can already begin to see, based on this simple line plot, that there are distinct clusters of lengths for each species. Remember that our sample dataset had 50 ordered examples of each type. This tells us that petal length is likely to be a useful feature to discriminate between the species if we were to build a classifier.
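Because the rows are ordered in blocks of 50, one way to make those clusters easier to read is to mark the block boundaries on the line plot. The sketch below uses ax.axvline() for that. The data is a made-up stand-in (three blocks of 50 random values around arbitrary centers), not the iris measurements, and the dashed black styling is just one choice:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: run without a display
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data: three ordered blocks of 50 values with placeholder centers
rng = np.random.default_rng(0)
lengths = np.concatenate([rng.normal(center, 0.3, 50)
                          for center in (1.5, 4.3, 5.5)])

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(lengths, color='blue')

# Dashed vertical lines where one block of 50 specimens ends and the next begins
for boundary in (50, 100):
    ax.axvline(boundary, color='black', linestyle='--')

ax.set_xlabel('Specimen Number')
ax.set_ylabel('Petal Length')
ax.set_title('Petal Length Plot')
```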

Let's look at one final type of chart from the matplotlib library, the bar chart. This is perhaps one of the more common charts you'll see. Here, we'll plot a bar chart for the mean of each feature for the three species of irises, and to make it more interesting, we'll make it a stacked bar chart with a number of additional matplotlib features:

import numpy as np
fig, ax = plt.subplots(figsize=(6,6))
bar_width = .8
labels = [x for x in df.columns if 'length' in x or 'width' in x]
set_y = [df[df['species']==0][x].mean() for x in labels]
ver_y = [df[df['species']==1][x].mean() for x in labels]
vir_y = [df[df['species']==2][x].mean() for x in labels]
x = np.arange(len(labels))
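# align='edge' anchors each bar at its left edge (the default before matplotlib 2.0),
# so the tick offset of bar_width/2 used below lands on the bar centers
ax.bar(x, set_y, bar_width, align='edge', color='black')
ax.bar(x, ver_y, bar_width, align='edge', bottom=set_y, color='darkgrey')
ax.bar(x, vir_y, bar_width, align='edge', bottom=[i+j for i,j in zip(set_y, ver_y)], color='white')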
ax.set_xticks(x + (bar_width/2))
ax.set_xticklabels(labels, rotation=-70, fontsize=12);
ax.set_title('Mean Feature Measurement By Species', y=1.01)
ax.legend(['Setosa','Versicolor','Virginica'])

The output for the preceding snippet is given here:

To generate the bar chart, we need to pass the x and y values into the .bar() function. In this case, the x values are just an array of positions, one for each feature we are interested in (four here, one for each measurement column in our DataFrame). The np.arange() function is an easy way to generate this, but we could nearly as easily input this array manually. Since we don't want the x axis to display these positions as 0 through 3, we call the .set_xticklabels() function and pass in the column names we wish to display. To line up the x labels properly, we also need to adjust the position of the ticks. This is why we set the xticks to x plus half the size of the bar_width, which we set earlier at 0.8. This offset assumes that each bar starts at its x value and extends to the right (align='edge', which was the default before matplotlib 2.0); from version 2.0 on, bars are centered on x by default, and the ticks can be set to x directly. The y values come from taking the mean of each feature for each species. We then plot each series by calling .bar(). It is important to note that we pass in a bottom parameter for the second and third series, which sets the minimum y point of that series to the maximum y point of the series below it. This creates the stacked bars. And finally, we add a legend, which describes each series. The names are inserted into the legend list in the order in which the bars were plotted, that is, from the bottom of the stack to the top.
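The stacking logic can also be written as a loop that keeps a running total for the bottom parameter. This is a minimal sketch, not the listing above: the feature names and values are hypothetical placeholders rather than iris means, the bars are explicitly centered on x (so no tick offset is needed), and the legend is built from label arguments instead of a list:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: run without a display
import matplotlib.pyplot as plt
import numpy as np

labels = ['feature A', 'feature B', 'feature C']  # hypothetical feature names

# Hypothetical values, one array per series (NOT iris results)
series = {
    'first':  np.array([1.0, 2.0, 3.0]),
    'second': np.array([2.0, 1.0, 1.0]),
    'third':  np.array([0.5, 0.5, 2.0]),
}
colors = ['black', 'darkgrey', 'white']

x = np.arange(len(labels))
bar_width = 0.8
fig, ax = plt.subplots(figsize=(6, 6))

bottom = np.zeros(len(labels))  # running height of the stack so far
for (name, values), color in zip(series.items(), colors):
    ax.bar(x, values, bar_width, bottom=bottom, align='center',
           color=color, edgecolor='black', label=name)
    bottom = bottom + values    # the next series starts where this one ends

ax.set_xticks(x)                # bars are centered on x, so ticks sit at x
ax.set_xticklabels(labels)
ax.legend()
```

Keeping the running total in a NumPy array avoids the nested zip() sums and extends to any number of series.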