Skip to content

Commit 97feafa

Browse files
baijumeswanicodemzs
authored andcommitted
Export the model with torch.no_grad() context (#6472)
1 parent 4c8bedc commit 97feafa

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

orttraining/orttraining/python/training/ortmodule.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -457,14 +457,15 @@ def _get_forward_graph(self, input_names, dynamic_axes, *inputs, **kwargs):
457457
# from onnxruntime.training import register_custom_ops_pytorch_exporter
458458
# register_custom_ops_pytorch_exporter.register_custom_op()
459459

460-
# Export torch.nn.Module to ONNX
461-
torch.onnx.export(self._original_module,
462-
sample_inputs_copy,
463-
f,
464-
input_names=input_names,
465-
opset_version=ONNX_OPSET_VERSION,
466-
do_constant_folding=False,
467-
training=torch.onnx.TrainingMode.TRAINING,
468-
dynamic_axes=dynamic_axes)
460+
with torch.no_grad():
461+
# Export torch.nn.Module to ONNX
462+
torch.onnx.export(self._original_module,
463+
sample_inputs_copy,
464+
f,
465+
input_names=input_names,
466+
opset_version=ONNX_OPSET_VERSION,
467+
do_constant_folding=False,
468+
training=torch.onnx.TrainingMode.TRAINING,
469+
dynamic_axes=dynamic_axes)
469470

470471
return onnx.load_model_from_string(f.getvalue())

orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
# orttraining_test_ortmodule_api.py
44

55
import torch
6+
from transformers import AutoConfig, BertForSequenceClassification
67
import pytest
8+
from unittest.mock import patch
79

810
import onnxruntime
911
from onnxruntime.training import ORTModule
@@ -84,6 +86,33 @@ def forward(self, model_input, x=None, y=None, z=None):
8486
out = self.fc2(out)
8587
return out
8688

89+
def _get_bert_for_sequence_classification_model(device):
90+
"""Returns the BertForSequenceClassification pretrained model"""
91+
92+
config = AutoConfig.from_pretrained(
93+
"bert-base-uncased",
94+
num_labels=2,
95+
num_hidden_layers=1,
96+
output_attentions = False,
97+
output_hidden_states = False,
98+
)
99+
100+
model = BertForSequenceClassification.from_pretrained(
101+
"bert-base-uncased",
102+
config=config,
103+
).to(device)
104+
105+
return model
106+
107+
def _get_bert_for_sequence_classification_sample_data(device):
108+
"""Returns sample data to be used with BertForSequenceClassification model"""
109+
110+
input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
111+
input_mask = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
112+
labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
113+
114+
return input_ids, input_mask, labels
115+
87116
# ORTModule-API tests
88117

89118
def test_forward_call_single_positional_argument():
@@ -286,3 +315,33 @@ def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder
286315
module_gradient_graph_builder = model._module_gradient_graph_builder
287316
model(x)
288317
assert module_gradient_graph_builder != model._module_gradient_graph_builder
318+
319+
def test_gpu_reserved_memory_with_torch_no_grad():
320+
device = 'cuda'
321+
322+
# Create a model and get the memory_reserved when torch.no_grad has been enabled
323+
# before and after export
324+
model_with_no_grad = _get_bert_for_sequence_classification_model(device)
325+
x, y, z = _get_bert_for_sequence_classification_sample_data(device)
326+
327+
torch.cuda.empty_cache()
328+
model_with_no_grad = ORTModule(model_with_no_grad)
329+
mem_reserved_before_export = torch.cuda.memory_reserved(device)
330+
model_with_no_grad(x, y, None, None, None, None, z)
331+
mem_reserved_after_export_with_torch_no_grad = torch.cuda.memory_reserved(device)
332+
del model_with_no_grad
333+
torch.cuda.empty_cache()
334+
mem_reserved_after_cache_empty = torch.cuda.memory_reserved(device)
335+
assert mem_reserved_before_export == mem_reserved_after_cache_empty
336+
337+
# Create another model and get the memory_reserved when torch.no_grad has not been enabled
338+
# after export
339+
model_without_no_grad = _get_bert_for_sequence_classification_model(device)
340+
model_without_no_grad = ORTModule(model_without_no_grad)
341+
mem_reserved_after_export_without_torch_no_grad = 0
342+
with patch('torch.no_grad'):
343+
model_without_no_grad(x, y, None, None, None, None, z)
344+
mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
345+
346+
assert mem_reserved_after_export_with_torch_no_grad < mem_reserved_after_export_without_torch_no_grad
347+
assert mem_reserved_before_export < mem_reserved_after_export_with_torch_no_grad

0 commit comments

Comments
 (0)