20
20
import numpy as np
21
21
22
22
from tests .test_modeling_common import floats_tensor
23
- from transformers import MaskFormerConfig , is_torch_available , is_vision_available
23
+ from transformers import DetrConfig , MaskFormerConfig , SwinConfig , is_torch_available , is_vision_available
24
24
from transformers .file_utils import cached_property
25
25
from transformers .testing_utils import require_torch , require_vision , slow , torch_device
26
26
@@ -47,12 +47,12 @@ def __init__(
47
47
batch_size = 2 ,
48
48
is_training = True ,
49
49
use_auxiliary_loss = False ,
50
- num_queries = 100 ,
50
+ num_queries = 10 ,
51
51
num_channels = 3 ,
52
- min_size = 384 ,
53
- max_size = 640 ,
54
- num_labels = 150 ,
55
- mask_feature_size = 256 ,
52
+ min_size = 32 * 4 ,
53
+ max_size = 32 * 6 ,
54
+ num_labels = 4 ,
55
+ mask_feature_size = 32 ,
56
56
):
57
57
self .parent = parent
58
58
self .batch_size = batch_size
@@ -79,11 +79,20 @@ def prepare_config_and_inputs(self):
79
79
return config , pixel_values , pixel_mask , mask_labels , class_labels
80
80
81
81
def get_config (self ):
82
- return MaskFormerConfig (
83
- num_queries = self .num_queries ,
82
+ return MaskFormerConfig .from_backbone_and_decoder_configs (
83
+ backbone_config = SwinConfig (
84
+ depths = [1 , 1 , 1 , 1 ],
85
+ ),
86
+ decoder_config = DetrConfig (
87
+ decoder_ffn_dim = 128 ,
88
+ num_queries = self .num_queries ,
89
+ decoder_attention_heads = 2 ,
90
+ d_model = self .mask_feature_size ,
91
+ ),
92
+ mask_feature_size = self .mask_feature_size ,
93
+ fpn_feature_size = self .mask_feature_size ,
84
94
num_channels = self .num_channels ,
85
95
num_labels = self .num_labels ,
86
- mask_feature_size = self .mask_feature_size ,
87
96
)
88
97
89
98
def prepare_config_and_inputs_for_common (self ):
@@ -161,7 +170,6 @@ def comm_check_on_output(result):
161
170
162
171
163
172
@require_torch
164
- @slow
165
173
class MaskFormerModelTest (ModelTesterMixin , unittest .TestCase ):
166
174
167
175
all_model_classes = (MaskFormerModel , MaskFormerForInstanceSegmentation ) if is_torch_available () else ()
@@ -221,11 +229,11 @@ def test_model_from_pretrained(self):
221
229
model = MaskFormerModel .from_pretrained (model_name )
222
230
self .assertIsNotNone (model )
223
231
224
- @slow
225
232
def test_model_with_labels (self ):
233
+ size = (self .model_tester .min_size ,) * 2
226
234
inputs = {
227
- "pixel_values" : torch .randn ((2 , 3 , 384 , 384 )),
228
- "mask_labels" : torch .randn ((2 , 10 , 384 , 384 )),
235
+ "pixel_values" : torch .randn ((2 , 3 , * size )),
236
+ "mask_labels" : torch .randn ((2 , 10 , * size )),
229
237
"class_labels" : torch .zeros (2 , 10 ).long (),
230
238
}
231
239
0 commit comments