This notebook shows how it's possible to understand parts of a complex black box model. Black box models built from decision trees or neural networks may have great predictive power, but they can be difficult to explain. Sometimes it is necessary to explain exactly how the model is working in a particular situation. That requirement might come from you, your management, or a regulatory body. This notebook uses the concept of a Local Interpretable Model-Agnostic Explanation (LIME) to explain a black box model.
Look at the image below. The wall here represents the surface of your complex black box model. What LIME does is allow us to create a “window” in a local area to gain insight into that part of the model. It provides a local explanation. This is useful if the model incorporates a sensitive categorization, e.g., men versus women. Or it could be important to understand a certain class of predictions, e.g., explaining why loans are being denied.
Image Credit: https://commons.wikimedia.org/wiki/File:Holy_Trinity_Church,_Takeley_-_nave_north_small_window_and_blocked_window_at_east.jpg
Image Credit: https://www.oreilly.com/ideas/ideas-on-interpreting-machine-learning
The second goal of this notebook is to explain reason codes. Reason codes allow us to understand the factors driving a prediction. For example, if your complex black box model denies a loan applicant, it could be necessary to explain the basis of that denial. Or you could be trying to understand why your model misclassified a certain prediction. To explain these predictions, it's useful to understand the features driving them.
One way to calculate reason codes is to use the coefficients from the local linear model. The notebook walks through this method, showing all the code to generate reason codes. Another way is to rely on the existing lime R package to generate reason codes. The notebook shows how you would use that package to get reasons for a classification model. Here is an example of reason codes used to explain a mushroom classifier's decision:
Image Credit: https://github.com/marcotcr/lime/
The code for building the surrogate models and creating reason codes is based on Patrick Hall's Python notebook.
The dataset comes from Savvas’s NFL scraping data tutorial, and is available at data.world. The data contains the career statistics for about 8000 NFL football players. Since I have a good understanding of American football, I could understand the relationships uncovered by the surrogate models and reason codes. If you don’t have a strong understanding of American football, go ahead and substitute another dataset in this notebook.
The dataset includes features on the college the player attended, the position they played, the round in which they were drafted, the team they played for, and their career statistics. The last column, named target (scroll to the right on the table), is the length of the player's career in years. A football glossary can explain all the terms and abbreviations. Let's start with loading and viewing the data.
library(dplyr)

data <- read.csv("https://query.data.world/s/WJhKX-vyU-mRj0ZNVab1apjJBnpBWE", header=TRUE, stringsAsFactors=FALSE)
data <- na.omit(data) ##For convenience, we are removing rows with NAs
data <- data %>% mutate(College = as.factor(College), Pos = as.factor(Pos), Tm = as.factor(Tm))
The first step is building a model for the length of a player's career. I am using h2o to build a gradient boosted machine (GBM) model, but feel free to use another algorithm/tool. This is the black box model, so it can be complex. The features for the model include all the information we have about the player. The target is the duration of a player's career (the target column). To keep the code simple, I am using default hyperparameter values when building the models.
library(h2o)

h2o.init()
data.h2o <- as.h2o(data)
splits <- h2o.splitFrame(data.h2o, 0.7, seed=1234)
train <- h2o.assign(splits[[1]], "train.hex")
valid <- h2o.assign(splits[[2]], "val.hex")
x <- 2:20   #Predictor columns (everything except Player and target)
y <- 21     #Target column (career length in years)
model.gbm <- h2o.gbm(x = x, y = y,
                     training_frame = train,
                     validation_frame = valid)
##Get predictions
h2o.no_progress()
preds <- h2o.predict(model.gbm,valid)
preds <- h2o.cbind(preds,valid)
preds <- as.data.frame(preds)
preds <- preds %>% arrange(target)
The ranked predictions plot below compares the global predictions from our ‘black box’ model to the actual values. Ideally, the predictions follow the actual values closely. For this model, the plot shows some variability in the predictions, but you can see a general trend in the predictions that follows the actual values. Another metric we look at below is R squared. After all, the better the fit of the global model, the better the model.
library(ggplot2)

ggplot(data=preds,aes(x=as.numeric(row.names(preds)),y=predict)) +
geom_point(aes(color="predictions",alpha=.1),size=.5) +
stat_smooth(method = "lm", formula = y ~ x, size = 1) +
geom_line(aes(y=target,color="actual")) +
ggtitle("Ranked Predictions Plot of the Black Box Model") +
xlab("Players") + ylab("Career length")
The variable importance plot highlights the features that have the most impact on the length of a career. The top three features (college, position, and round drafted) apply to all positions and seem to make intuitive sense for assessing the future career of a player. However, for the QB position, many of the other features like Def_Int, Sk, Rec, Rec_TD do not come into play. These are statistics earned by other positions, such as defensive players or receivers. So while this model may perform well according to its R squared value, it doesn’t help us understand the features driving individual positions. This is where LIME comes into play for explaining parts of a model.
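The variable importance plot and the R squared value referenced here come from the black box model. Here is a minimal sketch of how they can be produced with h2o (this assumes the validation frame metrics are the ones we want; the original notebook may have generated them slightly differently):

#A sketch: variable importance plot and R squared for the black box GBM
h2o.varimp_plot(model.gbm)
r2 <- h2o.r2(model.gbm, valid = TRUE)
paste0("R squared is: ", r2)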
## [1] "R squared is: 0.759900363013113"
To use LIME, we need to choose a local area of the model (the window). In this example, I chose the QB position, and the resulting dataset is shown below. I could just as easily have picked another feature or even part of the target (say, the top 10% longest careers). Other approaches include clustering related observations together (e.g., k-means) or simply choosing nearby points around a particular observation. For those looking for interesting research projects, the selection of a local area could use more investigation and empirical rigor.
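As an aside, here is a minimal sketch of one of the alternative approaches mentioned above: using k-means on the numeric features to define a local area. The choice of five clusters, and of keeping cluster one, is arbitrary and purely illustrative; local_frame_cluster is not used in the rest of the notebook.

#Illustrative only: cluster the numeric features and use one cluster as the local area
set.seed(1234)
num_feats <- scale(Filter(is.numeric, dplyr::select(preds, -predict, -target)))
km <- kmeans(num_feats, centers = 5)
local_frame_cluster <- preds %>% filter(km$cluster == 1)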
local_frame <- preds %>% filter (Pos == 'QB')
Local dataset of QBs for building the surrogate model:
One issue when building a linear model is correlated features, which can throw off the importance measurements. The code chunk below identifies the correlated features.
#First we need to identify the correlated features so we can remove them
library(reshape2)    #for melt()
library(knitr)       #for kable()
library(kableExtra)  #for kable_styling()
d <- Filter(is.numeric, data)
d_cor <- as.matrix(cor(d))
d_cor_melt <- arrange(melt(d_cor), -abs(value))
d_cor_melt <- filter(d_cor_melt, value > .8) %>% filter (Var1 != Var2) %>%
  rename (correlation = value)
options(knitr.table.format = "html")
kable(d_cor_melt, caption = "Correlated features") %>% kable_styling (bootstrap_options = "striped", full_width = F, position = "left")
Var1 | Var2 | correlation |
---|---|---|
Yds | Cmp | 0.9980872 |
Cmp | Yds | 0.9980872 |
Att | Cmp | 0.9974013 |
Cmp | Att | 0.9974013 |
Yds | Att | 0.9969299 |
Att | Yds | 0.9969299 |
Rush_Yds | Rush_Att | 0.9936215 |
Rush_Att | Rush_Yds | 0.9936215 |
TD | Yds | 0.9889622 |
Yds | TD | 0.9889622 |
TD | Cmp | 0.9848143 |
Cmp | TD | 0.9848143 |
TD | Att | 0.9802279 |
Att | TD | 0.9802279 |
Rec_Yds | Rec | 0.9646876 |
Rec | Rec_Yds | 0.9646876 |
Int | Att | 0.9565515 |
Att | Int | 0.9565515 |
Int | Yds | 0.9452818 |
Yds | Int | 0.9452818 |
Int | Cmp | 0.9403071 |
Cmp | Int | 0.9403071 |
Rec_TD | Rec_Yds | 0.9343107 |
Rec_Yds | Rec_TD | 0.9343107 |
Rush_TD | Rush_Yds | 0.9319339 |
Rush_Yds | Rush_TD | 0.9319339 |
Rush_TD | Rush_Att | 0.9305012 |
Rush_Att | Rush_TD | 0.9305012 |
Int | TD | 0.9233499 |
TD | Int | 0.9233499 |
Rec_TD | Rec | 0.8739940 |
Rec | Rec_TD | 0.8739940 |
Yikes! The career statistics are strongly correlated. The best approach is to remove correlated features before building a linear model. In this case, the following features are removed: Rush_TD, Rec_TD, TD, Att, Cmp, Rush_Yds, Rec_Yds, Yds.
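The next chunk drops these columns by their numeric index; an equivalent (and arguably less fragile) way, sketched here with dplyr, is to drop them by name. The name local_frame_reduced is illustrative only and is not used later.

#Illustrative only: drop the correlated career statistics by name rather than by column index
local_frame_reduced <- local_frame %>%
  dplyr::select(-Rush_TD, -Rec_TD, -TD, -Att, -Cmp, -Rush_Yds, -Rec_Yds, -Yds)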
The next step is building our surrogate local linear model. I am using the glm function in h2o, which builds an elastic net generalized linear model. It's also possible to use other interpretable models that provide coefficients, such as a generalized additive model (GAM).
x <- c(3:7, 12:13, 16, 19:21)   #Predictor columns; skips over the correlated features
#The surrogate model's target is the black box prediction, passed as y='predict' below
h2o.no_progress()
local_frame.h2o <- as.h2o(local_frame)
local_glm <- h2o.glm(x=x,y='predict',training_frame =local_frame.h2o, lambda_search = TRUE)
#Get predictions
pred_local <- h2o.predict(local_glm,local_frame.h2o)
pred_local <- as.data.frame(pred_local)
local_frame <- preds %>% filter (Pos == 'QB')
local_frame$predictlocal <- pred_local$predict
local_frame <- local_frame %>% arrange(predict)
The plot below compares the global predictions from our ‘black box’ model to the local linear model. Ideally, the local model follows the global model closely. However, since the local model is linear, it won't be able to fully fit the complex non-linear black box global model. This plot provides some insight into how well the local model fits the global model in this particular area, i.e., quarterbacks in our case. Here the global model is represented by the red dots and the general trend of the local predictions is shown. Another metric you can use to check that the local model is capturing the local area is R squared. After all, if the local linear model is not a good fit, then we cannot trust its interpretation.
ggplot(data=local_frame,aes(x=as.numeric(row.names(local_frame)),y=predictlocal)) +
geom_point(aes(color="linear"),size=1) +
stat_smooth(method = "lm", formula = y ~ x, size = 1) +
geom_point(aes(y=predict,color="global")) +
ggtitle("Ranked Predictions Plot of the Local Linear Model") +
xlab("Players") + ylab("Career length")
#r2 <- h2o.r2(local_glm)
#paste0("R squared is: ",r2)
With the local model built, we can use its coefficients to start explaining the local area. Ranking the coefficients provides a feature importance ranking for QBs. A positive sign means the coefficient is increasing the length of a career and vice versa.
imp <- h2o.varimp(local_glm)
imp <- imp %>% filter (coefficients>0) %>% mutate(coefficients=round(coefficients,3))
imp$names <- sub("College.","",imp$names)
imp$names <- sub("Tm.","",imp$names)
#h2o.varimp_plot(local_glm, num_of_features = NULL)
kable(imp, caption = "Local Feature Importance for QBs") %>% kable_styling (bootstrap_options = "striped", full_width = F, position = "left")
names | coefficients | sign |
---|---|---|
Int | 1.687 | POS |
Washington St. | 0.703 | POS |
Texas-El Paso | 0.676 | NEG |
RAI | 0.552 | POS |
Arizona St. | 0.548 | POS |
Rnd | 0.548 | NEG |
MIN | 0.524 | POS |
Utah St. | 0.521 | POS |
UCLA | 0.515 | POS |
STL | 0.481 | NEG |
Tulane | 0.447 | POS |
WAS | 0.440 | POS |
Tennessee | 0.440 | POS |
BUF | 0.440 | NEG |
SDG | 0.421 | POS |
Bowling Green | 0.386 | POS |
New Mexico | 0.353 | NEG |
Age | 0.328 | NEG |
Rush_Att | 0.321 | POS |
Arkansas | 0.318 | NEG |
RAM | 0.287 | POS |
PHI | 0.245 | NEG |
PIT | 0.184 | POS |
Yale | 0.156 | POS |
Rec | 0.133 | NEG |
NYJ | 0.119 | NEG |
Alabama | 0.106 | POS |
MIA | 0.047 | POS |
NWE | 0.044 | NEG |
Duke | 0.039 | NEG |
BAL | 0.029 | NEG |
BYU | 0.026 | POS |
Tkl | 0.014 | NEG |
CLE | 0.006 | POS |
San Jose St. | 0.006 | POS |
The results here differ from the global feature importance. The top features like interceptions and the round they were drafted make intuitive sense. The results for colleges also offer a lot of fodder for assessing quarterbacks. Quarterbacks looking for a long NFL career should consider Washington State or Arizona State and avoid Texas-El Paso (sorry, Paydirt Pete).
Reason codes are the factors affecting a particular prediction. This calculation starts with the prediction for a particular player. The code below selects a player and shows their data. The fields include the actual length of their career (target), the global prediction (predict), and the local prediction in the QB area (predictlocal).
row <- 108  #Select a row (player) to describe
local_frame[row,]
## predict Player Tm Age College Pos Rnd Cmp Att Yds TD Int Rush_Att Rush_Yds Rush_TD Rec Rec_Yds Rec_TD Tkl Def_Int Sk target predictlocal
## 108 13.47019 Jim Plunkett NWE 23 Stanford QB 1 121.4375 231.3125 1617.625 10.25 12.375 20.1875 83.5625 0.875 0 0 0 0 0 0 16 12.27773
playername <- local_frame$Player[row]
The next step is multiplying the local feature importance coefficients by the player's actual values. The code below includes a few data wrangling steps.
df <- as.data.frame(t(local_frame[row,]))
df1 <- df %>% tibble::rownames_to_column() %>%
  mutate (names = rowname)
colnames(df1)[2] <- "player"
df1$player <- as.character(df1$player)
df1[3,'names'] <-df1[3,'player'] #Copy Team over to names
df1[5,'names'] <-df1[5,'player'] #Copy College over to names
df1 <- df1 %>%
left_join(imp,by='names') %>% #Join local feature importance by names
filter (!is.na(sign)) %>% #Remove non matches
mutate (player = as.numeric(player)) %>%
mutate (player = ifelse (is.na(player),1,player)) %>% #Account for characters
mutate (strength = player*coefficients) %>%
mutate (strength = ifelse(sign=="NEG",strength*-1,strength)) %>%
filter (round(strength,1)!=0)
The table below is used for determining the reason codes. Multiplying the actual player statistics against the coefficient of the linear model gives the strength for that feature. For example, if the career average by season for interceptions is 10 and the coefficient for interceptions is 1.7, then the strength of that reason code would be 17. This is then calculated for all the features and plotted below.
rowname | player | names | coefficients | sign | strength |
---|---|---|---|---|---|
Age | 23.0000 | Age | 0.328 | NEG | -7.544000 |
Rnd | 1.0000 | Rnd | 0.548 | NEG | -0.548000 |
Int | 12.3750 | Int | 1.687 | POS | 20.876625 |
Rush_Att | 20.1875 | Rush_Att | 0.321 | POS | 6.480188 |
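Here is a minimal sketch of how the strengths in this table can be plotted with ggplot2. It reuses the df1 data frame and playername from the chunks above; the original notebook may have drawn its plot differently.

#Plot the reason code strengths for the selected player
ggplot(df1, aes(x = reorder(names, strength), y = strength, fill = sign)) +
  geom_bar(stat = "identity") +
  coord_flip() +
  ggtitle(paste("Reason codes for", playername)) +
  xlab("Feature") + ylab("Strength")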
Now we have reasons for an individual prediction! This allows us to understand the effects of different features on the prediction. In this case, interceptions is the strongest feature driving the length of a career. Try this with different players, for example, a quarterback with a short career, say row 10. Try local areas other than QB: look at a popular college team or the players with the shortest predicted careers. The key is that we now have insights at the level of individual predictions.
For those unwilling to roll their own reason codes, there are several packages available that can provide them. I will walk through the lime package available in R, ported over by Thomas Lin Pedersen. This package is still early in its development, so it doesn't yet cover regression models. The first step is turning the dataset into a classification problem. This is done by changing our target to whether an NFL career lasted more than three years versus three or fewer years.
library(lime)
library(randomForest)
library(caret)
data <- data %>% dplyr::select (-Rush_TD,-Rec_TD,-TD,-Att,-Cmp,-Rush_Yds,-Rec_Yds,-Yds,-College)
data$target <- ifelse(data$target>3,1,0)
data$target <- as.factor(data$target)
The next step is building a model that is compatible with the lime package. I used the randomForest package. I also had to drop the college feature, because the randomForest package can’t handle 400 categories. This model serves as our global “black box” model.
smp_size <- floor(0.75 * nrow(data))
##Set the seed to make the partition reproducible
set.seed(123)
train_ind <- sample(seq_len(nrow(data)), size = smp_size)
train <- data[train_ind,]
valid <- data[-train_ind, ]
train_lab <- train$target
train <- train %>% dplyr::select (-target)
##The model can take over an hour to run, so it is saved and loaded from disk
#model <- train(train[,2:11], train_lab, method = 'rf')
#save(model, file = "RF_NFL_R.rda")
load(file = "RF_NFL_R.rda")
#A copy of RF_NFL_R.rda is available at https://www.dropbox.com/s/7phbsg13wv2b00j/RF_NFL_R.rda?dl=0
The next step is running the explain part of the lime package. The explain step builds a local linear model and calculates the reason codes. The local linear model is not built around a specific position, but instead around each individual prediction. Here is the description from the authors:
An illustration of this process is given below. The original model's decision function is represented by the blue/pink background, and is clearly nonlinear. The bright red cross is the instance being explained (let's call it X). We sample perturbed instances around X, and weight them according to their proximity to X (weight here is represented by size). We get original model's prediction on these perturbed instances, and then learn a linear model (dashed line) that approximates the model well in the vicinity of X. Note that the explanation in this case is not faithful globally, but it is faithful locally around X.
explain <- lime(train[,2:11], model)
explanation <- explain(valid[201,2:11], n_labels = 1, n_features = 3)
valid[201,1]
## [1] "Jess Lewis"
plot_features(explanation)
The reason codes here show a strong probability for Class 0, which is a career of three years or fewer. The reason codes make sense: the player was a late-round draft pick, a linebacker, and 22 years old. Let's do one more example:
explanation <- explain(valid[205,2:11], n_labels = 1, n_features = 3)
valid[205,1]
## [1] "Jim Plunkett"
plot_features(explanation)
The reason codes show a good probability that the player is in Class 1, which is a career of more than three years. The reason codes make sense: the player was an early draft pick and a quarterback. In this case, being 22 years old actually lowers the predicted career length. These examples illustrate the insights reason codes provide into the factors driving a particular prediction.
Once you work through the code here, you should have a good understanding of how to use a linear surrogate model to get reason codes. Moving forward, try taking these ideas to get a window into your black box models. Here are some ideas to try:
- Improve the models by adding a grid search for hyperparameters
- Convert the NFL dataset into a classification problem and run your own reason codes
- Try different sorts of local areas and compare the results
- Try different sorts of interpretable algorithms for modeling the local area, e.g., GAMs
- Try this with another dataset
- Wrap reason codes in a shiny app and use this to explain a model
- Let me know if you found this useful
You can find a copy of this notebook in the inter-examples repo on my GitHub. For more about me, check out my website or find me on Twitter.
Thanks to Will, Patrick, and Rob for their detailed comments.