Pytorch miss call to eval
Severity: Medium

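Checks whether eval() is called before a model is validated or tested. Some layers, such as dropout and batch normalization, behave differently in training mode and evaluation mode, so a model that is still in training mode can produce incorrect or nondeterministic results.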
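A minimal sketch of the mode difference, assuming a small hypothetical model built only for illustration: in training mode Dropout randomly zeroes activations, while after eval() it becomes a no-op, so repeated forward passes on the same input give identical results.

```python
import torch
from torch import nn

# Hypothetical toy model: a Linear layer followed by Dropout,
# one of the layers whose behavior depends on the module's mode.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.train()  # training mode: Dropout randomly zeroes about half the activations
train_a = model(x)
train_b = model(x)  # usually differs from train_a because of random masking

model.eval()  # evaluation mode: Dropout passes activations through unchanged
with torch.no_grad():
    eval_a = model(x)
    eval_b = model(x)

print(model.training)  # False: eval() sets training=False recursively
print(torch.equal(eval_a, eval_b))  # True: outputs are deterministic in eval mode
```

Note that eval() is equivalent to train(False) and applies to every submodule, which is why a single call on the top-level model is enough.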

Detector ID
python/pytorch-miss-call-to-eval@v1.0
Category
Common Weakness Enumeration (CWE)
-

Noncompliant example

import torch


def pytorch_miss_call_to_eval_noncompliant(model):
    # Noncompliant: missing call to `eval()` after load.
    model.load_state_dict(torch.load("model.pth"))

Compliant example

import torch


def pytorch_miss_call_to_eval_compliant(model):
    model.load_state_dict(torch.load("model.pth"))
    # Compliant: `eval()` is called after load.
    model.eval()
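A fuller sketch of the compliant pattern, under stated assumptions: build_model, load_for_inference, and evaluate are hypothetical helpers, and an in-memory io.BytesIO buffer stands in for "model.pth" so the example is self-contained. The evaluate helper also shows one way to validate a model mid-training: switch to eval mode, then restore whatever mode the caller had.

```python
import io

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical architecture containing mode-dependent layers.
    return nn.Sequential(
        nn.Linear(8, 16),
        nn.BatchNorm1d(16),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(16, 2),
    )


def load_for_inference(source) -> nn.Module:
    model = build_model()
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(source, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # BatchNorm uses running stats, Dropout is disabled
    return model


@torch.no_grad()
def evaluate(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    was_training = model.training
    model.eval()  # always evaluate in eval mode
    try:
        return model(inputs)
    finally:
        model.train(was_training)  # restore the caller's mode


# Save a freshly initialized model to memory instead of "model.pth".
buffer = io.BytesIO()
torch.save(build_model().state_dict(), buffer)
buffer.seek(0)

model = load_for_inference(buffer)
batch = torch.randn(3, 8)
out1 = model(batch)
out2 = model(batch)
print(model.training)  # False
print(torch.equal(out1, out2))  # True: no dropout, fixed BatchNorm statistics

trainable = build_model()  # newly constructed modules start in training mode
preds = evaluate(trainable, batch)
print(trainable.training)  # True: mode restored after evaluation
```

Restoring the previous mode in a finally block keeps a validation pass from silently leaving a model in eval mode when training resumes.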