Cara menggunakan Klasifikasi SageMaker Gambar - TensorFlow algoritma - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Cara menggunakan Klasifikasi SageMaker Gambar - TensorFlow algoritma

Anda dapat menggunakan Klasifikasi Gambar - TensorFlow sebagai algoritma SageMaker bawaan Amazon. Bagian berikut menjelaskan cara menggunakan Klasifikasi Gambar - TensorFlow dengan SageMaker PythonSDK. Untuk informasi tentang cara menggunakan Klasifikasi Gambar - TensorFlow dari UI Amazon SageMaker Studio Classic, lihatSageMaker JumpStart model terlatih.

TensorFlow Algoritma Klasifikasi Gambar - mendukung pembelajaran transfer menggunakan salah satu model TensorFlow Hub terlatih sebelumnya yang kompatibel. Untuk daftar semua model terlatih yang tersedia, lihatTensorFlow Model Hub. Setiap model yang telah dilatih sebelumnya memiliki keunikanmodel_id. Contoh berikut menggunakan MobileNet V2 1.00 224 (model_id:tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4) untuk menyempurnakan dataset kustom. Model yang telah dilatih sebelumnya semuanya telah diunduh sebelumnya dari TensorFlow Hub dan disimpan dalam bucket Amazon S3 sehingga pekerjaan pelatihan dapat berjalan dalam isolasi jaringan. Gunakan artefak pelatihan model yang dibuat sebelumnya ini untuk membangun Estimator. SageMaker

Pertama, ambil gambar Docker, skrip pelatihan URIURI, dan model yang telah dilatih sebelumnya. URI Kemudian, ubah hyperparameters sesuai keinginan Anda. Anda dapat melihat kamus Python dari semua hyperparameters yang tersedia dan nilai defaultnya dengan. hyperparameters.retrieve_default Untuk informasi selengkapnya, lihat Klasifikasi Gambar - TensorFlow Hyperparameters. Gunakan nilai-nilai ini untuk membangun SageMaker Estimator.

catatan

Nilai hyperparameter default berbeda untuk model yang berbeda. Untuk model yang lebih besar, ukuran batch default lebih kecil dan train_only_top_layer hyperparameter diatur ke"True".

Contoh ini menggunakan tf_flowersdataset, yang berisi lima kelas gambar bunga. Kami mengunduh dataset sebelumnya dari TensorFlow bawah lisensi Apache 2.0 dan membuatnya tersedia dengan Amazon S3. Untuk menyempurnakan model Anda, hubungi .fit menggunakan lokasi Amazon S3 dari kumpulan data pelatihan Anda.

from sagemaker import image_uris, model_uris, script_uris, hyperparameters from sagemaker.estimator import Estimator model_id, model_version = "tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4", "*" training_instance_type = "ml.p3.2xlarge" # Retrieve the Docker image train_image_uri = image_uris.retrieve(model_id=model_id,model_version=model_version,image_scope="training",instance_type=training_instance_type,region=None,framework=None) # Retrieve the training script train_source_uri = script_uris.retrieve(model_id=model_id, model_version=model_version, script_scope="training") # Retrieve the pretrained model tarball for transfer learning train_model_uri = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope="training") # Retrieve the default hyper-parameters for fine-tuning the model hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) # [Optional] Override default hyperparameters with custom values hyperparameters["epochs"] = "5" # The sample training data is available in the following S3 bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tf_flowers/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-ic-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" # Create SageMaker Estimator instance tf_ic_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location, ) # Use S3 path of the training data to launch SageMaker TrainingJob tf_ic_estimator.fit({"training": training_dataset_s3_path}, logs=True)