本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
编译模型
满足先决条件后,即可使用 Amazon SageMaker Neo 编译模型。你可以使用 AWS CLI、控制台或SDK适用于 Python 的 Amazon Web Services (Boto3)
要编译模型, SageMaker Neo 需要以下信息:
-
您存储训练过的模型的 Amazon S3 存储桶URI。
如果您符合先决条件,则存储桶的名称将存储在名为
bucket
的变量中。以下代码段显示如何使用 AWS CLI列出所有存储桶:aws s3 ls
例如:
$ aws s3 ls 2020-11-02 17:08:50 bucket
-
您要保存已编译模型的 Amazon S3 存储桶URI。
以下代码段将您的 Amazon S3 存储桶URI与名为的输出目录的名称连接在一起:
output
s3_output_location = f's3://{bucket}/output'
-
用于训练模型的机器学习框架。
定义用于训练模型的框架。
framework = 'framework-name'
例如,如果要编译使用训练过的模型 TensorFlow,则可以使用
tflite
或tensorflow
。tflite
如果您想使用占用较少存储内存的较轻版本 TensorFlow ,请使用。framework = 'tflite'
有关 NEO 所支持框架的完整列表,请参阅支持的框架、设备、系统和架构。
-
模型输入的形状。
Neo 需要输入张量的名称和形状。名称和形状以键值对的形式传入。
value
是输入张量的整数维度的列表,key
是模型中输入张量的确切名称。data_shape = '{"name": [tensor-shape]}'
例如:
data_shape = '{"normalized_input_image_tensor":[1, 300, 300, 3]}'
注意
确保模型的格式正确,具体取决于您使用的框架。请参阅 SageMaker Neo 期望什么输入数据形状? 此字典中的键必须更改为新输入张量的名称。
-
要编译的目标设备的名称或硬件平台的一般详细信息
target_device =
'target-device-name'
例如,如果要部署到 Raspberry Pi 3,请使用:
target_device = 'rasp3b'
您可以在支持的框架、设备、系统和架构中找到支持的边缘设备的完整列表。
现在您已完成前面的步骤,可以向 Neo 提交编译作业。
# Create a SageMaker client so you can submit a compilation job sagemaker_client = boto3.client('sagemaker', region_name=AWS_REGION) # Give your compilation job a name compilation_job_name = 'getting-started-demo' print(f'Compilation job for {compilation_job_name} started') response = sagemaker_client.create_compilation_job( CompilationJobName=compilation_job_name, RoleArn=role_arn, InputConfig={ 'S3Uri': s3_input_location, 'DataInputConfig': data_shape, 'Framework': framework.upper() }, OutputConfig={ 'S3OutputLocation': s3_output_location, 'TargetDevice': target_device }, StoppingCondition={ 'MaxRuntimeInSeconds': 900 } ) # Optional - Poll every 30 sec to check completion status import time while True: response = sagemaker_client.describe_compilation_job(CompilationJobName=compilation_job_name) if response['CompilationJobStatus'] == 'COMPLETED': break elif response['CompilationJobStatus'] == 'FAILED': raise RuntimeError('Compilation failed') print('Compiling ...') time.sleep(30) print('Done!')
如果需要更多信息以进行调试,请包括以下打印语句:
print(response)
如果编译作业成功,则编译过的模型存储在您之前指定的输出 Amazon S3 存储桶中 (s3_output_location
)。将编译过的模型下载到本地:
object_path = f'output/{model}-{target_device}.tar.gz' neo_compiled_model = f'compiled-{model}.tar.gz' s3_client.download_file(bucket, object_path, neo_compiled_model)