本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。
培訓 Amazon Rekognition 自訂標籤模型
您可以使用 Amazon Rekognition 自訂標籤主控台或 Amazon Rekognition 自訂標籤 API 來培訓模型。如果模型培訓失敗,請使用中 偵錯失敗的模型訓練 的資訊來尋找失敗的原因。
您需根據成功培訓模型所需的時間付費。通常培訓需要 30 分鐘至 24 小時才能完成。如需更多詳細資訊,請參閱 培訓時間。
每次培訓模型時都會建立一個模型的新版本。Amazon Rekognition 自訂標籤會為模型建立一個名稱,該名稱是專案名稱和建立模型時的時間戳記的組合。
為了培訓您的模型,Amazon Rekognition 自訂標籤會製作來源培訓和測試圖像的副本。在預設情況下,複製的圖像會使用 AWS 擁有和管理的金鑰進行靜態加密。您也可以選擇使用自己的 AWS KMS key。如果您使用自己的 KMS 金鑰,則需要擁有該 KMS 金鑰的以下權限。
公里:CreateGrant
公里:DescribeKey
如需更多詳細資訊,請參閱 AWS 金鑰管理服務概念。您的來源圖像不會受到影響。
您可以使用 KMS 伺服器端加密 (SSE-KMS) 來加密 Amazon S3 儲存貯體中的培訓和測試圖像,然後再由 Amazon Rekognition 自訂標籤進行複製。若要允許 Amazon Rekognition 自訂標籤存取您的映像檔,您的 AWS 帳戶需要下列 KMS 金鑰的許可。
如需更多詳細資訊,請參閱 使用儲存在 AWS 金鑰管理服務 (SSE-KMS) 中的 KMS 金鑰進行伺服器端加密來保護資料。
培訓模型後,您可以評估其效能並進行改進。如需詳細資訊,請參閱 改善訓練過的 Amazon Rekognition 自訂標籤模型。
如需了解其他模型工作 (例如標籤模型),請參閱 管理 Amazon Rekognition 自訂標籤模型。
培訓模型(主控台)
您可以使用 Amazon Rekognition 自訂標籤主控台來培訓模型。
培訓需要具備培訓資料集和測試資料集的專案。如果您的專案沒有測試資料集,Amazon Rekognition 自訂標籤主控台會在培訓期間分割培訓資料集,以便為您的專案建立一個資料集。選擇的圖像是具代表性的取樣,不會用於培訓資料集。我們建議您只有在沒有可使用的替代測試資料集時,才分割培訓資料集。分割培訓資料集會減少可用於培訓的圖像數量。
系統會根據培訓模型所需的時間來向您收取費用。如需更多詳細資訊,請參閱 培訓時間。
培訓您的模型(主控台)
開啟 Amazon Rekognition 主控台:https://console.aws.amazon.com/rekognition/。
選擇 使用自訂標籤。
在左側導覽視窗中,選擇 專案。
在 專案 頁面,選擇包含要培訓的模型的專案。
在 專案 頁面上,選擇 培訓模型。
(可選) 如果您想要使用自己的 AWS KMS 加密金鑰,請執行以下操作:
在 圖像資料加密 中選擇 自訂加密配置 (進階)。
在 encryption.aws_kms_key,輸入您的金鑰的 Amazon Resource Name (ARN),或選擇現有的 AWS KMS key。若要建立新的金鑰,請選擇 建立 AWS IMS 金鑰。
(可選) 如果您要新增標籤到模型,請執行以下操作:
在 標籤 區域,選擇 新增。
輸入下列資料:
金鑰 中的金鑰名稱。
值 中的鍵/值。
若要新增更多標籤,請重複步驟 6a 和 6b。
(選用) 如果您要移除標籤,請選擇要刪除的標籤旁邊的刪除。如果您要移除先前儲存的標籤,則會在您儲存變更時移除該標籤。
在 培訓模型 的頁面上,選擇 培訓模型。專案的 Amazon Resource Name (ARN) 應該在 選擇專案 的編輯框中。如果沒有,請輸入專案的 ARN。
在您是否要訓練模型?的對話框中,選擇訓練模型。
在專案頁面的 模型 區域中,您可以在培訓正在進行的 Model Status
欄位中檢查目前狀態。培訓模型需要一段時間才能完成。
訓練完成後,請選擇模型名稱。當模型狀態轉為 培訓_完成 時,即培訓已經完成。如果培訓失敗,請參閱 偵錯失敗的模型訓練。
下一步:評估您的模型。如需更多詳細資訊,請參閱 改善訓練過的 Amazon Rekognition 自訂標籤模型。
培訓模型 (SDK)
您可以透過呼叫來訓練模型CreateProjectVersion。若要培訓模型,需要以下資料:
培訓使用與專案相關的培訓和測試資料集。如需詳細資訊,請參閱 管理資料集。
或者,您可以指定項目外部的培訓和測試資料集清單檔案。如果您在使用外部清單檔案培訓模型後開啟主控台,Amazon Rekognition 自訂標籤將使用最後一組用於培訓的清單檔案為您建立資料集。您無法再透過指定外部資訊清單檔案來培訓專案的模型版本。如需詳細資訊,請參閱CreatePrjectVersion。
來自 CreateProjectVersion
的回應是一個 ARN,您可以使用它來識別後續請求中的模型版本。您也可以使用 ARN 來保護模型版本。如需詳細資訊,請參閱 保護亞馬遜重新認知自訂標籤專案。
培訓模型版本需要一段時間才能完成。本主題中的 Python 和 Java 範例使用等待程式以等待培訓完成。等待程式是一種實用的程式方法,用於輪詢特定狀態發生。或者,您可以通過呼叫 DescribeProjectVersions
來獲取培訓的當前狀態。當 Status
的欄位值轉為 TRAINING_COMPLETED
時,表示培訓完成。培訓完成後,您可以檢閱評估結果來評估模型的品質。
培訓模型 (SDK)
以下範例顯示如何使用與專案相關的培訓和測試資料集來培訓模型。
培訓模型 (SDK)
-
如果您尚未這樣做,請安裝並設定 AWS CLI 和 AWS SDK。如需詳細資訊,請參閱 步驟 4:設定 AWS CLI 和 AWS SDKs。
使用下列範例程式碼來培訓專案。
- AWS CLI
-
以下範例會建立一個模型。培訓資料集會被分割以建立測試資料集。取代以下項目:
-
my_project_arn
與專案的 Amazon Resource Name (ARN)。
-
version_name
取代為您選擇的唯一版本名稱。
-
output_bucket
取代為 Amazon Rekognition 自訂標籤儲存培訓結果的 Amazon S3 儲存貯體的名稱。
-
output_folder
取代為儲存培訓結果的資料夾名稱。
(可選參數) --kms-key-id
包含您的 AWS 金鑰管理服務客戶主金鑰的識別碼。
aws rekognition create-project-version \
--project-arn project_arn
\
--version-name version_name
\
--output-config '{"S3Bucket":"output_bucket
", "S3KeyPrefix":"output_folder
"}' \
--profile custom-labels-access
- Python
-
以下範例會建立一個模型。提供下列命令列參數:
project_arn
– 專案的 Amazon Resource Name (ARN)。
version_name
– 您所選模型的唯一版本名稱。
output_bucket
– Amazon Rekognition 自訂標籤儲存培訓結果的 Amazon S3 儲存貯體的名稱。
output_folder
– 儲存培訓結果的資料夾名稱。
您也可選擇性地提供以下命令列參數,以將標籤附加至您的模型:
tag
– 您選擇需要附加至模型的標籤名稱。
tag_value
標籤值。
#Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#PDX-License-Identifier: MIT-0 (For details, see https://github.com/awsdocs/amazon-rekognition-custom-labels-developer-guide/blob/master/LICENSE-SAMPLECODE.)
import argparse
import logging
import json
import boto3
from botocore.exceptions import ClientError
logger = logging.getLogger(__name__)
def train_model(rek_client, project_arn, version_name, output_bucket, output_folder, tag_key, tag_key_value):
"""
Trains an Amazon Rekognition Custom Labels model.
:param rek_client: The Amazon Rekognition Custom Labels Boto3 client.
:param project_arn: The ARN of the project in which you want to train a model.
:param version_name: A version for the model.
:param output_bucket: The S3 bucket that hosts training output.
:param output_folder: The path for the training output within output_bucket
:param tag_key: The name of a tag to attach to the model. Pass None to exclude
:param tag_key_value: The value of the tag. Pass None to exclude
"""
try:
#Train the model
status=""
logger.info("training model version %s for project %s",
version_name, project_arn)
output_config = json.loads(
'{"S3Bucket": "'
+ output_bucket
+ '", "S3KeyPrefix": "'
+ output_folder
+ '" } '
)
tags={}
if tag_key is not None and tag_key_value is not None:
tags = json.loads(
'{"' + tag_key + '":"' + tag_key_value + '"}'
)
response=rek_client.create_project_version(
ProjectArn=project_arn,
VersionName=version_name,
OutputConfig=output_config,
Tags=tags
)
logger.info("Started training: %s", response['ProjectVersionArn'])
# Wait for the project version training to complete.
project_version_training_completed_waiter = rek_client.get_waiter('project_version_training_completed')
project_version_training_completed_waiter.wait(ProjectArn=project_arn,
VersionNames=[version_name])
# Get the completion status.
describe_response=rek_client.describe_project_versions(ProjectArn=project_arn,
VersionNames=[version_name])
for model in describe_response['ProjectVersionDescriptions']:
logger.info("Status: %s", model['Status'])
logger.info("Message: %s", model['StatusMessage'])
status=model['Status']
logger.info("finished training")
return response['ProjectVersionArn'], status
except ClientError as err:
logger.exception("Couldn't create model: %s", err.response['Error']['Message'] )
raise
def add_arguments(parser):
"""
Adds command line arguments to the parser.
:param parser: The command line parser.
"""
parser.add_argument(
"project_arn", help="The ARN of the project in which you want to train a model"
)
parser.add_argument(
"version_name", help="A version name of your choosing."
)
parser.add_argument(
"output_bucket", help="The S3 bucket that receives the training results."
)
parser.add_argument(
"output_folder", help="The folder in the S3 bucket where training results are stored."
)
parser.add_argument(
"--tag_name", help="The name of a tag to attach to the model", required=False
)
parser.add_argument(
"--tag_value", help="The value for the tag.", required=False
)
def main():
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
try:
# Get command line arguments.
parser = argparse.ArgumentParser(usage=argparse.SUPPRESS)
add_arguments(parser)
args = parser.parse_args()
print(f"Training model version {args.version_name} for project {args.project_arn}")
# Train the model.
session = boto3.Session(profile_name='custom-labels-access')
rekognition_client = session.client("rekognition")
model_arn, status=train_model(rekognition_client,
args.project_arn,
args.version_name,
args.output_bucket,
args.output_folder,
args.tag_name,
args.tag_value)
print(f"Finished training model: {model_arn}")
print(f"Status: {status}")
except ClientError as err:
logger.exception("Problem training model: %s", err)
print(f"Problem training model: {err}")
except Exception as err:
logger.exception("Problem training model: %s", err)
print(f"Problem training model: {err}")
if __name__ == "__main__":
main()
- Java V2
-
以下範例會培訓模型。提供下列命令列參數:
project_arn
– 專案的 Amazon Resource Name (ARN)。
version_name
– 您所選模型的唯一版本名稱。
output_bucket
– Amazon Rekognition 自訂標籤儲存培訓結果的 Amazon S3 儲存貯體的名稱。
output_folder
– 儲存培訓結果的資料夾名稱。
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/
package com.example.rekognition;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.core.waiters.WaiterResponse;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rekognition.RekognitionClient;
import software.amazon.awssdk.services.rekognition.model.CreateProjectVersionRequest;
import software.amazon.awssdk.services.rekognition.model.CreateProjectVersionResponse;
import software.amazon.awssdk.services.rekognition.model.DescribeProjectVersionsRequest;
import software.amazon.awssdk.services.rekognition.model.DescribeProjectVersionsResponse;
import software.amazon.awssdk.services.rekognition.model.OutputConfig;
import software.amazon.awssdk.services.rekognition.model.ProjectVersionDescription;
import software.amazon.awssdk.services.rekognition.model.RekognitionException;
import software.amazon.awssdk.services.rekognition.waiters.RekognitionWaiter;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
public class TrainModel {
public static final Logger logger = Logger.getLogger(TrainModel.class.getName());
public static String trainMyModel(RekognitionClient rekClient, String projectArn, String versionName,
String outputBucket, String outputFolder) {
try {
OutputConfig outputConfig = OutputConfig.builder().s3Bucket(outputBucket).s3KeyPrefix(outputFolder).build();
logger.log(Level.INFO, "Training Model for project {0}", projectArn);
CreateProjectVersionRequest createProjectVersionRequest = CreateProjectVersionRequest.builder()
.projectArn(projectArn).versionName(versionName).outputConfig(outputConfig).build();
CreateProjectVersionResponse response = rekClient.createProjectVersion(createProjectVersionRequest);
logger.log(Level.INFO, "Model ARN: {0}", response.projectVersionArn());
logger.log(Level.INFO, "Training model...");
// wait until training completes
DescribeProjectVersionsRequest describeProjectVersionsRequest = DescribeProjectVersionsRequest.builder()
.versionNames(versionName)
.projectArn(projectArn)
.build();
RekognitionWaiter waiter = rekClient.waiter();
WaiterResponse<DescribeProjectVersionsResponse> waiterResponse = waiter
.waitUntilProjectVersionTrainingCompleted(describeProjectVersionsRequest);
Optional<DescribeProjectVersionsResponse> optionalResponse = waiterResponse.matched().response();
DescribeProjectVersionsResponse describeProjectVersionsResponse = optionalResponse.get();
for (ProjectVersionDescription projectVersionDescription : describeProjectVersionsResponse
.projectVersionDescriptions()) {
System.out.println("ARN: " + projectVersionDescription.projectVersionArn());
System.out.println("Status: " + projectVersionDescription.statusAsString());
System.out.println("Message: " + projectVersionDescription.statusMessage());
}
return response.projectVersionArn();
} catch (RekognitionException e) {
logger.log(Level.SEVERE, "Could not train model: {0}", e.getMessage());
throw e;
}
}
public static void main(String args[]) {
String versionName = null;
String projectArn = null;
String projectVersionArn = null;
String bucket = null;
String location = null;
final String USAGE = "\n" + "Usage: " + "<project_name> <version_name> <output_bucket> <output_folder>\n\n" + "Where:\n"
+ " project_arn - The ARN of the project that you want to use. \n\n"
+ " version_name - A version name for the model.\n\n"
+ " output_bucket - The S3 bucket in which to place the training output. \n\n"
+ " output_folder - The folder within the bucket that the training output is stored in. \n\n";
if (args.length != 4) {
System.out.println(USAGE);
System.exit(1);
}
projectArn = args[0];
versionName = args[1];
bucket = args[2];
location = args[3];
try {
// Get the Rekognition client.
RekognitionClient rekClient = RekognitionClient.builder()
.credentialsProvider(ProfileCredentialsProvider.create("custom-labels-access"))
.region(Region.US_WEST_2)
.build();
// Train model
projectVersionArn = trainMyModel(rekClient, projectArn, versionName, bucket, location);
System.out.println(String.format("Created model: %s for Project ARN: %s", projectVersionArn, projectArn));
rekClient.close();
} catch (RekognitionException rekError) {
logger.log(Level.SEVERE, "Rekognition client error: {0}", rekError.getMessage());
System.exit(1);
}
}
}
如果培訓失敗,請參閱 偵錯失敗的模型訓練。