Commit b0a580a

BowenBao authored and pytorchmergebot committed
[ONNX] Export logical_not (pytorch#96315)
Fixes pytorch#95154
Pull Request resolved: pytorch#96315
Approved by: https://github.com/justinchuby
1 parent 5f89d14 commit b0a580a
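
As a quick illustration of what this change enables (not part of the PR; a minimal sketch assuming a PyTorch build that includes this commit and the TorchScript-based exporter), a module that calls torch.logical_not can now be exported at opset 9:

import io

import torch


class NotModel(torch.nn.Module):
    def forward(self, x):
        return torch.logical_not(x)


# Bool input; the exporter casts to BOOL and emits an ONNX Not node.
x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
buffer = io.BytesIO()
torch.onnx.export(NotModel(), (x,), buffer, opset_version=9)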

File tree: 3 files changed, +26 -0 lines

Diff for: test/onnx/test_op_consistency.py (+1)

@@ -305,6 +305,7 @@ def reason_flaky() -> str:
 TESTED_OPS: frozenset[str] = frozenset(
     [
         "ceil",
+        "logical_not",
         "sqrt",
         "stft",
         "t",

Diff for: test/onnx/test_pytorch_onnx_onnxruntime.py (+18)

@@ -4920,6 +4920,24 @@ def forward(self, x, y):
         y = torch.randint(10, (2, 3, 5), dtype=torch.long)
         self.run_test(XorModel(), input_args=(x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_logical_not(self):
+        class NotModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.logical_not(x)
+
+        x = torch.randint(0, 2, (5, 5), dtype=torch.bool)
+        self.run_test(NotModel(), input_args=(x,))
+
+        x = torch.randint(10, (5, 5), dtype=torch.int32)
+        self.run_test(NotModel(), input_args=(x,))
+
+        x = torch.randint(10, (5, 5), dtype=torch.double)
+        self.run_test(NotModel(), input_args=(x,))
+
+        x = torch.randint(10, (2, 3, 5), dtype=torch.float32)
+        self.run_test(NotModel(), input_args=(x,))
+
     @skipIfUnsupportedMinOpsetVersion(11)  # float equal added after opset 11
     def test_eq(self):
         class EqualModel(torch.nn.Module):

Diff for: torch/onnx/symbolic_opset9.py (+7)

@@ -140,6 +140,7 @@
     "log1p",
     "log2",
     "logical_and",
+    "logical_not",
     "logical_or",
     "logical_xor",
     "logsumexp",
@@ -2295,6 +2296,12 @@ def logical_xor(g: jit_utils.GraphContext, input, other):
     return g.op("Xor", input, other)
 
 
+@_onnx_symbolic("aten::logical_not")
+@_beartype.beartype
+def logical_not(g: jit_utils.GraphContext, input):
+    return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL))
+
+
 @_onnx_symbolic("aten::__rshift_")
 @_beartype.beartype
 def __rshift_(g: jit_utils.GraphContext, self, other):
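
A hedged way to sanity-check the new symbolic function (this is illustrative, not from the PR, and assumes the onnx Python package is installed): export a model with a non-bool input and confirm the graph contains the Cast followed by Not that logical_not now lowers to.

import io

import onnx
import torch


class NotModel(torch.nn.Module):
    def forward(self, x):
        return torch.logical_not(x)


x = torch.randint(10, (5, 5), dtype=torch.int32)
buffer = io.BytesIO()
torch.onnx.export(NotModel(), (x,), buffer, opset_version=9)

model = onnx.load_model_from_string(buffer.getvalue())
op_types = [node.op_type for node in model.graph.node]
# Non-bool inputs are first cast to BOOL, then negated with Not.
assert "Cast" in op_types and "Not" in op_types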
