Computing `BCELoss` on the outputs of a `Sigmoid` layer can be replaced by a single `BCEWithLogitsLoss`. By combining the two operations into one class, PyTorch can take advantage of the log-sum-exp trick, which offers better numerical stability: the loss is evaluated directly on the logits instead of on sigmoid outputs, which can underflow to 0 or saturate to 1.
def pytorch_sigmoid_before_bceloss_noncompliant():
    import torch
    import torch.nn as nn
    # Noncompliant: `Sigmoid` layer followed by `BCELoss`
    # is not numerically robust.
    m = nn.Sigmoid()
    loss = nn.BCELoss()

    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)

    output = loss(m(input), target)
    output.backward()
def pytorch_sigmoid_before_bceloss_compliant():
    import torch
    import torch.nn as nn
    # Compliant: `BCEWithLogitsLoss` combines a `Sigmoid` layer
    # and the `BCELoss` into one class
    # and is numerically robust.
    loss = nn.BCEWithLogitsLoss()

    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)

    output = loss(input, target)
    output.backward()