Use Custom Algorithms for Model Training and Hosting on Amazon SageMaker AI with Apache Spark
In the SageMaker AI Spark for Scala examples, you use the `KMeansSageMakerEstimator` because the examples use the k-means algorithm provided by Amazon SageMaker AI for model training. You might choose to use your own custom algorithm for model training instead. Assuming that you have already created a Docker image, you can create your own `SageMakerEstimator` and specify the Amazon Elastic Container Registry (Amazon ECR) path for your custom image.
The following example shows how to build the equivalent of the `KMeansSageMakerEstimator` from the generic `SageMakerEstimator`. In the new estimator, you explicitly specify the Docker registry paths to your training and inference code images.
```scala
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer

val estimator = new SageMakerEstimator(
  trainingImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1",
  modelImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1",
  requestRowSerializer = new ProtobufRequestRowSerializer(),
  responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(),
  hyperParameters = Map("k" -> "10", "feature_dim" -> "784"),
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1,
  trainingSparkDataFormat = "sagemaker")
```
In the code, the parameters in the `SageMakerEstimator` constructor include the following:

- `trainingImage` —Identifies the Docker registry path to the training image containing your custom code.
- `modelImage` —Identifies the Docker registry path to the image containing inference code.
- `requestRowSerializer` —Implements `com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer`. This parameter serializes rows in the input `DataFrame` to send them to the model hosted in SageMaker AI for inference.
- `responseRowDeserializer` —Implements `com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer`. This parameter deserializes responses from the model hosted in SageMaker AI back into a `DataFrame`.
- `trainingSparkDataFormat` —Specifies the data format that Spark uses when uploading training data from a `DataFrame` to S3. For example, `"sagemaker"` for protobuf format, `"csv"` for comma-separated values, and `"libsvm"` for LibSVM format.
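With the estimator configured, training and hosting follow the same `Estimator` pattern as the built-in `KMeansSageMakerEstimator`. The following is a minimal sketch; the `trainingData` `DataFrame` is an assumption and must contain the `"label"` and `"features"` columns that the k-means container expects.

```scala
// Sketch: assumes `estimator` from the example above and a DataFrame
// named `trainingData` with "label" and "features" columns.

// fit() runs a SageMaker AI training job using trainingImage, then hosts
// the resulting model on an endpoint using modelImage.
val model = estimator.fit(trainingData)

// transform() serializes rows with requestRowSerializer, invokes the
// endpoint, and turns the responses back into DataFrame rows with
// responseRowDeserializer.
val transformedData = model.transform(trainingData)
transformedData.show()
```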
You can implement your own `RequestRowSerializer` and `ResponseRowDeserializer` to serialize and deserialize rows in a data format that your inference code supports, such as `.libsvm` or `.csv`.
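For example, if your inference container accepts rows as comma-separated text, a custom serializer might look like the following sketch. This is not SDK code: the class name and the hard-coded `"features"` column are illustrative assumptions, and the sketch relies on the `serializeRow`, `contentType`, and `validateSchema` members that the SDK's built-in serializers implement.

```scala
import com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

// Illustrative sketch: serializes the "features" vector column of each
// row as one comma-separated text line for a container that accepts
// text/csv. The class name and column choice are assumptions.
class CsvRequestRowSerializer extends RequestRowSerializer {

  override val contentType: String = "text/csv"

  override def serializeRow(row: Row): Array[Byte] = {
    val features = row.getAs[Vector]("features")
    (features.toArray.mkString(",") + "\n").getBytes("UTF-8")
  }

  // No-op validation; a production implementation would verify that the
  // schema contains a "features" vector column.
  override def validateSchema(schema: StructType): Unit = {}
}
```

You would then pass an instance of this class as the `requestRowSerializer` constructor parameter, together with a `ResponseRowDeserializer` that matches your container's response format.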