使用 SageMaker AI 估算器來執行訓練任務 - Amazon SageMaker AI

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

使用 SageMaker AI 估算器來執行訓練任務

您也可以使用 SageMaker Python SDK 中的估算器來處理 SageMaker 訓練工作的組態和執行。下列程式碼範例,說明如何使用私有 Docker 登錄檔的映像來設定及執行估算器。

  1. 匯入必要的程式庫和相依性,如以下程式碼範例所示。

    import boto3 import sagemaker from sagemaker.estimator import Estimator session = sagemaker.Session() role = sagemaker.get_execution_role()
  2. 提供您的訓練映像、安全性群組和子網路的通用資源識別碼 (URI) 給您的訓練工作的 VPC 組態,如下列程式碼範例所示。

    image_uri = "myteam.myorg.com/docker-local/my-training-image:<IMAGE-TAG>" security_groups = ["sg-0123456789abcdef0"] subnets = ["subnet-0123456789abcdef0", "subnet-0123456789abcdef0"]

    如需有關 security_group_idssubnets 的詳細資訊,請參閱 SageMaker Python SDK 的估算器一節有關適當參數的說明。

    注意

    SageMaker AI 會在 VPC 中使用網路連線來存取 Docker 登錄檔中的映像。若要使用 Docker 登錄檔中的映像進行訓練,必須可以從您帳戶中的 Amazon VPC 存取該登錄檔。

  3. 或者,如果您的 Docker 登錄檔需要身分驗證,您還必須指定 函數的 Amazon Resource Name (ARN),該 AWS Lambda 函數提供 SageMaker AI 的存取憑證。以下程式碼範例說明如何指定 ARN。

    training_repository_credentials_provider_arn = "arn:aws:lambda:us-west-2:1234567890:function:test"

    如需詳細資訊以了解如何在需要驗證的 Docker 登錄檔中使用映像檔,請參閱下方的使用需要驗證的 Docker 登錄檔進行訓練

  4. 使用先前步驟的程式碼範例來設定估算器,如下列程式碼範例所示。

    # The training repository access mode must be 'Vpc' for private docker registry jobs training_repository_access_mode = "Vpc" # Specify the instance type, instance count you want to use instance_type="ml.m5.xlarge" instance_count=1 # Specify the maximum number of seconds that a model training job can run max_run_time = 1800 # Specify the output path for the model artifacts output_path = "s3://your-output-bucket/your-output-path" estimator = Estimator( image_uri=image_uri, role=role, subnets=subnets, security_group_ids=security_groups, training_repository_access_mode=training_repository_access_mode, training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # remove this line if auth is not needed instance_type=instance_type, instance_count=instance_count, output_path=output_path, max_run=max_run_time )
  5. 以您的工作名稱和輸入路徑做為參數呼叫 estimator.fit,以開始訓練工作,如下列程式碼範例所示。

    input_path = "s3://your-input-bucket/your-input-path" job_name = "your-job-name" estimator.fit( inputs=input_path, job_name=job_name )