@@ -3,7 +3,9 @@
 # orttraining_test_ortmodule_api.py
 
 import torch
+from transformers import AutoConfig, BertForSequenceClassification
 import pytest
+from unittest.mock import patch
 
 import onnxruntime
 from onnxruntime.training import ORTModule
@@ -84,6 +86,34 @@ def forward(self, model_input, x=None, y=None, z=None):
         out = self.fc2(out)
         return out
 
+def _get_bert_for_sequence_classification_model(device):
+    """Returns the BertForSequenceClassification pretrained model"""
+
+    config = AutoConfig.from_pretrained(
+        "bert-base-uncased",
+        num_labels=2,
+        num_hidden_layers=1,
+        output_attentions=False,
+        output_hidden_states=False,
+    )
+
+    model = BertForSequenceClassification.from_pretrained(
+        "bert-base-uncased",
+        config=config,
+    ).to(device)
+
+    return model
+
+def _get_bert_for_sequence_classification_sample_data(device):
+    """Returns sample data to be used with BertForSequenceClassification model"""
+
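+    # Batch of 32 sequences of length 64; attention mask and labels are binary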
+    input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
+    input_mask = torch.randint(0, 2, (32, 64), dtype=torch.long, device=device)
+    labels = torch.randint(0, 2, (32,), dtype=torch.long, device=device)
+
+    return input_ids, input_mask, labels
+
 # ORTModule-API tests
 
 def test_forward_call_single_positional_argument():
@@ -286,3 +316,37 @@ def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder
     module_gradient_graph_builder = model._module_gradient_graph_builder
     model(x)
     assert module_gradient_graph_builder != model._module_gradient_graph_builder
+
+def test_gpu_reserved_memory_with_torch_no_grad():
+    device = 'cuda'
+
+    # Create a model and record the reserved GPU memory before and after an
+    # export that runs under torch.no_grad (the default export path)
+    model_with_no_grad = _get_bert_for_sequence_classification_model(device)
+    x, y, z = _get_bert_for_sequence_classification_sample_data(device)
+
+    torch.cuda.empty_cache()
+    model_with_no_grad = ORTModule(model_with_no_grad)
+    mem_reserved_before_export = torch.cuda.memory_reserved(device)
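+    # Positional arguments follow BertForSequenceClassification.forward:
+    # (input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels)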
+    model_with_no_grad(x, y, None, None, None, None, z)
+    mem_reserved_after_export_with_torch_no_grad = torch.cuda.memory_reserved(device)
+    del model_with_no_grad
+    torch.cuda.empty_cache()
+    mem_reserved_after_cache_empty = torch.cuda.memory_reserved(device)
+    assert mem_reserved_before_export == mem_reserved_after_cache_empty
+
+    # Create another model and record the reserved GPU memory after an export
+    # performed without torch.no_grad
+    model_without_no_grad = _get_bert_for_sequence_classification_model(device)
+    model_without_no_grad = ORTModule(model_without_no_grad)
+    mem_reserved_after_export_without_torch_no_grad = 0
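+    # Patching torch.no_grad swaps it for a MagicMock, so the no_grad context
+    # that normally guards the export becomes a no-op and gradient state is allocated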
+    with patch('torch.no_grad'):
+        model_without_no_grad(x, y, None, None, None, None, z)
+        mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
+
+    assert mem_reserved_after_export_with_torch_no_grad < mem_reserved_after_export_without_torch_no_grad
+    assert mem_reserved_before_export < mem_reserved_after_export_with_torch_no_grad