Skip to content

Commit

Permalink
Use flag() CM instead of custom one
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Oct 20, 2023
1 parent 6a9463c commit 2a7b84e
Showing 1 changed file with 1 addition and 11 deletions.
12 changes: 1 addition & 11 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,6 @@
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"


@contextlib.contextmanager
def disable_tf32():
previous = torch.backends.cudnn.allow_tf32
torch.backends.cudnn.allow_tf32 = False
try:
yield
finally:
torch.backends.cudnn.allow_tf32 = previous


def list_model_fns(module):
return [get_model_builder(name) for name in list_models(module)]

Expand Down Expand Up @@ -681,7 +671,7 @@ def test_vitc_models(model_fn, dev):
test_classification_model(model_fn, dev)


@disable_tf32() # see: https://github.com/pytorch/vision/issues/7618
@torch.backends.cudnn.flags(allow_tf32=False) # see: https://github.com/pytorch/vision/issues/7618
@pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_classification_model(model_fn, dev):
Expand Down

0 comments on commit 2a7b84e

Please sign in to comment.