Skip to content

Commit 299fbbc

Browse files
justinchuby authored and pytorchmergebot committed
[ONNX] Fix check_training_mode in symbolic_helper (pytorch#78376)
`check_training_mode` always warned that an op is set to training because it was comparing an int `op_train_mode` with an Enum `GLOBALS.training_mode`. This PR fixes the behavior. Pull Request resolved: pytorch#78376 Approved by: https://github.com/garymm
1 parent dfd78bf commit 299fbbc

File tree

2 files changed

+94
-21
lines changed

2 files changed

+94
-21
lines changed

Diff for: test/onnx/test_symbolic_helper.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Owner(s): ["module: onnx"]
2+
"""Unit tests on `torch.onnx.symbolic_helper`."""
3+
4+
import torch
5+
from torch.onnx import symbolic_helper
6+
from torch.onnx._globals import GLOBALS
7+
from torch.testing._internal import common_utils
8+
9+
10+
class TestHelperFunctions(common_utils.TestCase):
    """Unit tests for `torch.onnx.symbolic_helper.check_training_mode`."""

    def setUp(self):
        super().setUp()
        # Remember the global export mode so tearDown can restore it,
        # keeping tests isolated from one another.
        self._original_training_mode = GLOBALS.training_mode

    def tearDown(self):
        GLOBALS.training_mode = self._original_training_mode

    @common_utils.parametrize(
        "op_train_mode,export_mode",
        [
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve"
            ),
            common_utils.subtest(
                [0, torch.onnx.TrainingMode.EVAL],
                name="modes_match_op_train_mode_0_export_mode_eval",
            ),
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.TRAINING],
                name="modes_match_op_train_mode_1_export_mode_training",
            ),
        ],
    )
    def test_check_training_mode_does_not_warn_when(
        self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
    ):
        """No warning when the op mode matches the export mode (or it is PRESERVE)."""
        GLOBALS.training_mode = export_mode

        def call_check_training_mode():
            symbolic_helper.check_training_mode(op_train_mode, "testop")

        self.assertNotWarn(call_check_training_mode)

    @common_utils.parametrize(
        "op_train_mode,export_mode",
        [
            common_utils.subtest(
                [0, torch.onnx.TrainingMode.TRAINING],
                name="modes_do_not_match_op_train_mode_0_export_mode_training",
            ),
            common_utils.subtest(
                [1, torch.onnx.TrainingMode.EVAL],
                name="modes_do_not_match_op_train_mode_1_export_mode_eval",
            ),
        ],
    )
    def test_check_training_mode_warns_when(
        self,
        op_train_mode: int,
        export_mode: torch.onnx.TrainingMode,
    ):
        """A UserWarning naming the export mode is raised when the modes disagree."""
        GLOBALS.training_mode = export_mode
        with self.assertWarnsRegex(
            UserWarning, f"ONNX export mode is set to {export_mode}"
        ):
            symbolic_helper.check_training_mode(op_train_mode, "testop")
65+
66+
67+
# Expand the @parametrize decorators above into concrete test methods.
common_utils.instantiate_parametrized_tests(TestHelperFunctions)


if __name__ == "__main__":
    common_utils.run_tests()

Diff for: torch/onnx/symbolic_helper.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -1114,27 +1114,29 @@ def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, na
11141114
return padding
11151115

11161116

1117-
def check_training_mode(op_train_mode: int, op_name: str) -> None:
    """Warns the user if the model's training mode and the export mode do not agree.

    Args:
        op_train_mode: 1 if the operator is in training mode, 0 if it is in
            eval mode.
        op_name: Name of the operator, used in the warning message.
    """
    if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
        # PRESERVE exports each op in whatever mode the module is in,
        # so there is never a mismatch to report.
        return

    # BUG FIX: the previous code compared the int-derived bool `op_train_mode`
    # against the TrainingMode enum in GLOBALS.training_mode, so the two never
    # compared equal and the warning fired on every call. Normalize the op's
    # 0/1 flag to the enum before comparing.
    if op_train_mode:
        op_mode_enum = _C_onnx.TrainingMode.TRAINING
    else:
        op_mode_enum = _C_onnx.TrainingMode.EVAL
    if op_mode_enum == GLOBALS.training_mode:
        # The modes agree. Do nothing
        return

    op_mode_text = f"train={bool(op_train_mode)}"
    # Setting the model mode could result in op_mode != GLOBALS.training_mode
    # if the model is a FuncModule. In this case we warn the user of
    # the state and export depending on op_mode
    # This is to support use-cases of fixing certain layer weights
    # in training.
    warnings.warn(
        f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' "
        f"is set to {op_mode_text}. Exporting with {op_mode_text}."
    )
1117+
def check_training_mode(op_train_mode: int, op_name: str) -> None:
    """Warns the user if the model's training mode and the export mode do not agree."""
    export_mode = GLOBALS.training_mode
    if export_mode == _C_onnx.TrainingMode.PRESERVE:
        # PRESERVE exports every op as-is, so a mismatch is never possible.
        return

    # Translate the op's 0/1 flag into the same enum used by the export settings.
    op_mode_enum = (
        _C_onnx.TrainingMode.TRAINING if op_train_mode else _C_onnx.TrainingMode.EVAL
    )
    if op_mode_enum == export_mode:
        # The modes agree. Do nothing
        return

    op_mode_text = f"train={bool(op_train_mode)}"
    # Setting the model mode could result in op_mode != GLOBALS.training_mode
    # if the model is a FuncModule. In this case we warn the user of
    # the state and export depending on op_mode
    # This is to support use-cases of fixing certain layer weights
    # in training.
    warnings.warn(
        f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' "
        f"is set to {op_mode_text}. Exporting with {op_mode_text}."
    )
11381140

11391141

11401142
def _flatten_helper(g, input, start_dim, end_dim, dim):

0 commit comments

Comments
 (0)