本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。
編譯模型
滿足先決條件 後,您可以使用 Amazon SageMaker Neo 編譯模型。您可以使用 AWS CLI、主控台或 Amazon Web Services SDK for Python (Boto3)
若要編譯模型, SageMaker Neo 需要下列資訊:
-
存放訓練模型的 Amazon S3 儲存貯URI體。
如果您遵循先決條件,則儲存貯體的名稱會儲存在名為
bucket
的變數中。下列程式碼片段顯示如何使用 AWS CLI列出所有您的儲存貯體:aws s3 ls
例如:
$ aws s3 ls 2020-11-02 17:08:50 bucket
-
URI您要儲存編譯模型的 Amazon S3 儲存貯體。
下列程式碼片段會將 Amazon S3 儲存貯體URI與名為 的輸出目錄名稱串連
output
:s3_output_location = f's3://{bucket}/output'
-
您用來訓練模型的機器學習架構。
定義您用來訓練模型的架構。
framework = 'framework-name'
例如,如果您想要編譯使用 訓練的模型 TensorFlow,您可以使用
tflite
或tensorflow
。tflite
如果您想要使用較輕版本的 ,且使用較少的儲存記憶體 TensorFlow ,請使用 。framework = 'tflite'
有關 Neo 支援的架構之完整清單,請參閱支援的架構、裝置、系統和架構。
-
模型輸入的形狀。
Neo 需要輸入張量的名稱和形狀。名稱和形狀會以鍵值對的形式傳遞。
value
是輸入張量的整數維度清單,key
是模型中輸入張量的確切名稱。data_shape = '{"name": [tensor-shape]}'
例如:
data_shape = '{"normalized_input_image_tensor":[1, 300, 300, 3]}'
注意
取決於您使用的架構,請確保模型格式正確。請參閱 SageMaker Neo 預期哪些輸入資料形狀? 此字典中的金鑰必須變更為新的輸入張量名稱。
-
要編譯的目標裝置名稱或硬體平台的一般詳細資訊
target_device =
'target-device-name'
例如,如果您想要部署到 Raspberry Pi 3,請使用:
target_device = 'rasp3b'
您可以在支援的架構、裝置、系統和架構中找到系統支援的邊緣裝置完整清單。
現在您已完成前面的步驟,可以將編譯任務提交給 Neo。
# Create a SageMaker client so you can submit a compilation job sagemaker_client = boto3.client('sagemaker', region_name=AWS_REGION) # Give your compilation job a name compilation_job_name = 'getting-started-demo' print(f'Compilation job for {compilation_job_name} started') response = sagemaker_client.create_compilation_job( CompilationJobName=compilation_job_name, RoleArn=role_arn, InputConfig={ 'S3Uri': s3_input_location, 'DataInputConfig': data_shape, 'Framework': framework.upper() }, OutputConfig={ 'S3OutputLocation': s3_output_location, 'TargetDevice': target_device }, StoppingCondition={ 'MaxRuntimeInSeconds': 900 } ) # Optional - Poll every 30 sec to check completion status import time while True: response = sagemaker_client.describe_compilation_job(CompilationJobName=compilation_job_name) if response['CompilationJobStatus'] == 'COMPLETED': break elif response['CompilationJobStatus'] == 'FAILED': raise RuntimeError('Compilation failed') print('Compiling ...') time.sleep(30) print('Done!')
如果您想要偵錯的其他資訊,請包含下列列印陳述式:
print(response)
如果編譯任務成功,編譯過的模型會儲存在先前指定的輸出 Amazon S3 儲存貯體中 (s3_output_location
)。在本機下載已編譯的模型:
object_path = f'output/{model}-{target_device}.tar.gz' neo_compiled_model = f'compiled-{model}.tar.gz' s3_client.download_file(bucket, object_path, neo_compiled_model)