使用自訂演算法搭配 Apache Spark 在 Amazon 上 SageMaker 進行模型訓練和託管 - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

使用自訂演算法搭配 Apache Spark 在 Amazon 上 SageMaker 進行模型訓練和託管

在 中SageMaker Spark for Scala 範例,您會使用 ,kMeansSageMakerEstimator因為 範例使用 Amazon 提供的 k 平均演算法 SageMaker 進行模型訓練。不過,您也可以選擇使用專屬的自訂演算法來訓練模型。假設您已建立 Docker 影像,就可以建立您專屬的 SageMakerEstimator,並指定自訂影像的 Amazon Elastic Container Registry 路徑。

以下範例會說明從 SageMakerEstimator 建立 KMeansSageMakerEstimator 的方式。請在新的估算器中明確地指定 Docker 登錄檔路徑,以便訓練和推論程式碼影像。

import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer val estimator = new SageMakerEstimator( trainingImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", modelImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", requestRowSerializer = new ProtobufRequestRowSerializer(), responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(), hyperParameters = Map("k" -> "10", "feature_dim" -> "784"), sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1, trainingSparkDataFormat = "sagemaker")

SageMakerEstimator 建構函式中的參數會包含以下程式碼:

  • trainingImage - 可識別訓練影像的 Docker 登錄檔路徑,該訓練影像包含自訂程式碼。

  • modelImage - 可識別影像的 Docker 登錄檔路徑,該影像包含推論程式碼。

  • requestRowSerializer - 實作 com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer

    此參數會序列化輸入中的資料列DataFrame,以將其傳送至 中託管 SageMaker 的模型以進行推論。

  • responseRowDeserializer - 實作

    com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer.

    此參數會將託管在 中的模型回應還原序列化 SageMaker,再傳回 DataFrame

  • trainingSparkDataFormat - 可指定 DataFrame 訓練資料上傳至 S3 期間,Spark 會使用的資料格式。例如,"sagemaker"protobuf 格式、"csv"逗號分隔值和 "libsvm" LibSVM 格式。

您可以實作專屬的 RequestRowSerializerResponseRowDeserializer,將使用您推論程式碼支援之資料格式 (如 libsvm 或 .csv) 的資料列序列化及還原序列化。