ci: fix pytest errors on higher versions of PyTorch #697
base: master
@@ -132,7 +132,11 @@
 import torch
 import torch.cuda
-from torch._six import string_classes
+if hasattr(torch, "_six"):
+    from torch._six import string_classes
+else:
+    string_classes = str
 import torch.backends.cudnn
 import torch.backends.mkl
 from enum import Enum

wangraying: I wonder, has torch removed this?
Reply: @wangraying torch._six has been removed in PyTorch 2.0.0.
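For context, here is a minimal, self-contained sketch of the compatibility shim this hunk introduces. The try/except form is an alternative to the hasattr check used in the diff, and check_command is a hypothetical helper that mirrors the shell() assertion further down; neither is part of the PR.

```python
import torch

try:
    # PyTorch < 2.0 still ships the private torch._six module.
    from torch._six import string_classes
except ImportError:
    # torch._six was removed in PyTorch 2.0; plain str is an equivalent
    # target for isinstance() checks on Python 3.
    string_classes = str


def check_command(command):
    # Mirrors the shell() assertion: reject a bare string, require a
    # list or tuple of tokens.
    assert not isinstance(
        command, string_classes
    ), "Command to shell should be a list or tuple of tokens"
    return command


print(check_command(["echo", "hello"]))  # OK: a list of tokens
```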
@@ -318,7 +322,7 @@ def shell(command, cwd=None, env=None):
     #
     # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
     assert not isinstance(
-        command, torch._six.string_classes
+        command, string_classes
     ), "Command to shell should be a list or tuple of tokens"
     p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
     return wait_for_process(p)
@@ -1040,7 +1044,7 @@ def check_slow_test_from_stats(test):
     global slow_tests_dict
     if slow_tests_dict is None:
         if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1":
-            url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
+            url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json"
             slow_tests_dict = fetch_and_cache(".pytorch-slow-tests", url)
         else:
             slow_tests_dict = {}
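As a rough illustration of what `fetch_and_cache(".pytorch-slow-tests", url)` is doing with the updated URL (the helper name and caching behaviour below are assumptions for the sketch, not the repository's actual implementation):

```python
import json
import os
import urllib.request

SLOW_TESTS_URL = (
    "https://raw.githubusercontent.com/pytorch/test-infra/"
    "generated-stats/stats/slow-tests.json"
)


def fetch_and_cache_slow_tests(cache_path=".pytorch-slow-tests"):
    """Download the slow-tests stats once and reuse the cached copy."""
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            return json.load(f)
    with urllib.request.urlopen(SLOW_TESTS_URL) as resp:
        data = json.loads(resp.read().decode("utf-8"))
    with open(cache_path, "w") as f:
        json.dump(data, f)
    return data
```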
@@ -1374,10 +1378,11 @@ def compare_with_numpy(
     torch.bfloat16: (0.016, 1e-5),
     torch.float32: (1.3e-6, 1e-5),
     torch.float64: (1e-7, 1e-7),
-    torch.complex32: (0.001, 1e-5),
     torch.complex64: (1.3e-6, 1e-5),
     torch.complex128: (1e-7, 1e-7),
 }
+if hasattr(torch, "complex32"):  # torch.complex32 has been removed from 1.11.0
+    dtype_precisions[torch.complex32] = (0.001, 1e-5)

 # Returns the "default" rtol and atol for comparing scalars or
 # tensors of the given dtypes.
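The same guard pattern generalizes to any dtype attribute that may be missing from a given torch build. A small sketch (hypothetical helper names, not code from the PR):

```python
import torch

# Only register tolerances for dtypes that exist in the installed build,
# so building this table never raises AttributeError.
dtype_precisions = {
    torch.float16: (0.001, 1e-5),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
}
for name, tol in [("bfloat16", (0.016, 1e-5)), ("complex32", (0.001, 1e-5))]:
    dtype = getattr(torch, name, None)  # None when the build lacks the dtype
    if dtype is not None:
        dtype_precisions[dtype] = tol


def default_tolerances(dtype):
    # Fall back to float32 tolerances for dtypes without an explicit entry.
    return dtype_precisions.get(dtype, dtype_precisions[torch.float32])


print(default_tolerances(torch.float64))  # (1e-07, 1e-07)
```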
@@ -24,7 +24,11 @@
 import torch.multiprocessing as mp
 import torch.nn.functional as F
 from torch import nn
-from torch._six import string_classes
+if hasattr(torch, "_six"):
+    from torch._six import string_classes
+else:
+    string_classes = str
 from bagua.torch_api.data_parallel import DistributedDataParallel
 from tests.internal.torch.common_distributed import (
     MultiProcessTestCase,
@@ -44,6 +48,7 @@
 )

 import bagua.torch_api.data_parallel.functional as bagua_dist
+from bagua.torch_api.contrib.sync_batchnorm import _SYNC_BN_V5, _SYNC_BN_V6, _SYNC_BN_V7

 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings

wangraying: I don't think we should couple this test with sync bn.
Reply: @wangraying This approach aims to keep the test cases compatible with multiple versions of PyTorch. Do you have other suggestions for handling that compatibility?
Follow-up: Yeah, I agree using …
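The `_SYNC_BN_V5`/`_SYNC_BN_V6`/`_SYNC_BN_V7` flags imported here act as version gates; judging from the skip messages below, V6 corresponds to torch >= 1.11.0 and V7 to torch >= 2.0.0. A hypothetical sketch of how such flags can be derived from the installed torch (the real definitions live in `bagua.torch_api.contrib.sync_batchnorm` and may differ):

```python
import torch
from packaging.version import parse as parse_version  # assumed to be installed

# Strip local version suffixes such as "+cu117" before comparing.
_TORCH_VERSION = parse_version(torch.__version__.split("+")[0])

_SYNC_BN_V6 = _TORCH_VERSION >= parse_version("1.11.0")  # torch >= 1.11 behaviour
_SYNC_BN_V7 = _TORCH_VERSION >= parse_version("2.0.0")   # torch >= 2.0 behaviour
```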
@@ -128,6 +133,7 @@ def num_keys_total(self):
         return 5


+@unittest.skipIf(_SYNC_BN_V7, "Skip FileStoreTest for torch >= 2.0.0")
 class FileStoreTest(TestCase, StoreTestBase):
     def setUp(self):
         super(FileStoreTest, self).setUp()
@@ -150,6 +156,7 @@ def _create_store(self):
         return store


+@unittest.skipIf(_SYNC_BN_V7, "Skip PrefixFileStoreTest for torch >= 2.0.0")
 class PrefixFileStoreTest(TestCase, StoreTestBase):
     def setUp(self):
         super(PrefixFileStoreTest, self).setUp()
@@ -172,16 +179,16 @@ def test_address_already_in_use(self):
         if sys.platform == "win32":
             err_msg_reg = "Only one usage of each socket address*"
         else:
-            err_msg_reg = "^Address already in use$"
+            err_msg_reg = "Address already in use"
         with self.assertRaisesRegex(RuntimeError, err_msg_reg):
             addr = DEFAULT_HOSTNAME
             port = find_free_port()

             # Use noqa to silence flake8.
             # Need to store in an unused variable here to ensure the first
             # object is not destroyed before the second object is created.
-            store1 = c10d.TCPStore(addr, port, 1, True)  # noqa: F841
-            store2 = c10d.TCPStore(addr, port, 1, True)  # noqa: F841
+            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
+            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

         # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
         # the user and one additional key used for coordinate all the workers.
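A quick note on why the anchors were dropped from the expected message: `assertRaisesRegex` matches with `re.search`, so an unanchored pattern still matches when a newer PyTorch wraps the error with extra detail, while `^...$` would not. A standalone illustration (the message string below is made up for the demo, not the exact text PyTorch emits):

```python
import re
import unittest


class RegexAnchorDemo(unittest.TestCase):
    def test_unanchored_pattern_tolerates_wrapped_message(self):
        # Illustrative error text only.
        message = "listen failed: Address already in use (127.0.0.1:29500)"
        self.assertIsNone(re.search("^Address already in use$", message))
        self.assertIsNotNone(re.search("Address already in use", message))


if __name__ == "__main__":
    unittest.main()
```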
@@ -845,8 +852,15 @@ def test_single_limit_single_dtype(self):
             torch.empty([100], dtype=torch.float),
             torch.empty([50], dtype=torch.float),
         ]
-        result = dist._compute_bucket_assignment_by_size(tensors, [400])
-        self.assertEqual([[0], [1], [2], [3]], result)
+        if _SYNC_BN_V5:
+            result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
+                tensors, [400]
+            )
+            self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
+            self.assertEqual([[0], [1], [2], [3]], result)
+        else:
+            result = dist._compute_bucket_assignment_by_size(tensors, [400])
+            self.assertEqual([[0], [1], [2], [3]], result)

     def test_single_limit_multi_dtype(self):
         tensors = [
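The branching above works around a return-type change: on the `_SYNC_BN_V5` path (newer PyTorch), `_compute_bucket_assignment_by_size` returns a `(bucket_indices, per_bucket_size_limits)` tuple, while older releases return only the bucket indices. A hypothetical helper that would let the assertions stay identical across versions (not part of the PR):

```python
import torch.distributed as dist


def compute_buckets(tensors, size_limits):
    """Normalize _compute_bucket_assignment_by_size across torch versions."""
    result = dist._compute_bucket_assignment_by_size(tensors, size_limits)
    if isinstance(result, tuple):
        # Newer torch: (bucket_indices, per_bucket_size_limits).
        return result
    # Older torch: only the bucket indices are returned.
    return result, None
```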
@@ -857,8 +871,15 @@ def test_single_limit_multi_dtype(self):
             torch.empty([50], dtype=torch.float),
             torch.empty([25], dtype=torch.double),
         ]
-        result = dist._compute_bucket_assignment_by_size(tensors, [400])
-        self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
+        if _SYNC_BN_V5:
+            result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
+                tensors, [400]
+            )
+            self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
+            self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
+        else:
+            result = dist._compute_bucket_assignment_by_size(tensors, [400])
+            self.assertEqual([[0, 2], [1, 3], [4], [5]], result)

     def test_multi_limit_single_dtype(self):
         tensors = [
@@ -867,8 +888,15 @@ def test_multi_limit_single_dtype(self):
             torch.empty([10], dtype=torch.float),
             torch.empty([10], dtype=torch.float),
         ]
-        result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
-        self.assertEqual([[0], [1, 2], [3]], result)
+        if _SYNC_BN_V5:
+            result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
+                tensors, [40, 80]
+            )
+            self.assertEqual(per_bucket_size_limits, [40, 80, 80])
+            self.assertEqual([[0], [1, 2], [3]], result)
+        else:
+            result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
+            self.assertEqual([[0], [1, 2], [3]], result)

     def test_multi_limit_multi_dtype(self):
         tensors = [
@@ -879,8 +907,15 @@ def test_multi_limit_multi_dtype(self):
             torch.empty([50], dtype=torch.float),
             torch.empty([25], dtype=torch.double),
         ]
-        result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
-        self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
+        if _SYNC_BN_V5:
+            result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
+                tensors, [200, 400]
+            )
+            self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
+            self.assertEqual(per_bucket_size_limits, [200, 200, 400, 400])
+        else:
+            result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
+            self.assertEqual([[0], [1], [2, 4], [3, 5]], result)


 class AbstractCommTest(object):
@@ -1032,6 +1067,7 @@ def tearDown(self):
         except OSError:
             pass

+    @unittest.skipIf(_SYNC_BN_V6, "Skip test for torch >= 1.11.0")
     def test_distributed_debug_mode(self):
         # Default should be off
         default_debug_mode = dist._get_debug_mode()
@@ -1057,6 +1093,45 @@ def test_distributed_debug_mode(self):
         with self.assertRaisesRegex(RuntimeError, "to be one of"):
             dist._get_debug_mode()

+    @unittest.skipIf(not _SYNC_BN_V6, "Skip test for torch < 1.11.0")
+    def test_debug_level(self):
+        try:
+            del os.environ["TORCH_DISTRIBUTED_DEBUG"]
+        except KeyError:
+            pass
+
+        dist.set_debug_level_from_env()
+        # Default should be off
+        default_debug_mode = dist.get_debug_level()
+        self.assertEqual(default_debug_mode, dist.DebugLevel.OFF)
+        mapping = {
+            "OFF": dist.DebugLevel.OFF,
+            "off": dist.DebugLevel.OFF,
+            "oFf": dist.DebugLevel.OFF,
+            "INFO": dist.DebugLevel.INFO,
+            "info": dist.DebugLevel.INFO,
+            "INfO": dist.DebugLevel.INFO,
+            "DETAIL": dist.DebugLevel.DETAIL,
+            "detail": dist.DebugLevel.DETAIL,
+            "DeTaIl": dist.DebugLevel.DETAIL,
+        }
+        invalid_debug_modes = ["foo", 0, 1, -1]
+
+        for mode in mapping.keys():
+            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
+            dist.set_debug_level_from_env()
+            set_debug_mode = dist.get_debug_level()
+            self.assertEqual(
+                set_debug_mode,
+                mapping[mode],
+                f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
+            )
+
+        for mode in invalid_debug_modes:
+            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
+            with self.assertRaisesRegex(RuntimeError, "The value of TORCH_DISTRIBUTED_DEBUG must"):
+                dist.set_debug_level_from_env()


 if __name__ == "__main__":
     assert (
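For reference, the new `test_debug_level` exercises the public debug-level API available on torch >= 1.11 (the older private `_get_debug_mode` test is skipped there); a brief usage sketch:

```python
import os

import torch.distributed as dist

os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
dist.set_debug_level_from_env()   # parse the environment variable
print(dist.get_debug_level())     # expected: DebugLevel.DETAIL
```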
wangraying: I wonder why we are adding so many steps on different containers here, rather than changing the **pytorch-1.9.0** step above directly?
Reply: @wangraying This is just to verify that the tests run properly on different versions of PyTorch. If this PR is to be merged, the final testing environment will be either PyTorch 1.9.0 or PyTorch 1.13.0. If the test cases are later switched to PyTorch 1.13.0, we can still keep the PyTorch 1.9.0 testing environment here.
Follow-up: I just checked PyTorch Lightning's CI: their pytest CPU tests support multiple torch versions, while the GPU tests mainly run on the latest PyTorch release. Maybe we could do the same, or just maintain all CI tests on PyTorch 1.13?