作者:Tyler Folkman
原文:https://blogs.ancestry.com/ancestry/2017/12/18/understanding-machine-learning-xgboost/
为了解释这些技术,我们将使用 Titanic 数据集。该数据集有每个泰坦尼克号乘客的信息(包括乘客是否生还)。我们的目标是预测一个乘客是否生还,并且理解做出该预测的过程。即使是使用这些数据,我们也能看到理解模型决策的重要性。想象一下,假如我们有一个关于最近发生的船难的乘客数据集。建立这样的预测模型的目的实际上并不在于预测结果本身,但理解预测过程可以帮助我们学习如何最大化意外中的生还者。
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import lime.lime_tabular
from sklearn.preprocessing import Imputer
from sklearn.grid_search import GridSearchCV我们要做的首件事是观察我们的数据,你可以在 Kaggle 上找到(https://www.kaggle.com/c/titanic/data)这个数据集。拿到数据集之后,我们会对数据进行简单的清理。即:
把分类变量转化为虚拟变量
这些清洗技巧非常简单,本文的目标不是讨论数据清洗,而是解释 XGBoost,因此这些都是快速、合理的清洗以使模型获得训练。
y = data.Survived
X = pd.get_dummies(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
pipeline = Pipeline([( imputer , Imputer(strategy= median )), ( model , XGBClassifier())])
model__learning_rate=[.01, .1],
cv = GridSearchCV(pipeline, param_grid=parameters)接着查看测试结果。为简单起见,我们将会使用与 Kaggle 相同的指标:准确率。
print("Test Accuracy: {}".format(accuracy_score(y_test, test_predictions)))至此我们得到了一个还不错的准确率,在 Kaggle 的大约 9000 个竞争者中排到了前 500 名。因此我们还有进一步提升的空间,但在此将作为留给读者的练习。
fi = list(zip(X.columns, cv.best_estimator_.named_steps[ model ].feature_importances_))
top_10 = fi[:10]
y = [x[1] for x in top_10]nerror="javascript:errorimg.call(this);">
我们可以很清楚地看到,那些生还者相比遇难者的平均票价要高得多,因此把票价当成重要特征可能是合理的。
这种个体层次上的分析对于生产式机器学习系统可能非常有用。考虑其它例子,使用模型预测是否可以某人一项贷款。我们知道信用评分将是模型的一个很重要的特征,但是却出现了一个拥有高信用评分却被模型拒绝的客户,这时我们将如何向客户做出解释?又该如何向管理者解释?
接下来我们尝试在模型中应用 LIME。基本上,首先需要定义一个处理训练数据的解释器(我们需要确保传递给解释器的估算训练数据集正是将要训练的数据集):
explainer = lime.lime_tabular.LimetabularExplainer(X_train_imputed,
class_names=["Not Survived", "Survived"],随后你必须定义一个函数,它以特征数组为变量,并返回一个数组和每个类的概率:
def xgb_prediction(X_array_in):
X_array_in = np.expand_dims(X_array_in, 0)最后,我们传递一个示例,让解释器使用你的函数输出特征数和标签:
exp = explainer.explain_instance(X_test_imputed[1], xgb_prediction, num_features=5, top_labels=1)在这里我们有一个示例,76% 的可能性是不存活的。我们还想看看哪个特征对于哪个类贡献最大,重要性又如何。例如,在 Sex = Female 时,生存几率更大。让我们看看柱状图:
nerror="javascript:errorimg.call(this);">
看起来 Pclass 等于 2 的存活率还是比较低的,所以我们对于自己的预测结果有了更多的理解。看看 LIME 上展示的 top5 特征,看起来这个人似乎仍然能活下来,让我们看看它的标签:
>>>1本文为读者提供了一个简单有效理解 XGBoost 的方法。希望这些方法可以帮助你合理利用 XGBoost,让你的模型能够做出更好的推断。
注:本文为机器之心编译,转载请联系本公众号获得授权。
nerror="javascript:errorimg.call(this);">
