How to Create Compound Charts using Altair?


Example of the selection and compound chart

By the end of this post, you will be able to create a chart like the one above with your own dataset. I used similar plots in my post on evaluating crowd density estimation models.

altair layers and crossfilter

In this notebook, we plot stock prices to demonstrate how to do the following in Altair:

  • create compound charts
  • create selections, i.e. cross-filtering capabilities

The data was downloaded from Yahoo Finance using R.

In [905]:
import pandas as pd
import altair as alt

Load Data

We will load the data and print some of its rows.

In [906]:
df = pd.read_csv("data.csv", parse_dates=['date'])
df.head(5)
Out[906]:
   Unnamed: 0  symbol        date      open      high       low     close       volume  adjusted
0           1    AAPL  2001-01-02  0.265625  0.272321  0.260045  0.265625  452312000.0  0.229537
1           2    AAPL  2001-01-03  0.258929  0.297991  0.257813  0.292411  817073600.0  0.252684
2           3    AAPL  2001-01-04  0.323940  0.330357  0.300223  0.304688  739396000.0  0.263292
3           4    AAPL  2001-01-05  0.302455  0.310268  0.286830  0.292411  412356000.0  0.252684
4           5    AAPL  2001-01-08  0.302455  0.303292  0.284598  0.295759  373699200.0  0.255577

Print the column info and the associated data types.

In [907]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 18523 entries, 0 to 18522
Data columns (total 9 columns):
Unnamed: 0    18523 non-null int64
symbol        18523 non-null object
date          18523 non-null datetime64[ns]
open          18523 non-null float64
high          18523 non-null float64
low           18523 non-null float64
close         18523 non-null float64
volume        18523 non-null float64
adjusted      18523 non-null float64
dtypes: datetime64[ns](1), float64(6), int64(1), object(1)
memory usage: 1.3+ MB
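If you don't have data.csv at hand, you can still follow along with a small synthetic frame that mimics the schema printed above. This is a minimal sketch, assuming random-walk prices and made-up volumes are good enough for plotting practice. It is not the Yahoo Finance data, it omits the `Unnamed: 0` index column, and `make_fake_prices` is a hypothetical helper of our own. Run it in place of the `read_csv` cell.

```python
import numpy as np
import pandas as pd


def make_fake_prices(symbols=("AAPL", "GOOG"), start="2020-01-02", periods=60, seed=0):
    """Hypothetical helper: synthetic long-format prices shaped like data.csv."""
    rng = np.random.default_rng(seed)
    dates = pd.bdate_range(start=start, periods=periods)  # weekdays only
    frames = []
    for symbol in symbols:
        # Random-walk closing prices starting near 100 (invented values)
        close = 100 * np.cumprod(1 + rng.normal(0, 0.02, size=periods))
        open_ = close * (1 + rng.normal(0, 0.005, size=periods))
        frames.append(pd.DataFrame({
            "symbol": symbol,
            "date": dates,
            "open": open_,
            "high": np.maximum(open_, close) * 1.01,
            "low": np.minimum(open_, close) * 0.99,
            "close": close,
            "volume": rng.integers(1_000_000, 5_000_000, size=periods).astype(float),
            "adjusted": close,  # fake data has no dividends or splits
        }))
    return pd.concat(frames, ignore_index=True)


df = make_fake_prices()
df.info()
```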

Which tickers (symbols) does the data contain?

In [908]:
df['symbol'].unique()
Out[908]:
array(['AAPL', 'GOOG', 'AMZN', 'TSLA', 'BTC-USD'], dtype=object)

What is the date range?

In [909]:
print(f"Min Date: {df.date.min()}, Max Date: {df.date.max()}")
Min Date: 2001-01-02 00:00:00, Max Date: 2020-07-24 00:00:00
In [910]:
# We are only going to use 2020 data
df = df.loc[df.date > "2020-01-01", :]
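Comparing the `date` column with the string "2020-01-01" works because pandas parses the string into a Timestamp. Note that `>` drops 1 January itself. Stock exchanges are closed that day, but BTC-USD trades every day, so its 1 January row is excluded. A small sketch with a handful of dates contrasts the strict comparison with an explicit year filter:

```python
import pandas as pd

dates = pd.Series(pd.to_datetime(["2019-12-31", "2020-01-01", "2020-01-02", "2020-07-24"]))

# The string is converted to a Timestamp, so this is a real date comparison
strictly_after = dates[dates > "2020-01-01"]
# Keeps every 2020 row, including 1 January itself
in_2020 = dates[dates.dt.year == 2020]

print(strictly_after.dt.strftime("%Y-%m-%d").tolist())
print(in_2020.dt.strftime("%Y-%m-%d").tolist())
```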

Print statistics of the numerical columns

In [911]:
# print stats
df.describe()
Out[911]:
         Unnamed: 0          open          high           low         close        volume      adjusted
count    769.000000    769.000000    769.000000    769.000000    769.000000  7.690000e+02    769.000000
mean   12942.535761   2991.745747   3047.534263   2936.710797   2996.093676  8.831951e+09   2996.039726
std     5031.284751   3509.525773   3572.568545   3445.237602   3513.739949  1.581048e+10   3513.784838
min     4780.000000     57.020000     57.125000     53.152500     56.092499  9.295000e+05     55.840385
25%     8841.000000    140.199997    147.953995    136.854004    141.126007  4.121900e+06    141.126007
50%    13812.000000   1449.160034   1465.430054   1432.469971   1451.859985  8.900750e+07   1451.859985
75%    18331.000000   6245.624512   6504.515137   5920.085938   6242.193848  1.571397e+10   6242.193848
max    18523.000000  10323.960938  10457.626953  10202.387695  10326.054688  7.415677e+10  10326.054688
In [912]:
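def find_data(df, symbol):
    # Keep only the rows for one ticker; .copy() avoids SettingWithCopyWarning on later edits
    result = df.loc[df['symbol'] == symbol].copy()
    # 'pct_change' is not a column in data.csv, so reindex adds it filled with NaN
    # (selecting a missing label with .loc raises a KeyError in pandas >= 1.0)
    return result.reindex(columns=['symbol', 'date', 'adjusted', 'pct_change', 'volume'])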

Actually, we are only going to use AAPL in this notebook.

In [913]:
aapl_df = find_data(df, 'AAPL')
aapl_df.head()
Out[913]:
      symbol        date   adjusted  pct_change       volume
4779    AAPL  2020-01-02  74.573036         NaN  135480400.0
4780    AAPL  2020-01-03  73.848030         NaN  146322800.0
4781    AAPL  2020-01-06  74.436470         NaN  118387200.0
4782    AAPL  2020-01-07  74.086395         NaN  108872000.0
4783    AAPL  2020-01-08  75.278160         NaN  132079200.0
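The `pct_change` column above is all NaN because data.csv has no such column; `find_data` only reserves a slot for it. If you want actual daily returns, one option (our suggestion, not part of the original notebook) is to compute them per ticker with pandas' `groupby(...).pct_change()`, so a return never spans two different symbols. The prices below are invented for illustration:

```python
import pandas as pd

# Tiny hand-made example: two tickers, three days each (illustrative values only)
prices = pd.DataFrame({
    "symbol": ["AAPL", "AAPL", "AAPL", "GOOG", "GOOG", "GOOG"],
    "date": pd.to_datetime(["2020-01-02", "2020-01-03", "2020-01-06"] * 2),
    "adjusted": [100.0, 110.0, 99.0, 50.0, 55.0, 55.0],
})

# Sort so each ticker's rows are in date order, then compute returns within each ticker;
# the first row of every ticker has no previous price and stays NaN
prices = prices.sort_values(["symbol", "date"])
prices["pct_change"] = prices.groupby("symbol")["adjusted"].pct_change()
print(prices)
```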

Plots

We will start by creating individual plots and then combine them into compound charts.

Daily Volume

We start by plotting Apple's daily trading volume in 2020.

In [914]:
aapl_bar = alt.Chart(aapl_df).mark_bar().encode(
    alt.X('date:T', title=""),
    alt.Y('volume', title="Volume")
).properties(    
    width=700,
    title="Apple: Daily Volume 2020"
)
aapl_bar
Out[914]:

Similarly, we can plot the price (adjusted price) too.

In [915]:
aapl_line = alt.Chart(aapl_df).mark_line().encode(
    alt.X('date:T', title=""),
    alt.Y('adjusted', title="Price")
).properties(    
    width=700,
    height=200,
    title="Apple: Daily Price 2020"
)
aapl_line
Out[915]:

Compound Chart

Let us combine them using layers. Here, we plot price and volume in the same chart, each with its own y-axis (volume on the left, price on the right).

In [916]:
# Shared x encoding; each layer sets its own mark and y channel
base = alt.Chart(aapl_df).encode(
    alt.X('date:T', title=""),
)

bar = base.mark_bar().encode(
    alt.Y('volume', title="Volume"),
    tooltip=['date:T', 'volume:Q', 'adjusted:Q'],
)

line = base.mark_line(color='orange').encode(
    alt.Y('adjusted', title="Price"),
    tooltip=['date:T', 'volume:Q', 'adjusted:Q'],
)

# resolve_scale(y='independent') gives each layer its own y-axis instead of one shared scale
alt.layer(bar, line).resolve_scale(y='independent').properties(
    title="Apple: Price and Volume Chart",
    width=600,
)
Out[916]:

The panic sell-off in March is very evident in this plot.


Since we have data for several stocks, it is better to wrap the plotting code in reusable functions.

In [917]:
def create_base(stock_df):
    # Shared base chart: 2020 rows only, with the date on the x-axis.
    # plot_line / plot_bar replace the mark and add their own encodings.
    base = alt.Chart(stock_df[stock_df['date'] > '2020-01-01']).mark_line().encode(
        alt.X('date:T', title=""),
    )
    return base

def plot_line(stock_df, base, color='magenta', width=700, height=400, date_labels=None, 
         xlab="Date", y1_lab="Volume", y2_lab="Price", y1_domain=None, y2_domain=None, 
         label_angle=45): 
    # Price (adjusted close) as a line. Both helpers share one signature, so the
    # volume arguments (y1_lab, y1_domain) and date_labels are accepted but unused here.
    if y2_domain is None:
        y2_domain = (stock_df.adjusted.min(), stock_df.adjusted.max())

    chart = base.mark_line(color=color).encode(
        alt.X('date:T', title=xlab, axis=alt.Axis(labelAngle=label_angle)),
        alt.Y("adjusted:Q", title=y2_lab, scale=alt.Scale(domain=y2_domain)),
    ).properties(height=height, width=width)

    return chart


def plot_bar(stock_df, base, color='magenta', width=700, height=400, date_labels=None, 
         xlab="Date", y1_lab="Volume", y2_lab="Price", y1_domain=None, y2_domain=None, 
         label_angle=45): 
    # Volume as bars with SI-prefixed tick labels (format='s'). The price arguments
    # (y2_lab, y2_domain), color and date_labels are accepted but unused here.
    if y1_domain is None:
        y1_domain = (stock_df.volume.min(), stock_df.volume.max())

    chart = base.mark_bar(opacity=0.7).encode(
        alt.X('date:T', title=xlab, axis=alt.Axis(labelAngle=label_angle)),
        alt.Y('volume:Q', title=y1_lab, scale=alt.Scale(domain=y1_domain), axis=alt.Axis(format='s')),
    ).properties(height=height, width=width)

    return chart

Another Price + Volume Chart

Let us now create the more common price-and-volume layout: price on top, volume below. We stack the plots with vconcat and adjust their heights. The bottom chart is interactive: you can drag to pan it and scroll to zoom in and out.

In [918]:
aapl_base = create_base(aapl_df)
aapl_line = plot_line(aapl_df, aapl_base, color="orange", xlab="", label_angle=0, width=700, height=200)
aapl_bar = plot_bar(aapl_df, aapl_base, label_angle=0, width=700, height=80)
alt.vconcat(aapl_line, aapl_bar.interactive()).properties(
        title="APPLE: 2020 Price Volume Chart"        
    )
Out[918]:

Zooming and panning are of little use in this plot, so let us focus on selections instead. Selections let us highlight part of a chart based on what the user picks with the mouse.

Cross Filtering

We add an interval (brush) selection along the x-axis of the price chart and use it to color the volume bars: bars inside the selected date range keep their color, and the rest turn light gray.

💡 You can use your mouse to select a part of the line chart to see it in action.

In [919]:
brush = alt.selection_interval(encodings=['x'])
color = alt.condition(brush,
                      # adding date as color feels like a hack, and that also causes 
                      # the blue gradient on the bar chart
                      alt.Color('date:T', legend=None),
                      alt.value('lightgray'))


# In Altair 5, add_selection is deprecated in favor of add_params
upper = aapl_line.add_selection(brush)
# aapl_bar already encodes y (title, domain, SI format), so only the color is added
lower = aapl_bar.encode(color=color)

alt.vconcat(upper, lower).properties(
    title="APPLE: 2020 Price Volume Chart"
)
Out[919]:

More Data Sources

We can now bring in other data sources, such as COVID-19 data (source: Our World in Data).

In [920]:
covid_cases = pd.read_csv("owid-covid-data.csv", parse_dates=['date'])
covid_cases.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 49668 entries, 0 to 49667
Data columns (total 41 columns):
iso_code                           49381 non-null object
continent                          49094 non-null object
location                           49668 non-null object
date                               49668 non-null datetime64[ns]
total_cases                        49032 non-null float64
new_cases                          48809 non-null float64
new_cases_smoothed                 48027 non-null float64
total_deaths                       49032 non-null float64
new_deaths                         48809 non-null float64
new_deaths_smoothed                48027 non-null float64
total_cases_per_million            48745 non-null float64
new_cases_per_million              48745 non-null float64
new_cases_smoothed_per_million     47962 non-null float64
total_deaths_per_million           48745 non-null float64
new_deaths_per_million             48745 non-null float64
new_deaths_smoothed_per_million    47962 non-null float64
new_tests                          18016 non-null float64
total_tests                        18435 non-null float64
total_tests_per_thousand           18435 non-null float64
new_tests_per_thousand             18016 non-null float64
new_tests_smoothed                 20379 non-null float64
new_tests_smoothed_per_thousand    20379 non-null float64
tests_per_case                     18772 non-null float64
positive_rate                      19232 non-null float64
tests_units                        21241 non-null object
stringency_index                   41070 non-null float64
population                         49381 non-null float64
population_density                 47109 non-null float64
median_age                         44260 non-null float64
aged_65_older                      43599 non-null float64
aged_70_older                      44030 non-null float64
gdp_per_capita                     43681 non-null float64
extreme_poverty                    29130 non-null float64
cardiovasc_death_rate              44247 non-null float64
diabetes_prevalence                45835 non-null float64
female_smokers                     34603 non-null float64
male_smokers                       34162 non-null float64
handwashing_facilities             20779 non-null float64
hospital_beds_per_thousand         39933 non-null float64
life_expectancy                    48754 non-null float64
human_development_index            42705 non-null float64
dtypes: datetime64[ns](1), float64(36), object(4)
memory usage: 15.5+ MB
In [921]:
covid_cases.head(1)
Out[921]:
   iso_code      continent  location        date  total_cases  new_cases  new_cases_smoothed  total_deaths  new_deaths  new_deaths_smoothed  ...
0       ABW  North America     Aruba  2020-03-13          2.0        2.0                 NaN           0.0         0.0                  NaN  ...

   gdp_per_capita  extreme_poverty  cardiovasc_death_rate  diabetes_prevalence  female_smokers  male_smokers
0       35973.781              NaN                    NaN                11.62             NaN           NaN

   handwashing_facilities  hospital_beds_per_thousand  life_expectancy  human_development_index
0                     NaN                         NaN            76.29                      NaN

1 rows × 41 columns

In [922]:
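# Sum total deaths per continent and day; selecting the column first keeps pandas
# from trying to sum text columns such as iso_code and location
deaths_by_continent = (
    covid_cases.groupby(['continent', 'date'], as_index=False)['total_deaths'].sum()
)

covid_deaths_plot = alt.Chart(deaths_by_continent).mark_area(opacity=0.7, clip=True).encode(
    # Pin the x-axis to the AAPL date range so this chart lines up with the stock charts
    alt.X('date:T', scale=alt.Scale(domain=(aapl_df.date.min(), aapl_df.date.max()))),
    alt.Y('total_deaths:Q', title="Total Deaths", axis=alt.Axis(format='s')),
    color='continent:N'  # one color per continent (overrides any fixed mark color)
).properties(
    title="Total deaths due to Covid-19",
    width=700
)
covid_deaths_plot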
Out[922]: