In the SageMaker AI Spark for Scala examples, you use the KMeansSageMakerEstimator because the examples use the k-means algorithm provided by Amazon SageMaker AI for model training. You might choose to use your own custom algorithm for model training instead. Assuming that you have already created a Docker image, you can create your own SageMakerEstimator and specify the Amazon Elastic Container Registry (Amazon ECR) path for your custom image.

The following example shows how to recreate the KMeansSageMakerEstimator using the generic SageMakerEstimator. In the new estimator, you explicitly specify the Docker registry paths to your training and inference code images.
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer

val estimator = new SageMakerEstimator(
  trainingImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1",
  modelImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1",
  requestRowSerializer = new ProtobufRequestRowSerializer(),
  responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(),
  hyperParameters = Map("k" -> "10", "feature_dim" -> "784"),
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1,
  trainingSparkDataFormat = "sagemaker")
In the code, the parameters in the SageMakerEstimator constructor include:

- trainingImage – Identifies the Docker registry path to the training image containing your custom code.
- modelImage – Identifies the Docker registry path to the image containing inference code.
- requestRowSerializer – Implements com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer. This parameter serializes rows in the input DataFrame to send them to the model hosted in SageMaker AI for inference.
- responseRowDeserializer – Implements com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer. This parameter deserializes responses from the model, hosted in SageMaker AI, back into a DataFrame.
- trainingSparkDataFormat – Specifies the data format that Spark uses when uploading training data from a DataFrame to Amazon S3. For example, "sagemaker" for protobuf format, "csv" for comma-separated values, and "libsvm" for LibSVM format.
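After you construct the estimator, training and hosting follow the same flow as with the built-in estimators. The following is a minimal sketch; the trainingData and testData DataFrame names are assumptions, and your DataFrame schema must match what your custom images expect.

// fit() launches a SageMaker AI training job that runs trainingImage, then
// returns a SageMakerModel hosted on an endpoint that runs modelImage.
val model = estimator.fit(trainingData)  // trainingData: DataFrame (assumed)

// transform() serializes rows with requestRowSerializer, sends them to the
// endpoint, and deserializes responses with responseRowDeserializer.
val transformed = model.transform(testData)  // testData: DataFrame (assumed)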
You can implement your own RequestRowSerializer and ResponseRowDeserializer to serialize and deserialize rows to and from a data format that your inference code supports, such as .libsvm or .csv.
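For example, a custom serializer that formats each row as one line of comma-separated values might look like the following sketch. It assumes the RequestRowSerializer trait exposes a serializeRow method and a contentType value, as in the open-source SageMaker Spark SDK; the class name and the "features" column name are illustrative, so check the trait in your SDK version before relying on the exact members.

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row
import com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer

// Illustrative sketch only: serializes the "features" vector column of each
// Row into one CSV line, so inference code that accepts text/csv can consume it.
class CSVFeaturesRequestRowSerializer(featuresColumnName: String = "features")
  extends RequestRowSerializer {

  // Content type sent with each inference request (assumed member name).
  override val contentType: String = "text/csv"

  // Converts one Row into the bytes sent to the endpoint (assumed signature).
  override def serializeRow(row: Row): Array[Byte] = {
    val features = row.getAs[Vector](featuresColumnName)
    (features.toArray.mkString(",") + "\n").getBytes("UTF-8")
  }
}

A custom ResponseRowDeserializer follows the same pattern in reverse: it turns the endpoint's response bytes back into rows with a schema that your downstream Spark code expects.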