Creating a Bar Chart Race Animation in Python with Matplotlib
In this tutorial, you’ll learn how to create a bar chart race animation such as the one below using the matplotlib data visualization library in Python. This post is rendered in the style of a Jupyter Notebook on the Dunder Data blog.
bar_chart_race Python package
Alongside this tutorial, I have released the Python package bar_chart_race, which automates the process of making these animations. This post explains the procedure from scratch.
What is a bar chart race?
A bar chart race is an animated sequence of bars that show data values at different moments in time. The bars re-position themselves at each time period so that they remain in order (either ascending or descending).
Transition bars smoothly between time periods
The trick to making a bar chart race is to transition the bars slowly to their new position when their order changes, allowing you to easily track the movements.
COVID-19 deaths data
For this bar chart race, we’ll use a small dataset produced by Johns Hopkins University containing the total deaths by date for six countries during the currently ongoing coronavirus pandemic. Let’s read it in now.
import pandas as pd
df = pd.read_csv('data/covid19.csv', index_col='date',
                 parse_dates=['date'])
df.tail()
Must use ‘wide’ data
For this tutorial, the data must be in ‘wide’ form where:
- Every row represents a single period of time
- Each column holds the value for a particular category
- The index contains the time component (optional)
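If your data arrives in "long" form instead (one row per date and country), pandas can reshape it. Below is a minimal sketch, assuming hypothetical column names `date`, `country`, and `deaths` and made-up numbers; `pivot` turns each country into its own column with the dates as the index.

```python
import pandas as pd

# Hypothetical long-form data: one row per (date, country) pair.
# The numbers are made up for illustration only.
long_df = pd.DataFrame({
    'date': pd.to_datetime(['2020-03-29', '2020-03-29',
                            '2020-03-30', '2020-03-30']),
    'country': ['USA', 'Iran', 'USA', 'Iran'],
    'deaths': [10, 20, 30, 25],
})

# Each unique country becomes a column; each unique date becomes a row
wide_df = long_df.pivot(index='date', columns='country', values='deaths')
print(wide_df)
```

The resulting columns come out sorted alphabetically, and every row holds one date, which is exactly the layout the rest of this tutorial expects.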
Individual bar charts for specific dates
Let’s begin by creating a single static bar chart for the specific date of March 29, 2020. First, we select the data as a Series.
s = df.loc['2020-03-29']
s
China 3304.0
USA 2566.0
Italy 10779.0
UK 1231.0
Iran 2640.0
Spain 6803.0
Name: 2020-03-29 00:00:00, dtype: float64
We’ll make a horizontal bar chart using the country names as the y-values and total deaths as the x-values (width of bars). Every bar will be a different color from the ‘Dark2’ colormap.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(4, 2.5), dpi=144)
colors = plt.cm.Dark2(range(6))
y = s.index
width = s.values
ax.barh(y=y, width=width, color=colors);
The function below changes several properties of the axes to make it look nicer.
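def nice_axes(ax):
    ax.set_facecolor('.8')                     # light gray background
    ax.tick_params(labelsize=8, length=0)      # small labels, no tick marks
    ax.grid(True, axis='x', color='white')     # vertical white grid lines
    ax.set_axisbelow(True)                     # draw the grid behind the bars
    for spine in ax.spines.values():           # hide the axes border
        spine.set_visible(False)

nice_axes(ax)
fig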
Plot three consecutive days ordering the bars
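For a bar chart race, the bars are often ordered from largest to smallest with the largest at the top. Here, we plot three days of data, sorting each one first. Sorting in ascending order places the largest bar at the top because barh draws the first value at the bottom of the axes.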
fig, ax_array = plt.subplots(nrows=1, ncols=3, figsize=(7, 2.5), dpi=144, tight_layout=True)
dates = ['2020-03-29', '2020-03-30', '2020-03-31']
for ax, date in zip(ax_array, dates):
    s = df.loc[date].sort_values()
    ax.barh(y=s.index, width=s.values, color=colors)
    ax.set_title(date, fontsize='smaller')
    nice_axes(ax)
Countries change color
Although the bars are ordered properly, the countries do not keep their original colors when changing places in the graph. Notice that the USA begins as the fifth bar and moves up one position each date, changing colors each time. This happens because barh assigns colors by position, and sorting changes each country's position in the Series.
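One alternative, which is not the approach this post takes, is to key each color by country name and look the colors up in sorted order. Here is a minimal sketch using the March 29, 2020 totals shown above; `color_map` and `sorted_colors` are hypothetical names, and the Agg backend is assumed so it runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Totals for March 29, 2020, as printed earlier in this post
s = pd.Series({'China': 3304.0, 'USA': 2566.0, 'Italy': 10779.0,
               'UK': 1231.0, 'Iran': 2640.0, 'Spain': 6803.0})

# Assign each country its color once, keyed by name instead of by position
color_map = dict(zip(s.index, plt.cm.Dark2(range(len(s)))))

s_sorted = s.sort_values()
# Look up colors in sorted order so each country keeps its own color
sorted_colors = [color_map[country] for country in s_sorted.index]

fig, ax = plt.subplots(figsize=(4, 2.5), dpi=144)
ax.barh(y=s_sorted.index, width=s_sorted.values, color=sorted_colors)
```

This keeps colors stable for static charts, but it still moves bars in discrete jumps, which is why the post switches to ranking.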
Don’t sort — rank instead!
Instead of sorting, use the rank method to find the numeric ranking of each country for each day. We use the 'first' method of ranking so that each numeric rank is a unique integer. The default method, 'average', gives tied values the same averaged rank, which would cause overlapping bars. Let's see the ranking for March 29, 2020.
df.loc['2020-03-29'].rank(method='first')
China 4.0
USA 2.0
Italy 6.0
UK 1.0
Iran 3.0
Spain 5.0
Name: 2020-03-29 00:00:00, dtype: float64
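The difference between the two ranking methods only shows up when values tie. A small sketch with a made-up row in which two categories share the same value:

```python
import pandas as pd

# Toy row with a tie between 'B' and 'C' (made-up values)
row = pd.Series({'A': 5, 'B': 10, 'C': 10, 'D': 1})

avg_ranks = row.rank()                  # default method='average'
first_ranks = row.rank(method='first')  # ties broken by order of appearance
print(avg_ranks.tolist())
print(first_ranks.tolist())
```

With 'average', B and C both get 3.5 and their bars would sit on top of each other; with 'first', B gets 3 and C gets 4, so every bar has its own slot.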
We now use this rank as the y-values. The order of the data in the Series never changes this way, ensuring that countries remain the same color regardless of their rank.
fig, ax_array = plt.subplots(nrows=1, ncols=3, figsize=(7, 2.5), dpi=144, tight_layout=True)
dates = ['2020-03-29', '2020-03-30', '2020-03-31']
for ax, date in zip(ax_array, dates):
    s = df.loc[date]
    y = s.rank(method='first').values
    ax.barh(y=y, width=s.values, color=colors, tick_label=s.index)
    ax.set_title(date, fontsize='smaller')
    nice_axes(ax)
How to smoothly transition?
Using each day as a single frame in an animation won’t work well as it doesn’t capture the transition from one time period to the next. In order to transition the bars that change positions, we’ll need to add extra rows of data between the dates that we do have. Let’s first select the three dates above as a DataFrame.
df2 = df.loc['2020-03-29':'2020-03-31']
df2
It’s easier to insert an exact number of new rows when using the default index — integers beginning at 0. Alternatively, if you do have a datetime in the index as we do here, you can use the asfreq method, which is explained at the end of this post. Use the reset_index method to get a default index and to place the dates as a column again.
df2 = df2.reset_index()
df2
Choose number of steps between each date
We want to insert new rows between the first and second rows and between the second and third rows. Begin by multiplying the index by the number of steps to transition from one time period to the next. We use 5 in this example.
df2.index = df2.index * 5
df2
Expand DataFrame with reindex
To insert the additional rows, pass the reindex method a sequence of all integers beginning at 0 to the last integer (10 in this case). pandas inserts new rows of all missing values for every index not in the current DataFrame.
last_idx = df2.index[-1] + 1
df_expanded = df2.reindex(range(last_idx))
df_expanded
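As a sanity check on the row count: with n original rows and s steps, the expanded frame has (n - 1) * s + 1 rows, of which n - 1... more precisely n rows are real data and the rest are placeholders. Here is a minimal sketch on a made-up three-row frame standing in for df2 after reset_index (the `small` and `expanded` names are hypothetical):

```python
import pandas as pd

# Three made-up daily rows standing in for df2 after reset_index
small = pd.DataFrame({
    'date': pd.to_datetime(['2020-03-29', '2020-03-30', '2020-03-31']),
    'value': [0.0, 10.0, 20.0],
})
steps = 5

small.index = small.index * steps                       # index becomes 0, 5, 10
expanded = small.reindex(range(small.index[-1] + 1))    # rows 0 through 10

# (3 - 1) * 5 + 1 = 11 rows: 3 original rows plus 8 new all-missing rows
print(len(expanded), expanded['value'].isna().sum())
```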
The new rows have a missing date. Let’s fill each one with the last known date using the ffill method (the older fillna(method='ffill') spelling is deprecated in recent pandas), and then set the date column as the index again.
df_expanded['date'] = df_expanded['date'].ffill()
df_expanded = df_expanded.set_index('date')
df_expanded
Rank each row
We also need a similar DataFrame that contains the rank of each country by row. Most pandas methods work down each column by default. Set axis to 1 to change the direction of the operation so that values in each row are ranked against each other.
df_rank_expanded = df_expanded.rank(axis=1, method='first')
df_rank_expanded
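To see what the axis argument changes, here is a small sketch on a made-up frame with two dates (rows) and three countries (columns):

```python
import pandas as pd

# Made-up wide frame: two dates (rows) by three countries (columns)
wide = pd.DataFrame({'X': [3, 1], 'Y': [1, 2], 'Z': [2, 3]})

by_column = wide.rank()                     # default axis=0: each column ranked on its own
by_row = wide.rank(axis=1, method='first')  # each row ranked across the countries
print(by_column)
print(by_row)
```

Only the row-wise version answers "where does each country stand on this date?": every row of `by_row` contains the ranks 1, 2, and 3 exactly once.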
Linear interpolate missing values
The interpolate method can fill in the missing values in a variety of ways. By default, it uses linear interpolation and works column-wise.
df_expanded = df_expanded.interpolate()
df_expanded
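The arithmetic behind the default linear interpolation is simple: between a known value v0 at row 0 and v1 at row 5, row k gets v0 + (v1 - v0) * k / 5. The default 'linear' method treats rows as equally spaced, which matches the evenly spaced rows created above. A sketch with made-up numbers:

```python
import pandas as pd

nan = float('nan')
# Made-up column: known values at rows 0 and 5, four empty rows in between
col = pd.Series([100.0, nan, nan, nan, nan, 200.0])
filled = col.interpolate()  # default method='linear'

# Linear interpolation from v0 at row 0 to v1 at row 5: v0 + (v1 - v0) * k / 5
expected = [100.0 + (200.0 - 100.0) * k / 5 for k in range(6)]
print(filled.tolist())
```

Each intermediate frame therefore grows the bar by the same amount, one fifth of the day's change.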
We also need to interpolate the ranking.
df_rank_expanded = df_rank_expanded.interpolate()
df_rank_expanded
Plot each step of the transition
The interpolated ranks will serve as the new position of the bars along the y-axis. Here, we’ll plot each step from the first to the second day, where Iran and the USA change places.
fig, ax_array = plt.subplots(nrows=1, ncols=6, figsize=(12, 2),
                             dpi=144, tight_layout=True)
labels = df_expanded.columns
for i, ax in enumerate(ax_array.flatten()):
    y = df_rank_expanded.iloc[i]
    width = df_expanded.iloc[i]
    ax.barh(y=y, width=width, color=colors, tick_label=labels)
    nice_axes(ax)
ax_array[0].set_title('2020-03-29')
ax_array[-1].set_title('2020-03-30');
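Why the swap looks smooth: the two interpolated ranks move toward each other by equal amounts each frame. A toy sketch with made-up ranks for two countries that trade places, assuming five steps as in this post (`ranks` and `first_swapped` are hypothetical names):

```python
import pandas as pd

steps = 5
nan = float('nan')
# Made-up ranks: 'USA' and 'Iran' swap places between two consecutive days,
# with steps - 1 empty rows in between (the same layout reindex produces)
ranks = pd.DataFrame({
    'USA': [2.0] + [nan] * (steps - 1) + [3.0],
    'Iran': [3.0] + [nan] * (steps - 1) + [2.0],
}).interpolate()

# First frame in which the USA bar sits above the Iran bar
first_swapped = int((ranks['USA'] > ranks['Iran']).idxmax())
print(ranks)
```

The USA moves 2.0, 2.2, 2.4, 2.6, 2.8, 3.0 while Iran mirrors it, so the bars pass each other between frames 2 and 3. With an even number of steps, one frame would land exactly at 2.5 with both bars at the same height.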
The next day’s transition is plotted below.
fig, ax_array = plt.subplots(nrows=1, ncols=6, figsize=(12, 2),
                             dpi=144, tight_layout=True)
labels = df_expanded.columns
for i, ax in enumerate(ax_array.flatten(), start=5):
    y = df_rank_expanded.iloc[i]
    width = df_expanded.iloc[i]
    ax.barh(y=y, width=width, color=colors, tick_label=labels)
    nice_axes(ax)
ax_array[0].set_title('2020-03-30')
ax_array[-1].set_title('2020-03-31');
Write a function to prepare all of the data
We can copy and paste the code above into a function to automate the process of preparing any data for the bar chart race. Then use it to create the two final DataFrames needed for plotting.
def prepare_data(df, steps=5):
    # Move the dates into a column and spread the rows `steps` apart
    df = df.reset_index()
    df.index = df.index * steps
    last_idx = df.index[-1] + 1
    # Insert all-missing rows between the original dates
    df_expanded = df.reindex(range(last_idx))
    df_expanded['date'] = df_expanded['date'].ffill()
    df_expanded = df_expanded.set_index('date')
    # Rank each row, then fill the gaps linearly in both frames
    df_rank_expanded = df_expanded.rank(axis=1, method='first')
    df_expanded = df_expanded.interpolate()
    df_rank_expanded = df_rank_expanded.interpolate()
    return df_expanded, df_rank_expanded

df_expanded, df_rank_expanded = prepare_data(df)
df_expanded.head()
df_rank_expanded.head()
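prepare_data assumes the index is named 'date'. If your data uses a different index name, a small variant can read the name from the index instead. This is a hypothetical sketch (`prepare_data_generic` is not part of this post or of bar_chart_race), tested here on a made-up two-day frame whose index is named 'day':

```python
import pandas as pd

def prepare_data_generic(df, steps=5):
    """Hypothetical variant of prepare_data that does not assume the index is named 'date'."""
    # reset_index names the new column 'index' when the index is unnamed
    # (assumes no existing column is already called 'index')
    index_name = df.index.name or 'index'
    expanded = df.reset_index()
    expanded.index = expanded.index * steps
    expanded = expanded.reindex(range(expanded.index[-1] + 1))
    expanded[index_name] = expanded[index_name].ffill()
    expanded = expanded.set_index(index_name).rename_axis(df.index.name)
    ranks = expanded.rank(axis=1, method='first').interpolate()
    return expanded.interpolate(), ranks

# Made-up data: two days, two categories
toy = pd.DataFrame({'A': [0.0, 10.0], 'B': [5.0, 1.0]},
                   index=pd.to_datetime(['2020-03-29', '2020-03-30']).rename('day'))
values, ranks = prepare_data_generic(toy, steps=5)
```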
Animation
We are now ready to create the animation. Each row represents a single frame in our animation and will slowly transition the bars’ y-value location and width from one day to the next.
The simplest way to do animation in matplotlib is to use FuncAnimation. You must define a function that updates the matplotlib axes object each frame. Because the axes object keeps all of the previous bars, we remove them at the beginning of the update function. The rest of the function is identical to the plotting from above. This function will be passed the index of the frame as an integer. We also set the title to show the current date.
Optionally, you can define a function that initializes the axes. Below, init clears the previous axes of all objects and then resets its nice properties.
Pass the figure (containing your axes), the update and init functions, and the number of frames to FuncAnimation. We also pass the number of milliseconds between frames, which is used when creating HTML. We use 100 milliseconds per frame, which, with five frames per day, equates to 500 milliseconds (half a second) per day.
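The playback length follows directly from these numbers. Below is a small sketch, assuming five steps per day and a 100 ms interval as above; `animation_seconds` is a hypothetical helper, and the result is approximate because it ignores any encoding overhead.

```python
def animation_seconds(n_days, steps=5, interval_ms=100):
    """Approximate playback length in seconds (hypothetical helper)."""
    n_frames = (n_days - 1) * steps + 1  # rows produced by prepare_data
    return n_frames * interval_ms / 1000

ms_per_day = 5 * 100                      # 500 ms per day, as stated above
three_day_length = animation_seconds(3)   # 11 frames at 100 ms each -> 1.1 seconds
print(ms_per_day, three_day_length)
```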
The figure and axes are created separately below so they do not get output in a Jupyter Notebook, which automatically happens if you call plt.subplots.
from matplotlib.animation import FuncAnimation

def init():
    ax.clear()
    nice_axes(ax)
    ax.set_ylim(.2, 6.8)

def update(i):
    # Iterate over a copy: removing a container also removes it from ax.containers
    for bar in list(ax.containers):
        bar.remove()
    y = df_rank_expanded.iloc[i]
    width = df_expanded.iloc[i]
    ax.barh(y=y, width=width, color=colors, tick_label=labels)
    # Build the date string without the '%-d' directive, which fails on Windows
    date = df_expanded.index[i]
    date_str = f'{date:%B} {date.day}, {date.year}'
    ax.set_title(f'COVID-19 Deaths by Country - {date_str}', fontsize='smaller')

fig = plt.Figure(figsize=(4, 2.5), dpi=144)
ax = fig.add_subplot()
anim = FuncAnimation(fig=fig, func=update, init_func=init, frames=len(df_expanded),
                     interval=100, repeat=False)
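Because FuncAnimation simply calls the update function with each frame index, you can check the update logic without rendering any video by calling it directly. A minimal sketch with made-up values and ranks for two frames, assuming the Agg backend; it confirms that old bars are removed so only one set of bars remains.

```python
import matplotlib
matplotlib.use('Agg')  # render off-screen; no display needed
import matplotlib.pyplot as plt
import pandas as pd

# Made-up values and ranks for two frames, in the same wide layout as df_expanded
values = pd.DataFrame({'A': [1.0, 2.0], 'B': [2.0, 1.0]})
ranks = values.rank(axis=1, method='first')

fig, ax = plt.subplots()

def update(i):
    # Copy the list first: removing a container also removes it from ax.containers
    for bar in list(ax.containers):
        bar.remove()
    ax.barh(y=ranks.iloc[i], width=values.iloc[i], tick_label=values.columns)

# Call the update function directly, as FuncAnimation would for frames 0 and 1
for frame in range(len(values)):
    update(frame)
```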
Return animation HTML or save to disk
Call the to_html5_video method to return the animation as an HTML string (this requires ffmpeg to be installed) and then embed it in the notebook with the HTML class from the IPython.display module.
from IPython.display import HTML
html = anim.to_html5_video()
HTML(html)
You can save the animation to disk as an mp4 file using the save method, which also relies on ffmpeg. Because we defined an init function, we don't have to worry about clearing the axes and resetting the limits; init does it for us.
anim.save('media/covid19.mp4')
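If ffmpeg is not available, one option is to fall back to an animated GIF written by Pillow. The sketch below uses a tiny stand-in line animation rather than the bar chart race, writes to a temporary directory, and assumes a file name `race` chosen for illustration; the same `if` block would work with the `anim` created above.

```python
import os
import tempfile

import matplotlib
matplotlib.use('Agg')  # off-screen rendering
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter, writers

# A tiny stand-in animation; the bar chart race `anim` would be saved the same way
fig, ax = plt.subplots(figsize=(4, 2.5), dpi=100)
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
line, = ax.plot([], [])

def update(i):
    line.set_data(list(range(i + 1)), list(range(i + 1)))
    return (line,)

anim = FuncAnimation(fig=fig, func=update, frames=3, interval=100, repeat=False)

out_dir = tempfile.mkdtemp()
if writers.is_available('ffmpeg'):
    # mp4 output goes through the ffmpeg writer
    out_path = os.path.join(out_dir, 'race.mp4')
    anim.save(out_path)
else:
    # Fall back to a GIF written with Pillow; fps=10 matches interval=100 ms
    out_path = os.path.join(out_dir, 'race.gif')
    anim.save(out_path, writer=PillowWriter(fps=10))
```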
Using bar_chart_race
I created the bar_chart_race Python package to automate this process. It creates bar chart races from wide pandas DataFrames. Install it with pip install bar_chart_race.
import bar_chart_race as bcr
html = bcr.bar_chart_race(df, figsize=(4, 2.5), title='COVID-19 Deaths by Country')
HTML(html)
Using the asfreq method
If you are familiar with pandas, you might know that the asfreq method can be used to insert new rows. Let's reselect the last three days of March again to show how it works.
df2 = df.loc['2020-03-29':'2020-03-31']
df2
Inserting new rows is actually easier with asfreq. We just need to supply it a date offset that divides evenly into 24 hours, so that every original date is kept. Here, we insert a new row every 6 hours.
df2.asfreq('6h')
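A quick check of how many rows that produces: 24 / 6 = 4 intervals per day, so each gap between consecutive dates gains 3 new rows. The sketch below uses made-up daily values and passes the offset object `pd.offsets.Hour(6)`, which is equivalent to the '6h' string; asfreq leaves the new rows empty, so interpolate fills them.

```python
import pandas as pd

# Made-up daily values on three consecutive days
daily = pd.DataFrame({'value': [0.0, 40.0, 80.0]},
                     index=pd.date_range('2020-03-29', periods=3, freq='D'))

# 2 gaps * 4 six-hour intervals + 1 = 9 rows; the 6 new rows start out missing
six_hourly = daily.asfreq(pd.offsets.Hour(6))
filled = six_hourly.interpolate()
print(filled)
```

An offset that does not divide 24 hours (5 hours, say) would skip the later midnights, and asfreq would drop those original daily values.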
Inserting a specific number of rows is a little trickier. First create a date range with pd.date_range, which lets you specify the total number of periods; you must calculate that number yourself.
num_periods = (len(df2) - 1) * 5 + 1
dr = pd.date_range(start='2020-03-29', end='2020-03-31',
                   periods=num_periods)
dr
DatetimeIndex(['2020-03-29 00:00:00', '2020-03-29 04:48:00',
'2020-03-29 09:36:00', '2020-03-29 14:24:00',
'2020-03-29 19:12:00', '2020-03-30 00:00:00',
'2020-03-30 04:48:00', '2020-03-30 09:36:00',
'2020-03-30 14:24:00', '2020-03-30 19:12:00',
'2020-03-31 00:00:00'],
dtype='datetime64[ns]', freq=None)
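The period count is a fence-post calculation: 2 gaps times 5 steps gives 10 intervals, which need 11 timestamps. Each interval is then 24 hours divided by 5, or 4 hours 48 minutes, matching the output above. A short sketch that checks this arithmetic:

```python
import pandas as pd

steps = 5
n_dates = 3
num_periods = (n_dates - 1) * steps + 1  # 10 intervals need 11 timestamps
dr = pd.date_range(start='2020-03-29', end='2020-03-31', periods=num_periods)

# Spacing is one day split into `steps` equal parts: 24 h / 5 = 4 h 48 min
spacing = dr[1] - dr[0]
print(spacing)
```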
Then pass this date range to reindex to achieve the same result.
df2.reindex(dr)
We can use this procedure on all of our data.
num_periods = (len(df) - 1) * 5 + 1
dr = pd.date_range(start=df.index[0], end=df.index[-1], periods=num_periods)
df_expanded = df.reindex(dr)
# Use method='first', as earlier, so tied values never share a rank
df_rank_expanded = df_expanded.rank(axis=1, method='first').interpolate()
df_expanded = df_expanded.interpolate()
df_expanded.iloc[160:166]
df_rank_expanded.iloc[160:166]
One line?
It’s possible to do all of the analysis in a single ugly line of code.
df_one = (
    df.reset_index()
    .reindex([i / 5 for i in range(len(df) * 5 - 4)])
    .reset_index(drop=True)
    .pipe(lambda x: pd.concat(
        [x, x.iloc[:, 1:].rank(axis=1, method='first')],
        axis=1, keys=['values', 'ranks']))
    .interpolate()
    .ffill()
    .set_index(('values', 'date'))
    .rename_axis(index='date')
)
df_one.head()
Master Data Analysis with Python
If you are looking for a single, comprehensive resource to master pandas, matplotlib, and seaborn, check out my book Master Data Analysis with Python. It contains 800 pages and 350 exercises with detailed solutions. If you want to become a trusted source for data analysis with Python, this book will ensure you get there.