為生產中的模型建立 SHAP 基準 - Amazon SageMaker AI

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

為生產中的模型建立 SHAP 基準

解釋通常是相反的 (也就是說,它們說明偏離基準的情況)。如需可解釋性基準的資訊,請參閱用於可解釋性的 SHAP 基準

除了針對每個執行個體推論提供說明之外,SageMaker Clarify 也支援機器學習 (ML) 模型的全域說明,協助您根據模型的功能來瞭解整個模型的行為。SageMaker Clarify 透過彙總多個執行個體的 Shapley 值,產生機器學習 (ML) 模型的全域說明。SageMaker Clarify 支援下列不同的彙總方式,您可以使用這些方式來定義基準:

  • mean_abs – 所有執行個體的絕對 SHAP 值的平均值。

  • median – 所有執行個體的 SHAP 值的中間值。

  • mean_sq – 所有執行個體的平方 SHAP 值的平均值。

將應用程式設定為擷取即時或批次轉換推論資料之後,監控功能屬性偏離的第一項任務就是建立要比較的基準。這包括設定資料輸入、哪些群組是敏感的、如何擷取預測,以及模型及其訓練後的偏差指標。然後,您需要開始進行基準工作。模型可解釋性監控可以解釋已部署模型的預測,該模型會產生推論並定期偵測功能屬性偏離。

model_explainability_monitor = ModelExplainabilityMonitor( role=role, sagemaker_session=sagemaker_session, max_runtime_in_seconds=1800, )

在這個範例中,可解釋性基準工作與偏差基準工作共用測試資料集,因此它使用相同的 DataConfig,唯一的差異是任務輸出 URI。

model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability" model_explainability_data_config = DataConfig( s3_data_input_path=validation_dataset, s3_output_path=model_explainability_baselining_job_result_uri, label=label_header, headers=all_headers, dataset_type=dataset_type, )

目前 SageMaker Clarify 解釋器提供了 SHAP 的可擴展性和高效,因此可解釋性組態是 ShapConfig,包括以下內容:

  • baseline – 要在核心 SHAP 演算法中用作基準資料集的資料列 (至少一個) 或 S3 物件 URI 的清單。此格式應與資料集格式相同。每一列應該只包含功能欄/值,並省略標籤欄/值。

  • num_samples – 要在核心 SHAP 演算法中使用的樣本數。此數字決定產生的合成資料集的大小來計算 SHAP 值。

  • agg_method — 全域 SHAP 值的彙總方法。以下為有效值:

    • mean_abs – 所有執行個體的絕對 SHAP 值的平均值。

    • median – 所有執行個體的 SHAP 值的中間值。

    • mean_sq – 所有執行個體的平方 SHAP 值的平均值。

  • use_logit – 是否將 logit 函式套用於模型預測的指示器。預設值為 False。如果 use_logitTrue,SHAP 值將有對數機率單位。

  • save_local_shap_values (bool) – 是否將本機 SHAP 值儲存在輸出位置的指示器。預設值為 False

# Here use the mean value of test dataset as SHAP baseline test_dataframe = pd.read_csv(test_dataset, header=None) shap_baseline = [list(test_dataframe.mean())] shap_config = SHAPConfig( baseline=shap_baseline, num_samples=100, agg_method="mean_abs", save_local_shap_values=False, )

開始基準工作。需要相同的 model_config,因為可解釋性基準工作需要建立陰影端點以取得產生的合成資料集的預測。

model_explainability_monitor.suggest_baseline( data_config=model_explainability_data_config, model_config=model_config, explainability_config=shap_config, ) print(f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}")