
The XAI Problem in Machine Learning

Alberto Rivera Martínez
Machine Learning Engineer

AI algorithms can be difficult to understand, but XAI and Machine Learning interpretability are helping to build trust in the future of AI.

January 20, 2022



Explainable AI and Machine Learning interpretability: one of the biggest challenges for AI


“The world, as we know it, is living through a period of extraordinary change, and a big driver of this change is Artificial Intelligence (AI) and Machine Learning (ML).” - Alexander Hagerup, CEO of Vic.ai.


Artificial intelligence, and more concretely ML (a subfield of AI), is one of the most sought-after technologies of the era. It is used across multiple industries and business processes. ML focuses on automatically learning and improving from experience without being explicitly programmed.


Key Takeaways:

  • Explainable AI is vital for us to understand why certain predictions or decisions have been made by algorithms
  • Neural networks learn complex patterns that humans struggle to visualize
  • AI has many subfields, such as machine learning and, within it, deep learning, which is built on neural networks

At a high level, an ML model (or algorithm) receives data as input, which we call experience, and learns complex patterns from it. Afterwards, the model can solve specific tasks based on what it has learned from the data.

Deep Learning (DL) is now the “hottest” area of ML, achieving remarkable results. For example, Computer Vision (CV) and Natural Language Processing (NLP) are used across multiple industries, ranging from healthcare to transportation and even accounting. Both fields have taken great leaps in recent years and, thanks to DL models, have even surpassed human performance on some specific tasks.

However, as a general rule, the more accurate the algorithm, the less interpretable it is. For example, a linear regression usually will not achieve results as good as a neural network (the building block of DL) on complex tasks, but it is much easier to interpret and explain.




In today’s article, we discuss what Explainable AI (XAI) and ML interpretability mean, focusing on DL and working through an example in Computer Vision (CV).


Introduction to the XAI problem in ML models

ML models focus on producing an answer or output (often with an associated probability or confidence score) for the values or attributes they receive as input. However, they are not designed to explain the reason for that output.

This may sound abstract if you haven’t worked with ML models before, so let’s use an example. Suppose we want to predict housing prices (a regression task). To do this, we train a model on real data about houses. Each house is described by a set of features (attributes), such as the number of rooms, number of bathrooms, square meters, geographical location and type of house, together with its price. The ML model learns to associate these features with the price. This way, when we input the features of a new house, the model can predict its price.
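To make this concrete, here is a toy stand-in for such a model: a k-nearest-neighbours regressor in plain Python. The training data, the function name predict_price and the choice of k-nearest neighbours are hypothetical and only meant as a sketch; the article does not say which model would be used for house prices.

```python
import math

# Hypothetical training data: (rooms, bathrooms, square_meters) -> price in EUR.
# These numbers are made up for illustration only.
TRAIN = [
    ((2, 1, 60), 150_000),
    ((3, 1, 80), 200_000),
    ((3, 2, 95), 240_000),
    ((4, 2, 120), 300_000),
    ((5, 3, 160), 400_000),
]


def predict_price(house, k=2):
    """Predict a price as the mean price of the k most similar training houses.

    Features are compared with plain Euclidean distance; a real model would
    scale them first, since square meters dominate the distance here.
    """
    by_distance = sorted(TRAIN, key=lambda row: math.dist(row[0], house))
    nearest = by_distance[:k]
    return sum(price for _, price in nearest) / k


print(predict_price((3, 2, 90)))  # prints 220000.0: a single number, no explanation
```

The function returns a bare number, and nothing in its output tells Alice why her house got that price.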

Now suppose that Alice wants to know how much her house is worth because she is thinking of selling it and moving to another city. She uses the ML model to predict the price of her house. After obtaining the output, Alice asks herself the following questions:

  • Why is this the price of my house?
  • Why is it so low?
  • What attribute is the most important for price prediction?
  • Can I improve a specific feature of my house and thus increase its price?






Unfortunately, by itself, this model cannot explain why the predicted price of Alice’s house is low, or how each input attribute affects the model’s output.


What is ML interpretability and why care about it?


Interpretability is the degree to which a human can understand the cause of a decision. The higher the interpretability of an ML model, the easier it is for someone to comprehend why certain decisions or predictions have been made. A model is more interpretable than another model if its decisions are easier for a human to comprehend.

Why do we need interpretability in ML models? Here are some reasons:

  • Trust. As humans, we may be reluctant to rely on ML models for critical tasks such as medical diagnosis or autonomous driving unless we know “how they work”.
  • Safety. There is almost always some distribution shift between model training and deployment. Interpretability can help us debug models (why the model fails on certain samples and what features it is learning), detect shifts between training data and real production data, and detect biases in predictions, such as racial, gender or medical diagnosis bias.
  • Contestability. We need to know why the ML model made a specific decision. As we delegate more decision-making to ML models, it becomes important for people to be able to appeal these decisions. This is what Alice needed in our housing price example.

Another interesting reason for ML interpretability is that, if we know how models make their decisions, we can also learn from them. This is not possible in every field, because some fields deal with amounts of data far too large for humans to inspect. However, it can be applied, for example, to medical diagnosis based on images such as X-rays, where some Deep Learning models are achieving remarkable results. Why not learn from AI by looking at the features these models use to make their predictions?


Interpretability in Deep Learning models


As I mentioned at the beginning of this article, DL is the “hottest” area of ML and is now the main area of research and development in AI. DL is framed within Representation Learning (RL), which is in turn framed within ML. While ML describes the idea of extracting knowledge directly from data, RL aims to learn a suitable representation of the data for a given task: for example, automatically finding the most important features, which might be combinations of the original data attributes. Neural networks are the most important methods for this approach. One instance is a shallow autoencoder, which is trained to reproduce its input data as its (potentially compressed) output and, in doing so, finds the relevant features that represent the data.

Before we continue, let’s define neural networks. According to MIT Management:

Neural networks are a commonly used, specific class of machine learning algorithms. Artificial neural networks are modeled on the human brain, in which thousands or millions of processing nodes are interconnected and organized into layers.


DL uses neural networks with a large number of hidden layers. The first layers might be specialized to find simple representations of basic features of the data (such as edges or corners in Computer Vision), which are then combined by subsequent layers to gain more complex representations.

When datasets are large enough, DL models generally produce excellent results, often much better than traditional ML methods. However, they are often seen as “black boxes”.

By contrast, linear regression, a well-known traditional ML method, is very easy to interpret. It predicts the target as a weighted sum of the input features. The learned relationships are linear and can be written (for a single instance) as follows, where the β values are the coefficients, or weights, of the input features and ε is the error term:

$$y = \beta_0 + \beta_1 x_1 + \dots + \beta_p x_p + \epsilon$$

As the formula shows, the “importance” of each feature is reflected in its weight, so we can tell which features matter most: those with the largest absolute weights, provided the features are on comparable scales (for example, after standardization).
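For a linear model, the weights answer Alice’s questions directly. The following sketch uses hypothetical coefficients and assumes the features have been standardized (mean 0, standard deviation 1) so their weights are comparable; the names predict, contributions and rank_by_importance are illustrative. Each product β_i · x_i is that feature’s contribution to the predicted price.

```python
# Hypothetical intercept and coefficients for standardized features
# (each feature rescaled to mean 0 and standard deviation 1).
BETA_0 = 250_000.0
BETAS = {"rooms": 15_000.0, "bathrooms": 8_000.0, "square_meters": 40_000.0}


def predict(x):
    """y = beta_0 + sum_i beta_i * x_i (the error term is not predicted)."""
    return BETA_0 + sum(BETAS[name] * value for name, value in x.items())


def contributions(x):
    """Each feature's additive contribution beta_i * x_i to the prediction."""
    return {name: BETAS[name] * value for name, value in x.items()}


def rank_by_importance():
    """Features sorted by absolute weight, most important first."""
    return sorted(BETAS, key=lambda name: abs(BETAS[name]), reverse=True)


# Hypothetical standardized features for Alice's house (below-average size).
alice = {"rooms": -0.5, "bathrooms": 0.0, "square_meters": -1.25}
print(predict(alice))        # 192500.0
print(contributions(alice))  # {'rooms': -7500.0, 'bathrooms': 0.0, 'square_meters': -50000.0}
print(rank_by_importance())  # ['square_meters', 'rooms', 'bathrooms']
```

In this made-up case, the negative contribution of square_meters is what pulls Alice’s predicted price down.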

DL models, in contrast, can have millions of parameters, and we don’t know what each parameter represents. That is why they are frequently seen as “black boxes”.

Interpretability methods for DL focus on explaining these “black boxes”: showing which input features are most important for a specific output, or what a specific layer of the neural network is learning.

In this article, we are going to apply three interpretability methods to a pre-trained Convolutional Neural Network (CNN) called VGG16. This model was trained on ImageNet; more precisely, on the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) subset of about 1.2 million images belonging to 1000 classes (the full ImageNet dataset has over 14 million images). We use Keras to load the model and TensorFlow to implement the three interpretability methods.

The first two methods are based on gradient ascent in the input space (in this case, the input image) to find the visual pattern that a filter, a class activation or any other function responds to most strongly. First, we will visualize the filters the CNN has learned and uses to predict the image class. Then, we will visualize the synthetic images that are maximally responsive to each output class of the model. Finally, we will implement a method called Grad-CAM, which instead uses the gradients of the predicted class score with respect to a layer’s feature maps, to show with a heatmap which parts of the input image are most important for the neural network when predicting its class. To follow along, go to my GitHub, download or clone the repository and open the notebook to see the whole process. You can also run it, change some parameters, try different input images, etc.


Visualizing filters with synthetic images

Before going deeper, it’s important to explain how a CNN works. A CNN is composed of a set of convolutional layers, each of which learns a specific set of filters. These filters detect or extract features from the input image. The first layers extract low-level features, such as edges, corners, colors or textures, and deeper layers combine these low-level features into higher-level, semantic features, such as body parts or objects. After the convolutional layers, the network has some dense layers that act as a classifier (in a classification task). This classifier uses the features extracted by the convolutional layers to classify the input into a specific class.
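The idea that an early filter detects edges can be sketched with a single hand-made filter. The conv2d_valid helper and the 3x3 vertical-edge kernel below are illustrative assumptions, not filters taken from VGG16, whose filters are learned from data. Only the positions where dark pixels meet bright ones produce a strong response:

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (stride 1, no padding) and return the responses.

    Like convolutional layers in deep learning frameworks, this computes a
    cross-correlation: the kernel is not flipped.
    """
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(kernel[a][b] * image[i + a][j + b] for a in range(kh) for b in range(kw))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


# Hypothetical hand-made 3x3 vertical-edge filter (a CNN learns its filters from data).
vertical_edge = [[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]

# 3x6 grayscale image: dark left half (0), bright right half (1).
image = [[0, 0, 0, 1, 1, 1] for _ in range(3)]

print(conv2d_valid(image, vertical_edge))  # [[0, 3, 3, 0]]
```

The flat regions on either side give 0, while the windows that straddle the edge give the maximum response.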

One way to see what a neural network is learning is to visualize what its filters are learning. For this, I used gradient ascent on the input image with a specific objective function. To see why this works, recall that a neural network learns through gradient descent in the weights space. We set up a cost function that measures the error the network makes on a subset of samples. When a training batch (a subset of the training samples) is fed to the network, it produces an output for each sample, and the value of the cost function (the error) is computed. This step is called the forward pass (feedforward). Then, gradient descent in the weights space is applied: the gradient of the cost function with respect to the weights is computed using backpropagation, the most widely used and very efficient method, which is based on the chain rule of derivatives. The gradient points in the direction in which the cost function increases fastest, so the network’s weights are updated by taking a small step in the opposite direction.

[Figure: gradient descent]
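The training loop described above can be sketched for the simplest possible case: a single linear neuron fitted with mean squared error, where backpropagation reduces to one application of the chain rule. The function names and the toy data (generated from y = 2x + 1) are hypothetical; a real network repeats the same forward pass, backward pass and update across many layers and batches.

```python
def train_linear_neuron(xs, ys, lr=0.05, epochs=2000):
    """Fit y = w * x + b by gradient descent on the mean squared error."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        # Forward pass: predictions for the whole batch.
        preds = [w * x + b for x in xs]
        # Backward pass via the chain rule:
        # dCost/dpred_i = 2 * (pred_i - y_i) / n, dpred_i/dw = x_i, dpred_i/db = 1.
        grad_preds = [2.0 * (p - y) / n for p, y in zip(preds, ys)]
        grad_w = sum(g * x for g, x in zip(grad_preds, xs))
        grad_b = sum(grad_preds)
        # The gradient points uphill, so step in the opposite direction.
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


def mse(w, b, xs, ys):
    """Mean squared error (the cost function) of the neuron on a dataset."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


xs, ys = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]  # generated from y = 2x + 1
w, b = train_linear_neuron(xs, ys)
print(round(w, 4), round(b, 4))  # 2.0 1.0
```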


How to Visualize Convolutional Neural Networks with Gradient Ascent


If we apply gradient descent in the parameter space to minimize the cost function, why not flip this around? Instead of gradient descent on the weights, we can keep the weights fixed and apply gradient ascent in the input space to maximize a specific objective function, obtaining the input that is maximally responsive to it.


Let’s test it out. As the objective function, I used the output (activations) of a specific filter in a specific layer. By applying gradient ascent to the input image, starting from a random image, I obtain the image that is maximally responsive to that objective function (that is, to that filter).
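A minimal sketch of gradient ascent in input space, assuming a single hypothetical linear filter instead of a real VGG16 filter. Because a linear activation can be increased forever just by scaling the input, the objective includes an L2 penalty whose weight lam is an assumption. Note that the filter weights stay fixed and only the input changes:

```python
import random


def visualize_filter(weights, lam=1.0, lr=0.1, steps=200, seed=0):
    """Gradient ascent in input space for a single (hypothetical) linear filter.

    Objective: activation(x) - (lam / 2) * ||x||^2, where activation(x) = w . x.
    The L2 penalty keeps the input bounded.
    """
    rng = random.Random(seed)
    # Start from a small random "image" (a flattened patch of pixel values).
    x = [rng.uniform(-0.1, 0.1) for _ in weights]
    for _ in range(steps):
        # Gradient of the objective with respect to the input: w - lam * x.
        grad = [w - lam * xi for w, xi in zip(weights, x)]
        # Ascent: move the input along the gradient; the filter weights stay fixed.
        x = [xi + lr * g for xi, g in zip(x, grad)]
    return x


# Hypothetical 3x3 vertical-edge filter, flattened row by row.
edge_filter = [-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]
pattern = visualize_filter(edge_filter)
print([round(p, 3) for p in pattern])
```

The optimum of this objective is x = w / lam, so the input that most excites the filter is a picture of the filter itself: a vertical dark-to-bright edge. For a deep, non-linear network the gradient cannot be derived by hand like this; it would come from automatic differentiation, for example TensorFlow’s tf.GradientTape.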


In the image below, you can see the results: the 64 filters learned by the neural network in the second layer. As I said before, the first layers learn low-level features, and indeed these filters respond to colors, lines, textures and orientations.



Important note: I am explaining the algorithms and methods at a high level, without getting into the weeds. If you are more interested in the implementation, go to my GitHub and see the notebook and code. For example, I apply some preprocessing and post-processing techniques to the images, normalize the gradient, etc.


Visualizing classes with synthetic images

In the same way, by modifying the objective function, we can also visualize the images that are maximally responsive to each class. In this case, we maximize the class logits (the scores the model outputs before the softmax). The image below, for example, is maximally responsive to the “ostrich” class. If you look at it closely, you will see some ostriches, or some of their features, in it.


You might be wondering why the result is a psychedelic, blurry image of ostriches. This is because a neural network learns complex patterns that humans cannot easily visualize. There are methods that transform these images into something more interpretable for humans.
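The same loop works for class visualization: only the objective changes, from a filter’s activation to a class logit. The sketch below uses a tiny network with hypothetical weights (3 inputs, 4 tanh units, 2 classes), an assumed L2 penalty, and central finite differences in place of TensorFlow’s automatic differentiation, so that any model given as a plain function can be probed:

```python
import math

# Hypothetical weights of a tiny network: 3 inputs -> 4 tanh units -> 2 class logits.
W1 = [
    [0.5, -0.2, 0.1],
    [-0.3, 0.8, 0.0],
    [0.2, 0.1, -0.6],
    [0.0, -0.4, 0.7],
]
W2 = [
    [1.0, -0.5, 0.3, 0.0],
    [-0.7, 0.6, 0.0, 0.9],
]


def logits(x):
    """Raw class scores (pre-softmax) of the toy network for input x."""
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in W2]


def numerical_grad(f, x, eps=1e-5):
    """Central finite differences: a stand-in for automatic differentiation."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


def visualize_class(target, lam=0.1, lr=0.1, steps=300):
    """Gradient ascent on the input to maximize one class logit (with an L2 penalty)."""
    def objective(v):
        return logits(v)[target] - 0.5 * lam * sum(vi * vi for vi in v)

    x = [0.0, 0.0, 0.0]  # start from a blank input
    for _ in range(steps):
        x = [xi + lr * gi for xi, gi in zip(x, numerical_grad(objective, x))]
    return x


x_star = visualize_class(target=1)
print(logits([0.0, 0.0, 0.0]), logits(x_star))
```

For VGG16 the input would be a full image and the target one of the 1000 ImageNet classes, such as “ostrich”, but the idea is the same.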


Visualizing which features a specific layer pays most attention to when classifying an input image


I’m aware that the images from the last method are not easy for humans to interpret. So now, let’s move on to another method that lets us see which features a specific layer pays most attention to when classifying the input image. This method is called Grad-CAM (Gradient-weighted Class Activation Mapping).
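The core Grad-CAM computation is short. For a chosen convolutional layer, the gradients of the class score with respect to each feature map are averaged over all spatial positions to give one weight per map; the heatmap is the ReLU of the weighted sum of the maps. In the notebook these tensors would come from the real network; below, the feature maps and gradients are hypothetical 2x2 numbers, and grad_cam is an illustrative name:

```python
def grad_cam(activations, gradients):
    """Grad-CAM heatmap from one layer's feature maps and their gradients.

    activations[k][i][j]: value of feature map k at position (i, j).
    gradients[k][i][j]:   d(class score) / d(activations[k][i][j]).
    """
    height, width = len(activations[0]), len(activations[0][0])
    heatmap = [[0.0] * width for _ in range(height)]
    for fmap, gmap in zip(activations, gradients):
        # Channel weight alpha_k: average gradient over all spatial positions.
        alpha = sum(sum(row) for row in gmap) / (height * width)
        for i in range(height):
            for j in range(width):
                heatmap[i][j] += alpha * fmap[i][j]
    # ReLU: keep only regions that push the class score up.
    return [[max(v, 0.0) for v in row] for row in heatmap]


# Hypothetical 2x2 feature maps from two channels, with made-up gradients.
activations = [
    [[1.0, 0.0], [0.0, 2.0]],
    [[0.0, 3.0], [1.0, 0.0]],
]
gradients = [
    [[0.5, 0.5], [0.5, 0.5]],    # channel 0 raises the class score (alpha = 0.5)
    [[-1.0, 0.0], [0.0, -1.0]],  # channel 1 lowers it (alpha = -0.5)
]
print(grad_cam(activations, gradients))  # [[0.5, 0.0], [0.0, 1.0]]
```

In practice the resulting heatmap is much smaller than the input image, so it is upsampled to the image size before being overlaid.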


In this case, instead of a random initial image, we use an image of a goldfinch. A heatmap will show which parts of the image are most important for a specific layer of the neural network. The input image with the goldfinch is shown below.

[Figure: input image of a goldfinch]

Applying Grad-CAM to this image and visualizing layer 17 (the last convolutional layer, block5_conv3 in the Keras VGG16 model), we can see that it focuses on the head and wings. So, this layer is extracting features from the goldfinch’s head and wings.

[Figure: Grad-CAM heatmap for the last convolutional layer of the neural network]

We can also explore other layers. For example, layer 16, shown in the image below, focuses on a smaller part of the goldfinch’s head.

[Figure: Grad-CAM heatmap for a specific layer of the neural network]

A CNN uses features from all its layers to classify the input image into a specific class, so all of these visualized features contribute to classifying the goldfinch image. The classification confidence for this sample is 96.25%.
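As a sketch of where such a number comes from, assuming the reported confidence is the softmax probability of the predicted class (Keras’s VGG16 ends in a softmax layer by default). The logits, the three class names and the helper names below are hypothetical:

```python
import math


def softmax(logits):
    """Turn raw class scores into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits, class_names):
    """Return the most likely class and its probability (the 'confidence')."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


# Hypothetical logits for three classes.
name, confidence = top_prediction([2.0, 1.0, 0.0], ["goldfinch", "house finch", "ostrich"])
print(name, f"{confidence:.2%}")  # goldfinch 66.52%
```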


Visualizing the most important features used by the model to classify the input image

Finally, we can also visualize the most important features the neural network uses to classify a specific sample. In the previous example, we applied Grad-CAM to specific convolutional layers, but what happens if we apply it to all the convolutional layers?

By doing this, we obtain the most important features the neural network uses to classify the goldfinch. The three images below show the results. The first (on the left) shows the input image with the heatmap overlaid, the second shows only the heatmap, and the third applies a binary mask to select and visualize the most important pixels of the image.


[Figure, left to right: input image with the Grad-CAM heatmap overlaid; the heatmap alone; a binary mask selecting the most important pixels of the image]
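The article does not spell out how the per-layer heatmaps are combined, so the following is only one plausible sketch: each heatmap (assumed already resized to a common shape) is min-max normalized to [0, 1], the heatmaps are averaged, and a hypothetical threshold of 0.5 turns the result into a binary mask. All function names and values are illustrative.

```python
def normalize(heatmap):
    """Min-max scale a heatmap to [0, 1]."""
    flat = [v for row in heatmap for v in row]
    lo, hi = min(flat), max(flat)
    span = (hi - lo) or 1.0  # a constant map would otherwise divide by zero
    return [[(v - lo) / span for v in row] for row in heatmap]


def combine_heatmaps(heatmaps):
    """Average per-layer heatmaps that were already resized to the same shape."""
    normed = [normalize(h) for h in heatmaps]
    rows, cols = len(normed[0]), len(normed[0][0])
    return [
        [sum(h[i][j] for h in normed) / len(normed) for j in range(cols)]
        for i in range(rows)
    ]


def binary_mask(heatmap, threshold=0.5):
    """1 where the combined importance reaches the threshold, 0 elsewhere."""
    return [[1 if v >= threshold else 0 for v in row] for row in heatmap]


# Hypothetical 2x2 heatmaps from two convolutional layers.
layer_maps = [
    [[0.0, 2.0], [4.0, 0.0]],
    [[1.0, 1.0], [3.0, 1.0]],
]
combined = combine_heatmaps(layer_maps)
print(combined)               # [[0.0, 0.25], [1.0, 0.0]]
print(binary_mask(combined))  # [[0, 0], [1, 0]]
```

Multiplying the input image by such a mask, pixel by pixel, keeps only the regions the network found most important.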


As you can see, (deep) neural networks and ML models in general are not complete “black boxes”. We can explain them and find out what they are learning, why they fail (debugging), and why they made a specific decision. It’s important to mention that not all ML or DL models need interpretability, only those that need contestability. But, as we have seen in this article, interpretability has many advantages even when contestability is not needed: it aids trust, helps find distribution shifts and biases in training, and enables debugging.

