Improving Interpretable Models based on Knowledge Distillation for ICU Mortality Prediction using Electronic Health Record

Recently, deep learning has shown strong performance in various fields such as computer vision, natural language processing, and healthcare. However, deep learning-based models operate as black boxes, making it difficult for humans to understand the reasons behind their results. Interpretability is especially important in the medical field, where decisions must be made based on strong evidence. Many studies on mortality prediction, a representative healthcare application that requires interpretability, rely on deep learning-based models. These models achieve good predictive performance but struggle to provide interpretable predictions, while linear models in machine learning are highly interpretable but less accurate. In this paper, we improve the performance of linear models using knowledge distillation. We show the effects of first-order and second-order input features on mortality prediction with logistic regression and a factorization machine, respectively. In addition, we demonstrate the interpretability of the models by visualizing their trained weights. We expect that providing interpretability of linear models through visualization can help humans intuitively understand the reasons behind predictions and, furthermore, improve the quality of healthcare provided to patients in hospitals.


INTRODUCTION
In recent years, deep learning-based models have shown good performance in various applications such as computer vision, natural language processing, and healthcare [6,10,17]. The advantage of deep learning is its complex structure, which can represent large-scale data and perform many operations [17]. However, even as the performance of deep learning improves, it operates as a black box, making it hard for humans to understand the reasons behind its results. In other words, the interpretability of the model's predictions decreases [3,13,17]. It is difficult to know intuitively how such models work and how significantly the input features affect their results. Obtaining this information is crucial, especially in the medical field, which deals with human life: decisions are made based on strong evidence, so the interpretability of models is very important [1,14,19].
Linear models in machine learning, such as linear regression and logistic regression, are generally highly interpretable [15]. They can show how much each input feature affected the output, and thus present the reason for a prediction in a way humans can understand intuitively. The Factorization Machine (FM) is a method widely used in recommender systems [18]; it can show the effect of combinations of two input features on the result. However, its performance is also poorer than that of deep learning-based models.
In this paper, we improve the performance of linear models, which are generally interpretable but weak predictors. We use logistic regression and a factorization machine as the linear models and improve their performance for Intensive Care Unit (ICU) mortality prediction through knowledge distillation [4,5]. Knowledge distillation is a machine learning method in which knowledge from a large pretrained neural network is transferred to a relatively small and simple model to be trained. The knowledge of the pretrained model can be used as a soft label with reduced noise, improving the performance of the model being trained [7]. We therefore improve the two linear models by training them on the knowledge of an existing deep learning model. We also present heatmap visualizations that show how much the first-order and second-order input features affect the output, so that humans can interpret the model. The contributions of this paper are summarized as follows.
• We show that linear models trained with knowledge distillation outperform their non-distilled counterparts.
• We present model interpretability, which is important in the medical domain, through logistic regression and a factorization machine, and we visualize first-order and second-order global and local explanations as heatmaps.
• We show through experiments that knowledge distillation can mitigate the performance limitations of linear models.

METHOD

Linear Models
In machine learning, linear models are highly interpretable [15]. Eq. (1) shows the operation of a linear model:

ŷ = w_0 + ∑_{i=1}^{n} w_i x_i    (1)

Linear models output the prediction ŷ as a weighted sum of the input features x_i with weights w_i, and they apply no nonlinear transformation. Such models can therefore show the relationship between a single input feature and the output. Logistic Regression (LR) is a representative and highly interpretable linear model: it directly exposes the relationship between the input features and the result.
The Factorization Machine (FM) was originally proposed as a model for recommender systems [18]. It models all pairwise interactions of input features and, in particular, can estimate feature interactions on very sparse data. Eq. (3) shows the FM model equation:

ŷ = w_0 + ∑_{i=1}^{n} w_i x_i + ∑_{i=1}^{n} ∑_{j=i+1}^{n} ⟨v_i, v_j⟩ x_i x_j    (3)

where w_0 + ∑_{i=1}^{n} w_i x_i is the same weighted sum of input features as in Eq. (1), which shows that the factorization machine is also a linear model. The rightmost term captures the importance of second-order feature interactions. Eq. (4) defines the dot product of two factor vectors of size k:

⟨v_i, v_j⟩ = ∑_{f=1}^{k} v_{i,f} v_{j,f}    (4)

These interactions improve the performance of the linear model, but it still performs worse than deep learning models.
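As a concrete illustration, the two scoring functions above can be sketched in a few lines of NumPy. The O(kn) reformulation of the pairwise term follows Rendle's original FM trick; all variable names are illustrative and not tied to the paper's implementation:

```python
import numpy as np

def linear_model(x, w0, w):
    # Eq. (1): y_hat = w0 + sum_i w_i * x_i
    return w0 + np.dot(w, x)

def factorization_machine(x, w0, w, V):
    # Eq. (3): linear part plus all pairwise interactions <v_i, v_j> x_i x_j.
    # V has shape (n_features, k), one factor vector per feature.
    linear = w0 + np.dot(w, x)
    # Rendle's O(kn) identity:
    # sum_{i<j} <v_i,v_j> x_i x_j
    #   = 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2 ]
    vx = V.T @ x
    interaction = 0.5 * np.sum(vx ** 2 - (V ** 2).T @ (x ** 2))
    return linear + interaction
```

The reformulated interaction term gives the same value as the naive double sum over feature pairs while avoiding the O(n^2) loop.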

Improving Performance with Knowledge Distillation
Knowledge distillation (KD) is a machine learning method that transfers knowledge from a large neural network to a small, simple model [8]. Knowledge distillation aims to make the student outperform its non-distilled counterpart by mimicking the teacher model. Since the logits of the teacher model, i.e., the outputs of its output layer, are used as soft labels that reduce the noise of the data, the student model is less affected by that noise, which can improve its performance [7]. In this paper, we focus on highly interpretable, simple linear models as the student models.

The knowledge distillation process used in this paper is shown in Figure 1. We first pretrain a teacher model on the raw data. Then, for the same data, we train the student model using the predictions of the teacher model as labels. This is a response-based knowledge distillation method, and we train the linear models with the loss function shown in Eq. (5) [7]:

L_{ResD}(z_t, z_s) = L_R(z_t, z_s)    (5)

where L_{ResD}(z_t, z_s) denotes the response-based distillation loss, L_R denotes the divergence loss of the logits, i.e., the loss function used for training, and z_t and z_s denote the predictions of the teacher and student models, respectively [4].
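A minimal sketch of this response-based step for a logistic-regression student, assuming the teacher's predicted probabilities are already available as soft labels. The optimizer, learning rate, and use of binary cross-entropy as the divergence loss are illustrative choices, not the paper's exact settings:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def distill_student(X, teacher_probs, lr=0.1, epochs=500):
    """Train a logistic-regression student to match the teacher's
    predicted probabilities (soft labels) with a cross-entropy loss.
    A sketch under illustrative hyperparameters."""
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(epochs):
        p = sigmoid(X @ w + b)       # student predictions z_s
        grad = p - teacher_probs     # d(BCE)/d(logit) with soft targets z_t
        w -= lr * (X.T @ grad) / n
        b -= lr * grad.mean()
    return w, b
```

With soft labels, the gradient has the same form as ordinary logistic regression; only the hard 0/1 targets are replaced by the teacher's probabilities.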

Interpretability
We leverage the weights of the linear models trained with knowledge distillation, without any additional formulas for interpretability. For global explanation, we compute the norm of the linear model weights, i.e., the magnitude of the weight vectors, to visualize the weight of each feature or field of the data. Eq. (6) shows the norm computation:

‖w‖_p = (∑_i |w_i|^p)^{1/p}    (6)
where p is 2 and w_i is each element of the weight vector learned by the linear model. For local explanation of a specific sample, we visualize via heatmaps how important the first-order and second-order features are to the prediction result. We show the interpretability of the model via heatmaps in Section 3.6.

…), and so on [9]. We utilize this data for the ICU mortality prediction task.
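A small sketch of this global-explanation step: compute the ℓ2 norm of the trained weights belonging to each field and rank the fields by magnitude. The field names and index layout below are purely illustrative:

```python
import numpy as np

def global_importance(W, field_slices):
    """Global explanation via Eq. (6): the L2 norm of the trained
    weights for each field. `W` is the learned weight vector and
    `field_slices` maps a field name to the indices of its features
    (e.g. one-hot bins or hourly measurements); names are hypothetical."""
    return {name: float(np.linalg.norm(W[idx], ord=2))
            for name, idx in field_slices.items()}
```

The resulting per-field magnitudes are what a bar chart or heatmap of global importance would display.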

Preprocessing
We describe the data preprocessing in this section. We use all variables used in [2]. For example, we extract patients who were admitted to the ICU for more than 48 hours and were aged 17 years or older at the time of admission. We include non-temporal variables such as age, AIDS, and lymphoma, and temporal variables such as heart rate and temperature in the first 48 hours after admission. We use the first ICU admission during each hospital admission but, unlike [2], include several ICU samples per patient. We also apply the same missing-value imputation. Table 1 shows the dataset statistics after preprocessing.

Implementation Details
We perform stratified 5-fold cross-validation to deal with the imbalanced data, and standardize all features. We also use an ℓ2 penalty and dropout for the teacher model. The models are implemented with PyTorch 1.10.2 [16].
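The stratified splitting above can be sketched without any framework dependency: each fold receives roughly the same proportion of deaths and survivals. This is a minimal sketch of the protocol; in practice a library routine such as scikit-learn's StratifiedKFold would typically be used:

```python
import numpy as np

def stratified_kfold_indices(y, n_splits=5, seed=0):
    """Split indices into `n_splits` folds so that each fold
    preserves the class ratio of the binary label vector `y`."""
    rng = np.random.default_rng(seed)
    folds = [[] for _ in range(n_splits)]
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)
        rng.shuffle(idx)
        # Deal each class's shuffled indices round the folds.
        for f, chunk in enumerate(np.array_split(idx, n_splits)):
            folds[f].extend(chunk.tolist())
    return [np.array(sorted(f)) for f in folds]
```

Each fold then serves once as the held-out set while the rest are used for training.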

Evaluation Methods
We evaluate mortality prediction performance using the area under the receiver operating characteristic curve (AUROC) and the area under the precision-recall curve (AUCPR). Both are widely used to evaluate model performance for binary classification with imbalanced data. For each fold, we report these metrics for the model with the lowest validation loss.
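For reference, AUROC can be computed directly from prediction scores via the rank (Mann-Whitney) formulation; this sketch assumes no tied scores, and in practice sklearn.metrics.roc_auc_score and average_precision_score (for AUCPR) would be used:

```python
import numpy as np

def auroc(y_true, scores):
    """AUROC via the Mann-Whitney rank statistic: the probability
    that a random positive is scored above a random negative.
    Assumes untied scores; equivalent to roc_auc_score in that case."""
    order = np.argsort(scores)
    ranks = np.empty(len(scores))
    ranks[order] = np.arange(1, len(scores) + 1)
    n_pos = int(np.sum(y_true))
    n_neg = len(y_true) - n_pos
    return (ranks[y_true == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
```

Because AUROC depends only on the ranking of scores, it is insensitive to the classification threshold, which is one reason it suits imbalanced mortality data.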

Results and Discussion
Table 2 shows the results of the teacher and student models for ICU mortality prediction, with and without the knowledge distillation method. For the factorization machine, both AUROC and AUCPR improve on the validation and test datasets. For logistic regression, however, the performance decreased; we believe logistic regression struggles to fit the predictive distribution of the teacher model. Although logistic regression did not improve, the attempt to provide interpretability through linear models is still meaningful.

Figure 4: second-order global explanations of FM. Each row of the figure refers to the importance of a combination of two features for mortality prediction. Each column refers to the influence of the two features at 1 hour and 48 hours after admission to the ICU.

Visualization for interpretability
We show global and local explanations of logistic regression and the factorization machine trained with knowledge distillation, respectively [11].

Logistic regression. Figure 2 shows the first-order global explanation of logistic regression. The trained model is strongly influenced by the BUN and GCSVerbal features compared to the other features across all five folds. Figure 3 shows first-order local explanations for specific individual samples. Each column of the figure describes the importance of a feature, including GCSEyes, FiO2, and urine output, for a true positive (TP) sample and a true negative (TN) sample, respectively. The closer to 48 hours after ICU admission, the larger the influence of each feature on the output. In other words, each feature has high importance for the mortality prediction of the TP sample but low importance for the TN sample.

Factorization machine. Figure 4 shows the second-order global explanation of the factorization machine. Each panel shows how a feature combination affects the prediction as the values of the two features change. For example, Figures 4a and 4b describe the importance of the combination of GCSMotor and GCSEyes to mortality prediction at 1 hour and 48 hours after admission to the ICU, respectively. The lower the values of the two features, the higher their importance to the prediction. In addition, the importance is higher at 48 hours than at 1 hour after ICU admission.

Figure 5 shows second-order local explanations at 48 hours for specific individual TP and TN samples. Figure 5b shows the local explanation of the TP sample at 48 hours. Most features show high importance for the mortality prediction when combined with GCSEyes, and especially high importance when combined with BUN. On the contrary, for the TN sample in Figure 5d, most features show low importance for the prediction when combined with BUN. However, the combination of the GCSEyes and PO2 features shows relatively high importance for the mortality prediction. Based on these results, we expect that better healthcare services could be provided through the prediction results and values such as those in Figures 5a and 5c.

CONCLUSION
In the medical field, where decisions must be made based on strong evidence, interpretability is crucial. Deep learning-based models operate as black boxes, so it is hard for humans to understand the reasons behind their results. Meanwhile, linear models in machine learning are highly interpretable but perform worse than deep learning-based ones. In this paper, we showed how to improve low-performance linear models with knowledge distillation. We applied this method to ICU mortality prediction and showed not only an improvement of the factorization machine, which reveals how much the combinations of second-order features affect the output and which we visualize via heatmaps, but also first-order global and local explanations of logistic regression. The interpretability of linear models helps humans intuitively understand how they work, so we expect this to contribute to the quality of healthcare provided to patients.

Figure 3: first-order local explanations of LR. The left-side figures describe one true positive (TP) sample, while the right-side figures describe one true negative (TN) sample.



Figure 5: second-order local explanations of FM. Each row of the figures refers to the TP and TN samples, respectively. The left side describes the raw data during the ICU admission, while the right side describes the importance of the combinations of all features for mortality prediction at 48 hours.

Table 2: Results of ICU mortality prediction. Values indicate the mean (standard deviation) over the 5-fold experiments.