Complete SHAP tutorial for model explanation Part 5. Python Example

Summer Hu
4 min read · Jan 2, 2021


Explore the SHAP force plot, summary plot, dependence plot and feature importance plot to explain a machine learning model.

[Cover image: Nazca Lines. © Original: https://www.machutravelperu.com/blog/how-were-the-nazca-lines-made]
  1. Part 1. Shapley Value
  2. Part 2. Shapley Value as Feature Contribution
  3. Part 3. KernelSHAP
  4. Part 4. TreeSHAP
  5. Part 5. Python Example

In this part we will explore the following four common ways to explain a machine learning model using SHAP:

  1. SHAP Individual and Collective Force Plot
  2. SHAP Summary Plot
  3. SHAP Feature Importance
  4. SHAP Dependence Plot

Please refer to Parts 1–4 to build up the SHAP intuition if you are interested in more of the details behind SHAP.

Example Model

The example uses XGBRegressor to predict Boston housing prices, using the Boston Housing dataset (also available on Kaggle, loaded below via scikit-learn).

First, we need to install the SHAP Python library with the following command:

pip install shap
or, in a conda environment,
conda install -c conda-forge shap
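
A quick way to confirm the installation is to import the library and print its version (a minimal check, not part of the original post):

import shap
# Prints the installed SHAP version; any reasonably recent release should work for this tutorial
print(shap.__version__)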

The XGBRegressor model is built as below:

import shap
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from xgboost import XGBRegressor
from sklearn import metrics
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
shap.initjs()

boston = load_boston()
X = pd.DataFrame(boston.data, columns=boston.feature_names)
y = boston.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=22)
# Summarise the training data with k-means to get a smaller background set for estimating SHAP values
X_train_summary = shap.kmeans(X_train, 10)
model = XGBRegressor(n_estimators=100)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print('RMSE:', np.sqrt(metrics.mean_squared_error(y_test, y_pred)))
print('R-Squared:', metrics.r2_score(y_test, y_pred))
RMSE: 3.5104295987738467
R-Squared: 0.8531088933245674

1. SHAP Force Plot

From a force plot, for any given instance (or observation), we can analyse:

a. Why the instance's prediction differs from the model's average prediction

b. How much each feature contributes to that difference.

tree_explainer = shap.TreeExplainer(model)
shap_values = tree_explainer.shap_values(X)
shap.force_plot(tree_explainer.expected_value, shap_values[10, :], X.iloc[10, :])

The above example targets the 10th instance in dataset X and uses TreeExplainer to calculate that instance's feature contributions. The base value is the model's average prediction, and f(x) is the prediction for the current instance.

From the chart above we can see that, for the 10th instance, the blue features contribute negatively to the prediction and the red features contribute positively; the strength of each feature's contribution is reflected by the width of its segment.
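
The list at the start of this part also mentions a collective force plot. A minimal sketch of one, reusing the tree_explainer and shap_values computed above, stacks the individual force plots of all instances side by side so that patterns across the whole dataset become visible:

# Collective (stacked) force plot: each instance's force plot is rotated 90 degrees
# and stacked horizontally, reusing the explainer and SHAP values from above
shap.force_plot(tree_explainer.expected_value, shap_values, X)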

2. SHAP Summary Plot

The summary plot shows each feature's SHAP value distribution across all instances in dataset X.

tree_explainer = shap.TreeExplainer(model)
shap_values = tree_explainer.shap_values(X)
shap.summary_plot(shap_values, X)

a. The colour encodes the feature's value, from low to high.

b. In each feature row, the vertical thickness at a given point reflects the density of instances with that SHAP value.

c. Taking the LSTAT feature as an example, larger LSTAT values generally map to smaller (more negative) SHAP values, and vice versa.

3. SHAP Feature Importance

Building on the summary plot above: for each feature, if we average the absolute SHAP values over all instances, that average is the feature's SHAP feature importance.

shap.summary_plot(shap_values, X, plot_type="bar")

From this definition of SHAP feature importance, we can see:

a. Feature importance is a global, aggregated measure of a feature; it averages over all instances to obtain a single score per feature.

b. A SHAP value is a local, instance-level descriptor of a feature; it only describes the feature's contribution to one instance's prediction.
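
To make the aggregation concrete, the bar-plot ranking can also be reproduced by hand (a small sketch, reusing shap_values and X from above; the variable name is mine):

import numpy as np
import pandas as pd

# SHAP feature importance = mean absolute SHAP value per feature, averaged over all instances
shap_importance = pd.Series(np.abs(shap_values).mean(axis=0), index=X.columns)
print(shap_importance.sort_values(ascending=False).head())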

4. SHAP Dependence Plot

Below is an example that plots the LSTAT feature value against the SHAP value of LSTAT across all instances in X.

shap.dependence_plot("LSTAT", shap_values, X, interaction_index=None)

Much like a partial dependence plot, the SHAP dependence plot can be used to check the relationship between a feature and the target. In the example above, LSTAT has a roughly negative linear relationship with its SHAP value, so LSTAT should also have a roughly negative linear relationship with the target variable.

Let’s check CRIM feature as below

shap.dependence_plot("CRIM", shap_values, X, interaction_index=None)

Clearly there is no linear relationship between CRIM and its SHAP value, so CRIM is unlikely to have a linear relationship with the target variable.
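
Setting interaction_index=None, as above, disables colouring; by default SHAP instead picks the feature it estimates interacts most strongly with the plotted one and colours the points by its value. A small sketch with the interaction feature chosen explicitly (RM is purely an illustrative choice, not from the original text):

# Colour each point by the value of RM to inspect a possible interaction between LSTAT and RM
shap.dependence_plot("LSTAT", shap_values, X, interaction_index="RM")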

Conclusion

In this part, we quickly went through several ways to explain a model's feature contributions using SHAP. The examples are based on TreeExplainer, but the SHAP library provides other explainers as well, such as KernelExplainer, which is model-agnostic and can be used with any model, and DeepExplainer, which targets deep learning models, so please refer to the SHAP documentation for more details.
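
As an illustration of the model-agnostic path (a minimal sketch, not part of the original example), KernelExplainer only needs a prediction function plus a background dataset, for which the k-means summary X_train_summary built earlier can be reused; KernelSHAP is much slower, so only a few test rows are explained here:

# Model-agnostic KernelSHAP: uses the k-means summary of the training data as background
kernel_explainer = shap.KernelExplainer(model.predict, X_train_summary)
# Explain just a handful of test instances, since KernelSHAP is computationally expensive
kernel_shap_values = kernel_explainer.shap_values(X_test.iloc[:10, :])
shap.summary_plot(kernel_shap_values, X_test.iloc[:10, :])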

REFERENCES

  1. Interpretable Machine Learning: https://christophm.github.io/interpretable-ml-book/shap.html
  2. A Unified Approach to Interpreting Model Predictions: https://arxiv.org/abs/1705.07874
  3. Consistent Individualized Feature Attribution for Tree Ensembles: https://arxiv.org/abs/1802.03888
  4. SHAP Part 3: Tree SHAP: https://medium.com/analytics-vidhya/shap-part-3-tree-shap-3af9bcd7cd9b
  5. PyData Tel Aviv Meetup: SHAP Values for ML Explainability — Adi Watzman: https://www.youtube.com/watch?v=0yXtdkIL3Xk
  6. The Science Behind InterpretML: SHAP: https://www.youtube.com/watch?v=-taOhqkiuIo
  7. Game Theory (Stanford) — 7.3 — The Shapley Value: https://www.youtube.com/watch?v=P46RKjbO1nQ
  8. Understanding SHAP for Interpretable Machine Learning: https://medium.com/ai-in-plain-english/understanding-shap-for-interpretable-machine-learning-35e8639d03db
  9. Kernel SHAP: https://www.telesens.co/2020/09/17/kernel-shap/
  10. Understanding the SHAP interpretation method: Kernel SHAP: https://data4thought.com/kernel_shap.html
