Criar uma linha de base do SHAP para modelos em produção - Amazon SageMaker

Criar uma linha de base do SHAP para modelos em produção

As explicações são tipicamente contrastivas, ou seja, elas explicam os desvios de uma linha de base. Para obter informações sobre linhas de base de explicabilidade, consulte Linhas de base do SHAP para explicabilidade.

Além de fornecer explicações para inferências por instância, o SageMaker Clarify também oferece apoio à explicação global para modelos de ML que ajudam você a entender o comportamento de um modelo como um todo em termos de seus atributos. O SageMaker Clarify gera uma explicação global de um modelo de ML agregando os valores Shapley em várias instâncias. O SageMaker Clarify é compatível com as seguintes formas diferentes de agregação, que você pode usar para definir linhas de base:

  • mean_abs: Média dos valores SHAP absolutos para todas as instâncias.

  • median: Média dos valores SHAP para todas as instâncias.

  • mean_sq: Média dos valores SHAP quadráticos para todas as instâncias.

Depois de configurar sua aplicação para capturar dados de inferência em tempo real ou de transformação de lotes, a primeira tarefa para monitorar o desvio da atribuição de atributos é criar uma linha de base para comparação. Isso envolve configurar as entradas de dados, quais grupos são confidenciais, como as predições são capturadas e o modelo e suas métricas de desvio pós-treinamento. Em seguida, você precisa iniciar o trabalho de linha de base. O monitor de explicabilidade do modelo pode explicar as predições de um modelo implantado que está produzindo inferências e detectar desvios na atribuição de atributos regularmente.

model_explainability_monitor = ModelExplainabilityMonitor( role=role, sagemaker_session=sagemaker_session, max_runtime_in_seconds=1800, )

Neste exemplo, o trabalho de linha de base de explicabilidade compartilha o conjunto de dados de teste com o trabalho de linha de base de desvio, então ele usa o mesmo DataConfig e a única diferença é o URI de saída do trabalho.

model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability" model_explainability_data_config = DataConfig( s3_data_input_path=validation_dataset, s3_output_path=model_explainability_baselining_job_result_uri, label=label_header, headers=all_headers, dataset_type=dataset_type, )

Atualmente, o explicador do SageMaker Clarify oferece uma implementação escalável e eficiente do SHAP; dessa forma, a configuração de explicabilidade é SHAPConfig, incluindo o seguinte:

  • baseline: Uma lista de linhas (pelo menos uma) ou URI do objeto S3 a ser usada como conjunto de dados de linha de base no algoritmo SHAP do Kernel. O formato deve ser igual ao formato do conjunto de dados. Cada linha deve conter somente as colunas/valores do atributo e omitir a coluna/valores do rótulo.

  • num_samples: Número de amostras a serem usadas no algoritmo SHAP do Kernel. Esse número determina o tamanho do conjunto de dados sintético gerado para calcular os valores SHAP.

  • agg_method: Método de agregação para valores globais de SHAP. Estes são valores válidos:

    • mean_abs: Média dos valores SHAP absolutos para todas as instâncias.

    • median: Média dos valores SHAP para todas as instâncias.

    • mean_sq: Média dos valores SHAP quadráticos para todas as instâncias.

  • use_logit: Indicador de se a função logit deve ser aplicada às predições de modelo. O padrão é False. Se o use_logit for True, os valores SHAP terão unidades logarítmicas de probabilidades.

  • save_local_shap_values (bool): Indicador de se os valores SHAP locais devem ser salvos no local de saída. O padrão é False.

# Here use the mean value of test dataset as SHAP baseline test_dataframe = pd.read_csv(test_dataset, header=None) shap_baseline = [list(test_dataframe.mean())] shap_config = SHAPConfig( baseline=shap_baseline, num_samples=100, agg_method="mean_abs", save_local_shap_values=False, )

Inicie um trabalho de linha de base. A mesma model_config é necessária porque o trabalho de definição de base de explicabilidade precisa criar um endpoint de sombra para obter predições para o conjunto de dados sintéticos gerado.

model_explainability_monitor.suggest_baseline( data_config=model_explainability_data_config, model_config=model_config, explainability_config=shap_config, ) print(f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}")