Skip to content

Commit 9932ee4

Browse files
made MaskFormerModelTest faster (huggingface#15942)
1 parent e8efaec commit 9932ee4

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

tests/maskformer/test_modeling_maskformer.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import numpy as np
2121

2222
from tests.test_modeling_common import floats_tensor
23-
from transformers import MaskFormerConfig, is_torch_available, is_vision_available
23+
from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
2424
from transformers.file_utils import cached_property
2525
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
2626

@@ -47,12 +47,12 @@ def __init__(
4747
batch_size=2,
4848
is_training=True,
4949
use_auxiliary_loss=False,
50-
num_queries=100,
50+
num_queries=10,
5151
num_channels=3,
52-
min_size=384,
53-
max_size=640,
54-
num_labels=150,
55-
mask_feature_size=256,
52+
min_size=32 * 4,
53+
max_size=32 * 6,
54+
num_labels=4,
55+
mask_feature_size=32,
5656
):
5757
self.parent = parent
5858
self.batch_size = batch_size
@@ -79,11 +79,20 @@ def prepare_config_and_inputs(self):
7979
return config, pixel_values, pixel_mask, mask_labels, class_labels
8080

8181
def get_config(self):
82-
return MaskFormerConfig(
83-
num_queries=self.num_queries,
82+
return MaskFormerConfig.from_backbone_and_decoder_configs(
83+
backbone_config=SwinConfig(
84+
depths=[1, 1, 1, 1],
85+
),
86+
decoder_config=DetrConfig(
87+
decoder_ffn_dim=128,
88+
num_queries=self.num_queries,
89+
decoder_attention_heads=2,
90+
d_model=self.mask_feature_size,
91+
),
92+
mask_feature_size=self.mask_feature_size,
93+
fpn_feature_size=self.mask_feature_size,
8494
num_channels=self.num_channels,
8595
num_labels=self.num_labels,
86-
mask_feature_size=self.mask_feature_size,
8796
)
8897

8998
def prepare_config_and_inputs_for_common(self):
@@ -161,7 +170,6 @@ def comm_check_on_output(result):
161170

162171

163172
@require_torch
164-
@slow
165173
class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
166174

167175
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
@@ -221,11 +229,11 @@ def test_model_from_pretrained(self):
221229
model = MaskFormerModel.from_pretrained(model_name)
222230
self.assertIsNotNone(model)
223231

224-
@slow
225232
def test_model_with_labels(self):
233+
size = (self.model_tester.min_size,) * 2
226234
inputs = {
227-
"pixel_values": torch.randn((2, 3, 384, 384)),
228-
"mask_labels": torch.randn((2, 10, 384, 384)),
235+
"pixel_values": torch.randn((2, 3, *size)),
236+
"mask_labels": torch.randn((2, 10, *size)),
229237
"class_labels": torch.zeros(2, 10).long(),
230238
}
231239

0 commit comments

Comments
 (0)