This notebook shows how it's possible to understand parts of a complex black box model. Black box models built from decision trees or neural networks may have great predictive power, but they can be difficult to explain. Sometimes you need to explain exactly how the model works in a particular situation. That request might come from yourself, your management, or a regulatory body. This notebook uses Local Interpretable Model-Agnostic Explanations (LIME) to explain a black box model.

Look at the image below. The wall represents the surface of your complex black box model. LIME lets us create a “window” in a local area to gain insight into that part of the model; in other words, it provides a local explanation. This is useful if the model incorporates a sensitive categorization, such as men versus women, or if it is important to understand a certain class of predictions, e.g., why loans are being denied.

Image Credit: https://commons.wikimedia.org/wiki/File:Holy_Trinity_Church,_Takeley_-_nave_north_small_window_and_blocked_window_at_east.jpg

This notebook has two main goals. The first is to show how to build a window, or local linear surrogate model, on top of a complex global model. The first example starts with a complex black box global model that predicts the career length of players at all NFL positions. To explain the predictions for quarterbacks, we build a window, i.e., a local linear model, around the predictions for quarterbacks. The local model can then identify the features that most affect the career length of a quarterback. Here is a pictorial representation of this process from the article by Hall. The tables in that example use loans, while our example uses statistics from American football:

Image Credit: https://www.oreilly.com/ideas/ideas-on-interpreting-machine-learning

The second goal of this notebook is to explain reason codes. Reason codes show the factors driving a prediction. For example, if your complex black box model denies a loan applicant, you may need to explain the basis of that denial. Or you could be trying to understand why your model misclassified a particular case. To explain such predictions, it's useful to know which features drive them.

One way to calculate reason codes is to use the coefficients from the local linear model. The notebook walks through this method and shows all the code needed to generate reason codes. Another way is to rely on the existing lime R package. The notebook shows how to use that package to get reasons for a classification model. Here is an example of reason codes used to explain a mushroom classifier's decision:



Image Credit: https://github.com/marcotcr/lime/

The code for building the surrogate models and creating reason codes is based on Patrick Hall’s python notebook.

The NFL data

The dataset comes from Savvas’s NFL scraping data tutorial, and is available at data.world. The data contains the career statistics for about 8000 NFL football players. Since I have a good understanding of American football, I could understand the relationships uncovered by the surrogate models and reason codes. If you don’t have a strong understanding of American football, go ahead and substitute another dataset in this notebook.

The dataset includes features for the college the player attended, the position they played, the round they were drafted in, the team they played for, and their career statistics. The last column, named target, is the length of the player’s career in years. A football glossary can explain all the terms and abbreviations. Let's start by loading and viewing the data.

library(dplyr)  # provides %>%, mutate(), filter(), arrange()

data <- read.csv("https://query.data.world/s/WJhKX-vyU-mRj0ZNVab1apjJBnpBWE", header=TRUE, stringsAsFactors=FALSE)
data <- na.omit(data)  ## For convenience, remove rows with NAs
data <- data %>% mutate(College = as.factor(College), Pos = as.factor(Pos), Tm = as.factor(Tm))





Modeling the length of a player’s career

The first step is building a model for the length of a player’s career. I am using h2o to build a gradient boosting machine (GBM) model, but feel free to use another algorithm or tool. This is the black box model, so it can be complex. The features for the model include all the information we have about the player, and the target is the duration of the player’s career (the target column). To keep the code simple, I use default parameter values when building the models.

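library(h2o)

h2o.init()
data.h2o <- as.h2o(data)
splits <- h2o.splitFrame(data.h2o, ratios = 0.7, seed = 1234)  # 70% train, 30% validation
train  <- h2o.assign(splits[[1]], "train.hex")
valid  <- h2o.assign(splits[[2]], "val.hex")
x <- 2:20  # predictors: every column except Player (1) and target (21)
y <- 21    # target: career length in years
model.gbm <- h2o.gbm(x = x, y = y,
                     training_frame = train,
                     validation_frame = valid)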

##Get predictions
h2o.no_progress()
preds <- h2o.predict(model.gbm,valid)
preds <- h2o.cbind(preds,valid)
preds <- as.data.frame(preds) 
preds <- preds %>% arrange(target)

The ranked predictions plot below compares the global predictions from our ‘black box’ model to the actual values. Ideally, the predictions follow the actual values closely. For this model, the predictions show some variability, but their general trend follows the actual values. Another metric we look at below is R squared. After all, the better the global model fits, the more meaningful its explanations are.

library(ggplot2)

ggplot(data=preds,aes(x=as.numeric(row.names(preds)),y=predict)) + 
  geom_point(aes(color="predictions"), alpha=.1, size=.5) +  # fixed transparency belongs outside aes()
  stat_smooth(method = "lm", formula = y ~ x, size = 1) + 
  geom_line(aes(y=target,color="actual")) +
  ggtitle("Ranked Predictions Plot of the Black Box Model") +
  xlab("Players") + ylab("Career length")


The variable importance plot highlights the features that have the most impact on the length of a career. The top three features (college, position, and round drafted) apply to all positions and seem to make intuitive sense for assessing the future career of a player. However, for the QB position, many of the other features like Def_Int, Sk, Rec, Rec_TD do not come into play. These are statistics earned by other positions, such as defensive players or receivers. So while this model may perform well according to its R squared value, it doesn’t help us understand the features driving individual positions. This is where LIME comes into play for explaining parts of a model.

## [1] "R squared is: 0.759900363013113"
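The code that printed this value is not shown. As a minimal sketch, R squared is 1 minus the ratio of the residual sum of squares to the total sum of squares. The helper `r_squared()` below is hypothetical and uses toy numbers rather than the NFL data; on the notebook's data you could call it as `r_squared(preds$target, preds$predict)`, and h2o can also report a validation R squared with `h2o.r2(model.gbm, valid = TRUE)`.

```r
# R squared = 1 - (residual sum of squares / total sum of squares)
r_squared <- function(actual, predicted) {
  ss_res <- sum((actual - predicted)^2)        # error left over after the model
  ss_tot <- sum((actual - mean(actual))^2)     # error of always predicting the mean
  1 - ss_res / ss_tot
}

# Toy values, not taken from the NFL data
actual    <- c(1, 2, 3, 4)
predicted <- c(1.1, 1.9, 3.2, 3.8)
r_squared(actual, predicted)  # 0.98
```

Here SS_res = 0.10 and SS_tot = 5, so the toy model explains 98% of the variance.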

Building a window into your model

To use LIME, we need to choose a local area of the model (the window). In this example, I chose the QB position; the resulting dataset is shown below. I could easily have picked another feature or even part of the target (say, the top 10% longest careers). Other approaches include clustering related observations (e.g., with k-means) or simply choosing points near a particular observation. For those looking for interesting research projects, the selection of a local area deserves more investigation and empirical rigor.
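A minimal sketch of the three kinds of window mentioned above, in base R. The data frame `toy` is a made-up stand-in for `preds` (random values, not the NFL data), and the choice of four clusters is arbitrary.

```r
# Hypothetical stand-in for `preds`; column names mirror the notebook's
set.seed(42)
toy <- data.frame(
  Pos    = sample(c("QB", "RB", "WR"), 200, replace = TRUE),
  Age    = sample(21:25, 200, replace = TRUE),
  Rnd    = sample(1:7, 200, replace = TRUE),
  target = sample(1:15, 200, replace = TRUE)
)

# Window 1: a categorical slice, like the notebook's Pos == 'QB'
qb_window <- toy[toy$Pos == "QB", ]

# Window 2: part of the target, here the top 10% longest careers
cutoff <- quantile(toy$target, 0.9)
long_window <- toy[toy$target >= cutoff, ]

# Window 3: one k-means cluster on standardized numeric features
km <- kmeans(scale(toy[, c("Age", "Rnd", "target")]), centers = 4, nstart = 10)
cluster_window <- toy[km$cluster == 1, ]

sapply(list(qb = qb_window, long = long_window, cluster = cluster_window), nrow)
```

Each window is just a subset of rows; the surrogate model is then fit on that subset only.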

local_frame <- preds %>% filter (Pos == 'QB')


Local dataset of QBs for building the surrogate model:


One issue when building a linear model is correlated features, which can distort importance measurements. The code chunk below identifies correlated features.

library(reshape2)    # melt()
library(knitr)       # kable()
library(kableExtra)  # kable_styling()

# First, identify the strongly correlated numeric features
d <- Filter(is.numeric, data)
d_cor <- as.matrix(cor(d))
d_cor_melt <- arrange(melt(d_cor), -abs(value))
d_cor_melt <- filter(d_cor_melt, abs(value) > .8) %>% filter(Var1 != Var2) %>%  # abs() also catches strong negative correlations
  rename(correlation = value)
options(knitr.table.format = "html") 
kable(d_cor_melt, caption = "Correlated features") %>% kable_styling(bootstrap_options = "striped", full_width = F, position = "left")
Correlated features
Var1      Var2      correlation
Yds       Cmp       0.9980872
Cmp       Yds       0.9980872
Att       Cmp       0.9974013
Cmp       Att       0.9974013
Yds       Att       0.9969299
Att       Yds       0.9969299
Rush_Yds  Rush_Att  0.9936215
Rush_Att  Rush_Yds  0.9936215
TD        Yds       0.9889622
Yds       TD        0.9889622
TD        Cmp       0.9848143
Cmp       TD        0.9848143
TD        Att       0.9802279
Att       TD        0.9802279
Rec_Yds   Rec       0.9646876
Rec       Rec_Yds   0.9646876
Int       Att       0.9565515
Att       Int       0.9565515
Int       Yds       0.9452818
Yds       Int       0.9452818
Int       Cmp       0.9403071
Cmp       Int       0.9403071
Rec_TD    Rec_Yds   0.9343107
Rec_Yds   Rec_TD    0.9343107
Rush_TD   Rush_Yds  0.9319339
Rush_Yds  Rush_TD   0.9319339
Rush_TD   Rush_Att  0.9305012
Rush_Att  Rush_TD   0.9305012
Int       TD        0.9233499
TD        Int       0.9233499
Rec_TD    Rec       0.8739940
Rec       Rec_TD    0.8739940

Yikes! The career statistics are strongly correlated. A common approach is to remove correlated features before building a linear model. In this case, the following features are removed: Rush_TD, Rec_TD, TD, Att, Cmp, Rush_Yds, Rec_Yds, Yds. That leaves Int, Rush_Att, and Rec to represent the passing, rushing, and receiving statistics.
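The features above were picked by hand. As a minimal sketch of automating the choice, the hypothetical helper `drop_correlated()` below works in the spirit of caret's `findCorrelation()`: it repeatedly takes the most correlated remaining pair and drops whichever member has the higher mean absolute correlation with the other features. The 0.8 threshold matches the filter used above; on the NFL data you would exclude `target` (and `predict`) before calling it.

```r
# Greedy filter for highly correlated numeric features; returns the names to drop
drop_correlated <- function(df, threshold = 0.8) {
  num <- Filter(is.numeric, df)       # ignore character/factor columns
  cm <- abs(cor(num))
  diag(cm) <- 0                       # a feature is not "correlated" with itself
  keep <- colnames(cm)
  repeat {
    sub <- cm[keep, keep, drop = FALSE]
    if (max(sub) <= threshold) break
    top <- which(sub == max(sub), arr.ind = TRUE)[1, ]  # row and column index of the top pair
    avg <- rowMeans(sub)[top]                            # mean |r| of each member of the pair
    keep <- keep[-top[which.max(avg)]]                   # drop the more redundant member
  }
  setdiff(colnames(num), keep)
}

# Toy data: b is a near-duplicate of a, c is independent noise
set.seed(1)
a <- rnorm(100)
toy <- data.frame(
  a = a,
  b = a + rnorm(100, sd = 0.05),
  c = rnorm(100),
  label = "x"                         # non-numeric column, ignored
)
dropped <- drop_correlated(toy)
dropped
```

Exactly one of `a` or `b` is dropped, while the independent column `c` survives.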

The next step is building our local linear surrogate model. I am using h2o's h2o.glm function, which fits an elastic net regularized generalized linear model. It's also possible to use other interpretable models that provide coefficients, such as a generalized additive model (GAM).
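For reference, here is a sketch of the elastic net objective for a Gaussian GLM like this one, assuming h2o's standard formulation. Here y_i is the black box prediction (`predict`) for quarterback i, x_i holds that player's local features, λ sets the overall penalty strength, and α mixes the L1 (lasso) and L2 (ridge) penalties. With `lambda_search = TRUE`, h2o fits a sequence of λ values instead of a single one. The L1 term can shrink coefficients to exactly zero, which is why the importance table later keeps only features with coefficients greater than 0.

```latex
\min_{\beta_0,\,\beta}\;
\frac{1}{2N}\sum_{i=1}^{N}\left(y_i - \beta_0 - x_i^{\top}\beta\right)^2
\;+\;\lambda\left(\alpha\,\lVert\beta\rVert_1 + \frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2\right)
```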

x <- c(3:7, 12:13, 16, 19:21)  # Tm, Age, College, Pos, Rnd, Int, Rush_Att, Rec, Tkl, Def_Int, Sk (skips correlated features)
y <- "predict"                 # the surrogate learns the black box predictions, not the actual target
h2o.no_progress()
local_frame.h2o <- as.h2o(local_frame)
local_glm <- h2o.glm(x = x, y = y, training_frame = local_frame.h2o, lambda_search = TRUE)

# Get predictions from the local model
pred_local <- h2o.predict(local_glm, local_frame.h2o)
pred_local <- as.data.frame(pred_local)
local_frame$predictlocal <- pred_local$predict
local_frame <- local_frame %>% arrange(predict)


The plot below compares the global predictions from our ‘black box’ model to those of the local linear model. Ideally, the local model follows the global model closely. However, since the local model is linear, it can't fully fit the complex, non-linear black box model. The plot shows how well the local model fits the global model in this particular area, i.e., quarterbacks in our case. The global predictions are the red dots, and the blue line shows the general trend of the local predictions. Another metric you can use to check that the local model captures the local area is R squared. After all, if a local linear model is not a good fit, then we cannot trust its interpretation.
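Because the surrogate is trained on `predict` rather than `target`, its R squared measures fidelity: how well it reproduces the black box inside the window, not how well it predicts real careers. A minimal sketch of that check with `lm()`, where the smooth nonlinear `global_pred` is a made-up stand-in for the GBM's predictions on quarterbacks:

```r
# Toy "window": a nonlinear global model evaluated on made-up QB-like features
set.seed(7)
n <- 300
Age <- runif(n, 21, 26)
Int <- runif(n, 0, 20)
global_pred <- 10 - 0.3 * Age + 0.5 * Int + 0.01 * Int^2  # stands in for `predict`

# Surrogate: fit the black box predictions, not the actual target
surrogate <- lm(global_pred ~ Age + Int)
fidelity_r2 <- summary(surrogate)$r.squared
fidelity_r2
```

A value close to 1 says the linear surrogate is a faithful local summary; a low value is a warning not to read much into its coefficients.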

ggplot(data=local_frame,aes(x=as.numeric(row.names(local_frame)),y=predictlocal)) + 
  geom_point(aes(color="linear"),size=1) + 
  stat_smooth(method = "lm", formula = y ~ x, size = 1) + 
  geom_point(aes(y=predict,color="global")) +
  ggtitle("Ranked Predictions Plot of the Local Linear Model") +
  xlab("Players") + ylab("Career length")

#r2 <- h2o.r2(local_glm)
#paste0("R squared is: ",r2)

Local feature importance

With the local model built, we can use its coefficients to start explaining the local area. Ranking the coefficients by magnitude gives a feature importance ranking for QBs. A positive sign (POS) means the feature increases the predicted career length; a negative sign (NEG) means it decreases it.
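The importance table below uses a names / coefficients / sign layout. As a minimal sketch of that idea (not how h2o computes its table internally), the code fits a plain `lm()` on standardized toy predictors, so that coefficient sizes are comparable across units, and splits each coefficient into its magnitude and its sign.

```r
# Toy data: x2 has a larger spread, so its raw coefficient understates its effect
set.seed(3)
n <- 200
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n, sd = 5), x3 = rnorm(n))
y <- 3 * toy$x1 - 0.4 * toy$x2 + 0.1 * toy$x3 + rnorm(n, sd = 0.1)

# Standardize predictors, then fit the linear model
X <- as.data.frame(scale(toy))
X$y <- y
fit <- lm(y ~ ., data = X)

# Split coefficients into magnitude and sign, then rank by magnitude
b <- coef(fit)[-1]  # drop the intercept
imp <- data.frame(
  names = names(b),
  coefficients = round(abs(b), 3),
  sign = ifelse(b > 0, "POS", "NEG"),
  row.names = NULL,
  stringsAsFactors = FALSE
)
imp <- imp[order(-imp$coefficients), ]
imp
```

After standardizing, x2's effect (about -0.4 × 5 = -2 per standard deviation) ranks above x3 and is flagged NEG.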

imp <- h2o.varimp(local_glm)
imp <- imp %>% filter(coefficients > 0) %>% mutate(coefficients = round(coefficients, 3))  # drop features zeroed out by the penalty
imp$names <- sub("College.", "", imp$names, fixed = TRUE)  # strip one-hot prefixes, e.g. "College.UCLA" -> "UCLA"
imp$names <- sub("Tm.", "", imp$names, fixed = TRUE)
#h2o.varimp_plot(local_glm, num_of_features = NULL)
kable(imp, caption = "Local Feature Importance for QBs") %>% kable_styling(bootstrap_options = "striped", full_width = F, position = "left")
Local Feature Importance for QBs
names           coefficients  sign
Int             1.687         POS
Washington St.  0.703         POS
Texas-El Paso   0.676         NEG
RAI             0.552         POS
Arizona St.     0.548         POS
Rnd             0.548         NEG
MIN             0.524         POS
Utah St.        0.521         POS
UCLA            0.515         POS
STL             0.481         NEG
Tulane          0.447         POS
WAS             0.440         POS
Tennessee       0.440         POS
BUF             0.440         NEG
SDG             0.421         POS
Bowling Green   0.386         POS
New Mexico      0.353         NEG
Age             0.328         NEG
Rush_Att        0.321         POS
Arkansas        0.318         NEG
RAM             0.287         POS
PHI             0.245         NEG
PIT             0.184         POS
Yale            0.156         POS
Rec             0.133         NEG
NYJ             0.119         NEG
Alabama         0.106         POS
MIA             0.047         POS
NWE             0.044         NEG
Duke            0.039         NEG
BAL             0.029         NEG
BYU             0.026         POS
Tkl             0.014         NEG
CLE             0.006         POS
San Jose St.    0.006         POS

The results here differ from the global feature importance. The top features, like interceptions and the round a player was drafted in, make intuitive sense. Interceptions largely stand in for passing volume here (Int correlates at 0.96 with Att), so more interceptions thrown mostly means more playing time. The results for colleges also offer plenty of fodder for assessing quarterbacks: quarterbacks looking for a long NFL career should consider Washington State or Arizona State and avoid Texas-El Paso (sorry, Paydirt Pete).

Reason codes

Reason codes are the factors affecting a particular prediction. The calculation starts with the prediction for a particular player. The code below selects a player and shows their data. The fields include the actual length of their career (target), the global prediction (predict), and the local prediction in the QB area (predictlocal).

row <- 108  # select a row (player) to explain; try others, such as 20
local_frame[row,]
##      predict       Player  Tm Age  College Pos Rnd      Cmp      Att      Yds    TD    Int Rush_Att Rush_Yds Rush_TD Rec Rec_Yds Rec_TD Tkl Def_Int Sk target predictlocal
## 108 13.47019 Jim Plunkett NWE  23 Stanford  QB   1 121.4375 231.3125 1617.625 10.25 12.375  20.1875  83.5625   0.875   0       0      0   0       0  0     16     12.27773
playername <- local_frame$Player[row]

The next step is multiplying the local feature importance coefficients by the player's actual values. The code below includes a few data wrangling steps.

df <- as.data.frame(t(local_frame[row,]))       # one row per feature, one column of player values
df1 <- df %>% tibble::rownames_to_column() %>%  # feature names become a column
          mutate(names = rowname)
colnames(df1)[2] <- "player"
df1$player <- as.character(df1$player)
df1[3,'names'] <- df1[3,'player']  # Copy the team (e.g. "NWE") over to names to match the one-hot names in imp
df1[5,'names'] <- df1[5,'player']  # Copy the college over to names
df1 <- df1 %>% 
        left_join(imp, by = 'names') %>%                           # Join local feature importance by names
        filter(!is.na(sign)) %>%                                   # Remove features without a coefficient
        mutate(player = suppressWarnings(as.numeric(player))) %>%  # Team and college become NA here...
        mutate(player = ifelse(is.na(player), 1, player)) %>%      # ...and count as 1, a one-hot indicator
        mutate(strength = player * coefficients) %>% 
        mutate(strength = ifelse(sign == "NEG", -strength, strength)) %>%  # restore the coefficient's sign
        filter(round(strength, 1) != 0)                            # drop negligible reasons

The table below is used for determining the reason codes. Multiplying a player's actual statistic by the corresponding coefficient of the linear model gives the strength of that feature, with the sign flipped for negative (NEG) coefficients. For example, if a player's per-season career average for interceptions is 10 and the coefficient for interceptions is 1.7, then the strength of that reason code is 17. This is calculated for every feature and shown below.

Reason Codes Table
rowname   player    names     coefficients  sign  strength
Age       23.0000   Age       0.328         NEG   -7.544000
Rnd       1.0000    Rnd       0.548         NEG   -0.548000
Int       12.3750   Int       1.687         POS   20.876625
Rush_Att  20.1875   Rush_Att  0.321         POS   6.480188
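To make the arithmetic concrete, here is a base R sketch that reproduces the strengths in this table without h2o. The values are copied from Jim Plunkett's row and the local QB importance table above; NWE is included as a one-hot team indicator (value 1) to show how a small contribution (-0.044) is filtered out by the same `round(strength, 1) != 0` rule the notebook uses.

```r
# Jim Plunkett's values and the local QB coefficients, copied from the tables above
reasons <- data.frame(
  names        = c("Age", "Rnd", "Int", "Rush_Att", "Rec", "Tkl", "NWE"),
  player       = c(23, 1, 12.375, 20.1875, 0, 0, 1),  # NWE: one-hot team indicator
  coefficients = c(0.328, 0.548, 1.687, 0.321, 0.133, 0.014, 0.044),
  sign         = c("NEG", "NEG", "POS", "POS", "NEG", "NEG", "NEG"),
  stringsAsFactors = FALSE
)

# Strength = value x coefficient, negated for NEG coefficients
reasons$strength <- reasons$player * reasons$coefficients
reasons$strength <- ifelse(reasons$sign == "NEG", -reasons$strength, reasons$strength)

# Drop negligible reasons, then rank by absolute strength
reasons <- reasons[round(reasons$strength, 1) != 0, ]
reasons <- reasons[order(-abs(reasons$strength)), ]
reasons[, c("names", "strength")]
```

Ranked by absolute strength, the reasons come out as Int (20.88), Age (-7.54), Rush_Att (6.48), and Rnd (-0.55).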

Now we have reasons for an individual prediction! This shows the effect of each feature on the prediction. In this case, interceptions are the strongest feature driving the length of the career. Try this with different players, for example, a quarterback with a short career, such as row 10. Also try local areas other than QB: look at a popular college team or at the players with the shortest predicted careers. The key is that we now have insights at the level of individual predictions.

The LIME package in R

For those unwilling to roll their own reason codes, several packages can provide them. I will walk through the lime package for R, ported by Thomas Lin Pedersen. When this notebook was written, the package was still early in its development and did not support regression models. So the first step is turning the dataset into a classification problem, by changing the target from NFL career length to more than three years versus three or fewer years.

library(lime)
library(randomForest)
library(caret)
data <- data %>%  dplyr::select (-Rush_TD,-Rec_TD,-TD,-Att,-Cmp,-Rush_Yds,-Rec_Yds,-Yds,-College)
data$target <- ifelse(data$target>3,1,0)  
data$target <- as.factor(data$target)


The next step is building a model that is compatible with the lime package. I used caret's random forest method (method = 'rf'), which relies on the randomForest package. I also had to drop the College feature, because randomForest can't handle categorical predictors with more than 53 categories and College has about 400. This model serves as our global “black box” model.
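A minimal sketch of checking for this problem before training: the hypothetical helper `too_many_levels()` lists categorical columns whose number of categories exceeds randomForest's limit of 53. The toy frame below is made up, with 400 college levels and 32 team levels.

```r
# List categorical columns with more categories than randomForest accepts
too_many_levels <- function(df, max_levels = 53) {
  is_cat <- vapply(df, function(col) is.factor(col) || is.character(col), logical(1))
  n_levels <- vapply(df[is_cat],
                     function(col) if (is.factor(col)) nlevels(col) else length(unique(col)),
                     integer(1))
  names(n_levels)[n_levels > max_levels]
}

# Made-up data: College has 400 levels, Tm has 32
toy <- data.frame(
  College = factor(rep(sprintf("C%03d", 1:400), length.out = 800)),
  Tm      = factor(rep(sprintf("T%02d", 1:32), length.out = 800)),
  Age     = rep(22:25, 200)
)
too_many_levels(toy)
```

Only College is flagged, matching the column dropped in the notebook.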

smp_size <- floor(0.75 * nrow(data))
## Set the seed to make the partition reproducible
set.seed(123)
train_ind <- sample(seq_len(nrow(data)), size = smp_size)
train <- data[train_ind,]
valid <- data[-train_ind, ]

train_lab <- train$target
train <- train %>% dplyr::select(-target)

## The model can take over an hour to train; caret's train() fits a random forest with method = 'rf'
#model <- train(train[,2:11], train_lab, method = 'rf')
#save(model, file = "RF_NFL_R.rda")
load(file = "RF_NFL_R.rda")  # restores the saved object `model`
# You can download a copy of RF_NFL_R.rda at https://www.dropbox.com/s/7phbsg13wv2b00j/RF_NFL_R.rda?dl=0


The next step is running the explain part of the lime package. It builds a local linear model and calculates the reason codes. This local model is not built around a specific position; instead, it is built around each individual prediction. Here is the description from the authors:

An illustration of this process is given below. The original model's decision function is represented by the blue/pink background, and is clearly nonlinear. The bright red cross is the instance being explained (let's call it X).  We sample perturbed instances around X, and weight them according to their proximity to X (weight here is represented by size).  We get original model's prediction on these perturbed instances, and then learn a linear model (dashed line) that approximates the model well in the vicinity of X.  Note that the explanation in this case is not faithful globally, but it is faithful locally around X.
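The four steps in this description (perturb, predict, weight by proximity, fit a weighted linear model) can be sketched in a few lines of base R. This is a toy illustration of the idea, not the lime package's implementation; `black_box` is a made-up nonlinear function standing in for the random forest, and the perturbation spread and kernel width are arbitrary assumptions.

```r
# Made-up nonlinear "black box" with two numeric inputs
black_box <- function(x1, x2) sin(x1) + x2^2

lime_sketch <- function(f, x0, n = 5000, spread = 0.5, kernel_width = 0.25) {
  # 1. Sample perturbed instances around the instance being explained (x0)
  x1 <- rnorm(n, mean = x0[1], sd = spread)
  x2 <- rnorm(n, mean = x0[2], sd = spread)
  # 2. Get the black box prediction for each perturbed instance
  y <- f(x1, x2)
  # 3. Weight each instance by its proximity to x0 (Gaussian kernel)
  w <- exp(-((x1 - x0[1])^2 + (x2 - x0[2])^2) / kernel_width^2)
  # 4. Fit a weighted linear model; its slopes are the local explanation
  fit <- lm(y ~ x1 + x2, data = data.frame(x1, x2, y, w), weights = w)
  coef(fit)[c("x1", "x2")]
}

set.seed(1)
local_coefs <- lime_sketch(black_box, x0 = c(0, 1))
local_coefs
```

Near x0 = (0, 1) the black box's true slopes are cos(0) = 1 and 2 × 1 = 2, so the weighted linear model should recover values close to those: faithful locally, even though a single line cannot describe the whole function.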



# In current versions of lime, lime() returns an explainer object that is passed to explain()
explainer <- lime(train[,2:11], model)
explanation <- explain(valid[201,2:11], explainer, n_labels = 1, n_features = 3)
valid[201,1]
## [1] "Jess Lewis"
plot_features(explanation)


The reason codes here show a strong probability for Class 0, which is a career of three years or less. The reason codes make sense: the player was a late-round draft pick, a linebacker, and 22 years old. Let's do one more example:

explanation <- explain(valid[205,2:11], explainer, n_labels = 1, n_features = 3)
valid[205,1]
## [1] "Jim Plunkett"
plot_features(explanation)


The reason codes show a good probability that the player is in Class 1, which is a career of more than three years. The reason codes make sense: the player was an early draft pick and a quarterback. In this case, being 22 years old actually counts against a longer career. These examples illustrate the insight that reason codes provide into the factors driving a particular prediction.

Next steps

Once you work through the code here, you should have a good understanding of how to use a linear surrogate model to get reason codes. Moving forward, try taking these ideas to get a window into your black box models. Here are some ideas to try:
- Improve the models by adding a grid search for hyperparameters
- Convert the NFL dataset into a classification problem and run your own reason codes
- Try different sorts of local areas and compare the results
- Try different sorts of interpretable algorithms for modeling the local area, e.g., GAMs
- Try this with another dataset
- Wrap reason codes in a shiny app and use this to explain a model
- Let me know if you found this useful

Further Reading


Credits

You can find a copy of this notebook in this repo on my github, inter-examples. For more about me, check out my website or find me on twitter.

Thanks to Will, Patrick, and Rob for their detailed comments.