为生产中的模型创建SHAP基准 - 亚马逊 SageMaker AI

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

为生产中的模型创建SHAP基准

解释通常是对比性的,也就是说,它们解释了与基准的偏差。有关可解释性基准的信息,请参阅SHAP可解释性基线

除了为每个实例的推断提供解释外, SageMaker Clarify 还支持对机器学习模型进行全局解释,以帮助您从特征的角度了解整个模型的行为。 SageMaker Clarify 通过聚合多个实例上的 Shapley 值来生成机器学习模型的全局解释。 SageMaker Clarify 支持以下不同的聚合方式,您可以使用这些方式来定义基线:

  • mean_abs— 所有实例的绝对SHAP值的平均值。

  • median— 所有实例的SHAP值的中位数。

  • mean_sq— 所有实例的平方SHAP值的平均值。

将应用程序配置为捕获实时或批量转换推理数据后,监控特征归因偏移的第一项任务是创建基准以进行比较。这包括配置数据输入、哪些组是敏感组、如何捕获预测,以及模型及其训练后偏差指标。然后,您需要启动设定基准作业。模型可解释性监控器可以解释正在生成推理的已部署模型的预测,并定期检测特征归因偏移。

model_explainability_monitor = ModelExplainabilityMonitor( role=role, sagemaker_session=sagemaker_session, max_runtime_in_seconds=1800, )

在此示例中,可解释性基准作业与偏差基线作业共享测试数据集,因此它使用相同的DataConfig测试数据集,唯一的区别是作业输出。URI

model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability" model_explainability_data_config = DataConfig( s3_data_input_path=validation_dataset, s3_output_path=model_explainability_baselining_job_result_uri, label=label_header, headers=all_headers, dataset_type=dataset_type, )

目前,C SageMaker larify 解释器提供了一种可扩展且高效的实现SHAP,因此可解释性配置为SHAPConfig,包括以下内容:

  • baseline— 在内核SHAP算法中用作基线数据集的行(至少一个)或 S3 对象URI的列表。其格式应与数据集格式相同。每行应仅包含要素columns/values and omit the label column/values。

  • num_samples— 要在内核SHAP算法中使用的样本数。此数字决定生成的合成数据集的大小以计算这些SHAP值。

  • agg_method — 全局值的聚合方法。SHAP有效值如下所示:

    • mean_abs— 所有实例的绝对SHAP值的平均值。

    • median— 所有实例的SHAP值的中位数。

    • mean_sq— 所有实例的平方SHAP值的平均值。

  • use_logit - 指示是否将 logit 函数应用于模型预测。默认值为 False。如果use_logitTrue,则这些SHAP值将具有对数赔率单位。

  • save_local_shap_values(bool) — 指示是否将本地SHAP值保存在输出位置。默认值为 False

# Here use the mean value of test dataset as SHAP baseline test_dataframe = pd.read_csv(test_dataset, header=None) shap_baseline = [list(test_dataframe.mean())] shap_config = SHAPConfig( baseline=shap_baseline, num_samples=100, agg_method="mean_abs", save_local_shap_values=False, )

启动设定基准作业。需要相同的 model_config,因为可解释性设定基准作业需要创建一个影子端点,以获取生成的合成数据集的预测。

model_explainability_monitor.suggest_baseline( data_config=model_explainability_data_config, model_config=model_config, explainability_config=shap_config, ) print(f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}")