Skip to content

Commit 16889b0

Browse files
authored
Update verifier
Differential Revision: D68839524 Pull Request resolved: #8034
1 parent c955969 commit 16889b0

File tree

3 files changed

+58
-51
lines changed

3 files changed

+58
-51
lines changed

exir/program/test/test_program.py

+39
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,45 @@ def forward(self, x, y):
313313
)
314314
edge_manager.to_executorch()
315315

316+
def test_data_dependent(self):
317+
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
318+
torch.library.define(
319+
"mylib::foo1",
320+
"(Tensor a, Tensor b) -> Tensor",
321+
tags=torch.Tag.pt2_compliant_tag,
322+
lib=lib,
323+
)
324+
325+
@torch.library.impl("mylib::foo1", "cpu", lib=lib)
326+
def foo_impl(a, b):
327+
return a + b
328+
329+
@torch.library.register_fake("mylib::foo1", lib=lib)
330+
def mylib_foo_default_fake(*args, **kwargs):
331+
ctx = torch.library.get_ctx()
332+
fake_shape = ctx.new_dynamic_size()
333+
return torch.empty(fake_shape, dtype=torch.float32, device="cpu")
334+
335+
class M(torch.nn.Module):
336+
def forward(self, a, b, c):
337+
res = torch.ops.mylib.foo1(a, b)
338+
339+
c_item = c.item()
340+
torch._check_is_size(c_item)
341+
torch._check(c_item < res.shape[0])
342+
return res[:c_item]
343+
344+
inp = (torch.randn(10), torch.randn(10), torch.tensor(3))
345+
346+
ep = export(M(), inp)
347+
edge = to_edge(ep)
348+
self.assertTrue(
349+
torch.allclose(
350+
edge.exported_program().module()(*inp),
351+
M()(*inp),
352+
)
353+
)
354+
316355
def test_edge_manager_transform(self):
317356
edge_manager: EdgeProgramManager = to_edge(
318357
get_exported_programs(), get_config_methods()

exir/verification/test/test_verifier.py

-34
Original file line numberDiff line numberDiff line change
@@ -36,40 +36,6 @@ def test_edge_verifier_check_valid_op_succeed_given_custom_op(self) -> None:
3636
verifier.check_valid_edge_op(edge_op)
3737
verifier.check_valid_op(edge_op)
3838

39-
def test_edge_verifier_enablement(self) -> None:
40-
class M(torch.nn.Module):
41-
def forward(self, x, y):
42-
z = y.item()
43-
torch._check(z > 0)
44-
torch._check(z < 4)
45-
return x[z : z + y.shape[0]]
46-
47-
ep = torch.export.export(M(), (torch.randn(10), torch.tensor([3])), strict=True)
48-
49-
compile_config_with_disable_ir_validity = EdgeCompileConfig(
50-
_check_ir_validity=False
51-
)
52-
edge_manager = to_edge(
53-
ep, compile_config=compile_config_with_disable_ir_validity
54-
)
55-
56-
normal_verifier = EXIREdgeDialectVerifier()
57-
disable_ir_validity_verifier = EXIREdgeDialectVerifier(
58-
compile_config_with_disable_ir_validity
59-
)
60-
61-
# exported model can not pass normal verifier due to
62-
# aten.sym_constrain_range.default is illegal to be edge op
63-
with self.assertRaises(SpecViolationError):
64-
normal_verifier(edge_manager.exported_program())
65-
66-
# exported model can pass disable_ir_validity_verifier due to verifier
67-
# is disabled by compile_config_with_disable_ir_validity
68-
# (_check_ir_validity=False). Noted that this verifation has been done
69-
# when calling `to_edge`. Explicitly calling verifier here just for better
70-
# demonstration and is unnecessary in real world for ir verification.
71-
disable_ir_validity_verifier(edge_manager.exported_program())
72-
7339
def test_edge_verifier_check_edge_op(self) -> None:
7440
class Model(torch.nn.Module):
7541
def __init__(self):

exir/verification/verifier.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from executorch.exir.error import ExportError, ExportErrorType
1717
from executorch.exir.lowered_backend_module import LoweredBackendModule
1818
from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
19+
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
20+
from executorch.exir.passes.replace_aten_with_edge_pass import DISALLOW_LIST
1921
from executorch.exir.verification.arg_validator import (
2022
EdgeOpArgValidator,
2123
RunHigherOrderOperatorError,
@@ -99,16 +101,20 @@ def __init__(self) -> None:
99101
self._exception_list = exception_list if exception_list else []
100102

101103
def _get_exception_list(self) -> List[torch._ops.OpOverload]:
102-
exception_list = [
103-
torch.ops.aten.mkldnn_rnn_layer.default,
104-
torch.ops.aten._upsample_bilinear2d_aa.default,
105-
torch.ops.aten.quantize_per_tensor.default,
106-
torch.ops.aten.dequantize.self,
107-
torch.ops.aten.max.default, # TODO(T188268054)
108-
torch.ops.aten.min.default, # TODO(T188268054)
109-
torch.ops.aten.full_like.default, # TODO(T183507359)
110-
]
111-
exception_list += self._exception_list
104+
exception_list = (
105+
[
106+
torch.ops.aten.mkldnn_rnn_layer.default,
107+
torch.ops.aten._upsample_bilinear2d_aa.default,
108+
torch.ops.aten.quantize_per_tensor.default,
109+
torch.ops.aten.dequantize.self,
110+
torch.ops.aten.max.default, # TODO(T188268054)
111+
torch.ops.aten.min.default, # TODO(T188268054)
112+
torch.ops.aten.full_like.default, # TODO(T183507359)
113+
]
114+
+ list(_EXECUTORCH_SYM_OPS)
115+
+ DISALLOW_LIST
116+
+ self._exception_list
117+
)
112118

113119
return exception_list
114120

@@ -249,13 +255,9 @@ def check_valid_edge_op(self, op):
249255
return
250256
if (
251257
op
252-
in [
253-
operator.getitem,
254-
torch.ops.aten.sym_size.int,
255-
torch.ops.aten.scalar_tensor.default,
256-
torch.ops.aten._assert_async.msg,
257-
torch.ops.aten._assert_scalar.default,
258-
]
258+
in [operator.getitem]
259+
+ DISALLOW_LIST
260+
+ list(_EXECUTORCH_SYM_OPS)
259261
+ self._exception_list
260262
):
261263
return

0 commit comments

Comments
 (0)