Eyes, JAPAN
SHAP: Understanding an NLP Model Through Visualization
Cherubin Mugisha
Introduction
Making sense of machine learning models has always been challenging, despite its paramount importance. Exploratory and explanatory algorithms have been key to guiding model architecture choices and to increasing the credibility of models, which have long been seen as black boxes.
Visualizing a model is a way to answer questions such as why the model works, how it works, and what it has learned. In this blog, I will talk about one of the main ML visualization tools, SHAP. Using natural language processing (NLP) models, I will walk through a demo with a simple attention visualization to show how interpretability ideas that originated in computer vision carry over to NLP. The examples will show how those two applications differ, how to interpret their outputs and graphics, and the specific limitations and shortcomings of each approach.
SHAP
SHapley Additive exPlanations (SHAP; Scott M. Lundberg, Su-In Lee) is a method based on game theory that explains individual predictions by computing the contribution of each feature to the prediction. These contributions are Shapley values, and they show how much each individual feature pushes the final prediction above or below the model's average output. Although this is fairly intuitive for tabular data and image processing, where the features (columns, or groups of pixels) are relatively few and easy to relate to the input, the outputs can be much harder to make sense of for an NLP task such as classification.
For example, with tabular data, a feature whose values are strongly correlated with the model output will typically receive large attributions, and the user can check whether that makes sense. Similarly, for images, pixels can be grouped into superpixels, and the resulting attribution map can be compared with what we see in the image, especially when it highlights the region we expect to be discriminative.
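To make the Shapley value concrete, here is a minimal sketch (not SHAP's own implementation, which uses much faster approximations) that computes exact Shapley values by enumerating every coalition of features. The model `toy_model` is hypothetical, and I assume that a feature "absent" from a coalition is replaced by a baseline value of 0. Each feature's value is its weighted average marginal contribution, phi_i = sum over S of |S|! (n - |S| - 1)! / n! * [f(S with i) - f(S)].

```python
from itertools import combinations
from math import factorial


def toy_model(x):
    """Hypothetical model: a linear part plus an interaction between features 1 and 2."""
    return 2.0 * x[0] + 1.0 * x[1] + 3.0 * x[1] * x[2]


def coalition_value(model, x, baseline, coalition):
    """Evaluate the model with every feature outside `coalition` set to its baseline."""
    masked = [x[i] if i in coalition else baseline[i] for i in range(len(x))]
    return model(masked)


def exact_shapley(model, x, baseline):
    """Exact Shapley values by enumerating all coalitions (needs O(2^n) model calls)."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                gain = (coalition_value(model, x, baseline, s | {i})
                        - coalition_value(model, x, baseline, s))
                phi[i] += weight * gain
    return phi


x = [1.0, 2.0, 1.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(toy_model, x, baseline)
print(phi)  # ≈ [2.0, 5.0, 3.0]
# Additivity: the attributions sum to f(x) - f(baseline)
print(sum(phi), toy_model(x) - toy_model(baseline))
```

Note how the interaction term 3 * x1 * x2 (worth 6 here) is split equally between features 1 and 2: that fair sharing of interactions is what makes Shapley values attractive, and the exponential cost is why SHAP approximates them.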
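For tabular data and a linear model, assuming independent features, Shapley values have a known closed form: phi_i = w_i * (x_i - mean_i), i.e. the coefficient times how far the feature sits from its average. The sketch below uses made-up weights and a synthetic background dataset (both hypothetical) to show that a large attribution needs both a sizeable coefficient and an unusual feature value.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical tabular setup: 200 background rows with 3 features
background = rng.normal(loc=[50.0, 3.0, 0.5], scale=[10.0, 1.0, 0.1], size=(200, 3))
weights = np.array([0.02, -0.5, 4.0])  # illustrative coefficients
bias = 1.0


def linear_model(X):
    return X @ weights + bias


def linear_shap(x, background):
    """Closed-form SHAP values for a linear model with independent features:
    phi_i = w_i * (x_i - mean_i), with the mean taken over the background data."""
    return weights * (x - background.mean(axis=0))


x = np.array([70.0, 2.0, 0.6])
phi = linear_shap(x, background)
base_value = linear_model(background).mean()  # average model output on the background

# The base value plus the attributions recovers the prediction exactly
print(phi, base_value + phi.sum(), linear_model(x))
```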
All the code and outputs are provided at the end of this blog. Please refer to them while reading, and feel free to run the code in a Google Colab notebook.
Explaining a transformers-based model
1. NLP Model for Classification
While the most common way to understand a linear model is to examine the coefficient learned for each feature, this is harder for transformers because of their attention mechanism. A model such as BERT uses contextual embeddings that depend on how the input sequence is tokenized, so understanding the model’s mechanism largely comes down to understanding its attention. However, the most popular NLP models are widely regarded as among the least explainable, because the high dimensionality of their architectures does not match our understanding of the relationships between words and tokens.
For example, BERT’s attention heads may attend mostly to delimiter tokens such as [CLS] and [SEP], to specific tokens, or broadly to the whole sentence. These patterns vary from one layer to another and can differ even for seemingly similar inputs. Although this may seem uninterpretable to humans, I will demonstrate that explainability tools such as SHAP or Ecco have the potential to surface substantial syntactic information captured by an NLP model.
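Because the units a BERT-style model sees (and that SHAP attributes) are tokens rather than words, it helps to see how a word gets split. BERT's WordPiece tokenizer applies greedy longest-match-first splitting against a vocabulary at inference time; here is a minimal sketch of that idea. The tiny vocabulary is hypothetical (DistilBERT's real one has about 30,000 entries), and real tokenizers also lower-case, split punctuation, and cap word length.

```python
# Hypothetical miniature WordPiece vocabulary ("##" marks a word-internal piece)
VOCAB = {"[UNK]", "un", "##believ", "##able", "play", "##ing", "the", "movie"}


def wordpiece(word, vocab=VOCAB):
    """Greedy longest-match-first subword split in the style of BERT's WordPiece."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return ["[UNK]"]  # no valid split: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces


print([wordpiece(w) for w in "the movie unbelievable".split()])
```

With this vocabulary, "unbelievable" becomes `un`, `##believ`, `##able`, so an attribution can land on a fragment of a word rather than on the word itself.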
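To make "attention pattern" concrete, the sketch below computes a single head's scaled dot-product attention weights, softmax(Q K^T / sqrt(d)), over a toy token sequence. The embeddings and projection matrices are random stand-ins (hypothetical); in BERT they are learned, and each layer has several heads, which is why the patterns change from head to head and layer to layer.

```python
import numpy as np

rng = np.random.default_rng(42)

tokens = ["[CLS]", "the", "movie", "was", "great", "[SEP]"]
d_model, d_head = 16, 8

# Hypothetical embeddings and projection matrices (learned in a real model)
embeddings = rng.normal(size=(len(tokens), d_model))
W_q = rng.normal(size=(d_model, d_head))
W_k = rng.normal(size=(d_model, d_head))


def attention_weights(X, W_q, W_k):
    """Row i holds how much token i attends to every token j: softmax(Q K^T / sqrt(d))."""
    Q, K = X @ W_q, X @ W_k
    scores = Q @ K.T / np.sqrt(W_q.shape[1])
    scores -= scores.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(scores)
    return exp / exp.sum(axis=1, keepdims=True)


A = attention_weights(embeddings, W_q, W_k)
for tok, row in zip(tokens, A):
    # Report the token each query attends to most strongly
    print(f"{tok:>7} -> {tokens[row.argmax()]} ({row.max():.2f})")
```

Each row is a probability distribution over the sequence; an attention visualization is essentially a heat map of `A` for one head in one layer.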
The following code demonstrates how SHAP can be used to visualize and interpret what a complex transformers-based model sees when applied to a classification task.
# import the necessary libraries
import pandas as pd
import shap
import sklearn
Next, load a pretrained sentiment classifier from the transformers library, as well as the IMDB review dataset.
import transformers
import datasets
import torch
import numpy as np
import scipy as sp
import scipy.special  # makes sp.special available on older SciPy versions
# load a DistilBERT sentiment analysis model fine-tuned on SST-2
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = transformers.DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = transformers.DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
).to(device)

# define a prediction function: list of strings -> positive-class logit for each string
def f(x):
    tv = torch.tensor(
        [tokenizer.encode(v, padding="max_length", max_length=500, truncation=True) for v in x]
    ).to(device)
    outputs = model(tv)[0].detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T  # softmax over the two classes
    val = sp.special.logit(scores[:, 1])  # use one vs rest logit units
    return val
# build an explainer using a token masker
explainer = shap.Explainer(f, tokenizer)
# explain the model's predictions on the first 30 IMDB training reviews
imdb_train = datasets.load_dataset("imdb")["train"]
shap_values = explainer(imdb_train[:30]["text"], fixed_context=1, batch_size=2)
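The prediction function `f` turns the two class logits into probabilities and then applies `logit` to the positive-class probability. For a two-class softmax, that simply recovers the difference between the two raw logits, so the SHAP values are expressed in log-odds of the positive class rather than in probabilities. A quick NumPy check with made-up logits (hypothetical values, no model needed):

```python
import numpy as np

# Hypothetical raw logits for three reviews; columns are [negative, positive]
outputs = np.array([[2.0, -1.0], [-0.5, 3.0], [0.2, 0.1]])

# The same softmax expression used in f
scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T

# logit(p) = log(p / (1 - p)), equivalent to scipy.special.logit
p_pos = scores[:, 1]
one_vs_rest = np.log(p_pos / (1 - p_pos))

print(one_vs_rest)                    # positive-class log-odds, as returned by f
print(outputs[:, 1] - outputs[:, 0])  # the same numbers: z_positive - z_negative
```

Working in logit units keeps the explanation additive: the base value plus the token contributions adds up to `f(x)`, which a sigmoid would then map back to a probability.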
For one review from this sample (index 27), we visualize the sentiment predicted by the model and which parts of the sequence contributed most to the prediction.
# plot a sentence's explanation
shap.plots.text(shap_values[27])
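This visualization shows the SHAP value of each token: how much that token pushes the model's output (here the positive-class logit returned by `f`) above or below the base value, in red when it pushes towards the positive class and in blue when it pushes towards the negative class. With a tokenizer as masker, SHAP computes these values by masking groups of tokens and re-running `f`; it does not read the neuron activations. Activation-based views, such as non-negative matrix factorization of neuron activations, are what tools like Ecco provide.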
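To illustrate what the highlighted tokens mean, the sketch below explains a deliberately simple, hypothetical classifier: a lexicon scorer with a negation rule. Masked tokens are simply dropped, and each token's value is its marginal contribution averaged over every order in which tokens could be revealed (the permutation form of the Shapley value). SHAP's own text explainer estimates such values far more cheaply by masking hierarchical groups of tokens; this brute-force version only scales to a handful of tokens.

```python
from itertools import permutations

# Hypothetical lexicon-based "sentiment model" returning a log-odds-like score
LEXICON = {"good": 2.0, "great": 3.0, "boring": -2.0, "bad": -2.5}


def toy_sentiment(tokens):
    score, negate = 0.0, False
    for tok in tokens:
        if tok == "not":
            negate = True  # flip the polarity of the next sentiment word
        elif tok in LEXICON:
            score += -LEXICON[tok] if negate else LEXICON[tok]
            negate = False
    return score


def token_shapley(model, tokens):
    """Average marginal contribution of each token over all reveal orders."""
    n = len(tokens)
    phi = [0.0] * n
    orders = list(permutations(range(n)))
    for order in orders:
        present = set()
        prev = model([])  # every token masked
        for i in order:
            present.add(i)
            # Unmasked tokens keep their original sentence order
            cur = model([tokens[j] for j in sorted(present)])
            phi[i] += cur - prev
            prev = cur
    return [p / len(orders) for p in phi]


sentence = ["the", "plot", "was", "not", "good"]
for tok, val in zip(sentence, token_shapley(toy_sentiment, sentence)):
    print(f"{tok:>5} {val:+.2f}")
```

Here the whole negative score is attributed to "not": on its own, "good" pushes the score up by 2, and once negated it pushes it down by 2, so its average contribution is zero. Token-level explanations therefore have to be read with interactions between words in mind.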
2. NLP Model for Translation
What happens inside a sequence-to-sequence model may be easier to picture than what happens inside a classifier, because each output token can be traced back to the input tokens that produced it. Our intuition about a translation can even show more precisely which tokens or words have been mistranslated.
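For a translation model, the explanation becomes a matrix: one Shapley value per (input token, output token) pair. The sketch below uses a hypothetical toy translator for "my big sister" in which the French determiner and adjective need the noun "sister" for gender agreement; its scoring function is made up purely to show how such a matrix is read.

```python
from itertools import combinations
from math import factorial

import numpy as np

source = ["my", "big", "sister"]
target = ["ma", "grande", "sœur"]


def toy_token_scores(present):
    """Hypothetical translator: a score for each target token given the set of
    unmasked source words. Agreement ('ma', 'grande') needs the noun 'sister'."""
    return np.array([
        1.0 if {"my", "sister"} <= present else 0.0,       # ma
        (1.0 if "big" in present else 0.0)
        + (1.0 if {"big", "sister"} <= present else 0.0),  # grande
        1.0 if "sister" in present else 0.0,               # sœur
    ])


def shapley_matrix(source, score_fn):
    """Exact Shapley values of each source word for every target token at once."""
    n = len(source)
    phi = np.zeros((n, len(score_fn(set()))))
    for i, word in enumerate(source):
        others = source[:i] + source[i + 1:]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (score_fn(s | {word}) - score_fn(s))
    return phi


phi = shapley_matrix(source, toy_token_scores)
# Rows: source words; columns: target tokens
print(" " * 6, " ".join(f"{t:>6}" for t in target))
for word, row in zip(source, phi):
    print(f"{word:>6}", " ".join(f"{v:6.2f}" for v in row))
```

Reading down a column shows which source words drove one output token: "grande" is attributed mostly to "big" but partly to "sister", reflecting agreement. A column that points at an unexpected source word is the kind of signal that can expose a mistranslation.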
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import shap
import torch
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-fr").cuda()
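s = ["In my family, we are six: my father, my mother, my elder sister, my younger brother and sister"]
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
# plot the explanation; hover over an output token to see which input tokens drove it
shap.plots.text(shap_values)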
By hovering over a token in the translated text, you can see which input tokens contributed most to generating it; these are SHAP values obtained by masking the input, not embedding scores. Relationships between words are better explored in the model's embedding space, where word similarity can easily be visualized. For example, in a well-trained embedding space, “brother” should be to “father” what “sister” is to “mother”. The same kind of embedding view can also be applied to the words seen by a sentiment classifier.
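The analogy above can be sketched with vector arithmetic: solve "father : brother :: mother : ?" by computing brother - father + mother and picking the nearest remaining word by cosine similarity. The 3-D embeddings here are hand-made toy vectors (hypothetical, not taken from the translation model), with one axis for "female", one for "older generation", and one shared "family word" axis; whether a real model's embeddings satisfy such an analogy has to be checked empirically.

```python
import numpy as np

# Hypothetical toy embeddings: [female, older generation, family word]
emb = {
    "father": np.array([0.0, 1.0, 1.0]),
    "mother": np.array([1.0, 1.0, 1.0]),
    "brother": np.array([0.0, 0.0, 1.0]),
    "sister": np.array([1.0, 0.0, 1.0]),
    "aunt": np.array([1.0, 1.5, 0.8]),    # distractor
    "movie": np.array([0.1, 0.1, -1.0]),  # unrelated distractor
}


def analogy(a, b, c, emb):
    """Solve a : b :: c : ? using b - a + c and cosine similarity."""
    target = emb[b] - emb[a] + emb[c]
    best, best_sim = None, -np.inf
    for word, vec in emb.items():
        if word in (a, b, c):
            continue  # exclude the query words themselves
        sim = vec @ target / (np.linalg.norm(vec) * np.linalg.norm(target))
        if sim > best_sim:
            best, best_sim = word, sim
    return best


print(analogy("father", "brother", "mother", emb))  # sister
```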
Github code
Conclusion
Machine learning explanatory tools can be very informative. However, due to the high dimensionality of current language models, we should not always expect to be able to visualize, and make sense of in a two- or three-dimensional space, what is computed by transformer-based models with millions (or even billions) of parameters.