Explainable Machine Learning using SHAP

Models that we put in production need to be explainable: we need to understand how each feature impacts the overall predictions.

Machine Learning models are often treated as black boxes, meaning it is difficult to demonstrate how they reach their specific predictions. To facilitate their adoption by stakeholders, we need to provide an interpretation mechanism for the model’s features. In other words, Explainable Machine Learning techniques are needed to unravel some of these aspects. The higher the complexity of a model, the harder it is to interpret (Figure 1).

Figure 1: Interpretability versus performance trade-off given common ML algorithms. Image adjusted from AWS white paper: Interpretability vs Explainability. https://docs.aws.amazon.com/whitepapers/latest/model-explainability-aws-ai-ml/interpretability-versus-explainability.html

In this tutorial, I use as an example a house pricing model previously developed to predict house prices in California. This is a regression model built with the LightGBM framework. LightGBM is a gradient boosting ensemble method that consists of multiple decision trees which are optimized at every iteration. It is considered a black-box algorithm, which makes its interpretation a challenge. SHapley Additive exPlanations (SHAP) is a model-agnostic method for feature importance that is well suited to complex tree-based algorithms. It is based on game theory, is used to increase transparency in black-box models, and offers both global and local interpretability of the machine learning model’s features.

The machine learning model developed consists of 7 independent variables (features) and 1 target variable, which is the price of the house to be predicted.

The model consists of 6 numerical features (median_income, latitude, longitude, population_per_household, rooms_per_bedroom, house_median_age) and 1 categorical feature, ocean_proximity, whose 5 categories are one-hot encoded as: op_INLAND, op_ISLAND, op_NEAR_BAY, op_NEAR_OCEAN, op_more_1H_ocean.
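
For context, below is a minimal sketch of how such a model could be trained; the file name, train/test split and hyperparameters are assumptions for illustration, not the original training setup.

Python
import lightgbm as lgb
import pandas as pd
from sklearn.model_selection import train_test_split

# Assumed: a prepared dataframe with the engineered features above and the
# target column median_house_value (file name is hypothetical)
df = pd.read_csv("california_housing_prepared.csv")

features = ["median_income", "latitude", "longitude", "population_per_household",
            "rooms_per_bedroom", "house_median_age", "op_INLAND", "op_ISLAND",
            "op_NEAR_BAY", "op_NEAR_OCEAN", "op_more_1H_ocean"]
X = df[features]
y = df["median_house_value"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# LightGBM regressor; hyperparameters are placeholders
model = lgb.LGBMRegressor(n_estimators=500, learning_rate=0.05)
model.fit(X_train, y_train)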

Global Interpretability 

Global interpretability refers to which variables are the most predictive for the model overall. 

Firstly, I used a summary plot to show the feature importance across all the predictions that the model made. The summary plot depicts the mean absolute SHAP value per feature, and the features are ranked from most to least important (top to bottom) (Figure 2). This plot helps to understand the overall impact of each feature on the predictions that the machine learning model makes.

Python
import shap

# TreeExplainer for the trained LightGBM model; SHAP values for the test set
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
# Bar summary plot: mean absolute SHAP value per feature
shap.summary_plot(shap_values.values, X_test, plot_type="bar")

Figure 2: Summary plot showing the mean absolute SHAP values for the features in the machine learning model.

The top 3 most influential features for predicting the price are the median income and the location of the house (latitude and longitude), which are very reasonable factors for determining the price of a house.

Furthermore, I checked the relationship between the most important variables and the predictions (Figure 3). To that end, I used the beeswarm plot from the SHAP package. Beeswarm plots not only reveal the relative importance of features, but also their actual relationships with the predicted outcome. They are useful when we want to examine how the underlying values of each feature relate to the model’s predictions.

Python
# Beeswarm plot of the SHAP values computed above
shap.plots.beeswarm(shap_values)

Figure 3: Beeswarm plot, ranked by mean absolute SHAP value. This provides a rich overview of how the variables impact the model’s predictions across all of the data.

The horizontal axis represents the SHAP value, while the color bar on the right shows whether an observation has a higher or a lower raw feature value. The top feature is ‘median_income’. We can see from the plot above that there is a positive correlation between ‘median_income’ and its SHAP value. More specifically, high values of ‘median_income’ (higher income) have positive SHAP values, which means a positive impact on the price prediction. In the same manner, lower ‘median_income’ values have lower SHAP values and a negative impact on the final price. This positive correlation is expected, as the correlation heatmap from the earlier analysis showed a positive correlation between the target variable ‘median_house_value’ and ‘median_income’.

Beeswarm plots are information-dense and provide a broad overview of SHAP values for many features at once. However, to delve further into the relationship between a feature’s values and the model’s predicted outcomes, it is necessary to examine dependence plots. In Figure 4, it can be seen that higher income leads to higher house price predictions.

SHAP values above the y=0 line lead to predictions of higher house prices, whereas those below it are associated with lower house price predictions. The raw variable value at which the distribution of SHAP values crosses the y=0 line reveals the threshold at which the model switches from predicting lower to higher house prices. For median_income this is approximately between 3.8 and 5, as marked by the red marks.

Python
# Dependence plot: SHAP value of median_income against its raw value
shap.dependence_plot("median_income", shap_values.values, X_test, interaction_index=None)

Figure 4: Dependence plot for median_income, the most important feature of the model as determined by mean absolute SHAP value.

Local Interpretability

For a given data point and its associated prediction, local interpretability aims to determine how each of the model’s features explains that specific prediction. In other words, the SHAP method explains the prediction of an individual instance/observation by computing the contribution of each feature to that prediction.
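
As an illustration of how such a local explanation can be produced, the sketch below uses SHAP’s waterfall plot for a single test instance; the instance index is an arbitrary example, not one of the ids from Table 1.

Python
# Local explanation for a single observation (index chosen arbitrarily)
i = 0
shap.plots.waterfall(shap_values[i])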

As an example, for three prediction instances (Table 1), the 3 features (variables) with the most impact on the default (expected) prediction were identified. Each cell represents the percentage by which a feature boosted or decreased the expected prediction towards the model’s final prediction.

For instance, the feature longitude decreases the expected prediction for the instance with id 14180 by almost 37%, whereas it increases the expected prediction for the instance with id 17963 by 11.5%. Moreover, the feature population_per_household consistently scored as the second most influential for the example predictions, whereas median_income and ‘op_INLAND’ were in first and third place respectively.

Table 1: Top 3 variables with the largest contribution. Percentage = (absolute feature contribution / expected prediction) * 100%
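
A minimal sketch of how these percentage contributions could be computed from the SHAP values, following the formula in the Table 1 caption, is shown below; the instance index and the choice to print only the top 3 features are assumptions.

Python
import numpy as np

# Percentage contribution of each feature for one instance, relative to the
# expected (base) prediction: |SHAP value| / expected prediction * 100
i = 0  # arbitrary example instance
contrib_pct = 100 * np.abs(shap_values.values[i]) / shap_values.base_values[i]
top3 = sorted(zip(X_test.columns, contrib_pct), key=lambda t: t[1], reverse=True)[:3]
print(top3)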

Overall, SHAP is an excellent method for improving the explainability of a model. However, like any other methodology, it has its own set of strengths and weaknesses. For instance, although it can provide insights into the correlation of each feature with the target, it cannot and should not be used as a tool for causal inference. It is imperative that the methodology is applied with its limitations in mind and that SHAP values are evaluated in the appropriate context.
