Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.
Wie benutzt man den SageMaker AI Text Classification — TensorFlow Algorithmus
Sie können Text Classification TensorFlow als integrierten Algorithmus von Amazon SageMaker AI verwenden. Im folgenden Abschnitt wird beschrieben, wie Sie Text Classification verwenden — TensorFlow mit dem SageMaker AI Python SDK. Informationen zur Verwendung der Textklassifizierung über die TensorFlow Amazon SageMaker Studio Classic-Benutzeroberfläche finden Sie unterSageMaker JumpStart vortrainierte Modelle.
Der TensorFlow Textklassifizierungsalgorithmus unterstützt Transfer-Learning unter Verwendung eines der kompatiblen vortrainierten TensorFlow Modelle. Eine Liste aller verfügbaren vortrainierten Modelle finden Sie unter TensorFlow Hub-Modelle. Jedes vortrainierte Modell hat ein Unikat model_id
. Im folgenden Beispiel wird BERT Base Uncased (model_id
:tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2
) zur Feinabstimmung eines benutzerdefinierten Datensatzes verwendet. Die vortrainierten Modelle werden alle vorab vom TensorFlow Hub heruntergeladen und in Amazon S3 S3-Buckets gespeichert, sodass Trainingsjobs netzwerkisoliert ausgeführt werden können. Verwenden Sie diese vorgenerierten Modelltrainingsartefakte, um einen AI Estimator zu erstellen. SageMaker
Rufen Sie zunächst den Docker-Image-URI, den Trainingsskript-URI und den vortrainierten Modell-URI ab. Ändern Sie dann die Hyperparameter nach Bedarf. Sie können ein Python-Wörterbuch mit allen verfügbaren Hyperparametern und ihren Standardwerten mit hyperparameters.retrieve_default
sehen. Weitere Informationen finden Sie unter Textklassifizierung — TensorFlow Hyperparameter. Verwenden Sie diese Werte, um einen SageMaker AI-Schätzer zu erstellen.
Anmerkung
Die Standard-Hyperparameterwerte sind für verschiedene Modelle unterschiedlich. Bei größeren Modellen ist die Standardstapelgröße beispielsweise kleiner.
In diesem Beispiel wird der SST2
.fit
indem Sie den Amazon S3-Speicherort Ihres Trainingsdatensatzes verwenden. Jeder in einem Notebook verwendete S3-Bucket muss sich in derselben AWS Region befinden wie die Notebook-Instanz, die darauf zugreift.
from sagemaker import image_uris, model_uris, script_uris, hyperparameters from sagemaker.estimator import Estimator model_id, model_version = "tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2", "*" training_instance_type = "ml.p3.2xlarge" # Retrieve the Docker image train_image_uri = image_uris.retrieve(model_id=model_id,model_version=model_version,image_scope="training",instance_type=training_instance_type,region=None,framework=None) # Retrieve the training script train_source_uri = script_uris.retrieve(model_id=model_id, model_version=model_version, script_scope="training") # Retrieve the pretrained model tarball for transfer learning train_model_uri = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope="training") # Retrieve the default hyperparameters for fine-tuning the model hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) # [Optional] Override default hyperparameters with custom values hyperparameters["epochs"] = "5" # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/SST2/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tc-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" # Create an Estimator instance tf_tc_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location, ) # Launch a training job tf_tc_estimator.fit({"training": training_dataset_s3_path}, logs=True)
Weitere Informationen zur Verwendung des SageMaker TensorFlow Textklassifizierungsalgorithmus für Transfer-Lernen in einem benutzerdefinierten Datensatz finden Sie im Notizbuch Einführung in die JumpStart Textklassifikation