TensorFlow redundant softmax
Severity: Medium

Computing the cross-entropy loss directly from logits with `tf.nn.softmax_cross_entropy_with_logits` is more numerically stable than applying a softmax first and then computing the cross-entropy; the function gains this stability from its internal use of the log-sum-exp trick. Passing already-softmaxed probabilities to it also applies the softmax twice, so the resulting loss is not merely less stable but incorrect.
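For intuition, the instability can be reproduced outside TensorFlow. The sketch below (NumPy for illustration; the function names are ours, not part of any API) contrasts a naive log-softmax, which overflows for large logits, with the log-sum-exp formulation that `tf.nn.softmax_cross_entropy_with_logits` relies on internally:

import numpy as np

def naive_log_softmax(x):
    # np.exp overflows for large entries, so the result degrades to nan/-inf
    e = np.exp(x)
    return np.log(e / e.sum())

def logsumexp_log_softmax(x):
    # log-sum-exp trick: shifting by the max keeps every exponent <= 0
    m = x.max()
    return (x - m) - np.log(np.exp(x - m).sum())

x = np.array([1000.0, 0.0, -1000.0])
print(naive_log_softmax(x))      # [ nan -inf -inf] (overflow warning)
print(logsumexp_log_softmax(x))  # [    0. -1000. -2000.]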

Detector ID
python/tensorflow-redundant-softmax@v1.0
Category
Common Weakness Enumeration (CWE): -

Noncompliant example

def tensorflow_redundant_softmax_noncompliant():
    import tensorflow as tf
    logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
    labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]
    # Noncompliant: using `tf.nn.softmax` with
    # `tf.nn.softmax_cross_entropy_with_logits` is redundant.
    tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=tf.nn.softmax(logits))

Compliant example

def tensorflow_redundant_softmax_compliant():
    import tensorflow as tf
    logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
    labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]
    # Compliant: the unscaled `logits` are passed directly
    # to `tf.nn.softmax_cross_entropy_with_logits`.
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
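To see the difference concretely, one can evaluate both variants on the example data (a quick check, assuming TensorFlow 2.x eager execution; the variable names here are our own). The noncompliant call feeds probabilities where logits are expected, so the softmax is applied twice and the reported loss differs from the true cross-entropy:

import tensorflow as tf

logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]

# Compliant: raw logits in, softmax applied internally exactly once.
correct = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)

# Noncompliant: the inner softmax squashes the logits into probabilities,
# so the internal softmax acts on the wrong inputs.
redundant = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=tf.nn.softmax(logits))

print(correct.numpy())    # per-example cross-entropy from the true logits
print(redundant.numpy())  # different values -- the loss is computed on a
                          # flattened distribution and is incorrect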