亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

如何在形狀中繪制錯誤分類的樣本?

如何在形狀中繪制錯誤分類的樣本?

莫回無 2023-10-25 10:28:32
我有一個基因數據集,其引起疾病的可能性得分在 0 到 1 之間(已知得分為 1 的基因會引起疾病,得分為 0.74 的基因可能會引起疾病)。我正在嘗試建立一個機器學習模型來預測回歸分類中新基因的疾病評分。我想查看已知疾病基因但得分較低的基因的形狀決策圖(例如,得分為 1 的基因,但我的模型得分低于 0.8)。我正在努力將這些基因組合在一起進行繪圖。我的數據如下所示:X:Index   Feature1  Feature2   ... FeatureNGene1     1           0.2          10Gene2     1           0.1          7Gene3     0           0.3          10#index is actually the index and not a columnY:Score10.60.4我運行帶有嵌套交叉驗證的 xgboost 回歸器,查看 MSE、預測的 r2,并繪制觀察值與預期值的關系圖。我可以在觀察到的與預期的圖中看到,Y 中得分為 1 的基因有許多模型預測的低分,我想了解為什么模型使用 shap 來做到這一點。不幸的是,我無法提供示例數據。我正在嘗試調整為標簽分類給出的示例 shap 代碼:import shapxgbr = xgboost.XGBRegressor()xgbr.fit(X_train, Y_train)select = range(8) #I have 8 features after feature selection with BorutaShapfeatures = X.iloc[select]features_display = X.loc[features.index]explainer = shap.TreeExplainer(xgbr)expected_value = explainer.expected_value#Example code from https://slundberg.github.io/shap/notebooks/plots/decision_plot.html: y_pred = xgbr.predict(X) y_pred = (shap_values.sum(1) + expected_value) > 0misclassified = y_pred != y_test[select]shap.decision_plot(expected_value, shap_values, features_display, link='logit', highlight=misclassified)我該如何選擇,y_pred以便預測/基因本應為 1,但實際上低于 0.8(或任何低數字)?編輯:為了回應給定的答案,我嘗試過:explainer = shap.TreeExplainer(xgbr)shap_values = explainer.shap_values(X_test)y_pred = xgbr.predict(X_test)m = (y_pred <= 0.5) & (Y_test == 1)shap.initjs()shap.decision_plot(explainer.expected_value, shap_values,  X_test[m],  return_objects=True)它運行但m長度為 171(我的 Y_test 數據中的全部行數),然后該圖繪制了它看起來像的所有 171 - 而且我從查看數據知道應該只有一個基因 <= 0.5 并且但實際上得分為 1。
查看完整描述

2 回答

?
慕尼黑5688855

TA貢獻1848條經驗 獲得超2個贊

首先,你提到在回歸分類中預測新基因的疾病評分,你是什么意思?輸出似乎是二進制的,0或1,因此這是一個二進制分類問題。您應該改用xgboost's 分類器。更新:讓我們根據評論假設一個回歸問題來模擬您的情況。盡管對于下面的示例,我們應該設置'objective':'multi:softmax'為輸出實際標簽。


根據您的問題,您似乎要做的就是在那些未正確預測的樣本上索引測試集,并分析誤導性的特征,這具有一定的意義。


讓我們用一些示例數據集重現您的問題:


from sklearn.datasets import load_iris


from sklearn.model_selection import train_test_split

import shap

import xgboost


X,y = shap.datasets.iris()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)


model = xgboost.train(params={"learning_rate": 0.01}, 

                      dtrain=xgboost.DMatrix(X_train, label=y_train), 

                      num_boost_round =100)

使用整個測試集的 SHAP 圖非常簡單。舉個force_plot例子:


explainer = shap.TreeExplainer(model)

shap_values = explainer.shap_values(X_test)


shap.initjs()

shap.force_plot(explainer.expected_value, shap_values, X_test)

https://img1.sycdn.imooc.com/65387d890001b6f108880363.jpg

現在,如果我們想對錯誤分類的樣本執行相同的操作,我們需要查看輸出概率。由于 iris 數據集有多個類,假設我們想要可視化force_plot那些應該分類為 的樣本2,但我們有一個輸出值如下1.7:


y_pred = model.predict(xgboost.DMatrix(X_test))

m = (y_pred <= 1.7) & (y_test == 2)

現在我們使用掩碼對集合執行布爾索引X_test,并更新shap_values:


shap.initjs()

c= explainer.shap_values(X_test[m])

shap.force_plot(explainer.expected_value, shap_values, X_test[m])

https://img1.sycdn.imooc.com/65387d980001cd5708780362.jpg

這告訴我們,花瓣的長度和寬度主要將回歸推向更高的值。因此,它們可能是在錯誤分類中發揮主要作用的變量。


同樣,對于一個decision_plot:


shap.decision_plot(explainer.expected_value, shap_values, 

                   X_test[m], feature_order='hclust', 

                   return_objects=True)

https://img1.sycdn.imooc.com/65387da50001edf306150210.jpg

查看完整回答
反對 回復 2023-10-25
?
拉莫斯之舞

TA貢獻1820條經驗 獲得超10個贊

由于我沒有您的數據集,因此無法檢查代碼,但這里的一些想法可能會為您指明方向。


看來你沒有訓練你的回歸者。應該是像線一樣


xgbr = xgboost.XGBRegressor()

xgbr.train(X, Y)

現在你可以使用了xgbr.predict(X);)


您還需要培訓解釋員:


explainer = shap.TreeExplainer(xgbr)

with warnings.catch_warnings():

     warnings.simplefilter("ignore")

     sh = explainer.shap_values(X)

現在您可以選擇值:


misclassified = (y_pred <= 0.7) & (Y == 1)

shap.decision_plot(expected_value, sh, features_display, link='logit', highlight=misclassified)

在使用之前,shap我建議您檢查回歸器對數據的擬合程度。因此,為此我建議您將部分數據用于測試,而不是在訓練中使用它。然后,您可以通過計算和比較測試集和訓練集的 MSE 來評估擬合優度。差異越大,預測器的表現就越差。


查看完整回答
反對 回復 2023-10-25
  • 2 回答
  • 0 關注
  • 126 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號