Add limited SoftmaxCrossEntropyLoss support
Signed-off-by: Jonathan Sparling <[email protected]>
Jonathan Sparling committed Sep 1, 2022
1 parent f9ebc35 commit b876a41
Showing 2 changed files with 42 additions and 0 deletions.
32 changes: 32 additions & 0 deletions onnx_tf/handlers/backend/softmax_cross_entropy_loss.py
@@ -0,0 +1,32 @@
import tensorflow as tf

from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func


@onnx_op("SoftmaxCrossEntropyLoss")
@tf_func(tf.nn.sparse_softmax_cross_entropy_with_logits)
class SoftmaxCrossEntropyLoss(BackendHandler):
  @classmethod
  def _common(cls, node, **kwargs):
    logits = kwargs["tensor_dict"][node.inputs[0]]
    labels = kwargs["tensor_dict"][node.inputs[1]]

    # tf.shape(labels) is a 1-D tensor whose static length is the rank of
    # labels, so this rejects any label tensor with rank greater than 1.
    labels_shape = tf.shape(labels)
    if labels_shape.shape[0] > 1:
      raise NotImplementedError(
          "SoftmaxCrossEntropyLoss support is limited to rank 1 label tensors.")

    return [
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits)
    ]

  @classmethod
  def version_12(cls, node, **kwargs):
    return cls._common(node, **kwargs)

  @classmethod
  def version_13(cls, node, **kwargs):
    return cls._common(node, **kwargs)
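Note (illustration only, not part of the commit): for a rank 1 integer label tensor, the per-example ONNX SoftmaxCrossEntropyLoss with no reduction is -log(softmax(logits)[label]), which is exactly what tf.nn.sparse_softmax_cross_entropy_with_logits computes. A minimal sketch verifying that equivalence in eager TensorFlow, with arbitrary example sizes and values:

import numpy as np
import tensorflow as tf

# Arbitrary logits for 4 examples over 10 classes, and integer class labels.
logits = np.random.rand(4, 10).astype(np.float32)
labels = np.array([2, 7, 0, 9], dtype=np.int32)

# Per-example loss from the TF op the handler delegates to.
tf_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits).numpy()

# Reference: -log of the softmax probability at each true class.
softmax = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
ref_loss = -np.log(softmax[np.arange(4), labels])

np.testing.assert_almost_equal(tf_loss, ref_loss, decimal=5)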
10 changes: 10 additions & 0 deletions test/backend/test_node.py
@@ -3987,6 +3987,16 @@ def test_softplus(self):
                                 np.log(np.exp(x) + 1),
                                 decimal=5)

def test_softmax_cross_entropy_loss(self):
  node_def = helper.make_node("SoftmaxCrossEntropyLoss", ["X", "Y"], ["Z"])
  classes = 10
  x = self._get_rnd_float32(shape=[1, classes])
  y = self._get_rnd_int(0, classes - 1, [1], np.int32)
  output = run_node(node_def, [x, y])
  # Expected loss is -log of the softmax probability at the true class.
  np.testing.assert_almost_equal(output["Z"],
                                 -np.log(np.exp(x)[0][y] / np.sum(np.exp(x))),
                                 decimal=5)

def test_softsign(self):
  node_def = helper.make_node("Softsign", ["X"], ["Y"])
  x = self._get_rnd_float32(shape=[3, 4, 5])

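The test above exercises the handler directly through run_node. As a rough end-to-end sketch (an assumption about typical onnx-tf usage, not verified against this commit), the same node can be wrapped in a one-node ONNX model and run through onnx_tf.backend.prepare; the model name, shapes, and the reduction="none" attribute below are illustrative choices:

import numpy as np
import onnx.helper as helper
from onnx import TensorProto
import onnx_tf.backend

classes = 10
# One-node graph; reduction="none" keeps the per-example loss the handler emits.
node = helper.make_node(
    "SoftmaxCrossEntropyLoss", ["X", "Y"], ["Z"], reduction="none")
graph = helper.make_graph(
    [node], "sce_example",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, classes]),
        helper.make_tensor_value_info("Y", TensorProto.INT32, [1]),
    ],
    outputs=[helper.make_tensor_value_info("Z", TensorProto.FLOAT, [1])])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

x = np.random.rand(1, classes).astype(np.float32)
y = np.array([3], dtype=np.int32)

tf_rep = onnx_tf.backend.prepare(model)
print(tf_rep.run({"X": x, "Y": y}))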