Pytorch avoid softmax with nllloss

Severity: Medium

NLLLoss requires log-probabilities as input and is therefore not compatible with the output of a Softmax layer, which produces probabilities. Consider using LogSoftmax instead, or CrossEntropyLoss, which operates directly on logits.

Detector ID
python/pytorch-avoid-softmax-with-nllloss-rule@v1.0
Category
Common Weakness Enumeration (CWE)
-

Noncompliant example

def pytorch_avoid_softmax_with_nllloss_rule_noncompliant():
    import torch
    import torch.nn as nn
    # Noncompliant: `Softmax` output (probabilities) is passed directly to
    # `NLLLoss`, which expects log-probabilities.
    m = nn.Softmax(dim=1)
    loss = nn.NLLLoss()
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.tensor([1, 0, 4])
    output = loss(m(input), target)

Compliant example

def pytorch_avoid_softmax_with_nllloss_rule_compliant():
    import torch
    import torch.nn as nn
    # Compliant: `LogSoftmax` produces the log-probabilities that `NLLLoss` expects.
    m = nn.LogSoftmax(dim=1)
    loss = nn.NLLLoss()
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.tensor([1, 0, 4])
    output = loss(m(input), target)
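The description also mentions CrossEntropyLoss as an alternative. A minimal sketch of that option: CrossEntropyLoss combines LogSoftmax and NLLLoss in one step, so the model's raw logits are passed to it directly, with no softmax layer at all.

```python
import torch
import torch.nn as nn


def cross_entropy_alternative():
    # CrossEntropyLoss applies LogSoftmax followed by NLLLoss internally,
    # so it consumes raw logits directly; no Softmax/LogSoftmax layer is needed.
    loss = nn.CrossEntropyLoss()
    input = torch.randn(3, 5, requires_grad=True)  # raw logits
    target = torch.tensor([1, 0, 4])
    return loss(input, target)
```

For the same logits and targets, this produces the same value as applying `nn.LogSoftmax(dim=1)` followed by `nn.NLLLoss()`, which is why either form is compliant.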