5. Next-level Data Visualization

Learning Objectives

  • Describe Matplotlib

  • Describe and identify the components of a Matplotlib Figure

  • Describe a Matplotlib Axes

  • Explain the difference between Matplotlib’s explicit and implicit interfaces

  • Create Matplotlib Figures and Axes

  • Modify Matplotlib plots with Figure and Axes methods

  • Describe Matplotlib’s coding styles

  • Choose appropriate visualizations based on data type(s)

  • List the steps to making a well-designed visualization

  • Assess what aspects of a visualization need improvement

  • Use Matplotlib to customize visualizations made with other packages

  • Create well-designed data visualizations

This chapter is about customizing data visualizations in Python. Seaborn, plotnine, pandas, and many other packages rely on the Matplotlib package to create visualizations. Thus the first section delves into Matplotlib’s interfaces and computational model: how to reason about and modify visualizations created with the package. The second section puts the ideas from the first into practice by presenting a selection of case studies with real data, where creating a great visualization depends on customizing the default outputs from Seaborn with some Matplotlib code.

5.1. Prerequisites

This chapter assumes you are already familiar with Python, pandas, and at least one way of making data visualizations. In particular, you should be comfortable working with DataFrames and creating scatter, line, and bar plots from data. DataLab’s Python Basics Reader and its accompanying workshop provide a suitable introduction to these topics.

To follow along, you’ll need the following software versions (or newer) installed on your computer:

One way to install all of these at once is to install the Anaconda Python distribution. Chapter 4 provides additional details about Anaconda and the conda package manager.

5.2. Thinking in Matplotlib

Matplotlib is a relatively low-level visualization package, which means you can use it to draw almost anything, but that creating and fine-tuning common statistical plots usually requires more code than with other packages. The package has extensive documentation, including tutorials and cheat sheets. While the size of Matplotlib’s programming interface can be daunting, there’s an ongoing effort to make it more user-friendly and make the documentation more approachable, and there have already been big improvements in the last three years.

Note

This section is loosely based on Matplotlib’s Quick Start Guide.

5.2.1. Introduction

../_images/lter_penguins.png

Fig. 5.1 Artwork by @allison_horst.

As a way to learn the fundamentals of Matplotlib, you’ll recreate the scatter plot (without the regression lines) in Fig. 5.2, which shows flipper length versus bill length for hundreds of individual penguins from three different species: Adélie, Chinstrap, and Gentoo. Both the plot and data come from the Palmer Penguins data set, which was collected by Dr. Kristen Gorman at Palmer Station, Antarctica, and packaged for public use by Allison Horst.

Important

You can download a version of the Palmer Penguins data set that we’ve prepared for Python HERE.

../_images/flipper-bill.png

Fig. 5.2 A scatter plot from the Palmer Penguins data set.

After downloading the data set, use pandas to read it and display some summary information:

import pandas as pd

penguins = pd.read_parquet("data/penguins.parquet")
penguins.head()
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex year
0 Adelie Torgersen 39.1 18.7 181.0 3750.0 male 2007
1 Adelie Torgersen 39.5 17.4 186.0 3800.0 female 2007
2 Adelie Torgersen 40.3 18.0 195.0 3250.0 female 2007
3 Adelie Torgersen NaN NaN NaN NaN NaN 2007
4 Adelie Torgersen 36.7 19.3 193.0 3450.0 female 2007
penguins.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 344 entries, 0 to 343
Data columns (total 8 columns):
 #   Column             Non-Null Count  Dtype   
---  ------             --------------  -----   
 0   species            344 non-null    category
 1   island             344 non-null    category
 2   bill_length_mm     342 non-null    float64 
 3   bill_depth_mm      342 non-null    float64 
 4   flipper_length_mm  342 non-null    float64 
 5   body_mass_g        342 non-null    float64 
 6   sex                333 non-null    category
 7   year               344 non-null    int32   
dtypes: category(3), float64(4), int32(1)
memory usage: 13.6 KB

Matplotlib’s primary interface for creating plots interactively is PyPlot (matplotlib.pyplot). The convention is to import PyPlot as plt:

import matplotlib.pyplot as plt

PyPlot’s plt.scatter function creates a scatter plot. The first and second arguments set the x- and y-coordinates of the points, respectively. For the plot in Fig. 5.2, penguin flipper length should be on the x-axis and bill length should be on the y-axis. Use the plt.scatter function to create the plot:

plt.scatter(penguins["flipper_length_mm"], penguins["bill_length_mm"])
<matplotlib.collections.PathCollection at 0x7f2bd0d02650>
../_images/70b707f6c6fc246f546dcc41ef3f79f39dbed538608a5a530cb6d2d47e05e1e3.png

This austere scatter plot is typical of Matplotlib functions: if you want fancy formatting, you have to set more parameters and make more function calls.

Note

If you’re using Jupyter, notice that in addition to the plot, there’s also some output about a matplotlib.collections.PathCollection. This happens because most Matplotlib functions and methods return a result—typically the object you added to the plot—and Jupyter prints the last result in each code cell.

The output is harmless and safe to ignore, but if it bothers you, you can add a call to plt.show at the end of each cell where you make a plot.

Tip

Most Matplotlib plotting functions provide an alternative interface through the data parameter for string-indexed data such as data frames and dicts. For instance, if you set data = penguins in the plt.scatter function, then you can use column names to set the x and y axis:

plt.scatter("flipper_length_mm", "bill_length_mm", data = penguins)

How much formatting a plot requires depends on the intended audience, but even if you’re only making the plot for yourself, you might want to put some axis labels on it to remind you of what it’s showing. You might also want to make the points partially transparent so that you can see whether any points are hidden under others and use point styles to incorporate additional information from the data set. If you’re making a plot to share with others, you might also want to add a title and go through DataLab’s Data Visualization Guidelines.

Eventually you’ll customize the penguin scatter plot, but first you need to know a little bit more about how Matplotlib works.

5.2.2. Figures & Axes

Matplotlib uses specific terms for the components of a visualization, which can be confusing if you aren’t aware of them. The most important terms are:

  • Artist: any Matplotlib object, including the other items on this list.

  • Figure: an entire visualization. A Figure can contain one or more plots. You’ll typically interact with Figures when you want to set margins on, display, or save a visualization.

  • Axes: a single plot. An Axes can contain many other Artists. You’ll typically interact with Axes when you want to add additional data or customize the components (styles, colors, and so on) of a plot.

Important

Despite the name, an Axes object represents an entire plot (within a larger visualization), not the axes on a plot. The represented plot might not even have visible axes.

Each axis on a plot is represented by a separate Axis object (with an “i”).

Note

In this reader, we use monospace (Axes) or title-case (Axes) to refer to Python classes and lowercase (axes) for plain English.

Thus in Matplotlib terms, a visualization consists of a Figure with at least one Axes. Both are Artists, and an Axes usually contains additional Artists for lines, points, axes, legends, and so on. Fig. 5.3 identifies Matplotlib terms on an actual visualization. Keep it or the cheat sheet at hand as you’re learning the package.

../_images/anatomy.png

Fig. 5.3 The anatomy of a Matplotlib figure.
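As a rough sketch of how these terms map onto Python objects, the following checks the relationships described above. It assumes the non-interactive Agg backend so that it runs without a display, and all of the variable names are only for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.axis import XAxis, YAxis
from matplotlib.figure import Figure

fig, ax = plt.subplots()

# A Figure and an Axes are both Artists.
is_figure = isinstance(fig, Figure)
is_axes = isinstance(ax, Axes)
both_artists = isinstance(fig, Artist) and isinstance(ax, Artist)

# The Figure keeps a list of its Axes, and each Axes knows its Figure.
axes_in_figure = fig.axes
parent = ax.figure

# Each Axes owns two Axis objects (with an "i"), one per direction.
x_axis, y_axis = ax.xaxis, ax.yaxis

print(is_figure, is_axes, both_artists)
print(len(axes_in_figure), parent is fig)
print(type(x_axis).__name__, type(y_axis).__name__)

plt.close(fig)
```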

The plt.subplots function creates a Figure with one or more Axes arranged in a grid, so it’s the starting point for most visualizations. The function returns a tuple with the Figure and the Axes. By default, it only creates one Axes:

fig, ax = plt.subplots()
../_images/9787f8a2451f2e7e36145ae170b7a0df1c349c31116bdd2985cc12a5f053b850.png

Warning

Matplotlib also has a plt.subplot function (without the final “s”). It adds an Axes to an existing Figure rather than creating a new Figure and Axes. To create a new visualization, the plt.subplots function (with the final “s”) is the one you want.

Once you’ve created a Figure and Axes, you can call methods on the Axes to add details. Section 5.2.1 explained how to create a scatter plot of the Palmer Penguins data set with the plt.scatter function. Another approach is to use the .scatter method on an Axes instead:

fig, ax = plt.subplots()
ax.scatter(penguins["flipper_length_mm"], penguins["bill_length_mm"])
<matplotlib.collections.PathCollection at 0x7f2bd0e0ef50>
../_images/70b707f6c6fc246f546dcc41ef3f79f39dbed538608a5a530cb6d2d47e05e1e3.png

Section 5.2.3 provides more information about these two different approaches to creating plots.

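The first two arguments to plt.subplots control the number of rows and number of columns, respectively, in the grid of Axes. When the grid has more than one row and more than one column, the Axes are returned in a 2-dimensional NumPy array (by default, a grid with a single row or a single column gives a 1-dimensional array instead):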

fig, axs = plt.subplots(2, 3)
../_images/82e761d2ebb2094489f34bb22cd8d325467724f84af0dd8a6199537d78d108b1.png
axs
array([[<Axes: >, <Axes: >, <Axes: >],
       [<Axes: >, <Axes: >, <Axes: >]], dtype=object)

You can index this array with square brackets [ ] in order to get the individual Axes. For example, axs[0, 2] gets the Axes in the first row and third column.
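Here is a small sketch of indexing and looping over the array of Axes. It also shows how the shape of the array depends on the grid and on the squeeze parameter of plt.subplots. The titles and labels are placeholders, and the Agg backend is assumed so that the code runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 3)

# Row 0, column 2: the top-right Axes in the grid.
top_right = axs[0, 2]
top_right.set_title("axs[0, 2]")

# .flat iterates over every Axes in row-major order.
for i, ax in enumerate(axs.flat):
    ax.set_xlabel(f"Axes {i}")

grid_shape = axs.shape

# With a single row (or column), the default is a 1-dimensional array...
fig_row, axs_row = plt.subplots(1, 3)
row_shape = axs_row.shape

# ...unless squeeze = False asks for a 2-dimensional array every time.
fig_2d, axs_2d = plt.subplots(1, 3, squeeze = False)
always_2d_shape = axs_2d.shape

print(grid_shape, row_shape, always_2d_shape)
plt.close("all")
```

Setting squeeze = False can be convenient in functions that accept an arbitrary grid size, since the indexing code then never has to special-case a single row or column.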

Tip

If you want to make a visualization where some subplots span multiple grid cells, take a look at the plt.subplot_mosaic function. You can use it to concisely describe and create Figures with complex arrangements of Axes.
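As a minimal sketch of plt.subplot_mosaic, the layout below has one wide Axes across the top row and two smaller Axes underneath. The labels "wide", "left", and "right" are arbitrary names chosen for this example, and the Agg backend is assumed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

# Each inner list is one row of the grid. Repeating a label makes
# that Axes span all of the cells where the label appears.
fig, axd = plt.subplot_mosaic(
    [["wide", "wide"],
     ["left", "right"]],
    layout = "constrained")

# The Axes come back in a dict keyed by label rather than in an array.
axd["wide"].set_title("Spans the whole top row")
axd["left"].set_title("Bottom left")
axd["right"].set_title("Bottom right")

panel_names = sorted(axd)
print(panel_names)
plt.close(fig)
```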

Notice that in the displayed Figure, there’s not much space between the components and some of them even overlap. You can tell Matplotlib to use a constraint solver to determine appropriate sizes for each subplot by setting layout = "constrained" on the Figure. The plt.subplots function accepts Figure-level keyword arguments, so one way to do this is:

fig, axs = plt.subplots(2, 3, layout = "constrained")
../_images/2300e9bfd347e09670d21d1034a67bdf260f985fade0dc58d62c96b367636ba4.png

Note

Constrained layout is not the only layout possible, and you may encounter visualizations that use tight layout instead. According to the Matplotlib developers, constrained layout produces better results in most cases.

You can control the size of an entire Figure with another setting, figsize. The argument should be a tuple with a width and height of the Figure in inches. This can also be set through plt.subplots:

fig, ax = plt.subplots(figsize = (1, 2))
../_images/80dab02de809beb8256e8bd8d9f21381c3086b2f4d71e7bb93205a9ed7cf09bb.png

The Matplotlib documentation has a reference page with a list of other Figure-level keyword arguments.

Tip

You can change Figure settings globally through mpl.rcParams, which is a dict-like object (here mpl is the conventional alias from import matplotlib as mpl). See the documentation page and the guide for more details.
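The sketch below shows two ways you might use rcParams: temporarily, inside an mpl.rc_context block, and for the rest of the session, by assigning to a key. The specific sizes are arbitrary values picked for the example, and the Agg backend is assumed:

```python
import matplotlib as mpl
mpl.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

default_size = tuple(mpl.rcParams["figure.figsize"])

# rc_context applies settings only inside the with block.
with mpl.rc_context({"figure.figsize": (4, 3), "axes.grid": True}):
    fig, ax = plt.subplots()
    temp_size = tuple(fig.get_size_inches())

# Outside the block, the previous defaults are back.
restored_size = tuple(mpl.rcParams["figure.figsize"])

# Assigning to a key changes the default for every later Figure.
mpl.rcParams["figure.figsize"] = (8, 6)
fig2, ax2 = plt.subplots()
new_size = tuple(fig2.get_size_inches())

print(default_size, temp_size, restored_size, new_size)
plt.close("all")
```

The context manager is usually the safer choice in shared code, because it cannot leak settings into plots made later.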

5.2.3. Coding Styles

Section 5.2.1 and Section 5.2.2 explained two different ways to create a scatter plot. Matplotlib actually supports three different coding styles, two of which use the PyPlot interface:

  1. PyPlot style, where you use the PyPlot interface for everything, from plot creation to customization. This style is convenient for interactive work in the Python console because you don’t have to keep track of any intermediate variables—you just use plt.

  2. Object-oriented (OO) style, where you only use the PyPlot interface to create the Figure and Axes, and then use method calls to add detail and customize.

  3. Embedded style, where you don’t use the PyPlot interface at all. Instead, you create a Figure by explicitly calling the Figure constructor function and attaching a canvas, and then use method calls to add Axes and details. This style is convenient for non-interactive work, because it allows greater control over where Figures are drawn.

Warning

When you use either the PyPlot or OO style, it’s up to you to explicitly close Figures by calling plt.close when you’re done with them. If you create many Figures without closing them, Python may run out of memory. Matplotlib will automatically issue a warning if you have more than 20 Figures open.

When you use the embedded style, Python will automatically close Figures when they go out of scope. If you need to close a Figure before it goes out of scope, you can use the del keyword.
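As a short illustration of closing Figures, the sketch below opens three Figures, closes one of them, and then closes the rest. It uses plt.get_fignums to list the Figures that PyPlot is still tracking, and assumes the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

plt.close("all")  # start from a clean slate

figs = [plt.figure() for _ in range(3)]
open_before = plt.get_fignums()

plt.close(figs[0])  # close one specific Figure
open_after_one = plt.get_fignums()

plt.close("all")  # close every Figure that PyPlot is tracking
open_after_all = plt.get_fignums()

print(len(open_before), len(open_after_one), len(open_after_all))
```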

Warning

Many PyPlot functions and Axes methods have the same names and signatures, but not all of them. So translating between the PyPlot style and OO style sometimes requires searching the documentation even if you have one style memorized.

The Matplotlib documentation suggests using the OO style in most scenarios. For scripts that are intended to run non-interactively and save many plots to disk, the embedded style is usually more memory efficient. Switching between the OO and embedded style is relatively easy, since the only major difference is in how Figures and Axes are created. Thus we recommend the OO style and use it in all of the remaining examples.
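To make the comparison concrete, here is a sketch of the same small scatter plot written in each of the three styles. The data are the four complete rows shown by penguins.head() in Section 5.2.1, typed in by hand so that the example is self-contained. The Agg backend and canvas are assumptions made so that the code runs without a display:

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure

x = [181.0, 186.0, 195.0, 193.0]  # flipper lengths (mm)
y = [39.1, 39.5, 40.3, 36.7]      # bill lengths (mm)

# 1. PyPlot style: every call goes through plt and acts on the
#    "current" Figure and Axes.
plt.figure()
plt.scatter(x, y)
plt.xlabel("Flipper length (mm)")
pyplot_label = plt.gca().get_xlabel()
plt.close()

# 2. OO style: plt only creates the Figure and Axes.
fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_xlabel("Flipper length (mm)")  # note: .set_xlabel, not .xlabel
oo_label = ax.get_xlabel()
plt.close(fig)

# 3. Embedded style: no plt calls at all, so there is nothing to close.
fig = Figure()
FigureCanvasAgg(fig)  # attach a canvas that can render PNG images
ax = fig.add_subplot()
ax.scatter(x, y)
ax.set_xlabel("Flipper length (mm)")
buffer = io.BytesIO()
fig.savefig(buffer, format = "png")  # write to memory rather than disk
png_bytes = buffer.getvalue()

print(pyplot_label == oo_label, len(png_bytes) > 0)
```

Notice that the OO and embedded versions differ only in the lines that create the Figure and Axes, while the PyPlot version uses differently named functions such as plt.xlabel.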

5.2.4. Drawing on Axes

Matplotlib provides Axes methods for drawing points, lines, text, legends, shapes, and more. In this section, you’ll use some of these methods to add detail to the scatter plot of the Palmer Penguins data set, so that it looks more like Fig. 5.2. Section 5.2.2 left off with this version of the plot:

fig, ax = plt.subplots(layout = "constrained")
ax.scatter(penguins["flipper_length_mm"], penguins["bill_length_mm"])
<matplotlib.collections.PathCollection at 0x7f2bd0812750>
../_images/d248ae372041e94e5b30efc581b8802d84aa0ec458d72d4b29e99fca82d1842c.png

In Fig. 5.2, the shape and color of each point indicates the species of penguin. Orange circles represent Adelie penguins, purple triangles represent Chinstrap penguins, and teal squares represent Gentoo penguins.

You can set the shape of the points plotted by the .scatter method with the marker parameter. Matplotlib uses shorthand strings to represent different marker types. For instance, "o" means a circle, "^" means a point-up triangle, and "s" means a square. The documentation includes a complete list of markers.

When you call the .scatter method, you can only set one shape for the points. To plot a different shape for each group, make a separate call to .scatter for each group. For example, to plot only the points for the Chinstrap penguins:

fig, ax = plt.subplots(layout = "constrained")

obs = penguins[penguins["species"] == "Chinstrap"]
ax.scatter(obs["flipper_length_mm"], obs["bill_length_mm"], marker = "^")
<matplotlib.collections.PathCollection at 0x7f2bd055f250>
../_images/fb42d75d8a131caf900692608cbddbe1a832df0efeffcd62d5ef4aa0446f52f5.png

You can use a dictionary and a loop to concisely plot multiple groups:

fig, ax = plt.subplots(layout = "constrained")

groups = {"Adelie": "o", "Chinstrap": "^", "Gentoo": "s"}
for species, marker in groups.items():
    obs = penguins[penguins["species"] == species]
    ax.scatter(
        obs["flipper_length_mm"], obs["bill_length_mm"], marker = marker) 
../_images/c36edb79a5f5d9ab06e40f8f2d50a4e9ee5bb3346b5e4c87f313581dc3b46164.png

When you call the .scatter method many times like this, Matplotlib automatically cycles through the colors in its default palette. The package also does this for most other kinds of plots.
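If you’re curious which colors Matplotlib cycles through, the sketch below reads them from the "axes.prop_cycle" setting and compares them with the colors of two consecutive .scatter calls. The points are made up for the example, and the Agg backend is assumed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

# The colors in the current default cycle, in order.
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

fig, ax = plt.subplots()
first = ax.scatter([1, 2], [1, 2])   # gets the first color in the cycle
second = ax.scatter([1, 2], [2, 1])  # gets the second color in the cycle

# .get_facecolor returns one RGBA row per distinct color.
first_color = to_hex(first.get_facecolor()[0])
second_color = to_hex(second.get_facecolor()[0])

# "C0", "C1", ... are shorthand names for positions in the cycle.
shorthand_color = to_hex("C0")

print(first_color, second_color, shorthand_color)
plt.close(fig)
```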

Note

In computing contexts, colors are often described in terms of their percentages of red, green, and blue (RGB) light or channels. You can write this as a tuple of decimal numbers, each between 0.0 and 1.0. For example, (0, 0, 0) is black, (1, 1, 1) is white, (0.5, 0, 0.5) is purple, and (0, 0.5, 0.5) is teal.

For memory efficiency, each color channel is usually limited to 256 different values (8 bits per channel, or 24-bit color). In this case, you can write a color as a tuple of integers. For instance, (128, 0, 128) is purple.

You can write integer color codes concisely by using the base-16 number system, called hexadecimal. In hexadecimal, each digit represents one of 16 values (0 to 15). Conventionally, the digits are written 0-9 and a-f (10-15). Thus 8 is written as 8, 90 is written as 5a, and 255 is written as ff. Since each color channel ranges from 0 to 255, you can write each channel with two hex digits, and thus a whole color with six hex digits. Hex codes for colors are usually prefixed with #, so for example #800080 is purple.

A red-green-blue-alpha (RGBA) color has a fourth channel called alpha that controls the opacity. RGBA colors are usually 32-bit colors and thus can be written with eight hex digits.

In Matplotlib, you can specify RGB and RGBA colors as hex codes. The documentation also describes several other ways to specify colors.
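The arithmetic in this note can be sketched in a few lines of plain Python. The helper functions rgb_to_hex and hex_to_rgba are hypothetical names written for this example, not part of Matplotlib:

```python
def rgb_to_hex(red, green, blue, alpha = None):
    """Convert integer channels (0 to 255) to a hex color code."""
    channels = [red, green, blue] + ([] if alpha is None else [alpha])
    # "02x" formats each channel as two lowercase hex digits.
    return "#" + "".join(f"{channel:02x}" for channel in channels)


def hex_to_rgba(code):
    """Convert a 6- or 8-digit hex code to fractions from 0.0 to 1.0."""
    digits = code.lstrip("#")
    values = [int(digits[i:i + 2], 16) / 255 for i in range(0, len(digits), 2)]
    if len(values) == 3:
        values.append(1.0)  # no alpha channel means fully opaque
    return tuple(values)


purple = rgb_to_hex(128, 0, 128)
ninety = format(90, "x")
translucent = hex_to_rgba("#800080c0")

print(purple)                    # → #800080
print(ninety)                    # → 5a
print(round(translucent[3], 2))  # → 0.75
```

The last line shows why a trailing c0 gives roughly 75% opacity: c0 is 192 in decimal, and 192 / 255 is about 0.75.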

To explicitly set the color of the points in a call to .scatter, use the color parameter. In Fig. 5.2, Adelie orange is approximately #fe8700, Chinstrap purple is approximately #9a27ef, and Gentoo teal is approximately #098b8b. The points are slightly transparent, so it’s also a good idea to append c0 to each color code to set 75% opacity. You can add the color codes to the groups dictionary:

fig, ax = plt.subplots(layout = "constrained")

groups = {
    "Adelie"    : ("o", "#fe8700c0"),
    "Chinstrap" : ("^", "#9a27efc0"),
    "Gentoo"    : ("s", "#098b8bc0")
}

for species, (marker, color) in groups.items():
    obs = penguins[penguins["species"] == species]

    ax.scatter(
        obs["flipper_length_mm"], obs["bill_length_mm"], marker = marker,
        color = color) 
../_images/5222af88c193f568ddd7179ee28c231b04f578d1ae4e421441458d6ad075c553.png

The figure also has a legend. You can add a legend to a Matplotlib Axes by calling the .legend method. Matplotlib will guess what to show in the legend based on the labeled components of the plot. You can set the label parameter in the .scatter method to label a group of points. The label will be displayed in the legend. You can also use parameters of the .legend method to control the legend directly. For instance, title sets the title and frameon controls whether the legend is displayed in a box. After adding a legend, the code becomes:

fig, ax = plt.subplots(layout = "constrained")

groups = {
    "Adelie"    : ("o", "#fe8700c0"),
    "Chinstrap" : ("^", "#9a27efc0"),
    "Gentoo"    : ("s", "#098b8bc0")
}

for species, (marker, color) in groups.items():
    obs = penguins[penguins["species"] == species]

    ax.scatter(
        obs["flipper_length_mm"], obs["bill_length_mm"], marker = marker,
        color = color, label = species) 

ax.legend(title = "Penguin species", frameon = False)
<matplotlib.legend.Legend at 0x7f2bd11b3810>
../_images/69be23c8c8078836aef528954a45333a9c6e3b96d9be530ebd2f463d5eb230d5.png

Fig. 5.2 also has grid lines. The Axes method .grid controls the display of grid lines. The first argument is a Boolean value that sets whether the grid lines are visible. An Axes can have major or minor grid lines, which correspond to the major (labeled) and minor (unlabeled) ticks on each axis. The which parameter controls which grid lines are affected by the .grid method. It’s necessary to also call the .set_axisbelow method if you want the grid to appear behind other plot components rather than in front of them. So to add major and minor grid lines:

fig, ax = plt.subplots(layout = "constrained")

groups = {
    "Adelie"    : ("o", "#fe8700c0"),
    "Chinstrap" : ("^", "#9a27efc0"),
    "Gentoo"    : ("s", "#098b8bc0")
}

for species, (marker, color) in groups.items():
    obs = penguins[penguins["species"] == species]

    ax.scatter(
        obs["flipper_length_mm"], obs["bill_length_mm"], marker = marker,
        color = color, label = species) 

ax.legend(title = "Penguin species", frameon = False)
ax.set_axisbelow(True)
ax.grid(True, which = "both")
../_images/8e5853da83132fea31f93d6fbb9372d88d74684d3f2f18622aeb3a8df4a1aaa4.png
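One detail worth knowing is that minor grid lines are only drawn where minor ticks exist, and a linear axis has no minor ticks by default. The sketch below turns minor ticks on with .minorticks_on and styles the two sets of grid lines separately. The four points come from the rows shown by penguins.head(), the line widths and styles are arbitrary choices, and the Agg backend is assumed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator

fig, ax = plt.subplots(layout = "constrained")
ax.scatter([181.0, 186.0, 195.0, 193.0], [39.1, 39.5, 40.3, 36.7])

ax.minorticks_on()      # place minor ticks between the major ticks
ax.set_axisbelow(True)  # draw the grid behind the points

# Style the major and minor grid lines separately.
ax.grid(True, which = "major", linewidth = 1.0)
ax.grid(True, which = "minor", linewidth = 0.5, linestyle = ":")

minor_locator = ax.yaxis.get_minor_locator()
grid_is_below = ax.get_axisbelow()

print(type(minor_locator).__name__, grid_is_below)
plt.close(fig)
```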

The common theme of the .scatter, .legend, and .grid methods is that all of them draw additional components on an Axes. Some of the components you might want to draw and the associated methods are:

  • Lines with .plot or .add_line

  • Points with .plot or .scatter

  • Patches, which includes geometric shapes such as rectangles, with .add_patch

  • Text with .text

  • Annotations with .annotate

  • Legends with .legend

  • Images with .imshow

You can learn more about drawing these components from Matplotlib’s Artists Tutorial.
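As a quick tour of several of the methods in the list above, the sketch below draws one of each kind of component on a single Axes. All coordinates and labels are invented for the example, and the Agg backend is assumed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

fig, ax = plt.subplots(layout = "constrained")

# A line through three made-up points.
ax.plot([0, 1, 2], [0, 1, 0], label = "A line")

# Points, drawn as unconnected markers.
ax.scatter([0.5, 1.5], [0.25, 0.25], label = "Some points")

# A patch: a rectangle anchored at (0.75, 0.4), 0.5 wide and 0.2 tall.
ax.add_patch(Rectangle((0.75, 0.4), 0.5, 0.2, alpha = 0.3))

# Plain text at a data coordinate.
ax.text(0.05, 0.9, "Plain text")

# An annotation: text plus an arrow pointing at the peak of the line.
ax.annotate(
    "Peak", xy = (1, 1), xytext = (1.5, 0.9),
    arrowprops = {"arrowstyle": "->"})

# The legend picks up the labeled line and points.
ax.legend()

# Each kind of component is stored in its own list on the Axes.
print(len(ax.lines), len(ax.collections), len(ax.patches), len(ax.texts))
plt.close(fig)
```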

5.2.5. Formatting Figures & Axes

Section 5.2.4 explained how to draw the components of a plot, using the scatter plot of the Palmer Penguins data set as an example. The result was a version of the plot that has all of the information from Fig. 5.2, but not all of the formatting. You can control formatting such as borders, titles, and labels through Figure and Axes methods. This section demonstrates how.

Axis labels add important context to a plot, since they describe the meaning and units for each axis. If you want people to use the axes when they interpret your plot, make sure to include axis labels. You can add axis labels to the x- and y-axis of a Matplotlib Axes with the .set_xlabel and .set_ylabel methods, respectively. For example, for the penguins scatter plot:

fig, ax = plt.subplots(layout = "constrained")

groups = {
    "Adelie"    : ("o", "#fe8700c0"),
    "Chinstrap" : ("^", "#9a27efc0"),
    "Gentoo"    : ("s", "#098b8bc0")
}

for species, (marker, color) in groups.items():
    obs = penguins[penguins["species"] == species]

    ax.scatter(
        obs["flipper_length_mm"], obs["bill_length_mm"], marker = marker,
        color = color, label = species) 

ax.legend(title = "Penguin species", frameon = False)
ax.set_axisbelow(True)
ax.grid(True, which = "both")
ax.set_xlabel("Flipper length (mm)")
ax.set_ylabel("Bill length (mm)")
Text(0, 0.5, 'Bill length (mm)')
../_images/867dfd24436e12454f35594f63ce36dd7bd8eb977c422250bf63a291e9acc014.png

Tip

Matplotlib supports a subset of TeX, so it’s possible to use mathematical symbols in visualizations. Any TeX expressions you enclose in dollar signs $ will be rendered. The documentation on mathematical expressions provides more details.
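For instance, the sketch below puts a TeX expression in an axis label and in a title. The label text is made up for the example. It uses raw strings (the r prefix) so that Python leaves the backslashes alone, and assumes the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(layout = "constrained")

# Everything between the dollar signs is rendered as math.
ax.set_xlabel(r"Body mass ($\times 10^3$ g)")
ax.set_title(r"$\bar{x} \pm \sigma$")

# Drawing the Figure forces Matplotlib to parse the expressions, so
# any syntax error shows up here rather than when the plot is saved.
fig.canvas.draw()

x_label = ax.get_xlabel()
print(x_label)
plt.close(fig)
```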

Similarly, you can add a title to a Matplotlib Axes with the .set_title method. It’s particularly important to put titles on plots that will be presented without other sources of contextual information, such as captions. Each Axes can have up to three titles, positioned at the left, center, and right. The loc parameter of .set_title controls which title is set. The title can have multiple lines, as in Fig. 5.2. With the title, the penguins plot becomes:

fig, ax = plt.subplots(layout = "constrained")

groups = {
    "Adelie"    : ("o", "#fe8700c0"),
    "Chinstrap" : ("^", "#9a27efc0"),
    "Gentoo"    : ("s", "#098b8bc0")
}

for species, (marker, color) in groups.items():
    obs = penguins[penguins["species"] == species]

    ax.scatter(
        obs["flipper_length_mm"], obs["bill_length_mm"], marker = marker,
        color = color, label = species) 

ax.legend(title = "Penguin species", frameon = False)
ax.set_axisbelow(True)
ax.grid(True, which = "both")
ax.set_xlabel("Flipper length (mm)")
ax.set_ylabel("Bill length (mm)")

title = """Flipper and bill length
Dimensions for Adelie, Chinstrap and Gentoo Penguins at Palmer Station LTER"""
ax.set_title(title, loc = "left")
Text(0.0, 1.0, 'Flipper and bill length\nDimensions for Adelie, Chinstrap and Gentoo Penguins at Palmer Station LTER')
../_images/2ba2e2ddebe9203c9a1646af83e7b0998bbf9bfa7e892ca5688c9c3a727fe737.png

Note

Three single quotes ''' or three double quotes """ mark the beginning and the end of a multi-line string.
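As a small sketch of the three title positions, the code below sets a left, center, and right title on one Axes and reads them back with .get_title. The title text is placeholder text, and the Agg backend is assumed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(layout = "constrained")

ax.set_title("Left title", loc = "left")
ax.set_title("Center title")  # loc defaults to the center position
ax.set_title("Right title", loc = "right")

titles = [ax.get_title(loc = where) for where in ("left", "center", "right")]
print(titles)
plt.close(fig)
```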

With the grid lines visible, the border around the outside of the plot just makes it look cluttered. According to the Matplotlib anatomy chart (Fig. 5.3), each line in the border is a Spine. Every Axes has a string-indexed .spines attribute which contains the Spines, and each Spine has a .set_visible method that controls whether it’s visible. You can call .set_visible on all of the Spines at once by slicing the .spines attribute, as in ax.spines[:]. So to hide the spines in the penguins plot:

fig, ax = plt.subplots(layout = "constrained")

groups = {
    "Adelie"    : ("o", "#fe8700c0"),
    "Chinstrap" : ("^", "#9a27efc0"),
    "Gentoo"    : ("s", "#098b8bc0")
}

for species, (marker, color) in groups.items():
    obs = penguins[penguins["species"] == species]

    ax.scatter(
        obs["flipper_length_mm"], obs["bill_length_mm"], marker = marker,
        color = color, label = species) 

ax.legend(title = "Penguin species", frameon = False)
ax.set_axisbelow(True)
ax.grid(True, which = "both")
ax.set_xlabel("Flipper length (mm)")
ax.set_ylabel("Bill length (mm)")

title = """Flipper and bill length
Dimensions for Adelie, Chinstrap and Gentoo Penguins at Palmer Station LTER"""
ax.set_title(title, loc = "left")
ax.spines[:].set_visible(False)
../_images/8f9e2689d5769ee7eac279cbcc1d9d995795ee61432ec72a9357faaecf28c90d.png
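If you only want to hide some of the Spines, you can index .spines with a list of names instead of a slice. The sketch below hides the top and right Spines, a common choice for plots without grid lines. It assumes a Matplotlib version recent enough to support ax.spines[:] and the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Indexing with a list of names selects several Spines at once.
ax.spines[["top", "right"]].set_visible(False)

# .spines behaves like a dict, so you can also loop over it.
visible = {name: spine.get_visible() for name, spine in ax.spines.items()}
print(visible)
plt.close(fig)
```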

Now that the border is gone, the tick marks on each axis look strange, and they don’t add any information since the grid lines are there. You can hide the tick marks while keeping the labels by setting the lengths of the tick marks to 0. To do this, use the .tick_params method, which controls parameters related to tick marks, such as their length.

Another issue with the tick marks is that more of them are labeled on the y-axis than in Fig. 5.2, and it makes the y-axis look cluttered. You can fix problems with the positioning of tick marks on an axis with the .set_xticks and .set_yticks methods.

This final version of the Palmer Penguins scatter plot fixes the tick marks:

fig, ax = plt.subplots(layout = "constrained")

groups = {
    "Adelie"    : ("o", "#fe8700c0"),
    "Chinstrap" : ("^", "#9a27efc0"),
    "Gentoo"    : ("s", "#098b8bc0")
}

for species, (marker, color) in groups.items():
    obs = penguins[penguins["species"] == species]

    ax.scatter(
        obs["flipper_length_mm"], obs["bill_length_mm"], marker = marker,
        color = color, label = species) 

ax.legend(title = "Penguin species", frameon = False)
ax.set_axisbelow(True)
ax.grid(True, which = "both")
ax.set_xlabel("Flipper length (mm)")
ax.set_ylabel("Bill length (mm)")

title = """Flipper and bill length
Dimensions for Adelie, Chinstrap and Gentoo Penguins at Palmer Station LTER"""
ax.set_title(title, loc = "left")
ax.spines[:].set_visible(False)

ax.tick_params(length = 0, which = "both", axis = "both")
ax.set_yticks([40, 50, 60])
ax.set_yticks(range(30, 61, 5), minor = True)
[<matplotlib.axis.YTick at 0x7f2bd0453dd0>,
 <matplotlib.axis.YTick at 0x7f2bd0219350>,
 <matplotlib.axis.YTick at 0x7f2bd051a3d0>,
 <matplotlib.axis.YTick at 0x7f2bd04e8310>,
 <matplotlib.axis.YTick at 0x7f2bd050f1d0>,
 <matplotlib.axis.YTick at 0x7f2bd0415950>,
 <matplotlib.axis.YTick at 0x7f2bd0833310>]
../_images/635ce7173245b758f078951cffa8080a2a73d6bd1f30dff9662a5fdb69ac5962.png

As you can see, it takes about 20-30 lines of code (or more) to create and customize a Matplotlib visualization. In addition, Matplotlib’s built-in plotting functions cover common plots such as line plots, scatter plots, bar plots, histograms, and box plots, but many statistical visualizations require even more work. Packages like Seaborn and plotnine exist to address this. On the other hand, perhaps you can also see the incredible flexibility of Matplotlib: it can be used to draw just about anything. The remainder of this reader will mostly focus on using Matplotlib to customize or enhance visualizations created with other packages, to have both convenience and flexibility.

The formatting in this section is just the tip of the iceberg. There are many more ways to customize Matplotlib visualizations. The documentation for the Axes class (as well as the Figure class) is a good place to find more formatting methods.

Tip

Integrated development environments (IDEs) such as JupyterLab and Visual Studio Code have a code auto-complete feature. You can use auto-complete to discover, learn, and remind yourself of attributes and methods on Python objects such as Figures and Axes.

5.2.6. Saving Figures

After creating a great data visualization, you might want to save it as an image or some other format so that you can use it in documents or share it with other people. Fortunately, Matplotlib makes saving visualizations straightforward. Every Figure has a .savefig method that saves the Figure. The first argument is the file to which to save the Figure, either as a path or as an open file. By default, the file format is inferred from the file extension (if present), or else a default is used (typically PNG). So to save the Palmer Penguins scatter plot:

fig.savefig("penguins.png")

The .savefig method has parameters to control various properties of the saved visualization, such as the dots per inch (DPI). The documentation describes these parameters.

Tip

Choosing an appropriate dots per inch (DPI) value is important! Generally, DPI should be somewhere between 72 and 300, with lower values for visualizations displayed online (to minimize file size) and higher values for visualizations displayed in print (to maximize quality). Some academic journals request even higher DPIs.
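To see how DPI and figsize interact, the sketch below saves a 4 by 3 inch Figure at 150 DPI and reads the pixel dimensions back from the PNG header. It saves to an in-memory buffer rather than a file so that it leaves nothing on disk. The sizes are arbitrary example values, and the Agg backend is assumed:

```python
import io
import struct

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize = (4, 3))  # width and height in inches
ax.plot([0, 1], [0, 1])

buffer = io.BytesIO()
fig.savefig(buffer, format = "png", dpi = 150)
data = buffer.getvalue()

# A PNG file stores its pixel width and height as two big-endian
# 4-byte integers at bytes 16 to 23.
width_px, height_px = struct.unpack(">II", data[16:24])

print(width_px, height_px)  # → 600 450
plt.close(fig)
```

The pixel dimensions are the size in inches multiplied by the DPI, so doubling the DPI quadruples the number of pixels and usually increases the file size substantially.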

5.3. Visualization Best Practices

Besides programming skills, to create a visualization that conveys a clear message about a data set, you need to know graphic design best practices and the principles of visual perception. DataLab’s Principles of Data Visualization covers these important skills.

DataLab’s Data Visualization Guidelines is a concise reference and reminder of these skills. Bookmark or print out a copy and run through the checklist whenever you design a visualization. The case studies in Section 5.4 show how to fix some of the issues on the checklist.

Note

Choosing colors is often one of the visualization details that trips people up. Matplotlib’s Choosing Colormaps Guide is a good source of information for what colormaps are available in Matplotlib and how to choose an appropriate one.

5.4. Case Studies

Now that you’ve learned a little bit about how to use Matplotlib, it’s time to put it into practice. This section presents a variety of case studies in making visualizations, using a mix of Matplotlib and Seaborn.

Note

Although the case studies use Seaborn, most of the strategies shown apply to any Matplotlib-based visualization package. Plotting functions in these packages often take an Axes as input or return an Axes as output. As long as you can get access to the Axes object, you can use Matplotlib methods to customize the plot.
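The same idea can be sketched with pandas, whose plotting methods both return an Axes and accept one through an ax parameter. The small DataFrame below is typed in from the rows shown by penguins.head() so that the example is self-contained, and the Agg backend is assumed:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({
    "flipper_length_mm": [181.0, 186.0, 195.0, 193.0],
    "bill_length_mm": [39.1, 39.5, 40.3, 36.7],
})

# Pattern 1: the plotting method returns an Axes to customize.
ax = df.plot.scatter(x = "flipper_length_mm", y = "bill_length_mm")
ax.set_xlabel("Flipper length (mm)")
ax.set_ylabel("Bill length (mm)")
ax.spines[["top", "right"]].set_visible(False)

# Pattern 2: create the Axes first, then tell the package to draw on it.
fig, ax2 = plt.subplots(layout = "constrained")
returned = df.plot.scatter(
    x = "flipper_length_mm", y = "bill_length_mm", ax = ax2)

print(ax.get_xlabel(), returned is ax2)
plt.close("all")
```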

5.4.1. Plotting a Categorical Distribution

Note

This case study shows how to:

  • Use Seaborn to plot the distribution of one or more categorical features

  • Use Seaborn to create faceted plots

When you start working with a new data set, the first thing you should do is explore (and clean) the features. Exploring the features means inspecting the distribution of each feature, as well as checking for relationships between features. For this case study, let’s use Seaborn and Matplotlib to explore the categorical features in the Palmer Penguins data set from Section 5.2.1.

Consider the species of the penguins in the data set. Are the three species—Adélie, Chinstrap, and Gentoo—equally represented? This question is asking how the categorical species feature is distributed. Tables, bar plots, and dot plots are all good tools for examining the distribution of a single categorical feature. Let’s make a table and a bar plot.

Tip

Inspecting data through multiple methods is a good way to verify that your code and interpretations are correct.

You can use the pandas .value_counts method to make a table:

penguins["species"].value_counts()
Adelie       152
Gentoo       124
Chinstrap     68
Name: species, dtype: int64

To make the bar plot, first import Seaborn. The conventional abbreviation for the package is sns:

import seaborn as sns

Seaborn provides two different functions for making bar plots:

  • The sns.barplot function requires two features: the category for each bar and the length of each bar. It’s best for data that’s already been grouped and aggregated. For example, you could compute the median flipper length of each penguin species and then plot the medians with sns.barplot.

  • The sns.countplot function only requires one feature: a categorical array. The function automatically groups and aggregates the values by counting them. It’s best for visualizing the distribution of a categorical feature. Note that this is purely a convenience function—you could just compute the counts yourself and plot them with sns.barplot.
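
As a rough illustration of the difference between the two functions, the sketch below draws the same counts both ways. The data frame toy is made up for the example (it is not the Palmer Penguins data), and the side-by-side layout is just one way to compare the results:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Toy data, made up for illustration (not the real penguins data).
toy = pd.DataFrame(
    {"species": ["Adelie", "Adelie", "Gentoo", "Adelie", "Chinstrap", "Gentoo"]})

# Count the rows for each species by hand.
counts = toy.groupby("species").size().reset_index(name = "count")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (8, 3))
# sns.countplot groups and counts the raw rows automatically...
sns.countplot(data = toy, x = "species", ax = ax1)
# ...while sns.barplot plots the counts computed above.
sns.barplot(data = counts, x = "species", y = "count", ax = ax2)
```

Both Axes should show bars of the same heights, although the order of the bars can differ because the grouped table is sorted alphabetically.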

Use the sns.countplot function to make a bar plot of the distribution of penguin species:

sns.countplot(penguins, x = "species")
<Axes: xlabel='species', ylabel='count'>
../_images/c7850fb1ed2531dbb112d29eb82e8503145dad4ab96522542fedbd1703401616.png

The function returns a Matplotlib Axes. You can use this to customize the plot. For example, to capitalize the axis labels:

ax = sns.countplot(penguins, x = "species")
ax.set_xlabel("Species")
ax.set_ylabel("Count")
Text(0, 0.5, 'Count')
../_images/8b445f891a6c6020f6a19ba8dd917fd8edcc6ee468a21c1af26e541344c72b05.png

Most Seaborn plotting functions also have an ax parameter you can use to specify an Axes on which to plot. This makes it possible to add to existing Axes and to customize plots produced by plotting functions that don’t return an Axes object. Using this approach, the equivalent of the previous example is:

fig, ax = plt.subplots()
sns.countplot(penguins, x = "species", ax = ax)
ax.set_xlabel("Species")
ax.set_ylabel("Count")
Text(0, 0.5, 'Count')
../_images/8b445f891a6c6020f6a19ba8dd917fd8edcc6ee468a21c1af26e541344c72b05.png

Besides species, the Palmer Penguins data set also includes information about the island where each bird was observed and the biological sex of each bird. When a data set has multiple categorical features, it’s important to check how the observations are grouped within them, since unbalanced groups can bias analyses. You can use color to represent a second categorical feature in a bar plot. Seaborn uses the hue parameter to link color to a feature. So to make the bar plot show both species and sex:

fig, ax = plt.subplots()
sns.countplot(penguins, x = "species", hue = "sex", ax = ax)
ax.set_xlabel("Species")
ax.set_ylabel("Count")
Text(0, 0.5, 'Count')
../_images/0e8089f9f7bfbbc59f15380de147703a0e760791830b56b37e18d276408df8d9.png

The plot suggests the sexes are relatively well-balanced across the observed penguins, regardless of species. Of course, there’s still the island feature to consider.

A bar plot can really only summarize two features at once, and many other types of plots are also limited to just two or three features. One way to represent more categorical features is to use facets, side-by-side plots that each show data for a mutually exclusive category or combination of categories.

Most plotting packages provide convenience functions for making faceted plots, and Seaborn is no exception. The sns.FacetGrid function creates a grid of plots based on one or two categorical features. You can use the returned FacetGrid’s .map_dataframe method to call a plotting function for each facet with the appropriate subsets of the data. For example, to indicate sex by row, island by column, and use a separate bar for each species:

grid = sns.FacetGrid(penguins, row = "sex", col = "island")
grid.map_dataframe(sns.countplot, x = "species")
<seaborn.axisgrid.FacetGrid at 0x7f2bafa64150>
../_images/b19c1c84d19468f29719eb83ae1c15fe10f2ee856c1225eae1b5e4a30d591f13.png

The margin_titles parameter of sns.FacetGrid controls whether the group titles are placed in the margins. The default is False, but setting the parameter to True makes the plot easier to read:

grid = sns.FacetGrid(
    penguins, row = "sex", col = "island", margin_titles = True)
grid.map_dataframe(sns.countplot, x = "species")
<seaborn.axisgrid.FacetGrid at 0x7f2baf871910>
../_images/5f624106473893e6b54f3bb6da01edff4b73c69047fbfa074223811b94d1e609.png

You can use the FacetGrid method .set_titles to set the titles of faceted plots. The title strings are treated as templates with some values replaced automatically. For instance, {row_name} is replaced by the row’s category. Here’s how to change the titles so that they only show the categories and not the feature names:

grid = sns.FacetGrid(
    penguins, row = "sex", col = "island", margin_titles = True)
grid.map_dataframe(sns.countplot, x = "species")
grid.set_titles(row_template = "{row_name}", col_template = "{col_name}")
<seaborn.axisgrid.FacetGrid at 0x7f2baf5a2790>
../_images/352a2429ff2b217ba33f5f44dc8a8b7d417d5868c6fabbc740af768c9a84767d.png

Tip

If you want to change the names of the categories in a feature, do that as a separate step before making visualizations. The pandas documentation about Categorical data provides details about how to create categorical features and set category names.

Similarly, you can use the .set_axis_labels method to set the axis labels:

grid = sns.FacetGrid(
    penguins, row = "sex", col = "island", margin_titles = True)
grid.map_dataframe(sns.countplot, x = "species")
grid.set_titles(row_template = "{row_name}", col_template = "{col_name}")
grid.set_axis_labels(x_var = "Species", y_var = "Count")
<seaborn.axisgrid.FacetGrid at 0x7f2baf547f90>
../_images/ed7b15cd4b69c24fe2a946ccf0118e5f2b1eee3d34abe33248947e77f0a732b1.png

Seaborn’s methods for creating and customizing faceted plots are convenient, and you can read more about them in the FacetGrid documentation. Occasionally, you may need to work with the underlying Matplotlib Figure and Axes. Fortunately, FacetGrid objects provide access to these through the .figure and .axes attributes.
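
Here is a minimal sketch of that kind of access, assuming a small made-up data frame in place of the penguins data. It uses .figure to add an overall title and loops over .axes (a 2-dimensional array of Axes) to add horizontal grid lines to every facet:

```python
import pandas as pd
import seaborn as sns

# Toy data, made up for illustration (not the real penguins data).
toy = pd.DataFrame({
    "species": ["Adelie", "Gentoo", "Adelie", "Gentoo",
                "Adelie", "Gentoo", "Adelie", "Adelie"],
    "sex": ["female", "female", "male", "male",
            "female", "female", "male", "male"],
    "island": ["Biscoe", "Biscoe", "Biscoe", "Biscoe",
               "Dream", "Dream", "Dream", "Dream"],
})

grid = sns.FacetGrid(toy, row = "sex", col = "island", margin_titles = True)
grid.map_dataframe(sns.countplot, x = "species")

# The underlying Figure: add a title above all of the facets.
grid.figure.suptitle("Penguin Counts (toy data)")

# The underlying Axes: one per facet, arranged in a 2-dimensional array.
for ax in grid.axes.flat:
    ax.grid(axis = "y", alpha = 0.3)
```

The title may overlap the facet titles; if it does, Figure methods such as .subplots_adjust can make room for it.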

Note

This case study focuses on categorical features and the next focuses on continuous features. So what should you do if you have a discrete feature?

Discrete features share properties of categorical features (both are enumerable) and properties of continuous features (both are quantitative). This usually means you can choose whether to treat discrete features as categorical or continuous. Categorical methods tend to be more appropriate for discrete features that take relatively few distinct values, and continuous methods tend to be more appropriate for ones that don’t.

5.4.2. Plotting a Continuous Distribution#

Note

This case study shows how to:

  • Use Seaborn to plot the distribution of one or more continuous features

Section 5.4.1 shows how to visualize the distribution of a categorical feature. For this case study, let’s examine one of the continuous features in the Palmer Penguins data set from Section 5.2.1.

There are many different ways to visualize continuous features, such as histograms, density plots, box plots, violin plots, and empirical cumulative distribution function plots. The Seaborn documentation provides a detailed guide to visualizing distributions.

Warning

The sns.displot function can create many different kinds of distribution plots, but it operates at the Figure level and does not accept or return an Axes, making it difficult to customize the result. For this reason, we recommend against using sns.displot.

If you decide to use sns.displot anyway, be careful not to confuse it with sns.distplot (with two “t”s). Both make distribution plots, but sns.displot uses a newer interface that’s more consistent with other Seaborn functions. You should use the new interface (sns.displot) in any new code you write, as support for the old interface may eventually end.

Let’s visualize the distribution of the birds’ body mass, which is recorded in the body_mass_g feature. Some of Seaborn’s functions for plotting continuous distributions are:

  • sns.histplot to make a histogram

  • sns.kdeplot to make a density (or “kernel density estimator”) plot

  • sns.boxplot to make a box plot

  • sns.boxenplot to make a boxen (or “letter value”) plot

  • sns.ecdfplot to make an empirical cumulative distribution function plot

  • sns.violinplot to make a violin plot

For example, to make a histogram:

sns.histplot(penguins, x = "body_mass_g")
<Axes: xlabel='body_mass_g', ylabel='Count'>
../_images/b3ecaa80c43f19195e3ec12041681e4ee9f4f6ccc4968b2c0cabc118e9a33630.png

And to make a density plot:

sns.kdeplot(penguins, x = "body_mass_g")
<Axes: xlabel='body_mass_g', ylabel='Density'>
../_images/a4de933cb0a7c188bebc653c9b34d26dbfe11b672f03b809f565b1e099c00ab2.png

Like most other Seaborn plotting functions, these functions have an ax parameter for the Axes on which to plot, and also return an Axes. So to make the x-axis label easier to read:

fig, ax = plt.subplots()
sns.kdeplot(penguins, x = "body_mass_g", ax = ax)
ax.set_xlabel("Body Mass (g)")
Text(0.5, 0, 'Body Mass (g)')
../_images/461775219071f0ec917f2bdbab3c232ef9c16c9129aefbdb43b29c13a3bb3cb7.png

Density plots are convenient when you want to break down the distribution of a continuous feature across the categories of a categorical feature, because the plot can show a separate line for each category. For instance, to group body mass by species:

fig, ax = plt.subplots()
sns.kdeplot(penguins, x = "body_mass_g", hue = "species", ax = ax)
ax.set_xlabel("Body Mass (g)")
Text(0.5, 0, 'Body Mass (g)')
../_images/44d8b324eb11f1d79ef7fc9e944382026005c173db6c5c26e828351b9272841c.png

This makes it easy to see that Gentoo penguins typically have more mass than the other two species.

As in Section 5.4.1, you can incorporate more categorical features by faceting. To incorporate another continuous feature, it’s necessary to change plot types.

You can visualize the distributions of and relationship between two continuous features with a scatter plot (sns.scatterplot) or a smoothed scatter plot (again sns.kdeplot). The latter smooths out the points in a scatter plot to show their relative density, using a 2-dimensional generalization of the method used to estimate a density plot. For instance, to plot body mass against flipper length:

fig, ax = plt.subplots()
sns.kdeplot(
    penguins, x = "body_mass_g", y = "flipper_length_mm", hue = "species",
    ax = ax)
ax.set_xlabel("Body Mass (g)")
ax.set_ylabel("Flipper Length (mm)")
Text(0, 0.5, 'Flipper Length (mm)')
../_images/bdbe0b20fcefff30331a43d19b84a10d43c92fa68b03a6352560c9f0ceead312.png

In this plot it’s possible to see how the distributions of body mass and flipper length differ across species, as well as how body mass and flipper length are related.

Tip

Visualizing more than two continuous features at once can be challenging. One way to get around this problem is to create a separate plot for each pair of features.

If it really is necessary to visualize more than two at once, consider converting at least one of them into a categorical feature by binning or discretizing the values. Then you can use strategies for categorical features such as varying colors and faceting.
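
One way to bin a continuous feature is the pandas pd.cut function. The sketch below is only an illustration: the body masses, the bin edges, and the labels "light", "medium", and "heavy" are all made up, and sensible cut points depend on the data and the question:

```python
import pandas as pd

# Toy body masses in grams, made up for illustration.
toy = pd.DataFrame({"body_mass_g": [2900, 3500, 4100, 4800, 5600, 6100]})

# Each bin includes its right edge: (0, 3500], (3500, 5000], (5000, 7000].
toy["mass_class"] = pd.cut(
    toy["body_mass_g"], bins = [0, 3500, 5000, 7000],
    labels = ["light", "medium", "heavy"])

print(toy)
```

The new mass_class column is categorical, so you can pass it to parameters such as hue, row, or col.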

5.4.3. Plotting a Time Series#

Note

This case study shows how to:

  • Convert temporal data to appropriate data types

  • Use Seaborn to plot a time series as a histogram or a line

  • Use multiple data sets in a single plot

  • Customize and rotate the text of tick labels

  • Move a legend on a Seaborn plot

Plotting a time series can be tricky because you first need to make sure that the temporal features have appropriate data types. For this case study, suppose you want to make a visualization that shows the relationship between flu rates in birds and humans.

People mail dead birds to the USDA and USGS, where scientists analyze the birds to find out why they died. The USDA compiles the information into a public Avian Influenza data set each year.

Important

You can download the Avian Influenza data HERE.

You can use pandas’ pd.read_csv function to read the data. Each row corresponds to one bird death. There are 8 columns with information about the date, species of bird, collection method, and location. The Date Detected column is a date, while the rest of the columns are categorical, which you can see by taking a peek at the data set:

import pandas as pd

pd.read_csv("data/hpai-wild-birds-ver2.csv", nrows = 5)
State County Date Detected HPAI Strain Bird Species WOAH Classification Sampling Method Submitting Agency
0 South Carolina Colleton 1/13/2022 EA H5N1 American wigeon Wild bird Hunter harvest NWDP
1 South Carolina Colleton 1/13/2022 EA H5N1 Blue-winged teal Wild bird Hunter harvest NWDP
2 North Carolina Hyde 1/12/2022 EA H5N1 Northern shoveler Wild bird Hunter harvest NWDP
3 North Carolina Hyde 1/20/2022 EA H5N1 American wigeon Wild bird Hunter harvest NWDP
4 North Carolina Hyde 1/20/2022 EA H5 Gadwall Wild bird Hunter harvest NWDP

To parse the dates, set the pd.read_csv function’s parse_dates parameter to a list of the date columns. You can also use the dtype parameter to set data types for the other columns, either as a dictionary with one entry per column or as a single type to use as the default. So to read the date column as dates and the rest as categories:

birds = pd.read_csv(
    "data/hpai-wild-birds-ver2.csv",
    parse_dates = ["Date Detected"], dtype = "category")

birds.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6542 entries, 0 to 6541
Data columns (total 8 columns):
 #   Column               Non-Null Count  Dtype         
---  ------               --------------  -----         
 0   State                6542 non-null   category      
 1   County               6542 non-null   category      
 2   Date Detected        6540 non-null   datetime64[ns]
 3   HPAI Strain          6541 non-null   category      
 4   Bird Species         6542 non-null   category      
 5   WOAH Classification  6542 non-null   category      
 6   Sampling Method      6542 non-null   category      
 7   Submitting Agency    6542 non-null   category      
dtypes: category(7), datetime64[ns](1)
memory usage: 144.6 KB

Tip

It’s also possible to parse dates in pandas with the pd.to_datetime function, and to convert columns to other data types with the .astype method.

Specifying appropriate data types when you read data often provides performance benefits, but doing data type conversions later gives you more flexibility to clean the data first. So which approach you should use in any given situation depends on the data set and your analysis goals.
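
As a sketch of the second approach, the example below reads a few rows as plain text and converts the columns afterwards. The rows are a small stand-in for the real file, with one missing date added for illustration, and the date format string is an assumption based on the dates shown above:

```python
import io

import pandas as pd

# A tiny stand-in for the real CSV file, with one missing date.
csv_text = """State,Date Detected,Sampling Method
South Carolina,1/13/2022,Hunter harvest
North Carolina,1/12/2022,Hunter harvest
North Carolina,,Hunter harvest
"""

toy = pd.read_csv(io.StringIO(csv_text))

# Convert after reading; missing dates become NaT ("not a time").
toy["Date Detected"] = pd.to_datetime(toy["Date Detected"], format = "%m/%d/%Y")
toy["State"] = toy["State"].astype("category")

toy.info()
```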

What kind of plot is appropriate for this data set? Note that there are several dates where more than one bird was collected:

birds["Date Detected"].value_counts().head()
2022-10-25    366
2022-11-21    148
2022-09-13    129
2022-09-23    128
2022-02-16    109
Name: Date Detected, dtype: int64

While you could use a line plot to show the counts for each date, there are also many dates where the counts are zero. To see a smoothed out version of the line plot, you can use a histogram of the observations. In Seaborn, the sns.histplot function makes a histogram (and returns a Matplotlib Axes):

import seaborn as sns

ax = sns.histplot(birds, x = "Date Detected")
../_images/11a6b8f09273daa6e3e43298b4a4a95912c0ac063b43d3405e13804a7a7b4aa3.png

Depending on how wide your screen is, the x-axis tick labels may get squeezed together, making them difficult to read. This problem is especially common when plotting dates, since the labels tend to be long. One way to fix it is to rotate the labels. With Matplotlib, you can use the Axes method .set_xticks (or .set_yticks) to control properties of the ticks and tick labels. The first two arguments to the method are the tick positions and tick labels; you can get the default values with the .get_xticks and .get_xticklabels methods. The rotation parameter controls the degree of rotation. So to make the histogram with the x-axis tick labels at 45 degrees:

ax = sns.histplot(birds, x = "Date Detected")
ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation = 45)
[<matplotlib.axis.XTick at 0x7f2bae436510>,
 <matplotlib.axis.XTick at 0x7f2bae10ecd0>,
 <matplotlib.axis.XTick at 0x7f2bae183810>,
 <matplotlib.axis.XTick at 0x7f2bae18dd50>,
 <matplotlib.axis.XTick at 0x7f2bae18ff90>,
 <matplotlib.axis.XTick at 0x7f2bae192450>,
 <matplotlib.axis.XTick at 0x7f2bae650d10>,
 <matplotlib.axis.XTick at 0x7f2bae199550>,
 <matplotlib.axis.XTick at 0x7f2bae19b950>]
../_images/10c87f066a2e18639d15db34623c2cc2eac6c4699ce10584039c868faeac5c96.png

The rotation interferes with the horizontal alignment of the tick labels, so they’re positioned too far to the right. To fix this, set the horizontal alignment parameter ha to "right" and set the rotation mode parameter rotation_mode to "anchor":

ax = sns.histplot(birds, x = "Date Detected")
ax.set_xticks(
    ax.get_xticks(), ax.get_xticklabels(), rotation = 45, ha = "right",
    rotation_mode = "anchor")
[<matplotlib.axis.XTick at 0x7f2bae1aa6d0>,
 <matplotlib.axis.XTick at 0x7f2bae1cecd0>,
 <matplotlib.axis.XTick at 0x7f2bae1832d0>,
 <matplotlib.axis.XTick at 0x7f2bae03d110>,
 <matplotlib.axis.XTick at 0x7f2bae03f510>,
 <matplotlib.axis.XTick at 0x7f2bae045750>,
 <matplotlib.axis.XTick at 0x7f2baf2662d0>,
 <matplotlib.axis.XTick at 0x7f2bae04c7d0>,
 <matplotlib.axis.XTick at 0x7f2bae04e950>]
../_images/78c8e7ef4ead46b4839c917036dfa5b52fe7db68a43c7363cb3d3221019e3c70.png

The data set includes information about how the birds were collected in the Sampling Method column. You can incorporate this information in the plot by setting the hue parameter in the call to sns.histplot. In this case, stacked bars are a good choice since the focus is on the total counts for each date range rather than the individual sampling methods. You can make the bars stack by setting the multiple parameter to "stack". The code becomes:

ax = sns.histplot(
    birds, x = "Date Detected", hue = "Sampling Method", multiple = "stack")
ax.set_xticks(
    ax.get_xticks(), ax.get_xticklabels(), rotation = 45, ha = "right",
    rotation_mode = "anchor")
[<matplotlib.axis.XTick at 0x7f2bae05b2d0>,
 <matplotlib.axis.XTick at 0x7f2bae080790>,
 <matplotlib.axis.XTick at 0x7f2badfa54d0>,
 <matplotlib.axis.XTick at 0x7f2badfa5750>,
 <matplotlib.axis.XTick at 0x7f2bae0bc950>,
 <matplotlib.axis.XTick at 0x7f2bae0bee50>,
 <matplotlib.axis.XTick at 0x7f2badfa9190>,
 <matplotlib.axis.XTick at 0x7f2badfa9dd0>,
 <matplotlib.axis.XTick at 0x7f2badfb4150>]
../_images/4cbeec86cb05929adf0972379b7144cf1e75b1b36280eaa784438dd60def38b9.png

Seaborn automatically creates a legend, but the placement on top of the plot isn’t ideal. While you could create the legend manually with Matplotlib instead, in this case it’s easier to use Seaborn’s sns.move_legend function to move the legend:

ax = sns.histplot(
    birds, x = "Date Detected", hue = "Sampling Method", multiple = "stack")
ax.set_xticks(
    ax.get_xticks(), ax.get_xticklabels(), rotation = 45, ha = "right",
    rotation_mode = "anchor")
sns.move_legend(ax, loc = "center left", bbox_to_anchor = (1, 0.5))
../_images/0380c07de5c300deb76bbacaf3666a4288fd0a2cc11743af9f474698af632a5c.png

The plot looks good, so now let’s add data about human flu rates. The CDC collects data about flu hospitalizations across 13 states. The data is publicly available as the FluSurv-NET data set.

Important

You can download the FluSurv-NET data HERE.

In the FluSurv-NET data set, each row corresponds to a single combination of week, year, age, sex, and race. There are 10 columns, which are a mix of data types. The CSV file the CDC provides contains extra text at the beginning and end, and the column names contain spaces and other problematic characters. You can read and clean up the data set with this code:

flu = pd.read_csv(
    "data/FluSurveillance_Custom_Download_Data.csv", skiprows = 2)
flu.head()
CATCHMENT NETWORK YEAR MMWR-YEAR MMWR-WEEK AGE CATEGORY SEX CATEGORY RACE CATEGORY CUMULATIVE RATE WEEKLY RATE
0 Entire Network FluSurv-NET 2021-22 2021.0 40.0 Overall Overall Overall 0.0 0.0
1 Entire Network FluSurv-NET 2021-22 2021.0 40.0 Overall Overall White 0.0 0.0
2 Entire Network FluSurv-NET 2021-22 2021.0 40.0 Overall Overall Black 0.0 0.0
3 Entire Network FluSurv-NET 2021-22 2021.0 40.0 Overall Overall Hispanic/Latino 0.0 0.0
4 Entire Network FluSurv-NET 2021-22 2021.0 40.0 Overall Overall Asian/Pacific Islander 0.0 0.0
# Fix the column names.
flu.columns = flu.columns.str.lower().str.strip()
flu.columns = flu.columns.str.replace("[ -]+", "_", regex = True)

# Remove the text at the end.
flu = flu.query("catchment == 'Entire Network'")

For the plot, you’ll only need the "Overall" age, sex, and race categories:

flu = flu.query(
    "age_category == 'Overall' and sex_category == 'Overall' and "
    "race_category == 'Overall'")

You can convert each year and week pair into a date by combining them into a single number, converting the number to a string, and then parsing the string with the pd.to_datetime function:

dates = 1000 * flu["mmwr_year"] + 10 * flu["mmwr_week"]
dates = dates.astype(int).astype(str)
flu.loc[:, "date"] = pd.to_datetime(dates, format = "%Y%W%w")
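
To see why the arithmetic works, here is the same conversion applied to two made-up year and week pairs (the data frame toy is hypothetical). The year fills the first four digits, the week fills the next two, and the final digit is left at 0, which the %w format code reads as Sunday:

```python
import pandas as pd

# Two made-up year/week pairs, stored as floats like in the flu data.
toy = pd.DataFrame({"mmwr_year": [2021.0, 2022.0], "mmwr_week": [40.0, 5.0]})

# 2021 and week 40 become 2021400; 2022 and week 5 become 2022050.
digits = 1000 * toy["mmwr_year"] + 10 * toy["mmwr_week"]
digits = digits.astype(int).astype(str)

# %Y reads the year, %W the week number, and %w the day of the week.
toy["date"] = pd.to_datetime(digits, format = "%Y%W%w")

print(digits.tolist())
print(toy["date"].dt.day_name().tolist())
```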

Next, you can add a line for the weekly flu rates in the weekly_rate column with the .plot method. The weekly flu rates are measured in hospitalizations per 100,000 people and typically range from 0 to 10. Multiplying by 100 to convert to hospitalizations per 10 million people makes the range 0 to 1000, which is a nice match for the y-axis already on the plot. The code becomes:

ax = sns.histplot(
    birds, x = "Date Detected", hue = "Sampling Method", multiple = "stack")
ax.set_xticks(
    ax.get_xticks(), ax.get_xticklabels(), rotation = 45, ha = "right",
    rotation_mode = "anchor")
sns.move_legend(ax, loc = "center left", bbox_to_anchor = (1, 0.5))

ax.plot(flu["date"], flu["weekly_rate"] * 100, color = "#000000")
[<matplotlib.lines.Line2D at 0x7f2bade06c10>]
../_images/955a119c3634e4726d860c5edb87562e68caf05c8e91c7e1a3183f42b16d651f.png

Finally, add a title and clarify what the y-axis means:

ax = sns.histplot(
    birds, x = "Date Detected", hue = "Sampling Method", multiple = "stack")
ax.set_xticks(
    ax.get_xticks(), ax.get_xticklabels(), rotation = 45, ha = "right",
    rotation_mode = "anchor")
sns.move_legend(ax, loc = "center left", bbox_to_anchor = (1, 0.5))

ax.plot(flu["date"], flu["weekly_rate"] * 100, color = "#000000")

ax.set_title("Flu Rates for Birds and Humans")
ax.set_xlabel("Date")
ax.set_ylabel("Reported bird deaths\nHospitalizations per 10 million people")
Text(0, 0.5, 'Reported bird deaths\nHospitalizations per 10 million people')
../_images/1557ff592edd7cb4d2700e669298cf0bd0391d98df9a7082097cd4e3cdfe26d8.png

5.4.4. Plotting a Function#

Note

This case study shows how to:

  • Plot a function

  • Fill the area under a function

  • Set custom colors for lines, fills, and annotations

  • Hide the axes of a plot

  • Add annotations to a plot

Plotting a curve or function can make it much easier to understand and explain its behavior. For this case study, suppose you want to make a visualization that shows how the area of overlap between two different probability density functions is a measure of how similar they are. They should have some overlap, but not too much. The color palette should be UC blue (#1a3f68) and gold (#e6c257) to match the rest of your presentation.

Note

A probability density function is a function that shows how likely different outcomes are for a continuous probability distribution. The probability of outcomes in any given interval is the area under the curve. For example, if the area under the curve between 0 and 1 is 0.2, then there’s a 20% chance of an outcome between 0 and 1. The total area under the curve is always 1.

Since the total area under a probability density function is always 1, the area of overlap between two density functions is always between 0 and 1. As a result, the area of overlap is a convenient measure of similarity.

To plot a function, first evaluate it at many points over the interval of interest. Let’s start with 1,000 points in the interval \((-20, 20)\). You can use NumPy’s np.linspace function to compute evenly spaced points:

import numpy as np

x = np.linspace(-20, 20, 1000)

The SciPy package provides probability density functions for common distributions. Let’s use a normal distribution and a gamma distribution. Normal distributions are widely known, while gamma distributions are visually different but will still have some overlap. You can import the probability density functions for these distributions from scipy.stats. The gamma distribution requires a shape parameter; let’s use 10. Evaluate the functions at the x-coordinates:

from scipy.stats import norm, gamma

y1 = norm.pdf(x)
y2 = gamma.pdf(x, 10)

You can use Matplotlib to plot the \((x, y)\) coordinates as lines. The .plot method creates a line plot by default. Set the linestyle parameter on one of the lines to a dash (--) so that the lines are distinct:

fig, ax = plt.subplots()
ax.plot(x, y1)
ax.plot(x, y2, linestyle = "--")
[<matplotlib.lines.Line2D at 0x7f2bada92110>]
../_images/58226a13e7833f764c681554d86c3a46a878c1c87cb7681d3707a6d98233654f.png

The two distributions don’t overlap much, but you can fix that by adjusting their parameters. SciPy provides location (loc) and scale (scale) parameters to control where distributions are located and how much they spread out. The defaults are 0 and 1, respectively. To get some overlap, let’s increase the scale of the normal distribution to 4 and shift the location of the gamma distribution over to -2. Then the code for the plot becomes:

x = np.linspace(-20, 20, 1000)
y1 = norm.pdf(x, scale = 4)
y2 = gamma.pdf(x, 10, loc = -2)

fig, ax = plt.subplots()
ax.plot(x, y1)
ax.plot(x, y2, linestyle = "--")
[<matplotlib.lines.Line2D at 0x7f2bad919b10>]
../_images/cad1e0b2e6f53932a34b0dad910bc23a7902295a6f99d95c4ff5a82f406649b0.png
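
If it’s unclear what loc and scale do, the following check may help. It compares a shifted and scaled normal density from SciPy with the same density computed by hand from the standard normal density. The values 3 and 4 are arbitrary choices for the example:

```python
import numpy as np
from scipy.stats import norm

x = np.linspace(-20, 20, 1000)
loc, scale = 3, 4

# Shifted and scaled density, computed by SciPy.
shifted = norm.pdf(x, loc = loc, scale = scale)

# The same density by hand: shift and rescale x, then divide by the scale
# so that the total area under the curve is still 1.
by_hand = norm.pdf((x - loc) / scale) / scale

print(np.allclose(shifted, by_hand))
```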

Let’s emphasize the overlap by filling it with a color. Axes have a .fill_between method that fills the area underneath a curve. In this case, the area of overlap is the area underneath whichever function happens to be smaller at a given point—the minimum of the two functions. NumPy’s np.fmin function computes the element-wise minimum of two arrays. So the code becomes:

x = np.linspace(-20, 20, 1000)
y1 = norm.pdf(x, scale = 4)
y2 = gamma.pdf(x, 10, loc = -2)

fig, ax = plt.subplots()
ax.plot(x, y1)
ax.plot(x, y2, linestyle = "--")

# Fill area of overlap.
ymin = np.fmin(y1, y2)
ax.fill_between(x, ymin)
<matplotlib.collections.PolyCollection at 0x7f2bad9ac790>
../_images/3abaa02c3b99da233158d5e2a15877a802fe0929bc308c468b9b5325d241f49b.png
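
The filled region is the quantity the diagram is about, so it can be useful to estimate its area as well. This sketch uses SciPy’s trapezoid function for numerical integration, which is one reasonable choice among several. The result is an approximation that depends on the number of points and the interval:

```python
import numpy as np
from scipy.integrate import trapezoid
from scipy.stats import norm, gamma

x = np.linspace(-20, 20, 1000)
y1 = norm.pdf(x, scale = 4)
y2 = gamma.pdf(x, 10, loc = -2)

# Element-wise minimum: the height of the overlap at each x.
ymin = np.fmin(y1, y2)

# Approximate the area under the minimum with the trapezoid rule.
overlap = trapezoid(ymin, x)
print(round(overlap, 3))
```

An area near 0 means the two distributions barely overlap, while an area near 1 means they are almost identical.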

This diagram is meant to make it easier to explain area of overlap, which is shown by the curves and fill. The axis ticks and labels don’t aid understanding, so let’s remove them. You can use the Axes .axis method with the argument "off" to do so. Let’s adjust the x interval to \((-15, 25)\) to center the curves, change the color to UC blue for the lines, and change the color to UC gold for the fill:

x = np.linspace(-15, 25, 1000)
y1 = norm.pdf(x, scale = 4)
y2 = gamma.pdf(x, 10, loc = -2)

fig, ax = plt.subplots()
ax.plot(x, y1, color = "#1a3f68")
ax.plot(x, y2, color = "#1a3f68", linestyle = "--")

# Fill area of overlap.
ymin = np.fmin(y1, y2)
ax.fill_between(x, ymin, color = "#e6c257")

ax.axis("off")
(-17.0, 27.0, -0.006587663025137597, 0.13834092352788954)
../_images/0cb6070e20beded1f617c16cdf1d8a227b117c8ed380179c166c55c1537264b1.png

Finally, let’s add a text label and arrow to emphasize the area of overlap. You can use the .annotate method to add an annotation to an Axes. Annotations are flexible: they can be lines, arrows, text, images, or some combination of these.

For making a text label and arrow with .annotate, the following parameters are important:

  • text: the text of the label, as a string

  • xy: the coordinates of the arrow’s tip, as a tuple

  • xytext: the coordinates of the text and the arrow’s tail, as a tuple

  • xycoords: a string that specifies the coordinate system; for example, the default "data" uses the data’s coordinate system, while "axes fraction" uses 0 to 1 along each axis, so \((0.5, 0.5)\) is the center; see the documentation for all possible options

  • color: the color of the text (but not the arrow), as a string

  • arrowprops: a dictionary with properties for the arrow, such as:

    • arrowstyle: a string that specifies the arrow’s head and line style

    • color: the color of the arrow, as a string

  • fontsize: the point size of the text, as an integer

With the annotation, the plot becomes:

x = np.linspace(-15, 25, 1000)
y1 = norm.pdf(x, scale = 4)
y2 = gamma.pdf(x, 10, loc = -2)

fig, ax = plt.subplots()
ax.plot(x, y1, color = "#1a3f68")
ax.plot(x, y2, color = "#1a3f68", linestyle = "--")

# Fill area of overlap.
ymin = np.fmin(y1, y2)
ax.fill_between(x, ymin, color = "#e6c257")

ax.axis("off")

ax.annotate(
    text = "overlap", color = "#1a3f68", fontsize = 20,
    xy = (0.5, 0.15), xytext = (0.75, 0.75), xycoords = "axes fraction",
    arrowprops = {"color": "#1a3f68", "arrowstyle": "->"})
Text(0.75, 0.75, 'overlap')
../_images/e55d0bc67842993dae2dc69ed859c054fc1c5c5c6d74202d6e0548c67d9828be.png
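
As a variation, the sketch below places an annotation with the default "data" coordinate system instead of "axes fraction". It labels the highest point of a single normal curve. The label text and the horizontal offset of 8 for the text are arbitrary choices for the example:

```python
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

x = np.linspace(-15, 25, 1000)
y = norm.pdf(x, scale = 4)

# Find the highest point of the curve, in data coordinates.
peak = np.argmax(y)
x_peak, y_peak = x[peak], y[peak]

fig, ax = plt.subplots()
ax.plot(x, y, color = "#1a3f68")

# xycoords defaults to "data", so xy and xytext use the units of the axes.
label = ax.annotate(
    text = "peak", color = "#1a3f68", fontsize = 20,
    xy = (x_peak, y_peak), xytext = (x_peak + 8, y_peak),
    arrowprops = {"color": "#1a3f68", "arrowstyle": "->"})
```

Data coordinates keep the arrow attached to the same point on the curve even if the axis limits change, whereas "axes fraction" keeps it at the same position within the Axes.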

5.4.5. Plotting an Image#

Note

This case study shows how to:

  • Plot an image

  • Plot other shapes on top of an image, such as rectangles

Image data sets are common in many disciplines. For example, a radiologist might have a collection of x-ray images. For these data sets, being able to display and annotate images is extremely helpful. In this case study, suppose you want to display the picture of American Bison in Fig. 5.4 and add a box around each animal.

Important

You can download the American Bison image by right-clicking on the image in Fig. 5.4 and selecting Save Image As... from the context menu.

../_images/bison.png

Fig. 5.4 American Bison#

The first step is to read the image into Python. The Pillow package provides functions for reading, editing, and writing images. The package is a fork (an alternative version) of the older Python Imaging Library (PIL), and still uses the PIL name in Python code. Import the Image class from the package:

from PIL import Image

You can use the Image method .open to read the image:

img = Image.open("data/bison.png")

Note

In computing contexts, an image is typically represented by an array of color values. Grayscale images can be represented by a 2-dimensional matrix, but color images require an extra dimension for the red, green, and blue channels.

The .imshow method expects an image with dimensions:

\[ \textrm{height} \times \textrm{width} \times \textrm{channel} \]

The channel dimension should have 3 elements: red, green, and blue, in that order. The .imshow method also works with grayscale images that only have 2 dimensions.

If you try to plot an image with .imshow and Python raises an error or the image looks strange (especially the colors), the first thing to check is whether the image has the right dimensions and color channels. You can usually fix any problems with the functions NumPy and Pillow provide to reshape arrays and convert between different color spaces.
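
A quick way to run this check is to convert the image to a NumPy array and look at its shape. The sketch below uses a blank image created with Image.new as a stand-in for a real file. Note that Pillow reports an image’s size as (width, height), while the array puts the rows (height) first:

```python
import numpy as np
from PIL import Image

# A blank stand-in image that is 4 pixels wide and 2 pixels tall.
img = Image.new("RGBA", (4, 2))

# Pillow reports (width, height); the array has rows (height) first.
print(img.size, img.mode)
print(np.asarray(img).shape)

# Drop the alpha channel, or convert to grayscale (mode "L").
rgb = np.asarray(img.convert("RGB"))
gray = np.asarray(img.convert("L"))
print(rgb.shape, gray.shape)
```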

The Axes method .imshow plots an image. As usual, use plt.subplots to set up a Figure and Axes before plotting:

fig, ax = plt.subplots()
ax.imshow(img)
<matplotlib.image.AxesImage at 0x7f2badc2d790>
../_images/08387c4e4935b6c45070a50f310dfb6e3936340bdd3201049e687c8e2a425d68.png

Now let’s add a red rectangle around the bison in the foreground. In Matplotlib terminology, a 2-dimensional geometric shape is a Patch, and a Rectangle is a specific kind of Patch. To create a rectangle, import matplotlib and use the patches.Rectangle constructor function. The function’s first argument is a tuple with the top left coordinates of the rectangle. You can use separate arguments to specify the width and height. For the bison image, position the rectangle at (120, 200), make the width 140, and make the height 160:

import matplotlib as mpl

rect = mpl.patches.Rectangle((120, 200), width = 140, height = 160)
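
If it's unclear which corner the coordinates refer to, one way to check is to ask the Axes whether its y-axis is inverted after calling .imshow. The sketch below assumes a blank array as a hypothetical stand-in for the bison image (the 400 by 500 size is made up) and computes the rectangle's two opposite corners:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Hypothetical stand-in for the image: 400 rows by 500 columns, RGB.
pixels = np.zeros((400, 500, 3), dtype=np.uint8)

fig, ax = plt.subplots()
ax.imshow(pixels)

# With an image, y increases downward, so the y-axis is inverted.
print(ax.yaxis_inverted())

rect = Rectangle((120, 200), width=140, height=160)

# The anchor corner and the diagonally opposite corner, in data coordinates.
x0, y0 = rect.get_xy()
x1 = x0 + rect.get_width()
y1 = y0 + rect.get_height()
print((x0, y0), (x1, y1))
```

Because the y-axis is inverted, the anchor (120, 200) is drawn above the opposite corner (260, 360), which is why it ends up at the top left of the rectangle.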

To add a Patch to a plot, use the Axes method .add_patch:

fig, ax = plt.subplots()
ax.imshow(img)
# Add rectangle.
ax.add_patch(rect)

By default, patches are filled in blue, but you can change the fill and edge colors by setting the facecolor and edgecolor arguments when you create the Patch. The special color "none" is fully transparent. To give the rectangle a transparent fill and red edges:

rect = mpl.patches.Rectangle(
    (120, 200), width = 140, height = 160, facecolor = "none",
    edgecolor = "#FF0000")

fig, ax = plt.subplots()
ax.imshow(img)
# Add rectangle.
ax.add_patch(rect)

You can create and add as many patches as needed to a plot. If you need to add lots of patches, consider whether it’s possible to use a loop. Image annotations are especially useful for displaying results from image processing and machine learning algorithms.
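
As a rough sketch of the loop approach, the code below draws one rectangle per entry in a list of boxes. The gray array and the box coordinates are hypothetical placeholders, standing in for a real image and for the output of something like an object detection model:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Hypothetical stand-in for an image: 400 rows by 500 columns of gray pixels.
pixels = np.full((400, 500, 3), 128, dtype=np.uint8)

# Hypothetical boxes as (x, y, width, height) in pixel coordinates.
boxes = [(120, 200, 140, 160), (300, 150, 90, 110), (20, 40, 60, 50)]

fig, ax = plt.subplots()
ax.imshow(pixels)

# Create a new Rectangle for each box and add it to the Axes.
for x, y, width, height in boxes:
    rect = Rectangle(
        (x, y), width=width, height=height, facecolor="none",
        edgecolor="#FF0000")
    ax.add_patch(rect)
```

The loop creates a new Rectangle on every iteration rather than reusing one object, since each Patch stores its own position and size.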

5.4.6. Plotting a Map#

Geospatial data are often best visualized by making a map. When you make a map, visualization best practices still apply, but there are also many additional details to consider, such as what projection to use. The GeoPandas package provides functions to read and visualize geospatial data. Like many other packages, the GeoPandas visualization functions are built on Matplotlib, so you can use the customization strategies described in this reader. No case study is provided here because working with geospatial data requires specialized knowledge that goes beyond the focus of this chapter.

Note

If you work with geospatial data frequently or need to make publication-quality maps, it may be better to use dedicated geospatial software, such as QGIS. DataLab’s Intro to Desktop GIS with QGIS workshop provides an introduction.