4
4
from functools import lru_cache
5
5
from itertools import product
6
6
from typing import Callable , List , Tuple
7
- import unittest
8
7
9
8
import numpy as np
10
9
import pytest
21
20
from torchvision .models .feature_extraction import get_graph_node_names
22
21
23
22
23
+ OPTESTS = [
24
+ "test_schema" ,
25
+ "test_autograd_registration" ,
26
+ "test_faketensor" ,
27
+ "test_aot_dispatch_dynamic" ,
28
+ ]
29
+
30
+
24
31
# Context manager for setting deterministic flag and automatically
25
32
# resetting it to its original value
26
33
class DeterministicGuard :
@@ -464,9 +471,10 @@ def test_boxes_shape(self):
464
471
465
472
@pytest .mark .parametrize ("aligned" , (True , False ))
466
473
@pytest .mark .parametrize ("device" , cpu_and_cuda_and_mps ())
467
- @pytest .mark .parametrize ("x_dtype" , (torch .float16 , torch .float32 , torch .float64 ), ids = str )
474
+ @pytest .mark .parametrize ("x_dtype" , (torch .float16 , torch .float32 , torch .float64 )) # , ids=str)
468
475
@pytest .mark .parametrize ("contiguous" , (True , False ))
469
476
@pytest .mark .parametrize ("deterministic" , (True , False ))
477
+ @pytest .mark .opcheck_only_one ()
470
478
def test_forward (self , device , contiguous , deterministic , aligned , x_dtype , rois_dtype = None ):
471
479
if deterministic and device == "cpu" :
472
480
pytest .skip ("cpu is always deterministic, don't retest" )
@@ -484,6 +492,7 @@ def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois
484
492
@pytest .mark .parametrize ("deterministic" , (True , False ))
485
493
@pytest .mark .parametrize ("x_dtype" , (torch .float , torch .half ))
486
494
@pytest .mark .parametrize ("rois_dtype" , (torch .float , torch .half ))
495
+ @pytest .mark .opcheck_only_one ()
487
496
def test_autocast (self , aligned , deterministic , x_dtype , rois_dtype ):
488
497
with torch .cuda .amp .autocast ():
489
498
self .test_forward (
@@ -499,6 +508,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
499
508
@pytest .mark .parametrize ("device" , cpu_and_cuda_and_mps ())
500
509
@pytest .mark .parametrize ("contiguous" , (True , False ))
501
510
@pytest .mark .parametrize ("deterministic" , (True , False ))
511
+ @pytest .mark .opcheck_only_one ()
502
512
def test_backward (self , seed , device , contiguous , deterministic ):
503
513
if deterministic and device == "cpu" :
504
514
pytest .skip ("cpu is always deterministic, don't retest" )
@@ -513,6 +523,7 @@ def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
513
523
@pytest .mark .parametrize ("aligned" , (True , False ))
514
524
@pytest .mark .parametrize ("scale, zero_point" , ((1 , 0 ), (2 , 10 ), (0.1 , 50 )))
515
525
@pytest .mark .parametrize ("qdtype" , (torch .qint8 , torch .quint8 , torch .qint32 ))
526
+ @pytest .mark .opcheck_only_one ()
516
527
def test_qroialign (self , aligned , scale , zero_point , qdtype ):
517
528
"""Make sure quantized version of RoIAlign is close to float version"""
518
529
pool_size = 5
@@ -582,6 +593,15 @@ def test_jit_boxes_list(self):
582
593
self ._helper_jit_boxes_list (model )
583
594
584
595
596
+ optests .generate_opcheck_tests (
597
+ testcase = TestRoIAlign ,
598
+ namespaces = ["torchvision" ],
599
+ failures_dict_path = "test/optests_failures_dict.json" ,
600
+ additional_decorators = [],
601
+ test_utils = OPTESTS ,
602
+ )
603
+
604
+
585
605
class TestPSRoIAlign (RoIOpTester ):
586
606
mps_backward_atol = 5e-2
587
607
@@ -837,20 +857,13 @@ def test_batched_nms_implementations(self, seed):
837
857
empty = torch .empty ((0 ,), dtype = torch .int64 )
838
858
torch .testing .assert_close (empty , ops .batched_nms (empty , None , None , None ))
839
859
840
- data_dependent_torchvision_test_checks = [
841
- "test_schema" ,
842
- "test_autograd_registration" ,
843
- "test_faketensor" ,
844
- "test_aot_dispatch_dynamic" ,
845
- ]
846
860
847
861
optests .generate_opcheck_tests (
848
- TestNMS ,
849
- ["torchvision" ],
850
- {},
851
- "test/test_ops.py" ,
852
- [],
853
- data_dependent_torchvision_test_checks ,
862
+ testcase = TestNMS ,
863
+ namespaces = ["torchvision" ],
864
+ failures_dict_path = "test/optests_failures_dict.json" ,
865
+ additional_decorators = [],
866
+ test_utils = OPTESTS ,
854
867
)
855
868
856
869
0 commit comments