使用smdebug客户端库创建自定义规则作为 Python 脚本 - Amazon SageMaker

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用smdebug客户端库创建自定义规则作为 Python 脚本

smdebug规则API提供了一个界面来设置您自己的自定义规则。以下 Python 脚本示例演示了如何构造自定义规则 CustomGradientRule。本教程的自定义规则监控梯度变是否太大并将默认阈值设置为 10。自定义规则采用 SageMaker 估算器在启动训练作业时创建的基本试验。

from smdebug.rules.rule import Rule class CustomGradientRule(Rule): def __init__(self, base_trial, threshold=10.0): super().__init__(base_trial) self.threshold = float(threshold) def invoke_at_step(self, step): for tname in self.base_trial.tensor_names(collection="gradients"): t = self.base_trial.tensor(tname) abs_mean = t.reduction_value(step, "mean", abs=True) if abs_mean > self.threshold: return True return False

您可以在同一个 python 脚本中按需要添加任意数量的自定义规则类,并通过在下个部分中构造自定义规则对象,来将它们部署到任何训练作业试验中。