編譯模型 - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

編譯模型

滿足先決條件 後,您可以使用 Amazon SageMaker Neo 編譯模型。您可以使用 AWS CLI、主控台或 Amazon Web Services SDK for Python (Boto3) 編譯模型,請參閱使用 Neo 編譯模型 。在這個範例中,你會用 Boto3 編譯你的模型。

若要編譯模型, SageMaker Neo 需要下列資訊:

  1. 存放訓練模型的 Amazon S3 儲存貯URI體。

    如果您遵循先決條件,則儲存貯體的名稱會儲存在名為 bucket 的變數中。下列程式碼片段顯示如何使用 AWS CLI列出所有您的儲存貯體:

    aws s3 ls

    例如:

    $ aws s3 ls 2020-11-02 17:08:50 bucket
  2. URI您要儲存編譯模型的 Amazon S3 儲存貯體。

    下列程式碼片段會將 Amazon S3 儲存貯體URI與名為 的輸出目錄名稱串連output

    s3_output_location = f's3://{bucket}/output'
  3. 您用來訓練模型的機器學習架構。

    定義您用來訓練模型的架構。

    framework = 'framework-name'

    例如,如果您想要編譯使用 訓練的模型 TensorFlow,您可以使用 tflitetensorflowtflite 如果您想要使用較輕版本的 ,且使用較少的儲存記憶體 TensorFlow ,請使用 。

    framework = 'tflite'

    有關 Neo 支援的架構之完整清單,請參閱支援的架構、裝置、系統和架構

  4. 模型輸入的形狀。

    Neo 需要輸入張量的名稱和形狀。名稱和形狀會以鍵值對的形式傳遞。value 是輸入張量的整數維度清單,key 是模型中輸入張量的確切名稱。

    data_shape = '{"name": [tensor-shape]}'

    例如:

    data_shape = '{"normalized_input_image_tensor":[1, 300, 300, 3]}'
    注意

    取決於您使用的架構,請確保模型格式正確。請參閱 SageMaker Neo 預期哪些輸入資料形狀? 此字典中的金鑰必須變更為新的輸入張量名稱。

  5. 要編譯的目標裝置名稱或硬體平台的一般詳細資訊

    target_device = 'target-device-name'

    例如,如果您想要部署到 Raspberry Pi 3,請使用:

    target_device = 'rasp3b'

    您可以在支援的架構、裝置、系統和架構中找到系統支援的邊緣裝置完整清單。

現在您已完成前面的步驟,可以將編譯任務提交給 Neo。

# Create a SageMaker client so you can submit a compilation job sagemaker_client = boto3.client('sagemaker', region_name=AWS_REGION) # Give your compilation job a name compilation_job_name = 'getting-started-demo' print(f'Compilation job for {compilation_job_name} started') response = sagemaker_client.create_compilation_job( CompilationJobName=compilation_job_name, RoleArn=role_arn, InputConfig={ 'S3Uri': s3_input_location, 'DataInputConfig': data_shape, 'Framework': framework.upper() }, OutputConfig={ 'S3OutputLocation': s3_output_location, 'TargetDevice': target_device }, StoppingCondition={ 'MaxRuntimeInSeconds': 900 } ) # Optional - Poll every 30 sec to check completion status import time while True: response = sagemaker_client.describe_compilation_job(CompilationJobName=compilation_job_name) if response['CompilationJobStatus'] == 'COMPLETED': break elif response['CompilationJobStatus'] == 'FAILED': raise RuntimeError('Compilation failed') print('Compiling ...') time.sleep(30) print('Done!')

如果您想要偵錯的其他資訊,請包含下列列印陳述式:

print(response)

如果編譯任務成功,編譯過的模型會儲存在先前指定的輸出 Amazon S3 儲存貯體中 (s3_output_location)。在本機下載已編譯的模型:

object_path = f'output/{model}-{target_device}.tar.gz' neo_compiled_model = f'compiled-{model}.tar.gz' s3_client.download_file(bucket, object_path, neo_compiled_model)