import pandas as pd
import numpy as np
import dask.dataframe as dd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.ticker as ticker
plt.style.use('default')
from datetime import datetime
from typing import Tuple, Union
from collections.abc import Iterable
In my day job I’ve recently been working on visualizing time series data with matplotlib. When I’ve wanted some sort of specific formatting or behavior, I’ve had to wade deep into the library documentation to figure out the code for the things I want to do. That’s fine, but what frustrates me is that in the back of my mind, I’m pretty sure these are all things I’ve actually done before. A few years ago in graduate school I took a course that was heavy on matplotlib plotting, and I spent many hours figuring out bizarre techniques for getting very specific results. Unfortunately I’ve forgotten almost all of that library-specific minutiae.
This has led me to a new project: I’m creating a personal data science utilities library with modules for matplotlib/seaborn, pandas, geopandas, and scikit-learn, each containing functions that do all the random things I’ve repeatedly built custom code solutions for in the past. The hope is that I don’t need to keep re-learning the same things and can call simple wrapper functions that I’ve written.
In this post, I’m going to cover a few matplotlib extensions that I often need. I’ll cover:
- Formatting y-axis strings into whole numbers with commas
- Moving a legend outside of the axis
- Reusing a colormap in seaborn so that the same classes correspond to the same colors across various visualizations
- Grid plots
I’ll be using the BART dataset that I created in my blog post on dask.
df = dd.read_parquet("../dask-data-io/data/parquet_data")
df.head()
 | Date | Hour | Start | End | Riders |
---|---|---|---|---|---|
0 | 2011-01-01 | 0 | 12TH | 16TH | 1 |
1 | 2011-01-01 | 0 | 12TH | 24TH | 3 |
2 | 2011-01-01 | 0 | 12TH | ASHB | 2 |
3 | 2011-01-01 | 0 | 12TH | BAYF | 5 |
4 | 2011-01-01 | 0 | 12TH | CIVC | 3 |
Let’s say I’m interested specifically in an aggregate time series with the number of riders that used the system, for each hour from 2011 to 2024. I’ll do the appropriate grouping to build that dataset and move from dask to pandas.
df_rph = df.groupby(['Date', 'Hour'])['Riders'].sum().compute()
df_rph = df_rph.reset_index()
df_rph.head()
 | Date | Hour | Riders |
---|---|---|---|
0 | 2011-01-01 | 0 | 5174 |
1 | 2011-01-01 | 1 | 15479 |
2 | 2011-01-01 | 2 | 11055 |
3 | 2011-01-01 | 3 | 5592 |
4 | 2011-01-01 | 4 | 795 |
I’m going to be doing a lot of slicing on the date, and I find that it’s easiest to do that when the date column is a datetime index. Also, since I’ll do some analysis of months of data, I’m going to truncate the time series to the last full month.
df_rph['Date'] = pd.to_datetime(df_rph['Date'], format="%Y-%m-%d")
df_rph_d = df_rph.set_index("Date")
df_rph_d = df_rph_d.loc[:"2024-04-30"]
df_rph_d.head()
Date | Hour | Riders |
---|---|---|
2011-01-01 | 0 | 5174 |
2011-01-01 | 1 | 15479 |
2011-01-01 | 2 | 11055 |
2011-01-01 | 3 | 5592 |
2011-01-01 | 4 | 795 |
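As a small illustration of why a datetime index makes this kind of truncation convenient, here is a sketch on made-up toy data (the `toy` frame and its values are assumptions for illustration, not the BART data). It shows that a string slice on a sorted `DatetimeIndex` includes the end label, so all rows on the cutoff date are kept.

```python
import pandas as pd

# Toy data standing in for the ridership table (values are made up).
toy = pd.DataFrame(
    {"Date": ["2024-04-29", "2024-04-30", "2024-04-30", "2024-05-01"],
     "Riders": [10, 20, 30, 40]}
)
toy["Date"] = pd.to_datetime(toy["Date"], format="%Y-%m-%d")
toy_d = toy.set_index("Date")

# Label-based slicing on a sorted DatetimeIndex is inclusive of the end
# label, so both rows dated 2024-04-30 stay and 2024-05-01 is dropped.
truncated = toy_d.loc[:"2024-04-30"]
print(truncated["Riders"].tolist())
```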
For now I’ll sum over the hours and days and just examine ride totals per month in a pd.Series:
month_time_series = df_rph_d.resample("ME")['Riders'].sum()
month_time_series.head()
Date
2011-01-31 8203562
2011-02-28 7933264
2011-03-31 9049039
2011-04-30 8824840
2011-05-31 8940380
Freq: ME, Name: Riders, dtype: int64
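The "ME" (month-end) alias was introduced in pandas 2.2; earlier releases spell it "M". As a version-independent sketch of the same monthly aggregation, the toy example below groups by calendar month with `to_period("M")` instead of resampling. This is an alternative technique rather than the code used in this post, and the toy values are made up.

```python
import pandas as pd

# Toy hourly-style observations with repeated dates (values are made up).
idx = pd.to_datetime(["2011-01-01", "2011-01-01", "2011-01-31", "2011-02-01"])
toy = pd.Series([5, 7, 3, 4], index=idx, name="Riders")

# Group rows by their calendar month and sum within each month.
monthly = toy.groupby(toy.index.to_period("M")).sum()
print(monthly)
```

The result is indexed by monthly periods rather than month-end timestamps, which is the main visible difference from `resample`.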
Simply plotting the series shows the impact of the pandemic, which is interesting:
month_time_series.plot()
plt.title("Monthly Total BART Riders");
Pandemic Impact
I might want to do a slightly more dynamic plot of the pandemic’s impact, where I more directly show how this drop in ridership compared to typical ridership.
To create this more complex visual, I’ll create a suite of functions:
- time_overlay_plot, which creates an overlay plot where the x-axis represents months and each series corresponds to a year of data.
- show_all_xticks, which overrides default matplotlib x-tick behavior and shows all possible ticks
- custom_legend, which will make it easier to move a legend outside of the plot area
def time_overlay_plot(month_time_series: pd.Series,
                      ax: matplotlib.axes.Axes,
                      step_size: int,
                      highlight_year: int,
                      date_formatter: str,
                      bg_alpha: float = .3) -> Tuple[matplotlib.axes.Axes, pd.Index]:
    """
    Create an overlay plot where the x-axis represents months and each series corresponds to a year of data.
    This function reads in a month-level time series with data spanning multiple years. It overlays the data for each year
    on the same plot, allowing for year-over-year comparison. One year can be highlighted for emphasis.
    Args:
        month_time_series (pd.Series): A time series with observations at one-month intervals.
        ax (matplotlib.axes.Axes): A Matplotlib axis object to plot on.
        step_size (int): The number of years to skip between plotted series.
        highlight_year (int): The year to highlight in the overlay plot. Set to None for no highlight.
        date_formatter (str): The format string for the date labels on the x-axis.
        bg_alpha (float): The alpha transparency for the non-highlighted lines. Defaults to 0.3.
    Returns:
        matplotlib.axes.Axes: The updated axis object with the overlay plot.
        pd.Index: The x-tick labels of the plot.
    Raises:
        AssertionError: If step_size is larger than the number of unique years in the time series.
    """
    target_years = month_time_series.index.year.unique()[::step_size]
    assert (len(target_years) > step_size), "Step size larger than index length"
    labs = month_time_series.loc[str(
        target_years[-step_size]):str(target_years[-1])].index.strftime(date_formatter)
    for i in range(len(target_years)-(step_size-1)):
        a = 1 if (target_years[i + (step_size-1)]
                  == highlight_year) else bg_alpha
        subset_time_series = (
            month_time_series
            .loc[str(target_years[i]): str(target_years[i + (step_size-1)])]
            .reset_index()
            .set_index(labs)[month_time_series.name]
        )
        label = fr"${target_years[i]} \rightarrow {target_years[i + (step_size-1)]}$" if step_size != 1 else fr"${target_years[i]}$"
        subset_time_series.plot(linewidth=3,
                                ax=ax,
                                alpha=a,
                                label=label)
    return ax, labs
def show_all_xticks(ax: matplotlib.axes.Axes, labs: pd.Index) -> matplotlib.axes.Axes:
    """
    Sets all x-ticks on a Matplotlib axis object and labels them with the provided list of labels.
    This function ensures that all x-ticks are displayed and labeled as specified, making it easier to read and interpret the x-axis of the plot. The labels are displayed horizontally.
    Parameters:
        ax (matplotlib.axes.Axes): The axis object on which to set the x-ticks and labels.
        labs (list of str): A list of labels to set on the x-axis. The number of labels should correspond to the number of ticks.
    Returns:
        matplotlib.axes.Axes: The modified axis object with updated x-ticks and labels.
    """
    ax.set_xticks(range(len(labs)))
    ax.set_xticklabels(labs, rotation=0)
    return ax
def comma_formatter() -> ticker.FuncFormatter:
    """
    Creates a custom Matplotlib tick formatter that formats axis ticks with commas.
    This formatter converts axis tick values to integers and formats them with commas for thousands
    separators. This can be useful for improving the readability of plots with large numerical values.
    Returns:
    --------
    ticker.FuncFormatter:
        A Matplotlib FuncFormatter object that applies the comma formatting to axis ticks.
    """
    def comma(x: float, pos) -> str:
        """Format the tick value `x` with commas. The parameter `pos` is not used."""
        return '{:,}'.format(int(x))
    return ticker.FuncFormatter(comma)
def custom_legend(ax: matplotlib.axes.Axes,
                  outside_loc: str = None,
                  order: Union[str, list] = "default",
                  title: str = "",
                  linewidth: int = 2,
                  **kwargs) -> matplotlib.axes.Axes:
    """
    Customize the legend location and order on a Matplotlib axis.
    This function adjusts the position of the legend, optionally placing it outside the plot area,
    and can reorder the legend entries based on specified criteria.
    Args:
        ax (matplotlib.axes.Axes): An existing Matplotlib axis object to which the legend belongs.
        outside_loc (str, optional): Specifies the location of the legend outside the plot area.
            Must be one of ["lower", "center", "upper", None].
            Defaults to None.
        order (str, optional): Determines the order of the legend entries.
            Must be one of ["default", "reverse", "desc"].
            "default" keeps the current order,
            "reverse" reverses the current order,
            "desc" orders entries by descending values.
            Defaults to "default".
    Returns:
        matplotlib.axes.Axes: The axis object with the customized legend.
    Raises:
        AssertionError: If `outside_loc` is not in ["lower", "center", "upper", None].
    """
    handles, labels = ax.get_legend_handles_labels()
    if order == 'default':
        pass
    elif order == 'reverse':
        handles = handles[::-1]
        labels = labels[::-1]
    elif order == 'desc':
        ordering = np.flip(np.argsort(np.array([line.get_ydata()[-1] for line in ax.lines if (len(line.get_ydata())>0 and line.get_label() in labels)])))
        handles = np.array(handles)[ordering].tolist()
        labels = np.array(labels)[ordering].tolist()
    elif isinstance(order, Iterable):
        value_to_index = {value: idx for idx, value in enumerate(labels)}
        indices = [value_to_index[value] for value in order]
        labels = list(order)
        handles = np.array(handles)[indices].tolist()
    else:
        raise Exception("Invalid Order")
    error_msg = "legend_to_right loc must be None or in 'lower', 'center', or 'upper'"
    assert outside_loc in ["lower", "center", "upper", None], error_msg
    if outside_loc == "lower":
        ax.legend(handles, labels, loc='lower left', bbox_to_anchor=(1, 0), **kwargs)
    elif outside_loc == "center":
        ax.legend(handles, labels, loc='center left', bbox_to_anchor=(1, .5), **kwargs)
    elif outside_loc == "upper":
        ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(1, 1), **kwargs)
    else:
        ax.legend(handles, labels, **kwargs)
    legend = ax.get_legend()
    legend.set_title(title)
    for line in legend.get_lines():
        line.set_linewidth(linewidth)
    return ax
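To make the windowing inside time_overlay_plot more concrete, here is a minimal pure-Python sketch of which year ranges end up as individual lines for a given step_size. The helper name `overlay_windows` is hypothetical and exists only for this illustration; it mirrors the slicing and loop bounds of the function above without doing any plotting.

```python
def overlay_windows(years, step_size):
    """Return the (start_year, end_year) pairs that each become one line."""
    # Keep every `step_size`-th year, as time_overlay_plot does.
    target_years = years[::step_size]
    # Pair each kept year with the one (step_size - 1) positions later.
    return [
        (target_years[i], target_years[i + (step_size - 1)])
        for i in range(len(target_years) - (step_size - 1))
    ]


years = list(range(2011, 2022))  # 2011 through 2021
windows = overlay_windows(years, step_size=2)
print(windows)
# [(2011, 2013), (2013, 2015), (2015, 2017), (2017, 2019), (2019, 2021)]

yearly = overlay_windows(years, step_size=1)
print(yearly[:2])
# [(2011, 2011), (2012, 2012)]
```

With step_size=2 each line spans three calendar years, and the highlighted line is the one whose end year matches highlight_year.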
Using these functions, I’ll create a time series overlay that shows the impact of the pandemic on BART ridership in the context of previous three-year trends, e.g. how does 2019-2021 compare to 2015-2017? The resulting plot suggests that monthly BART ridership was on a trajectory similar to previous years leading up to the pandemic before sharply departing.
Note that I’m able to flexibly define the date formatting for the x-axis using a date_formatter argument, and I can select to highlight a particular year with the highlight_year argument.
fig, ax = plt.subplots(figsize=(8, 4))
date_formatter = '%b\n\'%y'
_, labs = time_overlay_plot(month_time_series.loc["2011":"2021"],
                            ax=ax, step_size=2,
                            highlight_year=2021,
                            date_formatter=date_formatter)

ax.yaxis.set_major_formatter(comma_formatter())
ax.set(ylim=0, title="Three Year Trends in BART Ridership",
       ylabel="Total Monthly Riders")
ax.axvline(np.where(labs == datetime(2020, 1, 1).strftime(date_formatter)),
           linestyle=":",
           label="First Covid Case in CA",
           color="black",
           linewidth=3,
           alpha=.4)
ax.legend();
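The comma_formatter used on the y-axis wraps a very small formatting function. As a standalone sketch of what that inner function does to a few sample tick values (the sample values are made up), note that `int()` truncates toward zero rather than rounding:

```python
def comma(x, pos=None):
    """Format a tick value as a whole number with thousands separators."""
    # int() drops the fractional part (truncation toward zero, not rounding).
    return '{:,}'.format(int(x))


samples = [0, 999.9, 1234567.0, -2500.75]
labels = [comma(v) for v in samples]
print(labels)
# ['0', '999', '1,234,567', '-2,500']
```

For ridership counts in the millions the truncation is immaterial, but it is worth knowing about if the same formatter is reused on an axis with small fractional values.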
I’ll use the same time_overlay_plot function to examine the recovery of the BART system in the post-pandemic period. This time, I’ll also use custom_legend to easily move the legend outside of the plot area.
month_time_series_trunc = pd.concat(
    [df_rph_d.resample("ME")['Riders'].sum().loc["2019-01-01": "2020-01-01"],
     df_rph_d.resample("ME")['Riders'].sum().loc["2021-01-01": "2024-01-01"]])

fig, ax = plt.subplots(figsize=(8, 4))
date_formatter = '%b'
_, labs = time_overlay_plot(month_time_series_trunc, ax=ax, step_size=1,
                            highlight_year=2019, date_formatter=date_formatter, bg_alpha=.5)

show_all_xticks(ax, labs)
ax.yaxis.set_major_formatter(comma_formatter())
ax.set(ylim=0, title="Post-Pandemic Recovery", ylabel="Total Monthly Riders")
ax.grid(alpha=.5)
custom_legend(ax, outside_loc="center", order="desc",
              title="Year", linewidth=3)
handles, labels = ax.get_legend_handles_labels()
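The order="desc" option in custom_legend sorts legend entries by each line's final y-value, so the legend reads top to bottom in the same order as the line endpoints on the right edge of the plot. Here is a minimal sketch of just that ordering step, using made-up labels and end values in place of real plotted lines:

```python
import numpy as np

# Hypothetical legend labels and the last y-value of each plotted line.
labels = ["2019", "2021", "2022", "2023"]
last_values = np.array([9.5, 2.1, 3.4, 4.0])

# argsort gives ascending positions; flipping them yields descending order.
ordering = np.flip(np.argsort(last_values))
ordered_labels = np.array(labels)[ordering].tolist()
print(ordered_labels)
# ['2019', '2023', '2022', '2021']
```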
Commute Patterns
I’ll now move on to examining some commute patterns. Here I’ll create some custom functions for building a colormap that I can use to consistently denote color-class relationships across multiple plots. Default behavior in matplotlib and seaborn typically leads to color-class relationships changing across charts of different types or involving different subsets of the data, e.g. in Chart 1, class A is Red, but in Chart 2, class A is Blue. My custom function, build_colormap, creates a simple color-class relationship dictionary that defines the color-class relationships once and can be input to various seaborn visuals.
I’ll define color-class relationships for the days of the week, but first I’ll do some setup in pandas:
df_rph['day_of_week'] = df_rph['Date'].dt.day_name()
df_rph['Weekend'] = df_rph['day_of_week'].isin(['Saturday', 'Sunday'])
df_rph['Hour_12'] = pd.to_datetime(
    df_rph['Hour'], format='%H').dt.strftime('%I:%M %p')
df_rph['Hour_12'] = df_rph['Hour_12'].str.replace(":00 ", "\n").str.strip("0")
df_rph['Rider_Thousands'] = df_rph['Riders'] / 1000
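The Hour_12 column packs a few string operations into two lines, so here is a plain-Python sketch of the same transformation applied to single hours. The helper name `hour_label` is hypothetical, and the sketch assumes the default (English) AM/PM locale markers.

```python
from datetime import datetime


def hour_label(hour: int) -> str:
    """Turn an hour in 0-23 into a compact two-line 12-hour tick label."""
    # e.g. 7 -> "07:00 AM"
    text = datetime(1900, 1, 1, hour).strftime('%I:%M %p')
    # Swap ":00 " for a line break, then trim the leading zero:
    # "07:00 AM" -> "07\nAM" -> "7\nAM"
    return text.replace(":00 ", "\n").strip("0")


print(repr(hour_label(0)))   # '12\nAM'
print(repr(hour_label(7)))   # '7\nAM'
print(repr(hour_label(10)))  # '10\nAM'
print(repr(hour_label(20)))  # '8\nPM'
```

Because `strip("0")` only removes zeros at the ends of the string, "10" keeps its trailing zero: the label ends in "AM" or "PM", not in a digit.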
Now I’ll define the build_colormap function and create my color mapping.
def build_colormap(series: pd.Series) -> dict:
    """
    Build a colormap dictionary for a pandas Series.
    This function creates a dictionary that maps each unique value in the input
    Series to a unique color from the 'tab10' colormap provided by Seaborn. The
    function ensures that the number of unique values in the Series is less than 10.
    Parameters:
    ----------
    series : pd.Series
        The pandas Series for which to build the colormap. Each unique value in the
        Series will be assigned a unique color.
    Returns:
    -------
    dict
        A dictionary where the keys are the unique values from the input Series
        and the values are the corresponding colors from the 'tab10' colormap.
    Raises:
    ------
    AssertionError
        If the number of unique values in the Series is 10 or more.
    """
    unique_set = series.unique()
    assert len(unique_set) < 10
    colors = sns.color_palette("tab10")
    colormap = {}
    for d, c in zip(unique_set, colors[:len(unique_set)]):
        colormap[d] = c
    return colormap
Here is the resulting color-mapping dictionary:
colormap = build_colormap(df_rph['day_of_week'])
colormap
{'Saturday': (0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
'Sunday': (1.0, 0.4980392156862745, 0.054901960784313725),
'Monday': (0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
'Tuesday': (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
'Wednesday': (0.5803921568627451, 0.403921568627451, 0.7411764705882353),
'Thursday': (0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
'Friday': (0.8901960784313725, 0.4666666666666667, 0.7607843137254902)}
And here’s a simple application of that colormap to examine the average riders-per-hour in the BART system.
fig, ax = plt.subplots(figsize=(8, 4))
sns.despine(top=True, ax=ax)

sns.lineplot(
    data=df_rph,
    x='Hour_12',
    y='Rider_Thousands',
    hue='day_of_week',
    palette=colormap,
    linewidth=3,
    estimator='mean',
    errorbar=None,
    ax=ax)
ax.set(
    title="Average Hourly Riders",
    ylabel="Average Riders (1,000s)",
    xlabel="Hour")
custom_legend(ax, title="Day of the week", linewidth=2, order=list(colormap.keys()))
ax.grid(alpha=.6)
fig.tight_layout()
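The reason to build the mapping once and reuse it is easiest to see on a subset of the data. The sketch below is a dependency-free stand-in for build_colormap: the helper name `build_colormap_sketch` is hypothetical, and the hex codes are the first seven tab10 colors written out by hand rather than fetched from seaborn.

```python
# First seven colors of the tab10 palette, as hex strings.
PALETTE = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
           "#9467bd", "#8c564b", "#e377c2"]


def build_colormap_sketch(values):
    """Map each unique value, in order of first appearance, to a color."""
    unique_values = list(dict.fromkeys(values))  # ordered de-duplication
    assert len(unique_values) <= len(PALETTE)
    return dict(zip(unique_values, PALETTE))


days = ["Saturday", "Sunday", "Monday", "Saturday", "Tuesday"]
colormap = build_colormap_sketch(days)

# Rebuilding the mapping on a weekday-only subset shifts every color...
weekdays_only = [d for d in days if d not in ("Saturday", "Sunday")]
fresh = build_colormap_sketch(weekdays_only)
print(colormap["Monday"], fresh["Monday"])
# #2ca02c #1f77b4

# ...whereas filtering the original dictionary keeps Monday's color fixed.
reused = {d: colormap[d] for d in weekdays_only}
print(reused["Monday"])
# #2ca02c
```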
I can see the degree to which BART seems to have a lot of its ridership concentrated in commute times. Now let’s say that I want to reproduce this plot repeatedly for different years, to see how this general shape changes before and after the pandemic. This is going to be a grid plot, where each individual plot in the grid represents one year of data. I’ll create a flexible function for grid-plotting, grid_plot, below, which allows one to pass in a seaborn plotting function and **kwargs for that function, and repeat that plot across some grouping variable, group_var.
def grid_plot(grid_subset: pd.DataFrame,
              seaborn_func: callable,
              rows: int,
              cols: int,
              group_var: str,
              figsize: tuple = (10, 5),
              legend_loc: str = "lower",
              **kwargs):
    # Arrange the unique group values into the requested grid shape
    facet = grid_subset[group_var].unique()
    facet = facet.reshape(rows, cols)
    fig, axes = plt.subplots(
        facet.shape[0], facet.shape[1], sharex=True, sharey=True, figsize=figsize)
    if rows > 1:
        for m in range(facet.shape[0]):
            for n in range(facet.shape[1]):
                group = facet[m, n]
                subset = grid_subset[grid_subset[group_var] == group]
                # Only the last column of the bottom (or top) row gets a legend
                if legend_loc == 'lower':
                    legend = True if (m+1 == rows) and (n+1 == cols) else False
                elif legend_loc == "upper":
                    legend = True if (m+1 == 1) and (n+1 == cols) else False
                else:
                    legend = False
                seaborn_func(
                    data=subset,
                    ax=axes[m, n],
                    legend=legend,
                    **kwargs)
                axes[m, n].set(ylabel=None, xlabel=None, title=group)
                axes[m, n].grid()
    else:
        # Single-row grid (assumes cols > 1): plt.subplots returns a 1-D array
        # of axes, so index axes by column only and read groups from facet[0].
        for n in range(facet.shape[1]):
            group = facet[0, n]
            subset = grid_subset[grid_subset[group_var] == group]
            # With one row, "lower" and "upper" both mean the last column
            if legend_loc in ('lower', 'upper'):
                legend = True if (n+1 == cols) else False
            else:
                legend = False
            seaborn_func(
                data=subset,
                ax=axes[n],
                legend=legend,
                **kwargs)
            axes[n].set(ylabel=None, xlabel=None, title=group)
            axes[n].grid()
    return fig, axes
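One constraint of grid_plot worth spelling out is that the unique values of group_var are reshaped directly into a rows-by-cols array, so rows * cols has to equal the number of groups exactly. A small NumPy sketch with six illustrative year values shows both the layout and the failure mode:

```python
import numpy as np

# Six illustrative group values, one per panel.
years = np.array([2018, 2019, 2020, 2021, 2022, 2023])

# Row-major fill: the first row holds the first three groups.
facet = years.reshape(2, 3)
print(facet.tolist())
# [[2018, 2019, 2020], [2021, 2022, 2023]]

# A grid with the wrong number of cells cannot be built.
try:
    years.reshape(2, 2)
    reshape_failed = False
except ValueError:
    reshape_failed = True
print(reshape_failed)
# True
```

The panel at `facet[1, 2]` (here, 2023) is the bottom-right cell, which is where legend_loc="lower" places the legend.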
For an application, I’ll create a dataset covering 2018 through 2023 where I want to display one chart per year in a grid pattern:
grid_subset = df_rph.copy()[(df_rph['Date'] > "2018")
                            & (df_rph['Date'] < "2024")]
grid_subset['Year'] = grid_subset['Date'].dt.year
Now I can use my function, grid_plot, to specify a plot type, a grid shape, and pass seaborn **kwargs to the underlying plotting function.
fig, axes = grid_plot(grid_subset,
                      sns.lineplot,
                      rows=2, cols=3,
                      group_var='Year',
                      legend_loc="lower",
                      x='Hour',
                      y='Rider_Thousands',
                      hue='day_of_week',
                      figsize=(8, 5),
                      palette=colormap,
                      estimator='mean',
                      linewidth=2,
                      errorbar=None)
custom_legend(axes[1, 2], outside_loc="upper",
              title="Day of the week", linewidth=2, order=list(colormap.keys()))
fig.supxlabel("Hour of Day")
fig.supylabel("Average Riders (1,000s)")
fig.tight_layout()
fig.suptitle("Hourly Ridership Trends Across Years", y=1.05);
The grid presents an interesting pattern: the general level of ridership drops across all hours, but the raw plot also suggests that the gap between peak commute ridership and off-peak ridership has shrunk post-pandemic. In late 2023, BART instituted scheduling changes to reflect their belief that demand has shifted in the post-pandemic period towards rides outside of peak commute times.
We can check that with a simple time series plot that relies on my custom_legend function to move the legend and resize the lines within it.
is_commute = (df_rph['Hour'].isin([7, 8, 9, 16, 17, 18, 19]) &
              ~df_rph['Weekend']).replace({
                  True: 'Riders during Peak Commute times',
                  False: 'Riders during Off-Peak times'
              })
fig, ax = plt.subplots()
(df_rph[df_rph['Date'] < "2024-05-01"]
 .pivot_table(index=pd.Grouper(key='Date', freq='ME'),
              columns=is_commute,
              values="Riders",
              aggfunc='sum')
 ).plot(linewidth=2, ax=ax)

ax.yaxis.set_major_formatter(comma_formatter())
ax.grid(alpha=.4)
ax.set(ylabel="Riders", title="Trends in Commute versus Non-Commute Ridership")
custom_legend(ax, title="Travel Time", linewidth=3, loc="lower left");
It does look like off-peak and peak commute ridership had a growing gap pre-pandemic, but have been essentially equal post-pandemic. Interesting!
Citation
@online{amerkhanian2024,
author = {Amerkhanian, Peter},
title = {Building {Time-series} {Utilities} for `Matplotlib`},
date = {2024-06-16},
url = {https://peter-amerkhanian.com/posts/time-series-matplotlib/},
langid = {en}
}