使用 A SageMaker I 估算器来运行训练作业 - 亚马逊 SageMaker AI

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 A SageMaker I 估算器来运行训练作业

您还可以使用 Pyth SageMaker on SDK 中的估算器来处理 SageMaker 训练作业的配置和运行。以下代码示例显示如何使用私有 Docker 注册表中的映像配置和运行估算器。

  1. 导入所需的库和依赖项,如以下代码示例中所示。

    import boto3 import sagemaker from sagemaker.estimator import Estimator session = sagemaker.Session() role = sagemaker.get_execution_role()
  2. 向您的训练映像、安全组和子网提供统一资源标识符 (URI),用于您的训练作业 VPC 配置,如以下代码示例所示。

    image_uri = "myteam.myorg.com/docker-local/my-training-image:<IMAGE-TAG>" security_groups = ["sg-0123456789abcdef0"] subnets = ["subnet-0123456789abcdef0", "subnet-0123456789abcdef0"]

    有关security_group_ids和的更多信息subnets,请参阅 Pyth SageMaker on SDK 的 “估算器” 部分中的相应参数描述。

    注意

    SageMaker AI 使用您的 VPC 内的网络连接来访问您的 Docker 注册表中的镜像。要将您 Docker 注册表中的映像用于训练,注册表必须可以从您账户中的 Amazon VPC 访问。

  3. 或者,如果您的 Docker 注册表需要身份验证,则还必须指定向 AI 提供访问凭证 SageMaker 的函数的 AWS Lambda 亚马逊资源名称 (ARN)。以下示例演示了如何指定 ARN。

    training_repository_credentials_provider_arn = "arn:aws:lambda:us-west-2:1234567890:function:test"

    有关使用需要身份验证的 Docker 注册表中的映像的更多信息,请参阅下文中的使用需要身份验证的 Docker 注册表进行训练

  4. 使用前面步骤中的代码示例来配置估算器,如以下代码示例所示。

    # The training repository access mode must be 'Vpc' for private docker registry jobs training_repository_access_mode = "Vpc" # Specify the instance type, instance count you want to use instance_type="ml.m5.xlarge" instance_count=1 # Specify the maximum number of seconds that a model training job can run max_run_time = 1800 # Specify the output path for the model artifacts output_path = "s3://your-output-bucket/your-output-path" estimator = Estimator( image_uri=image_uri, role=role, subnets=subnets, security_group_ids=security_groups, training_repository_access_mode=training_repository_access_mode, training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # remove this line if auth is not needed instance_type=instance_type, instance_count=instance_count, output_path=output_path, max_run=max_run_time )
  5. 使用您的作业名称和输入路径作为参数来调用 estimator.fit,以启动训练作业,如以下代码示例所示。

    input_path = "s3://your-input-bucket/your-input-path" job_name = "your-job-name" estimator.fit( inputs=input_path, job_name=job_name )