rrimyuu 님의 블로그

Paper review | XAI–reduct: accuracy preservation despite dimensionality reduction for heart disease classification using explainable AI 본문

Neuroscience/Artificial Intelligence

Paper review | XAI–reduct: accuracy preservation despite dimensionality reduction for heart disease classification using explainable AI

rrimyuu 2024. 7. 31. 15:33

1. Introduction 

  • 심장병 (Heart disease) 은 전 세계적으로 흔한 사망 원인 중 하나임.
  • ML 도구 및 등장은 의사가 질병의 특징과 증상과 같은 이유와 사망을 식별하는 데 도움. 
  • 지금까지 ML 모델을 사용한 심장병 연구는 설명이 부족함. 
  • ML 모델을 설명하는 데 가장 널리 사용되는 프레임워크는 LIME, SHapley Additive exPlanations (SHAP) 임. 
  • 본 연구는 심장병 분류를 설명하기 위해 SHAPASH, PDP 및 DALEX 모델을 사용한 최초의 연구임. 
  • ML 모델로 의 decision tree, random forest, support vector machine, and XGBoost 사용. 
  • 설명에 따라 특징 기여도 (feature contributions) 와 특징 가중치 (feature weights) 를 고려하여 특징 선택 (feature selection) 을 수행함.
    -> 선별된 특징은 추가 분류를 위해 XGBoost 모델에 입력되었음.

 

 

2. Related work 

(TODO) 25, 26, 27, 28, 29, -> 최근 연구로 추가  

3. Dataset description 

  • The dataset used in this study is publicly available.
  • This dataset has been chosen as the feature vector comprised of selected and noted features from four famous large datasets from 1988 (Cleveland, Hungary, Switzerland, and Long Beach V), which have been used for heart diseases widely. 
  • The dataset comprised thirteen features and one target variable.
  • Five correlations were done on the feature vectors to observe the level of correlation

 

4. Proposed classification using four machine learning models

  • The classification for heart diseases was done using four machine learning models: decision trees, support vector machines, random forest, and Xtreme Gradient Boosting. 
  • Further to the classification, the models were explained using SHAP. 
  • XGBoost is the highest among all the four models.
  • The false negative values for XGBoost also reflect a minimum count.
    (Since the models are being applied in healthcare data, more precisely in heart disease data, the lower the false negatives, the better the prediction model.)

5. Proposed explainable classification using four machine learning models

5.1. SHapley additive exPlanations (SHAP)

  • (Table 5) Decision tree와 Random forest 에서 <BIAS> 가 가장 높은 contribution 을 가짐. 따라서 이러한 모델은 무시됨. 
  • (Fig 6) fbs (fasting_blood_sugar) 의 기여도가 가장 낮음. 이에 따라 fbs 를 제외한 feature subset (reduced feature subset) 를 통해 XGBoost 를 학습함. 

Fig 6. SHAP explanation for XGBoost model
Table 5. Feature and relevant contributions for the four models post explanation using SHAP

  • reduced feature set 를 통해 XGBoost 를 학습한 결과임. 모델 정확도는 93.18% 관찰되었는데, 모든 feature 를 사용하여 학습한 모델보다 낮았지만 그럼에도 불구하고 다른 3개 ML 모델보다 정확도가 높았음. 
  • reduced feature vector 를 사용하면 XGBoost 가 다른 3개 ML 모델보다 더 나은 성능을 보이며 분류에 필요한 리소스 또한 적다는 결론을 내릴 수 있음. 
  • (Table 6) <BIAS> 에 대한 기여도가 상당히 떨어졌다는 점이 흥미로움. 

Table 6. Feature, weight, and contribution for XGBoost classifier using reduced feature set
Fig 7. SHAP explanation for XGBoost model using reduced feature vector

5.1.1. Dependence plot

  • Dependence plot 은 SHAP 에서 특정 input feature 와 모델의 output 간의 관계를 탐색하는 데 사용되는 시각화 유형. 
  • input 의 다른 feature 값을 고려하면서 feature 값 변화가 모델 예측 결과에 어떻게 영향을 미치는지 보여줌. 
  • (Fig 8) 가장 모델에 큰 기여를 하는 chest_pain_type 값에 따른 다른 상위 4가지 features 과의 비선형 관계를 보여줌.  

Fig 8. Dependence plot for five important features (chest_pain_type, st_depression, num_major_vessels, thalassemia, and age). This figure shows that chest_pain_type has the highest contribution in model prediction as compared to other features and it also shows the nonlinear relationship with other top four features (st_depression, thalassemia, num_major_vessels, and age)

5.1.2. Waterfall plot

  • 특정 instance 에 대해 예상되는 모델 출력에 대한 feature 기여도를 보여줌. 
  • 특정 instance 에 대한 waterfall plot 을 통해 예상되는 모델 출력에 대한 각 feature 의 긍정적 기여도, 부정적 기여도를 시각화함. 
  • 특정 instance 에 대한 모델 출력의 예상 출력 값과 함께 실제 출력 값도 함께 나타냄. 
  • SHAP 값은 예상 출력 값과 실제 출력 값의 차이에 대해 feature 의 기여도를 정량화 하는 것임.  
  • (Waterfall plot) 누적된 결과를 어떻게 나눠가져가는지 파악하기 위함. 
  • (Fig 9)  Shap value 를 절대값으로 내림차순 정렬. 막대 그래프의 시작점과 종료 지점을 계산해서 그래프 출력. 

Fig 9. Waterfall plot for all the features. This figure shows that  chest_pain_type  has positive contribution to the predicted output of the model, whereas, the feature  sex  has negative contribution to the model prediction. The expected value of the model output is  0.094,  and the actual output is  0.762

5.2. SHAPASH

  • 설명 가능한 ML 모델을 구축하고 배포하기 위한 사용하기 쉽고 사용자 정의가 가능한 프레임워크를 제공함.
  • 사용자에게 모델 성능을 시각화, 이해 및 설명하기 위한 도구를 제공하여 ML 모델을 구축하고 배포하는 프로세스를 간소화하도록 설계되었음. 

Fig 10. SHAPASH visualization for the complete feature vector related to feature importance
Fig 11. SHAPASH visualization for four features (chest_pain_type, thalassemia, num-major-vessels, and st_slope). Here, for feature  chest_pain_type , 0 gives negative contribution but 1, 2, and 3 gives positive contribution for model output prediction
Fig 13. SHAPASH visualization for four features (st_depression, age, cholesterol, and max_heart_rate_achieved). In this figure, all the features give the positive and negative contribution to the model prediction

5.3. Local interpretable model agnostic explanations (LIME)

Fig 16. LIME explanation of the feature importance in classification, in this present study. It visualizes the contribution of features for model prediction in certain instance. The value associated with every feature gives the contribution for heart disease classification

5.4. Descriptive mAchine learning eXplanations (DALEX)

  • DALEX 는 ML 모델에 대한 독립적인 설명 및 탐색을 위한 포괄적인 도구를 제공. 
  • 각 feature 가 모델 예측에 기여하는 정도를 측정하고 feature 과 모델 예측 간의 관계를 보여줌. 

Fig 17. Feature importance values for a particular instance of the input data using DALEX for breakdown
Fig 18. Feature importance values for a particular instance of the input data using DALEX for Shapely values. The figure visualizes the contribution of every feature for model prediction for a certain instance. The value associated with every feature gives the contribution

5.5. Partial dependency plots (PDP) 

  • 특정 feature 와 모델 예측 결과 간의 관계를 보여주는 시각화 도구임. 
  • 특정 feature 만 변화시키고 다른 모든 feature 들을 일정하게 유지할 때 모델이 예측한 타깃 변수 (target variable) 의 평균. 
  • 관심 feature 를 선택한 다음 해당 feature 에 대한 그리드를 만듦. 이후 ML 모델을 사용하여 다른 모든 feature 들을 일정하게 유지하면서 그리드의 각 값에 대한 타깃 변수를 예측함. 예측된 타깃 변수에 대해 feature 값을 플로팅하여 관계를 시각화 함. 
  • 음영 처리된 파란색 영역은 모델의 분산을 기반으로 예측 결과에 대한 불확실성의 범위를 나타냄. 

Fig 19. Partial dependency plot for four features (chest_pain_type, st_depression, thalassemia, and num_major_vessels). The figure represents the average prediction of the target variable by the model when one feature is varied while holding all others constant

6. Comparative analysis 

  • 설명 가능한 ML 접근 방식을 사용한 heart disease 분류 연구는 적음. 본 연구에선 accuracy 점수로만 비교하였음. 
  • (Table 7) Proposed work - 1 결과 (test accuracy = 97.86%) 는 타 기존 연구들보다 더 나은 정확도를 보임. 
  • (Table 7) Proposed work - 2 결과 (training accuracy = 97.63%, test accuracy = 93.18%) 는 training accuracy 를 고려할 경우, 타 기존 연구들보다 더 나은 정확도를 보였고, 9개 특징만을 사용했으므로 더 나은 성과를 거둔 거라 주장할 수 있음.  

Table 7. Comparative analysis of heart disease prediction proposals using explainable machine learning

6.1. Correctness (by Hosain, M. T., Jim, J. R., Mridha, M. F., & Kabir, M. M. (2024). Explainable AI approaches in deep learning: Advancements, applications and challenges. Computers and Electrical Engineering, 117, 109246.

  • XAI 모델에서 얻은 상위 5개 feature 를 비교함. 
  • The dependency on specific feature sets for correctness evaluation might hinder the model’s adaptability to datasets with varying feature distributions, limiting its generalizability. (위 언급된 XAI 논문에서 기재한 본 논문의 limitation 임) 

Table 8. Top five features in order of contributions obtained from the XAI techniques
Table 9. Occurrence chart for each of the top five features in Table 8

7. Conlusion 

  • 정확도 측면에서 XGBoost 모델이 다른 모델들보다 성능이 더 우수함.. 
  • XGBoost 모델의 전체 feature vector 와 reduced feature vector 에 대한 분류 정확도는 대부분 기존 문헌보다 우수했음. 
  • 설명 가능 모델 5개를 통해 진단에 기여하는 상위 4가지 feature 를 선별함. (각 feature 가 기여한 정도를 반영하여 다섯 가지 feature 중 공통적으로 발생하는 특성들을 관찰하였음.)