Explaining machine learning models

Packages

library(tidyverse)
library(caret)
library(magrittr)
library(DALEX)

Overview

This post covers DALEX explainers. These are very useful when we need to validate a model, or to explain why a model made a particular prediction for an individual observation.

The data

To show our explainers in action we will use the apartments dataset that ships with the DALEX package:

test_data <- DALEX::apartments

test_data %>% head
##   m2.price construction.year surface floor no.rooms    district
## 1     5897              1953      25     3        1 Srodmiescie
## 2     1818              1992     143     9        5     Bielany
## 3     3643              1937      56     1        2       Praga
## 4     3517              1995      93     7        3      Ochota
## 5     3013              1992     144     6        5     Mokotow
## 6     5795              1926      61     6        2 Srodmiescie
test_data %>% dim
## [1] 1000    6

The problem

Say we want to predict the m2.price of an apartment, but this prediction needs to be validated for properties of different sizes. In this case it’s not good enough to measure only a global fit. If, for example, the model fails to predict the value of expensive properties properly because there are few of them and their prices vary a lot, deploying it in production could carry a large amount of risk.
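To make the point about global fit concrete, here is a minimal sketch in base R. The helper rmse_by_group() and the numbers are made up for illustration and do not come from the apartments data. It shows how a single global RMSE can hide a segment that is predicted badly:

```r
# Hypothetical helper: RMSE computed separately within each segment.
rmse_by_group <- function(actual, predicted, group) {
  squared_error <- (predicted - actual)^2
  sqrt(tapply(squared_error, group, mean))
}

# Made-up numbers purely for illustration (not from the apartments data).
actual    <- c(100, 110, 120, 500, 520, 540)
predicted <- c(102, 108, 121, 450, 560, 600)

# Split the observations into two price segments.
segment <- cut(actual, breaks = c(0, 200, Inf), labels = c("cheap", "expensive"))

global_rmse  <- sqrt(mean((predicted - actual)^2))
segment_rmse <- rmse_by_group(actual, predicted, segment)

print(global_rmse)
print(segment_rmse)
```

The expensive segment ends up with a far larger error than the global figure suggests, which is exactly the situation we want to catch before deployment.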

Benchmark many models with caret

To start off we want to throw a bunch of models at the problem to see what we are dealing with in terms of predictive power. Since DALEX is model agnostic we can simply leverage our caret package:

Set cross-validation parameters

trControl <- trainControl(method = "cv",number = 4)
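If it helps to see what method = "cv" with number = 4 asks for, here is a rough hand-rolled sketch of 4-fold cross-validation. It uses a plain lm() on the built-in mtcars data rather than caret and the apartments data, so the model, the formula and the seed are assumptions chosen only for illustration:

```r
# Manual 4-fold cross-validation for a linear model on the built-in mtcars data.
set.seed(42)
k <- 4
n <- nrow(mtcars)

# Randomly assign every row to one of the k folds.
fold_id <- sample(rep(seq_len(k), length.out = n))

fold_rmse <- sapply(seq_len(k), function(fold) {
  train <- mtcars[fold_id != fold, ]   # fit on the other k - 1 folds
  test  <- mtcars[fold_id == fold, ]   # evaluate on the held-out fold
  fit   <- lm(mpg ~ wt + hp, data = train)
  sqrt(mean((predict(fit, newdata = test) - test$mpg)^2))
})

# The cross-validated RMSE is the average over the held-out folds.
cv_rmse <- mean(fold_rmse)
print(cv_rmse)
```

As far as I know, caret reports its resampled RMSE in the same spirit, as an average over the held-out folds.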

Build model data framework

train_data <- 
    tibble(model = c("rf","gbm","lm","glm","ridge"),data = list(apartments))

Train models

model_frame <- train_data %>%
    mutate(caret_models = map2(.x = model,.y=data,~train(form = m2.price~., data = .y,method = .x,trControl = trControl)))

Let’s pull out the root mean squared error from the models:

model_frame %<>%
    mutate(RMSE = caret_models %>% map_dbl(~.x %>% pluck("results","RMSE") %>% min))

model_frame
## # A tibble: 5 x 4
##   model data                     caret_models  RMSE
##   <chr> <list>                   <list>       <dbl>
## 1 rf    <data.frame [1,000 x 6]> <S3: train>  174. 
## 2 gbm   <data.frame [1,000 x 6]> <S3: train>   85.8
## 3 lm    <data.frame [1,000 x 6]> <S3: train>  283. 
## 4 glm   <data.frame [1,000 x 6]> <S3: train>  284. 
## 5 ridge <data.frame [1,000 x 6]> <S3: train>  284.
model_frame %>% 
    ggplot()+
    geom_bar(aes(x = model,y=RMSE,fill=model),stat="identity")

We can see that the random forest and gradient boosted models have a much lower cross-validated RMSE than the linear models out of the box, as expected, with gbm the lowest of all.

Visualize the residuals

Before we throw our new explainers at the models let’s quickly look at these residuals.

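We can plot the distribution of residuals to get a better idea of what’s going on, but first we need to pull these out of the models. Note that these are in-sample residuals, since we predict on the same data the models were trained on: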

model_frame %<>%
    mutate(residuals = map2(data,caret_models,~predict(object = .y,.x)-.x %>% pull("m2.price")))

Visualize the distribution:

model_frame %>% 
    unnest(residuals) %>% 
    ggplot()+
    geom_histogram(aes(x=residuals,fill=model))+
    facet_wrap(~model)+
    ggtitle("Distribution of errors by model")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Let’s visualize the predicted vs. actuals:

model_frame %<>% 
    mutate(predicted = map2(data,caret_models,~predict(object = .y,.x))) %>% 
    mutate(actuals = data %>% map(~.x %>% pull("m2.price")))
model_frame %>% 
    # tidyr >= 1.0 expects the columns to unnest wrapped in c()
    unnest(c(predicted, actuals)) %>% 
    ggplot(aes(x=actuals %>% log,y=predicted %>% log,col=model))+
    geom_point()+
    geom_abline(slope = 1,intercept = 0)+
    facet_wrap(~model)+
    ggtitle("Actuals vs predicted on log scale")

We notice that the linear models tend to underestimate the value of the apartment more often than they overestimate it.

Introducing DALEX explainers!

There’s a lot we can do with packages like yardstick and broom to pull out performance metrics and make plots. But assembling these by hand can still feel clumsy. How can we quickly and easily produce more professional-looking reports to show prospective clients and stakeholders?

The answer might be either DALEX or LIME. A newer package that works with DALEX explainers is modelDown. It lets us create a static webpage to show off the performance of our models and also scrutinize them.

So how does it work?

DALEX::explain() needs the following:
- a model object
- data
- actuals (to be predicted)
- function that can get the predicted values
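
To see why these four ingredients are enough for a model-agnostic tool, here is a small base R sketch. The functions make_explainer() and explainer_residuals() are hypothetical stand-ins written for this post, not part of DALEX, and the model is a plain lm() on the built-in mtcars data:

```r
# Hypothetical stand-in for an explainer: just a list of the ingredients.
make_explainer <- function(model, data, y, predict_function, label) {
  list(model = model, data = data, y = y,
       predict_function = predict_function, label = label)
}

# Residuals (actual minus predicted) computed without knowing the model class:
# everything goes through the supplied predict function.
explainer_residuals <- function(explainer) {
  explainer$y - explainer$predict_function(explainer$model, explainer$data)
}

fit <- lm(mpg ~ wt + hp, data = mtcars)

ex <- make_explainer(
  model            = fit,
  data             = mtcars[, c("wt", "hp")],
  y                = mtcars$mpg,
  predict_function = function(m, d) predict(m, newdata = d),
  label            = "lm"
)

res <- explainer_residuals(ex)
print(head(res))
```

Any model that can be wrapped in a function of the form predict_function(model, data) fits this pattern, which is the idea behind passing predict to the explainers below.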

I quickly map over the table to create an explainer for each of our models:

model_frame %<>% 
    mutate(explainers = pmap(list(model,data,caret_models,actuals),~explain(model = ..3,data = ..2,y = ..4,label = ..1,predict_function = predict)))

Since we really only care about the good models I can filter the rest out:

good_models <- 
    model_frame %>% 
    filter(model %in% c("rf","gbm"))

The modelDown::modelDown() function expects each explainer as a separate argument through its dots (...) argument. To do this easily I use purrr::invoke() to pass the list of explainers from our table:

invoke(modelDown::modelDown,good_models$explainers)
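
purrr::invoke() has since been retired in favour of exec(), so on a recent purrr you may see a deprecation warning. A base R alternative is do.call(), which spreads a list over a function's arguments in the same way. The sketch below uses a hypothetical function count_inputs() and placeholder strings instead of real explainers, just to show the mechanics:

```r
# Hypothetical function that, like modelDown(), takes its inputs through `...`.
count_inputs <- function(...) {
  inputs <- list(...)
  length(inputs)
}

# Placeholders standing in for the list of explainers.
explainer_list <- list("explainer_rf", "explainer_gbm")

# Equivalent to calling count_inputs("explainer_rf", "explainer_gbm").
n_inputs <- do.call(count_inputs, explainer_list)
print(n_inputs)
```

Under that assumption, do.call(modelDown::modelDown, good_models$explainers) should behave like the invoke() call above.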

This will produce the website with some explainers from the DALEX package.

Model performance

The first tab shows the residual distribution plots for all provided models. My only criticism here is that they show only positive errors. This hides the sign of the residuals, and with it some information about bias.

Variable Importance

The next tab shows us how important certain variables in the data were to the model. This is very standard output that we would normally produce anyway. But it’s nice to have it in here so we don’t have to worry about doing that ourselves.
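
For intuition, a common model-agnostic way to measure importance is to shuffle one variable at a time and see how much the error grows. The sketch below does this by hand with an lm() on the built-in mtcars data. It is a simplified illustration under those assumptions, not DALEX's own implementation:

```r
set.seed(1)
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)

rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

# Error of the model on the untouched data.
baseline <- rmse(mtcars$mpg, predict(fit, newdata = mtcars))

# For each variable: shuffle it, predict again, and record the increase in RMSE.
permutation_importance <- sapply(c("wt", "hp", "qsec"), function(variable) {
  shuffled <- mtcars
  shuffled[[variable]] <- sample(shuffled[[variable]])  # break the link to the target
  rmse(mtcars$mpg, predict(fit, newdata = shuffled)) - baseline
})

print(sort(permutation_importance, decreasing = TRUE))
```

A variable whose shuffling barely changes the error is one the model does not rely on much. In practice the shuffle is usually repeated several times and averaged to reduce noise.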

Variable response

This is where things get juicy!

For each variable in the data the explainer will show us how the prediction function responds to changes in that variable.
This is fantastic since we can validate the assumptions made by the model. For example, is it reasonable that very old and very new apartments should respond in this way?
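
One standard way to compute such a response curve is a partial dependence profile: set the variable to a fixed value for every row, average the predictions, and repeat over a grid of values. Here is a minimal sketch with a hypothetical helper partial_dependence() and an lm() on the built-in mtcars data. It illustrates the idea and is not the code modelDown runs:

```r
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Hypothetical helper: average prediction when `variable` is forced to each grid value.
partial_dependence <- function(model, data, variable, grid) {
  sapply(grid, function(value) {
    modified <- data
    modified[[variable]] <- value             # same value for every row
    mean(predict(model, newdata = modified))  # average over the other variables
  })
}

wt_grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)
pd_wt   <- partial_dependence(fit, mtcars, "wt", wt_grid)

print(data.frame(wt = wt_grid, average_prediction = pd_wt))
```

For a linear model the profile is a straight line with slope equal to the coefficient. For a random forest or gbm it can take any shape, which is what makes the plot worth inspecting.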

For factor variables the explainer does something even more awesome! It will try to figure out how each level affects the response of the model (for each model), grouping levels that behave alike. In the screenshot we can see 3 distinct district groups that affect the price of the property. This gives us immediate insight into the data without having to explore these sub-levels ourselves!

Prediction breakdown

This page uses the function DALEX::prediction_breakdown() on the observation with the largest residual. This is really cool because we can see which variables pushed the predicted price up and which pushed it down. In this example we see that the district pushed the predicted price of this apartment far too high, because apartments in that district are generally very expensive.
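
The idea behind such a breakdown can be sketched in a few lines of base R: start from the average prediction, then fix one variable at a time to the observation's value and record how the average prediction moves. The helper break_down() below is hypothetical, uses an lm() on the built-in mtcars data and a fixed variable order, and is a simplification rather than the algorithm DALEX uses:

```r
fit         <- lm(mpg ~ wt + hp + qsec, data = mtcars)
variables   <- c("wt", "hp", "qsec")
observation <- mtcars["Fiat 128", ]

# Hypothetical helper: sequentially fix variables at the observed values.
break_down <- function(model, data, observation, variables) {
  current  <- data
  previous <- mean(predict(model, newdata = current))  # average prediction as the starting point
  contributions <- c(intercept = previous)
  for (variable in variables) {
    current[[variable]] <- observation[[variable]]     # fix this variable for every row
    updated <- mean(predict(model, newdata = current))
    contributions[variable] <- updated - previous      # how much this step moved the prediction
    previous <- updated
  }
  contributions
}

contrib <- break_down(fit, mtcars, observation, variables)
print(contrib)
```

By construction the contributions add up to the model's prediction for that observation. For models with interactions the result depends on the order in which the variables are fixed.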

This DALEX function is very useful if we find observations that we need to explain. For example, if we have trained and vetted a model that is currently running in production, we may deliver predictions to users via an app. These users may benefit from a breakdown to sense-check the predictions coming through.