Crie uma linha de base de qualidade do modelo - Amazon SageMaker

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

Crie uma linha de base de qualidade do modelo

Crie um trabalho de linha de base que compare suas previsões de modelo com rótulos de veracidade em um conjunto de dados de linha de base que você armazenou no Amazon S3. Normalmente, você usa um conjunto de dados de treinamento como o conjunto de dados de linha de base. O trabalho de linha de base calcula as métricas do modelo e sugere restrições a serem usadas para monitorar a variação da qualidade do modelo.

Para criar um trabalho de linha de base, você precisa ter um conjunto de dados que contenha previsões do seu modelo junto com rótulos que representem o Ground Truth para seus dados.

Para criar um trabalho básico, use a ModelQualityMonitor classe fornecida pelo SageMaker SDK Python e conclua as etapas a seguir.

Para criar uma linha de base de qualidade do modelo
  1. Primeiramente, crie uma instância da classe ModelQualityMonitor. O trecho de código a seguir mostra como fazer isso.

    from sagemaker import get_execution_role, session, Session from sagemaker.model_monitor import ModelQualityMonitor role = get_execution_role() session = Session() model_quality_monitor = ModelQualityMonitor( role=role, instance_count=1, instance_type='ml.m5.xlarge', volume_size_in_gb=20, max_runtime_in_seconds=1800, sagemaker_session=session )
  2. Agora, chame o método suggest_baseline do objeto ModelQualityMonitor para executar um trabalho de linha de base. O trecho de código a seguir pressupõe que você tenha um conjunto de dados de linha de base que contém previsões e rótulos armazenados no Amazon S3.

    baseline_job_name = "MyBaseLineJob" job = model_quality_monitor.suggest_baseline( job_name=baseline_job_name, baseline_dataset=baseline_dataset_uri, # The S3 location of the validation dataset. dataset_format=DatasetFormat.csv(header=True), output_s3_uri = baseline_results_uri, # The S3 location to store the results. problem_type='BinaryClassification', inference_attribute= "prediction", # The column in the dataset that contains predictions. probability_attribute= "probability", # The column in the dataset that contains probabilities. ground_truth_attribute= "label" # The column in the dataset that contains ground truth labels. ) job.wait(logs=False)
  3. Após a conclusão do trabalho de linha de base, é possível visualizar as restrições que o trabalho gerou. Primeiro, obtenha os resultados do trabalho de linha de base chamando o método latest_baselining_job do objeto ModelQualityMonitor.

    baseline_job = model_quality_monitor.latest_baselining_job
  4. O trabalho de linha de base sugere restrições, que são limites para métricas que modelam medidas de monitoramento. Se uma métrica ultrapassar o limite sugerido, o Model Monitor relata uma violação. Para visualizar as restrições que o trabalho de linha de base gerou, chame o método suggested_constraints do trabalho de linha de base. O trecho de código a seguir carrega as restrições de um modelo de classificação binária em um dataframe Pandas.

    import pandas as pd pd.DataFrame(baseline_job.suggested_constraints().body_dict["binary_classification_constraints"]).T

    Recomendamos que você visualize as restrições geradas e as modifique conforme necessário antes de usá-las para monitoramento. Por exemplo, se uma restrição for muito agressiva, você poderá receber mais alertas de violações do que gostaria.

    Se sua restrição contiver números expressos em notação científica, você precisará convertê-los em flutuantes. O exemplo de script de pré-processamento de python a seguir mostra como converter números em notação científica em flutuantes.

    import csv def fix_scientific_notation(col): try: return format(float(col), "f") except: return col def preprocess_handler(csv_line): reader = csv.reader([csv_line]) csv_record = next(reader) #skip baseline header, change HEADER_NAME to the first column's name if csv_record[0] == “HEADER_NAME”: return [] return { str(i).zfill(20) : fix_scientific_notation(d) for i, d in enumerate(csv_record)}

    Você pode adicionar seu script de pré-processamento a uma linha de base ou programação de monitoramento como um record_preprocessor_script, conforme definido na documentação do Model Monitor.

  5. Quando estiver satisfeito com as restrições, passe-as como parâmetro constraints ao criar uma programação de monitoramento. Para obter mais informações, consulte Agende trabalhos de monitoramento da qualidade do modelo.

As restrições de linha de base sugeridas estão contidas no arquivo constraints.json no local com o qual você especifica output_s3_uri. Para obter informações sobre o esquema desse arquivo no Esquema para restrições (arquivo constraints.json).