PyTorch disable gradient calculation Medium

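Checks whether gradient calculation is disabled with `torch.inference_mode()` rather than `torch.no_grad()` during evaluation. Both context managers turn off gradient tracking, but `torch.inference_mode()` also skips view tracking and version-counter updates, so it typically runs faster when the results never need to interact with autograd.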

Detector ID
python/pytorch-disable-gradient-calculation@v1.0
Category
Common Weakness Enumeration (CWE)
-

Noncompliant example

def disable_gradient_calculation_noncompliant(model):
    # `model` is assumed to be a `torch.nn.Module` supplied by the caller.
    import torch
    # Noncompliant: disables gradient calculation using `torch.no_grad()`.
    with torch.no_grad():
        model.eval()

Compliant example

def disable_gradient_calculation_compliant(model):
    # `model` is assumed to be a `torch.nn.Module` supplied by the caller.
    import torch
    # Compliant: disables gradient calculation using `torch.inference_mode()`.
    with torch.inference_mode():
        model.eval()
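
The examples above only switch the model to evaluation mode inside the context manager. As a minimal sketch of how the compliant pattern might look around an actual forward pass, the snippet below runs a small model under `torch.inference_mode()`. The `evaluate` helper, the `nn.Linear(4, 2)` stand-in model, and the random input batch are assumptions made for illustration; they are not part of the detector.

```python
import torch
from torch import nn


def evaluate(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run one forward pass without autograd tracking."""
    # eval() changes layer behavior (dropout, batch norm); it does not
    # disable gradient tracking on its own.
    model.eval()
    # inference_mode() disables gradient tracking for everything inside it.
    with torch.inference_mode():
        return model(inputs)


# Stand-in model and data, assumed for illustration only.
model = nn.Linear(4, 2)
inputs = torch.randn(3, 4)
outputs = evaluate(model, inputs)
```

Tensors created under `torch.inference_mode()` are inference tensors and cannot later be used in computations recorded by autograd. If the outputs must feed back into a training step, `torch.no_grad()` may still be the appropriate choice.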