Use the SageMakerEstimator in a Spark Pipeline - Amazon SageMaker AI



You can use org.apache.spark.ml.Estimator estimators and org.apache.spark.ml.Model models together with SageMakerEstimator estimators and SageMakerModel models in an org.apache.spark.ml.Pipeline pipeline, as shown in the following example:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.PCA
import org.apache.spark.sql.SparkSession
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer

val spark = SparkSession.builder.getOrCreate

// load MNIST data as a DataFrame from LibSVM files in S3
val region = "us-east-1"
val trainingData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")

val testData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")

// substitute your SageMaker IAM role here
val roleArn = "arn:aws:iam::account-id:role/rolename"

// Spark ML stage: project the 784 pixel features onto 50 principal components
val pcaEstimator = new PCA()
  .setInputCol("features")
  .setOutputCol("projectedFeatures")
  .setK(50)

// SageMaker stage: train k-means on the projected features and host the model on an endpoint
val kMeansSageMakerEstimator = new KMeansSageMakerEstimator(
  sagemakerRole = IAMRole(roleArn),
  requestRowSerializer = new ProtobufRequestRowSerializer(featuresColumnName = "projectedFeatures"),
  trainingSparkDataFormatOptions = Map("featuresColumnName" -> "projectedFeatures"),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1)
  .setK(10).setFeatureDim(50)

val pipeline = new Pipeline().setStages(Array(pcaEstimator, kMeansSageMakerEstimator))

// train: fits PCA in Spark, then runs a SageMaker training job for k-means
val pipelineModel = pipeline.fit(trainingData)

// inference: PCA runs in Spark, k-means predictions come from the hosted endpoint
val transformedData = pipelineModel.transform(testData)
transformedData.show()
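Because a SageMakerEstimator is an ordinary org.apache.spark.ml.Estimator, you can check the stage wiring locally before launching any SageMaker resources. The following is a minimal sketch under stated assumptions, not part of the SageMaker Spark SDK: it swaps in Spark's own org.apache.spark.ml.clustering.KMeans for KMeansSageMakerEstimator, uses a tiny hypothetical in-memory dataset instead of MNIST, projects to 2 components instead of 50, and the object name LocalPipelineDryRun is illustrative.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.PCA
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical local dry run: Spark's KMeans stands in for KMeansSageMakerEstimator,
// so no training job or endpoint is created.
object LocalPipelineDryRun {
  def run(): DataFrame = {
    val spark = SparkSession.builder.master("local[*]").appName("pipeline-dry-run").getOrCreate()

    // Tiny hypothetical dataset that reuses the column names of the MNIST example.
    val data = spark.createDataFrame(Seq(
      (0.0, Vectors.dense(1.0, 0.0, 0.0, 0.1)),
      (0.0, Vectors.dense(0.9, 0.1, 0.0, 0.0)),
      (1.0, Vectors.dense(0.0, 0.0, 1.0, 0.9)),
      (1.0, Vectors.dense(0.1, 0.0, 0.9, 1.0))
    )).toDF("label", "features")

    // Same PCA stage shape as in the example, but with 2 components.
    val pca = new PCA()
      .setInputCol("features")
      .setOutputCol("projectedFeatures")
      .setK(2)

    // Local clustering stage that reads the PCA output column.
    val kMeans = new KMeans()
      .setK(2)
      .setSeed(1L)
      .setFeaturesCol("projectedFeatures")
      .setPredictionCol("closest_cluster")

    val model = new Pipeline().setStages(Array(pca, kMeans)).fit(data)
    model.transform(data)
  }

  def main(args: Array[String]): Unit = run().show()
}
```

Once the columns line up, replacing kMeans with the KMeansSageMakerEstimator from the example restores the SageMaker-backed pipeline. Spark's KMeans writes an integer prediction, whereas the SageMaker k-means output on this page has a double-valued closest_cluster plus a distance_to_cluster column.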

The trainingSparkDataFormatOptions parameter configures Spark to serialize the "projectedFeatures" column to protobuf for model training. Spark also serializes the "label" column to protobuf by default.

Because we want to run inference on the "projectedFeatures" column, we pass that column name to the ProtobufRequestRowSerializer.
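The column name "projectedFeatures" therefore appears in three places: the PCA output column, the featuresColumnName entry of trainingSparkDataFormatOptions, and the featuresColumnName argument of ProtobufRequestRowSerializer. If they get out of sync, the SageMaker stage looks for a column that PCA did not produce. One way to keep them aligned is to define the name once. This is a sketch, and the object name FeatureColumnWiring is hypothetical; the constructor arguments are the ones used in the example above.

```scala
import org.apache.spark.ml.feature.PCA
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer

// Hypothetical holder that keeps every reference to the projected-features column in sync.
object FeatureColumnWiring {
  // Single source of truth for the column name.
  val projectedCol = "projectedFeatures"

  // PCA writes its 50-dimensional projection to projectedCol.
  val pca: PCA = new PCA()
    .setInputCol("features")
    .setOutputCol(projectedCol)
    .setK(50)

  // Training: serialize projectedCol (and, by default, "label") to protobuf.
  val trainingFormatOptions: Map[String, String] = Map("featuresColumnName" -> projectedCol)

  // Inference: send projectedCol to the hosted endpoint.
  val requestSerializer = new ProtobufRequestRowSerializer(featuresColumnName = projectedCol)
}
```

You would then pass FeatureColumnWiring.trainingFormatOptions as trainingSparkDataFormatOptions and FeatureColumnWiring.requestSerializer as requestRowSerializer when constructing KMeansSageMakerEstimator.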

The following example shows the transformed DataFrame:

+-----+--------------------+--------------------+-------------------+---------------+
|label|            features|   projectedFeatures|distance_to_cluster|closest_cluster|
+-----+--------------------+--------------------+-------------------+---------------+
|  5.0|(784,[152,153,154...|[880.731433034386...|     1500.470703125|            0.0|
|  0.0|(784,[127,128,129...|[1768.51722024166...|      1142.18359375|            4.0|
|  4.0|(784,[160,161,162...|[704.949236329314...|  1386.246826171875|            9.0|
|  1.0|(784,[158,159,160...|[-42.328192193771...| 1277.0736083984375|            5.0|
|  9.0|(784,[208,209,210...|[374.043902028333...|   1211.00927734375|            3.0|
|  2.0|(784,[155,156,157...|[941.267714528850...|  1496.157958984375|            8.0|
|  1.0|(784,[124,125,126...|[30.2848596410594...| 1327.6766357421875|            5.0|
|  3.0|(784,[151,152,153...|[1270.14374062052...| 1570.7674560546875|            0.0|
|  1.0|(784,[152,153,154...|[-112.10792566485...|     1037.568359375|            5.0|
|  4.0|(784,[134,135,161...|[452.068280676606...| 1165.1236572265625|            3.0|
|  3.0|(784,[123,124,125...|[610.596447285397...|  1325.953369140625|            7.0|
|  5.0|(784,[216,217,218...|[142.959601818422...| 1353.4930419921875|            5.0|
|  3.0|(784,[143,144,145...|[1036.71862533658...| 1460.4315185546875|            7.0|
|  6.0|(784,[72,73,74,99...|[996.740157435754...| 1159.8631591796875|            2.0|
|  1.0|(784,[151,152,153...|[-107.26076167417...|   960.963623046875|            5.0|
|  7.0|(784,[211,212,213...|[619.771820430940...|   1245.13623046875|            6.0|
|  2.0|(784,[151,152,153...|[850.152101817161...|  1304.437744140625|            8.0|
|  8.0|(784,[159,160,161...|[370.041887230547...| 1192.4781494140625|            0.0|
|  6.0|(784,[100,101,102...|[546.674328209335...|    1277.0908203125|            2.0|
|  9.0|(784,[209,210,211...|[-29.259112927426...| 1245.8182373046875|            6.0|
+-----+--------------------+--------------------+-------------------+---------------+
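To interpret this output, you can aggregate by closest_cluster, for example to count how many rows landed in each cluster, how many distinct true labels each cluster mixes, and the mean distance_to_cluster. The sketch below is self-contained: instead of calling an endpoint, it rebuilds only the label, distance_to_cluster, and closest_cluster values of the 20 rows shown above, and the object name ClusterSummary is hypothetical. On real output you would apply the same groupBy to transformedData. For these rows, cluster 5.0 holds five rows, four of them labeled 1.0.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{avg, count, countDistinct, lit}

object ClusterSummary {
  // (label, distance_to_cluster, closest_cluster) copied from the 20 rows shown above.
  val rows: Seq[(Double, Double, Double)] = Seq(
    (5.0, 1500.470703125, 0.0), (0.0, 1142.18359375, 4.0),
    (4.0, 1386.246826171875, 9.0), (1.0, 1277.0736083984375, 5.0),
    (9.0, 1211.00927734375, 3.0), (2.0, 1496.157958984375, 8.0),
    (1.0, 1327.6766357421875, 5.0), (3.0, 1570.7674560546875, 0.0),
    (1.0, 1037.568359375, 5.0), (4.0, 1165.1236572265625, 3.0),
    (3.0, 1325.953369140625, 7.0), (5.0, 1353.4930419921875, 5.0),
    (3.0, 1460.4315185546875, 7.0), (6.0, 1159.8631591796875, 2.0),
    (1.0, 960.963623046875, 5.0), (7.0, 1245.13623046875, 6.0),
    (2.0, 1304.437744140625, 8.0), (8.0, 1192.4781494140625, 0.0),
    (6.0, 1277.0908203125, 2.0), (9.0, 1245.8182373046875, 6.0)
  )

  // Per cluster: row count, number of distinct true labels, mean distance to the assigned cluster.
  def summarize(spark: SparkSession): DataFrame =
    spark.createDataFrame(rows).toDF("label", "distance_to_cluster", "closest_cluster")
      .groupBy("closest_cluster")
      .agg(
        count(lit(1)).as("rows"),
        countDistinct("label").as("distinct_labels"),
        avg("distance_to_cluster").as("mean_distance"))
      .orderBy("closest_cluster")

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("cluster-summary").getOrCreate()
    summarize(spark).show()
    spark.stop()
  }
}
```

A cluster with many rows but few distinct labels, such as 5.0 here, suggests that k-means grouped mostly one digit together; a 20-row sample is too small to draw firm conclusions.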