Interpretable and Steerable Sequence Learning via Prototypes

A Brief Introduction

The interpretability of machine learning models is becoming increasingly important in critical decision-making scenarios. It is especially challenging for deep learning models, which consist of massive numbers of parameters and complicated architectures.

ProSeNet (Prototype Sequence Network) is a sequence model that is interpretable while retaining the accuracy of sequence neural networks such as RNNs and LSTMs.

ProSeNet is interpretable in the sense that its predictions are produced under a case-based reasoning framework, which naturally generates explanations by comparing the input to typical cases. For each input sequence, the model computes its similarity with the prototype sequences learned from the training data, and produces the final prediction by consulting the most similar prototypes. For example, the prediction and explanation of a ProSeNet-based sentiment classifier for text would look something like:

Here the numbers (0.69, 0.30) indicate the similarity between the input and the prototypes. You may find this framework similar to k-nearest neighbor models. And yes, the idea of the model originates from classical k-nearest neighbors and metric learning!

ProSeNet

The architecture of the model is illustrated in the figure at the top. It uses an LSTM encoder r to map sequences into a fixed embedding space, and learns a set of k prototype vectors that serve as a basis for inference. The embedding of a sequence is compared with the prototype vectors to produce k similarity scores. A fully connected layer f then produces the final output. The weights of the layer f encode the relation between the prototypes and the output classes.
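Below is a minimal PyTorch sketch of this forward pass. The encoder configuration, the similarity function (exponential of the negative squared distance), and all sizes are illustrative assumptions rather than the exact implementation:

```python
import torch
import torch.nn as nn

class ProSeNetSketch(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, hidden_dim=128,
                 n_prototypes=20, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)  # encoder r
        # k learnable prototype vectors in the embedding space
        self.prototypes = nn.Parameter(torch.randn(n_prototypes, hidden_dim))
        # fully connected layer f: prototype similarities -> class scores
        self.fc = nn.Linear(n_prototypes, n_classes, bias=False)

    def forward(self, x):
        # x: (batch, seq_len) token ids
        _, (h, _) = self.encoder(self.embed(x))
        e = h[-1]                                     # (batch, hidden_dim) sequence embedding
        d2 = torch.cdist(e, self.prototypes).pow(2)   # squared distance to each prototype
        sim = torch.exp(-d2)                          # k similarity scores in (0, 1]
        return self.fc(sim), sim                      # logits and per-prototype similarities
```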

However, the model is still not interpretable, because the prototypes are vectors in the embedding space! Thus, we use a projection technique that replaces each prototype vector with its closest sequence embedding during training, which associates each prototype with a "real", readable sequence. For more details, please check our paper listed below.
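A hedged sketch of this projection step is shown below. It assumes PyTorch, a model shaped like the sketch above, and that all training-sequence embeddings fit into a single tensor:

```python
import torch

@torch.no_grad()
def project_prototypes(model, train_embeddings, train_sequences):
    # train_embeddings: (N, hidden_dim) embeddings of all training sequences
    # train_sequences:  list of N readable sequences aligned with the embeddings
    d = torch.cdist(model.prototypes, train_embeddings)    # (k, N) distances
    nearest = d.argmin(dim=1)                               # index of closest training sequence
    model.prototypes.copy_(train_embeddings[nearest])       # snap prototypes onto real data
    return [train_sequences[i] for i in nearest.tolist()]   # readable prototype sequences
```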

ProtoSteer

To assist the interaction between the user and the model, we further designed a complete interaction scheme that makes the incorporation of user knowledge much easier. The key idea of the interaction scheme is illustrated below:

The model can be explained and diagnosed by presenting the prototypes, the cases most similar to each prototype, or by comparing different prototypes. The users can then gain knowledge about the model and the dataset from these explanations. Combined with their own knowledge, the users can generate feedback, such as creating new prototypes, updating existing ones, or deleting bad ones. For example, a user might disagree with some prototypes and want to correct them. Finally, we update the knowledge of the model using the user's feedback through a constrained fine-tuning process, as sketched below.
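The sketch below illustrates how such feedback could be applied to a model shaped like the earlier sketch. The function name, arguments, and fine-tuning recipe are illustrative assumptions, not the exact ProtoSteer implementation:

```python
import torch

def apply_feedback(model, add=None, update=None, delete=None):
    # add:    list of new prototype embeddings, each of shape (hidden_dim,)
    # update: dict {prototype_index: new_embedding}
    # delete: list of prototype indices to remove
    protos = model.prototypes.data.clone()
    weight = model.fc.weight.data.clone()                  # (n_classes, k)
    if update:
        for i, emb in update.items():
            protos[i] = emb
    if delete:
        keep = [i for i in range(protos.size(0)) if i not in set(delete)]
        protos, weight = protos[keep], weight[:, keep]
    if add:
        new = torch.stack(add)
        protos = torch.cat([protos, new], dim=0)
        weight = torch.cat([weight, torch.zeros(weight.size(0), new.size(0))], dim=1)
    # rebuild the prototype layer and the fully connected layer f to match
    model.prototypes = torch.nn.Parameter(protos)
    model.fc = torch.nn.Linear(protos.size(0), weight.size(0), bias=False)
    model.fc.weight.data.copy_(weight)
    # The model is then fine-tuned for a few epochs with the user-edited
    # prototypes held (nearly) fixed, so the rest of the model adapts to the feedback.
```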

Visual Interface

An example of the user interface and interactions is shown below:

Publication

  • Yao Ming, Panpan Xu, Huamin Qu, Liu Ren. Interpretable and Steerable Sequence Learning via Prototypes. In Proceedings of KDD 2019. [preprint]
  • Yao Ming, Panpan Xu, Furui Cheng, Huamin Qu, Liu Ren. ProtoSteer: Steering Deep Sequence Model with Prototypes. To be presented at IEEE VIS 2019 and published in IEEE TVCG. [ieee]

Video Preview

Code

The code of ProSeNet and ProtoSteer will be hosted on GitHub here.

We are actively working on the code review to meet the legal and copyright requirements of Bosch. The code will be released once the review is finished.

Acknowledgements

Most of this work was done during Yao's internship at Bosch Research North America.