使用 SageMaker 估算器執行訓練工作 - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

使用 SageMaker 估算器執行訓練工作

您也可以使用 SageMaker Python SDK 中的估算器來處理 SageMaker 訓練工作的設定和執行。下列程式碼範例,示範如何使用私有 Docker 登錄檔的映像來設定及執行估算器。

  1. 匯入必要的程式庫和相依性,如以下程式碼範例所示。

    import boto3 import sagemaker from sagemaker.estimator import Estimator session = sagemaker.Session() role = sagemaker.get_execution_role()
  2. 提供統一資源識別碼 (URI) 給訓練工作的 VPC 組態的訓練映像、安全群組和子網路,如下列程式碼範例所示。

    image_uri = "myteam.myorg.com/docker-local/my-training-image:<IMAGE-TAG>" security_groups = ["sg-0123456789abcdef0"] subnets = ["subnet-0123456789abcdef0", "subnet-0123456789abcdef0"]

    如需security_group_ids和的詳細資訊subnets,請參閱 SageMaker Python SDK 的 [估算器] 一節中的適當參數說明。

    注意

    SageMaker 使用 VPC 中的網路連線來存取 Docker 登錄中的映像檔。若要使用 Docker 登錄檔中的映像進行訓練,必須可以從您帳戶中的 Amazon VPC 存取該登錄檔。

  3. 或者,如果您的 Docker 登錄需要身份驗證,您還必須指定提供存取登入資料的 AWS Lambda 函數的 Amazon 資源名稱 (ARN)。 SageMaker以下程式碼範例說明如何指定 ARN。

    training_repository_credentials_provider_arn = "arn:aws:lambda:us-west-2:1234567890:function:test"

    如需詳細資訊以了解如何在需要驗證的 Docker 登錄檔中使用映像檔,請參閱下方的 使用需要為訓練進行驗證的 Docker 登錄檔

  4. 使用先前步驟的程式碼範例來設定估算器,如下列程式碼範例所示。

    # The training repository access mode must be 'Vpc' for private docker registry jobs training_repository_access_mode = "Vpc" # Specify the instance type, instance count you want to use instance_type="ml.m5.xlarge" instance_count=1 # Specify the maximum number of seconds that a model training job can run max_run_time = 1800 # Specify the output path for the model artifacts output_path = "s3://your-output-bucket/your-output-path" estimator = Estimator( image_uri=image_uri, role=role, subnets=subnets, security_group_ids=security_groups, training_repository_access_mode=training_repository_access_mode, training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # remove this line if auth is not needed instance_type=instance_type, instance_count=instance_count, output_path=output_path, max_run=max_run_time )
  5. 以您的工作名稱和輸入路徑做為參數呼叫 estimator.fit,以開始訓練工作,如下列程式碼範例所示。

    input_path = "s3://your-input-bucket/your-input-path" job_name = "your-job-name" estimator.fit( inputs=input_path, job_name=job_name )