Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.
Come usare l'algoritmo SageMaker Image Classification - TensorFlow
Puoi usare Image Classification TensorFlow , come algoritmo SageMaker integrato di Amazon. La sezione seguente descrive come usare la classificazione delle immagini - TensorFlow con SageMaker PythonSDK. Per informazioni su come utilizzare la classificazione delle immagini, TensorFlow dall'interfaccia utente di Amazon SageMaker Studio Classic, consultaSageMaker JumpStart modelli preaddestrati.
L' TensorFlow algoritmo Image Classification supporta il transfer learning utilizzando uno qualsiasi dei modelli TensorFlow Hub preaddestrati compatibili. Per un elenco di tutti i modelli preaddestrati disponibili, consulta TensorFlow Modelli Hub. Ogni modello preaddestrato ne ha model_id
univoco. L'esempio seguente utilizza MobileNet V2 1.00 224 (model_id
:tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4
) per ottimizzare un set di dati personalizzato. I modelli preaddestrati sono tutti pre-scaricati dall' TensorFlow Hub e archiviati in bucket Amazon S3 in modo che i lavori di formazione possano essere eseguiti in isolamento di rete. Usa questi artefatti di addestramento dei modelli pregenerati per costruire un Estimator. SageMaker
Innanzitutto, recupera l'immagine Docker, lo script di addestramento e il modello URI preaddestrato. URI URI Quindi, modifica gli iperparametri per adattarli al tuo caso. Puoi vedere un dizionario Python di tutti gli iperparametri disponibili e i loro valori predefiniti con hyperparameters.retrieve_default
. Per ulteriori informazioni, consulta Classificazione delle immagini - TensorFlow Iperparametri. Usa questi valori per costruire un Estimator. SageMaker
Nota
I valori predefiniti degli iperparametri sono diversi per i diversi modelli. Per i modelli più grandi, la dimensione del batch predefinita è inferiore e l'iperparametro train_only_top_layer
è impostato su. "True"
Questo esempio utilizza il set di dati tf_flowers
.fit
utilizzando la posizione Amazon S3 del tuo set di dati di addestramento.
from sagemaker import image_uris, model_uris, script_uris, hyperparameters from sagemaker.estimator import Estimator model_id, model_version =
"tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4"
, "*" training_instance_type ="ml.p3.2xlarge"
# Retrieve the Docker image train_image_uri = image_uris.retrieve(model_id=model_id,model_version=model_version,image_scope="training",instance_type=training_instance_type,region=None,framework=None) # Retrieve the training script train_source_uri = script_uris.retrieve(model_id=model_id, model_version=model_version, script_scope="training") # Retrieve the pretrained model tarball for transfer learning train_model_uri = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope="training") # Retrieve the default hyper-parameters for fine-tuning the model hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) # [Optional] Override default hyperparameters with custom values hyperparameters["epochs"] = "5" # The sample training data is available in the following S3 bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tf_flowers/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-ic-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" # Create SageMaker Estimator instance tf_ic_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location, ) # Use S3 path of the training data to launch SageMaker TrainingJob tf_ic_estimator.fit({"training": training_dataset_s3_path}, logs=True)