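Checks whether `eval()` is called on a model before it is validated or tested. Some layers, such as dropout and batch normalization, behave differently in training and evaluation modes, so forgetting to call `eval()` can silently produce incorrect validation or test results.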
def pytorch_miss_call_to_eval_noncompliant(model):
    """Load weights from ``model.pth`` into *model* without switching modes.

    This is the *noncompliant* example for the rule. After
    ``load_state_dict`` the model stays in whatever mode it was already in.
    A freshly constructed module is in training mode, so layers that behave
    differently during training and evaluation are NOT put into evaluation
    mode before the model is used.
    """
    import torch

    checkpoint = torch.load("model.pth")
    # Noncompliant: miss call to `eval()` after load.
    model.load_state_dict(checkpoint)
def pytorch_miss_call_to_eval_compliant(model, path="model.pth"):
    """Load a saved ``state_dict`` into *model* and switch it to evaluation mode.

    This is the *compliant* example for the rule. ``model.eval()`` is called
    right after the weights are loaded. As a result, layers that behave
    differently during training and evaluation are in evaluation mode before
    the model is validated or tested.

    Args:
        model: The ``torch.nn.Module`` whose parameters are replaced.
        path: File to read the state dict from. Defaults to ``"model.pth"``,
            so existing callers are unaffected.
    """
    # Local import keeps the example self-contained; without it ``torch``
    # is an undefined name here and the call below raises NameError.
    import torch

    model.load_state_dict(torch.load(path))
    # Compliant: `eval()` is called after load.
    model.eval()