Creating a Bar Chart Race Animation in Python with Matplotlib

In this tutorial, you’ll learn how to create a bar chart race animation such as the one below using the matplotlib data visualization library in Python. This post is rendered in the style of a Jupyter Notebook on the Dunder Data blog.

bar chart race

bar_chart_race Python package

Alongside this tutorial, I've released the Python package bar_chart_race, which automates the process of making these animations. This post explains the procedure from scratch.

What is a bar chart race?

A bar chart race is an animated sequence of bars that show data values at different moments in time. The bars re-position themselves at each time period so that they remain in order (either ascending or descending).

Transition bars smoothly between time periods

The trick to making a bar chart race is to transition the bars slowly to their new position when their order changes, allowing you to easily track the movements.

COVID-19 deaths data

For this bar chart race, we’ll use a small dataset produced by Johns Hopkins University containing the total deaths by date for six countries during the currently ongoing coronavirus pandemic. Let’s read it in now.

import pandas as pd
df = pd.read_csv('data/covid19.csv', index_col='date',
                 parse_dates=['date'])
df.tail()

Must use ‘wide’ data

For this tutorial, the data must be in ‘wide’ form where:

  • Every row represents a single period of time
  • Each column holds the value for a particular category
  • The index contains the time component (optional)
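If your data arrives in 'long' form instead (one row per date and category), pandas' pivot method is one way to reshape it into the wide form described above. The sketch below uses made-up categories 'A' and 'B' and made-up values purely for illustration.

```python
import pandas as pd

# Made-up 'long' (tidy) data: one row per (date, category) pair.
long_df = pd.DataFrame({
    'date': pd.to_datetime(['2020-01-01', '2020-01-01', '2020-01-02', '2020-01-02']),
    'category': ['A', 'B', 'A', 'B'],
    'value': [10.0, 20.0, 15.0, 30.0],
})

# pivot makes each category a column and each date a single row (wide form).
wide_df = long_df.pivot(index='date', columns='category', values='value')
print(wide_df)
```

Each row of wide_df is now one period of time and each column holds one category, which is the shape the rest of this tutorial expects.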

Individual bar charts for specific dates

Let’s begin by creating a single static bar chart for the specific date of March 29, 2020. First, we select the data as a Series.

s = df.loc['2020-03-29']
s

China 3304.0
USA 2566.0
Italy 10779.0
UK 1231.0
Iran 2640.0
Spain 6803.0
Name: 2020-03-29 00:00:00, dtype: float64

We’ll make a horizontal bar chart using the country names as the y-values and total deaths as the x-values (width of bars). Every bar will be a different color from the ‘Dark2’ colormap.

import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(4, 2.5), dpi=144)
colors = plt.cm.Dark2(range(6))
y = s.index
width = s.values
ax.barh(y=y, width=width, color=colors);

The function below changes several properties of the axes to make it look nicer.

def nice_axes(ax):
    ax.set_facecolor('.8')
    ax.tick_params(labelsize=8, length=0)
    ax.grid(True, axis='x', color='white')
    ax.set_axisbelow(True)
    # hide all four spines (the axes border lines)
    for spine in ax.spines.values():
        spine.set_visible(False)

nice_axes(ax)
fig

Plot three consecutive days ordering the bars

For a bar chart race, the bars are often ordered from largest to smallest with the largest at the top. Here, we plot three days of data, sorting each one first.

fig, ax_array = plt.subplots(nrows=1, ncols=3, figsize=(7, 2.5), dpi=144, tight_layout=True)
dates = ['2020-03-29', '2020-03-30', '2020-03-31']
for ax, date in zip(ax_array, dates):
    s = df.loc[date].sort_values()
    ax.barh(y=s.index, width=s.values, color=colors)
    ax.set_title(date, fontsize='smaller')
    nice_axes(ax)

Countries change color

Although the bars are ordered properly, the countries do not keep their original colors when changing places in the graph. Notice that the USA begins as the fifth bar from the top and moves up one position each date, changing color each time. This happens because sorting changes each country's position in the Series, and colors are assigned by position.

Don’t sort — rank instead!

Instead of sorting, use the rank method to find the numeric ranking of each country for each day. We use the 'first' method of ranking so that each numeric rank is a unique integer (ties are broken by order of appearance). The default method, 'average', gives tied values the same averaged rank, which would cause overlapping bars. Let's see the ranking for March 29, 2020.

df.loc['2020-03-29'].rank(method='first')

China 4.0
USA 2.0
Italy 6.0
UK 1.0
Iran 3.0
Spain 5.0
Name: 2020-03-29 00:00:00, dtype: float64
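To see why 'first' matters, here is a tiny made-up Series with a tie. The default 'average' method gives both tied entries rank 2.5, which would draw two bars at the same height. The 'first' method breaks the tie by order of appearance.

```python
import pandas as pd

# Made-up values with a tie between 'a' and 'c'.
s = pd.Series([5, 3, 5], index=['a', 'b', 'c'])

avg = s.rank()                  # default method='average'
first = s.rank(method='first')  # ties broken by order of appearance

print(avg.tolist())    # [2.5, 1.0, 2.5]
print(first.tolist())  # [2.0, 1.0, 3.0]
```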

We now use this rank as the y-values. The order of the data in the Series never changes this way, ensuring that countries remain the same color regardless of their rank.

fig, ax_array = plt.subplots(nrows=1, ncols=3, figsize=(7, 2.5), dpi=144, tight_layout=True)
dates = ['2020-03-29', '2020-03-30', '2020-03-31']
for ax, date in zip(ax_array, dates):
    s = df.loc[date]
    # rank gives each country a unique y-position without reordering the Series
    y = s.rank(method='first').values
    ax.barh(y=y, width=s.values, color=colors, tick_label=s.index)
    ax.set_title(date, fontsize='smaller')
    nice_axes(ax)

How to smoothly transition?

Using each day as a single frame in an animation won’t work well because it doesn’t capture the transition from one time period to the next. To transition the bars that change positions, we need to add extra rows of data between the dates that we do have. Let’s first select the three dates above as a DataFrame.

df2 = df.loc['2020-03-29':'2020-03-31']
df2

It’s easier to insert an exact number of new rows when using the default index (integers beginning at 0). Alternatively, if you have a datetime index, as we do here, you can use the asfreq method, which is explained at the end of this post. Use the reset_index method to get a default index and to place the dates back in a column.

df2 = df2.reset_index()
df2

Choose number of steps between each date

We want to insert new rows between the first and second rows and between the second and third rows. Begin by multiplying the index by the number of steps to transition from one time period to the next. We use 5 in this example.

df2.index = df2.index * 5
df2

Expand DataFrame with reindex

To insert the additional rows, pass the reindex method a sequence of all integers from 0 to the last integer (10 in this case). pandas inserts a new row of missing values for every index label not already in the DataFrame.

last_idx = df2.index[-1] + 1
df_expanded = df2.reindex(range(last_idx))
df_expanded
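The multiply-then-reindex steps can be checked on a small made-up DataFrame (the column names 'A' and 'B' and their values are hypothetical). With three periods and 5 steps, the index becomes 0, 5 and 10, and reindexing to range(11) adds 8 rows of NaN between the original ones.

```python
import pandas as pd

steps = 5
# Made-up wide data for three consecutive periods.
toy = pd.DataFrame({'A': [1.0, 2.0, 3.0], 'B': [10.0, 20.0, 30.0]})

toy.index = toy.index * steps                     # index becomes 0, 5, 10
expanded = toy.reindex(range(toy.index[-1] + 1))  # labels 0..10 inclusive

print(len(expanded))               # 11
print(expanded['A'].isna().sum())  # 8
```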

The date in each new row is missing. Let’s forward-fill each one with the last known date above it and then set the date column as the index again.

# ffill() forward-fills; fillna(method='ffill') is deprecated in recent pandas
df_expanded['date'] = df_expanded['date'].ffill()
df_expanded = df_expanded.set_index('date')
df_expanded

Rank each row

We also need a similar DataFrame that contains the rank of each country by row. Most pandas methods work down each column by default. Set axis to 1 to change the direction of the operation so that values in each row are ranked against each other.

df_rank_expanded = df_expanded.rank(axis=1, method='first')
df_rank_expanded
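A small made-up example shows what axis=1 does: each row is ranked on its own, and rows that are entirely NaN (like the inserted rows) stay NaN, so they can be filled by interpolation later. The column names and values here are hypothetical.

```python
import numpy as np
import pandas as pd

# Made-up wide data; the middle row mimics an inserted row of missing values.
toy = pd.DataFrame({
    'A': [1.0, np.nan, 5.0],
    'B': [3.0, np.nan, 2.0],
    'C': [2.0, np.nan, 9.0],
})

# Rank the values within each row rather than down each column.
ranks = toy.rank(axis=1, method='first')
print(ranks)
```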

Linearly interpolate missing values

The interpolate method can fill in the missing values in a variety of ways. By default, it uses linear interpolation and works column-wise.

df_expanded = df_expanded.interpolate()
df_expanded

We also need to interpolate the ranking.

df_rank_expanded = df_rank_expanded.interpolate()
df_rank_expanded
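Linear interpolation fills the k-th of n inserted steps with v0 + (v1 - v0) * k / n. The sketch below applies this to made-up data: a value going from 0 to 100 and a rank going from 2 to 3 over 5 steps. The fractional ranks are what make a bar slide smoothly between positions.

```python
import numpy as np
import pandas as pd

steps = 5
# Made-up columns: known values at the first and last rows, NaN in between.
toy = pd.DataFrame({
    'value': [0.0] + [np.nan] * (steps - 1) + [100.0],
    'rank': [2.0] + [np.nan] * (steps - 1) + [3.0],
})

# Linear by default, column by column, treating rows as equally spaced.
filled = toy.interpolate()
print(filled)
```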

Plot each step of the transition

The interpolated ranks will serve as the new position of the bars along the y-axis. Here, we’ll plot each step from the first to the second day, where Iran and the USA change places.

fig, ax_array = plt.subplots(nrows=1, ncols=6, figsize=(12, 2),
                             dpi=144, tight_layout=True)
labels = df_expanded.columns
for i, ax in enumerate(ax_array.flatten()):
    y = df_rank_expanded.iloc[i]
    width = df_expanded.iloc[i]
    ax.barh(y=y, width=width, color=colors, tick_label=labels)
    nice_axes(ax)
ax_array[0].set_title('2020-03-29')
ax_array[-1].set_title('2020-03-30');

The next day’s transition is plotted below.

fig, ax_array = plt.subplots(nrows=1, ncols=6, figsize=(12, 2),
                             dpi=144, tight_layout=True)
labels = df_expanded.columns
# start=5 begins at the row for 2020-03-30
for i, ax in enumerate(ax_array.flatten(), start=5):
    y = df_rank_expanded.iloc[i]
    width = df_expanded.iloc[i]
    ax.barh(y=y, width=width, color=colors, tick_label=labels)
    nice_axes(ax)
ax_array[0].set_title('2020-03-30')
ax_array[-1].set_title('2020-03-31');

Write a function to prepare all of the data

We can copy and paste the code above into a function to automate the process of preparing any data for the bar chart race. Then use it to create the two final DataFrames needed for plotting.

def prepare_data(df, steps=5):
    # assumes the index is named 'date', so reset_index creates a 'date' column
    df = df.reset_index()
    df.index = df.index * steps
    last_idx = df.index[-1] + 1
    df_expanded = df.reindex(range(last_idx))
    df_expanded['date'] = df_expanded['date'].ffill()
    df_expanded = df_expanded.set_index('date')
    df_rank_expanded = df_expanded.rank(axis=1, method='first')
    df_expanded = df_expanded.interpolate()
    df_rank_expanded = df_rank_expanded.interpolate()
    return df_expanded, df_rank_expanded

df_expanded, df_rank_expanded = prepare_data(df)
df_expanded.head()

df_rank_expanded.head()

Animation

We are now ready to create the animation. Each row represents a single frame in our animation and will slowly transition the bars’ y-value location and width from one day to the next.

The simplest way to animate in matplotlib is with FuncAnimation. You must define a function that updates the matplotlib axes object for each frame. Because the axes object keeps all of the previous bars, we remove them at the beginning of the update function. The rest of the function is identical to the plotting code from above. This function is passed the index of the frame as an integer. We also set the title to show the current date.

Optionally, you can define a function that initializes the axes. Below, init clears the previous axes of all objects and then resets its nice properties.

Pass the figure (containing your axes), the update and init functions, and the number of frames to FuncAnimation. We also pass interval, the number of milliseconds between frames, which sets the playback speed when creating HTML or saving a video. We use 100 milliseconds per frame; with 5 frames per day, that equates to 500 milliseconds (half a second) per day.
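The timing arithmetic can be written out as a small hypothetical helper, animation_timing (not part of matplotlib or the original post). With n dates and steps frames per transition, there are (n - 1) * steps + 1 frames, and each date lasts steps * interval milliseconds.

```python
def animation_timing(n_periods, steps=5, interval_ms=100):
    """Hypothetical helper: (frames, ms per period, total seconds) for this setup."""
    n_frames = (n_periods - 1) * steps + 1   # rows after reindexing
    ms_per_period = steps * interval_ms      # time spent moving between two dates
    total_seconds = n_frames * interval_ms / 1000
    return n_frames, ms_per_period, total_seconds

print(animation_timing(3))  # three dates, as in the examples above
```

For the three dates used earlier, this gives 11 frames, 500 ms per day and about 1.1 seconds in total.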

The figure and axes are created separately below so they do not get output in a Jupyter Notebook, which automatically happens if you call plt.subplots.

from matplotlib.animation import FuncAnimation

labels = df_expanded.columns

def init():
    ax.clear()
    nice_axes(ax)
    ax.set_ylim(.2, 6.8)

def update(i):
    # remove the previous frame's bars; iterate over a copy because
    # each remove() also deletes the container from ax.containers
    for bar in list(ax.containers):
        bar.remove()
    y = df_rank_expanded.iloc[i]
    width = df_expanded.iloc[i]
    ax.barh(y=y, width=width, color=colors, tick_label=labels)
    date = df_expanded.index[i]
    # portable version of '%B %-d, %Y' ('%-d' is not supported on Windows)
    date_str = f'{date:%B} {date.day}, {date.year}'
    ax.set_title(f'COVID-19 Deaths by Country - {date_str}', fontsize='smaller')

fig = plt.Figure(figsize=(4, 2.5), dpi=144)
ax = fig.add_subplot()
anim = FuncAnimation(fig=fig, func=update, init_func=init, frames=len(df_expanded),
                     interval=100, repeat=False)
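The bar-removal step in update relies on the fact that calling remove() on a BarContainer also deletes it from ax.containers. A minimal sketch with made-up bar widths shows why iterating over a copy, list(ax.containers), is safer than iterating over ax.containers directly when more than one container is present.

```python
from matplotlib.figure import Figure

fig = Figure(figsize=(4, 2.5))
ax = fig.add_subplot()

# Two made-up sets of bars, so ax holds two BarContainers.
ax.barh(y=[1, 2, 3], width=[3, 1, 2])
ax.barh(y=[1, 2, 3], width=[1, 2, 3])
print(len(ax.containers))  # 2

# Iterate over a copy: each remove() also shrinks ax.containers,
# so looping over the live list would skip every other container.
for container in list(ax.containers):
    container.remove()

print(len(ax.containers))  # 0
print(len(ax.patches))     # 0
```

In the animation above only one container exists per frame, so either loop works there, but the copy avoids surprises if the update function ever draws more than one set of bars.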

Return animation HTML or save to disk

Call the to_html5_video method to return the animation as an HTML string (this requires FFmpeg to be installed), and then embed it in the notebook with help from the IPython.display module.

from IPython.display import HTML
html = anim.to_html5_video()
HTML(html)

You can save the animation to disk as an mp4 file using the save method, which also requires FFmpeg. Since we have an init function, we don't have to worry about clearing our axes and resetting the limits; the init function does it for us.

anim.save('media/covid19.mp4')
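If FFmpeg is not available, one alternative (not used in the original post) is to save a GIF with matplotlib's PillowWriter, which relies on Pillow, a matplotlib dependency. The sketch below uses a made-up three-frame animation and a temporary file path rather than the COVID-19 data.

```python
import os
import tempfile

from matplotlib.animation import FuncAnimation, PillowWriter
from matplotlib.figure import Figure

fig = Figure(figsize=(4, 2.5), dpi=72)
ax = fig.add_subplot()
widths = [[3, 1, 2], [2, 2, 2], [1, 3, 2]]  # made-up bar widths, one list per frame

def update(i):
    # same pattern as above: clear the old bars, then draw this frame's bars
    for container in list(ax.containers):
        container.remove()
    ax.barh(y=[1, 2, 3], width=widths[i], color=['C0', 'C1', 'C2'])

anim = FuncAnimation(fig=fig, func=update, frames=len(widths), interval=100)

path = os.path.join(tempfile.mkdtemp(), 'toy_race.gif')
anim.save(path, writer=PillowWriter(fps=10))  # fps=10 matches interval=100 ms
print(os.path.getsize(path) > 0)
```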

Using bar_chart_race

I created the bar_chart_race Python package to automate this process. It creates bar chart races from wide pandas DataFrames. Install it with pip install bar_chart_race.

import bar_chart_race as bcr
html = bcr.bar_chart_race(df, figsize=(4, 2.5), title='COVID-19 Deaths by Country')
HTML(html)

Using the asfreq method

If you are familiar with pandas, you might know that the asfreq method can be used to insert new rows. Let's reselect the last three days of March again to show how it works.

df2 = df.loc['2020-03-29':'2020-03-31']
df2

Inserting new rows is actually easier with asfreq. We just need to supply it a frequency that divides 24 hours evenly, so that every original date is kept. Here, we insert a new row every 6 hours.

df2.asfreq('6h')
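A small made-up example shows both the convenience and the catch. The '6h' frequency divides 24 hours evenly, so every original date survives. A frequency such as '5h' never lands on the later midnights, so their data is lost. The column 'A' and its values are hypothetical.

```python
import pandas as pd

idx = pd.date_range('2020-03-29', periods=3, freq='D')
daily = pd.DataFrame({'A': [1.0, 2.0, 3.0]}, index=idx)

six_hourly = daily.asfreq('6h')   # 9 rows: 00:00, 06:00, ... through 2020-03-31 00:00
five_hourly = daily.asfreq('5h')  # misses 2020-03-30 00:00 and 2020-03-31 00:00

print(len(six_hourly))                     # 9
print(six_hourly['A'].notna().sum())       # 3
print(five_hourly['A'].notna().sum())      # 1
print(six_hourly.interpolate()['A'].iloc[1])  # 1.25
```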

Inserting a specific number of rows between dates is a little trickier. One way is to first create a date range with pd.date_range, which lets you specify the total number of periods. You must calculate that number yourself: (number of dates - 1) * steps + 1.

num_periods = (len(df2) - 1) * 5 + 1
dr = pd.date_range(start='2020-03-29', end='2020-03-31',
                   periods=num_periods)
dr

DatetimeIndex(['2020-03-29 00:00:00', '2020-03-29 04:48:00',
               '2020-03-29 09:36:00', '2020-03-29 14:24:00',
               '2020-03-29 19:12:00', '2020-03-30 00:00:00',
               '2020-03-30 04:48:00', '2020-03-30 09:36:00',
               '2020-03-30 14:24:00', '2020-03-30 19:12:00',
               '2020-03-31 00:00:00'],
              dtype='datetime64[ns]', freq=None)

Then pass this date range to reindex to achieve the same result.

df2.reindex(dr)
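As a sanity check, the integer-index approach from earlier in the post and this date_range approach should give the same interpolated values, because interpolate uses linear interpolation by default and treats rows as equally spaced. The sketch below compares them on made-up data with hypothetical columns 'A' and 'B'.

```python
import numpy as np
import pandas as pd

steps = 5
idx = pd.date_range('2020-03-29', periods=3, freq='D')
daily = pd.DataFrame({'A': [1.0, 2.0, 3.0], 'B': [30.0, 20.0, 10.0]}, index=idx)

# Approach 1: integer index multiplied by steps, then reindex and interpolate.
int_df = daily.reset_index(drop=True)
int_df.index = int_df.index * steps
int_expanded = int_df.reindex(range(int_df.index[-1] + 1)).interpolate()

# Approach 2: evenly spaced timestamps from date_range(periods=...).
num_periods = (len(daily) - 1) * steps + 1
dr = pd.date_range(start=daily.index[0], end=daily.index[-1], periods=num_periods)
date_expanded = daily.reindex(dr).interpolate()

same = np.allclose(int_expanded.to_numpy(), date_expanded.to_numpy())
print(same)  # True
```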

We can use this procedure on all of our data.

num_periods = (len(df) - 1) * 5 + 1
dr = pd.date_range(start=df.index[0], end=df.index[-1], periods=num_periods)
df_expanded = df.reindex(dr)
df_rank_expanded = df_expanded.rank(axis=1, method='first').interpolate()
df_expanded = df_expanded.interpolate()
df_expanded.iloc[160:166]

df_rank_expanded.iloc[160:166]

One line?

It’s possible to do all of the analysis in a single ugly line of code.

df_one = (df.reset_index()
            .reindex([i / 5 for i in range(len(df) * 5 - 4)])
            .reset_index(drop=True)
            .pipe(lambda x: pd.concat(
                                [x, x.iloc[:, 1:].rank(axis=1, method='first')],
                                axis=1, keys=['values', 'ranks']))
            .interpolate()
            .ffill()
            .set_index(('values', 'date'))
            .rename_axis(index='date'))
df_one.head()

Master Data Analysis with Python

If you are looking for a single, comprehensive resource to master pandas, matplotlib, and seaborn, check out my book Master Data Analysis with Python. It contains 800 pages and 350 exercises with detailed solutions. If you want to be a trusted source to do data analysis using Python, this book will ensure you get there.

