本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。
使用自訂演算法搭配 Apache Spark 在 Amazon 上 SageMaker 進行模型訓練和託管
在 中SageMaker Spark for Scala 範例,您會使用 ,kMeansSageMakerEstimator
因為 範例使用 Amazon 提供的 k 平均演算法 SageMaker 進行模型訓練。不過,您也可以選擇使用專屬的自訂演算法來訓練模型。假設您已建立 Docker 影像,就可以建立您專屬的 SageMakerEstimator
,並指定自訂影像的 Amazon Elastic Container Registry 路徑。
以下範例會說明從 SageMakerEstimator
建立 KMeansSageMakerEstimator
的方式。請在新的估算器中明確地指定 Docker 登錄檔路徑,以便訓練和推論程式碼影像。
import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer val estimator = new SageMakerEstimator( trainingImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", modelImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", requestRowSerializer = new ProtobufRequestRowSerializer(), responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(), hyperParameters = Map("k" -> "10", "feature_dim" -> "784"), sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1, trainingSparkDataFormat = "sagemaker")
SageMakerEstimator
建構函式中的參數會包含以下程式碼:
-
trainingImage
- 可識別訓練影像的 Docker 登錄檔路徑,該訓練影像包含自訂程式碼。 -
modelImage
- 可識別影像的 Docker 登錄檔路徑,該影像包含推論程式碼。 -
requestRowSerializer
- 實作com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer
。此參數會序列化輸入中的資料列
DataFrame
,以將其傳送至 中託管 SageMaker 的模型以進行推論。 -
responseRowDeserializer
- 實作com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer
.此參數會將託管在 中的模型回應還原序列化 SageMaker,再傳回
DataFrame
。 -
trainingSparkDataFormat
- 可指定DataFrame
訓練資料上傳至 S3 期間,Spark 會使用的資料格式。例如,"sagemaker"
protobuf 格式、"csv"
逗號分隔值和"libsvm"
LibSVM 格式。
您可以實作專屬的 RequestRowSerializer
和 ResponseRowDeserializer
,將使用您推論程式碼支援之資料格式 (如 libsvm 或 .csv) 的資料列序列化及還原序列化。