為生產中的模型建立SHAP基準 - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

為生產中的模型建立SHAP基準

解釋通常是相反的 (也就是說,它們說明偏離基準的情況)。如需可解釋性基準的資訊,請參閱SHAP 可解釋性基準

除了針對每個執行個體推論提供解釋之外, SageMaker Clarify 也支援 ML 模型的全域解釋,協助您了解模型整體行為的特徵。 SageMaker Clarify 透過多個執行個體彙總 Shapley 值來產生 ML 模型的全域解釋。 SageMaker Clarify 支援下列不同的彙總方式,您可以使用它來定義基準:

  • mean_abs – 所有執行個體的絕對SHAP值平均值。

  • median – 所有執行個體SHAP的值中位數。

  • mean_sq – 所有執行個體的平方SHAP值平均值。

將應用程式設定為擷取即時或批次轉換推論資料之後,監控功能屬性偏離的第一項任務就是建立要比較的基準。這包括設定資料輸入、哪些群組是敏感的、如何擷取預測,以及模型及其訓練後的偏差指標。然後,您需要開始進行基準工作。模型可解釋性監控可以解釋已部署模型的預測,該模型會產生推論並定期偵測功能屬性偏離。

model_explainability_monitor = ModelExplainabilityMonitor( role=role, sagemaker_session=sagemaker_session, max_runtime_in_seconds=1800, )

在此範例中,可解釋性基礎描述任務與偏差基礎描述任務共用測試資料集,因此使用相同的 DataConfig,唯一的區別是任務輸出 URI。

model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability" model_explainability_data_config = DataConfig( s3_data_input_path=validation_dataset, s3_output_path=model_explainability_baselining_job_result_uri, label=label_header, headers=all_headers, dataset_type=dataset_type, )

目前 SageMaker ,Clarify explainer 提供可擴展且高效率的 實作SHAP,因此可解釋性組態為 SHAPConfig,包括下列項目:

  • baseline – URI要用作核心SHAP演算法中基準資料集的資料列 (至少一個) 或 S3 物件清單。此格式應與資料集格式相同。每一列應該只包含功能欄/值,並省略標籤欄/值。

  • num_samples – 核心SHAP演算法中使用的範例數目。此數字會決定產生的合成資料集大小來計算SHAP值。

  • agg_method – 全域SHAP值的彙總方法。以下為有效值:

    • mean_abs – 所有執行個體的絕對SHAP值平均值。

    • median – 所有執行個體SHAP的值中位數。

    • mean_sq – 所有執行個體的平方SHAP值平均值。

  • use_logit – 是否將 logit 函式套用於模型預測的指示器。預設值為 False。如果 use_logitTrue,則SHAP值將具有 log-odds 單位。

  • save_local_shap_values (布爾) – 指示是否要將本機SHAP值儲存在輸出位置。預設值為 False

# Here use the mean value of test dataset as SHAP baseline test_dataframe = pd.read_csv(test_dataset, header=None) shap_baseline = [list(test_dataframe.mean())] shap_config = SHAPConfig( baseline=shap_baseline, num_samples=100, agg_method="mean_abs", save_local_shap_values=False, )

開始基準工作。需要相同的 model_config,因為可解釋性基準工作需要建立陰影端點以取得產生的合成資料集的預測。

model_explainability_monitor.suggest_baseline( data_config=model_explainability_data_config, model_config=model_config, explainability_config=shap_config, ) print(f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}")