From 1f5fe7b2f0db0961842e62e6cb2d0fddeaf1e75a Mon Sep 17 00:00:00 2001
From: Yafen Fang
Date: Fri, 14 Apr 2023 18:32:02 +0200
Subject: [PATCH 1/3] ci: fix pytest errors on higher versions of PyTorch

---
 bagua/torch_api/contrib/sync_batchnorm.py |   3 +
 tests/internal/torch/common_utils.py      |  14 ++-
 .../data_parallel/test_c10d_common.py     | 105 ++++++++++++++++--
 3 files changed, 105 insertions(+), 17 deletions(-)

diff --git a/bagua/torch_api/contrib/sync_batchnorm.py b/bagua/torch_api/contrib/sync_batchnorm.py
index 40788e305..b7a53ba7d 100644
--- a/bagua/torch_api/contrib/sync_batchnorm.py
+++ b/bagua/torch_api/contrib/sync_batchnorm.py
@@ -26,6 +26,9 @@
 ) <= LooseVersion("1.6.0")
 _SYNC_BN_V3 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
 _SYNC_BN_V4 = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
+_SYNC_BN_V5 = LooseVersion(torch.__version__) >= LooseVersion("1.10.0")
+_SYNC_BN_V6 = LooseVersion(torch.__version__) >= LooseVersion("1.11.0")
+_SYNC_BN_V7 = LooseVersion(torch.__version__) >= LooseVersion("2.0.0")


 class SyncBatchNorm(_BatchNorm):
diff --git a/tests/internal/torch/common_utils.py b/tests/internal/torch/common_utils.py
index 868b43fb5..8703a3352 100644
--- a/tests/internal/torch/common_utils.py
+++ b/tests/internal/torch/common_utils.py
@@ -132,7 +132,10 @@

 import torch
 import torch.cuda
-from torch._six import string_classes
+if hasattr(torch, "_six"):
+    from torch._six import string_classes
+else:
+    string_classes = str
 import torch.backends.cudnn
 import torch.backends.mkl
 from enum import Enum
@@ -318,7 +321,7 @@ def shell(command, cwd=None, env=None):
     #
     # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
     assert not isinstance(
-        command, torch._six.string_classes
+        command, string_classes
     ), "Command to shell should be a list or tuple of tokens"
     p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
     return wait_for_process(p)
@@ -1040,7 +1043,7 @@ def check_slow_test_from_stats(test):
     global slow_tests_dict
     if slow_tests_dict is None:
         if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1":
-            url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
+            url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json"
             slow_tests_dict = fetch_and_cache(".pytorch-slow-tests", url)
         else:
             slow_tests_dict = {}
@@ -1374,10 +1377,11 @@ def compare_with_numpy(
         torch.bfloat16: (0.016, 1e-5),
         torch.float32: (1.3e-6, 1e-5),
         torch.float64: (1e-7, 1e-7),
-        torch.complex32: (0.001, 1e-5),
         torch.complex64: (1.3e-6, 1e-5),
         torch.complex128: (1e-7, 1e-7),
     }
+    if hasattr(torch, "complex32"): # torch.complex32 was removed in 1.11.0
+        dtype_precisions[torch.complex32] = (0.001, 1e-5)

     # Returns the "default" rtol and atol for comparing scalars or
     # tensors of the given dtypes.
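
Both common_utils.py fixes above use the same feature-probing pattern: check for an attribute with hasattr() and fall back when a newer PyTorch release has removed it, instead of hard-coding version numbers. A minimal, self-contained sketch of that pattern (the is_string helper is hypothetical, not code from either repository):

    import torch

    # Older PyTorch ships torch._six.string_classes (a tuple of string
    # types); newer releases dropped the _six module entirely, leaving
    # plain str as the only string type left to check against.
    if hasattr(torch, "_six"):
        from torch._six import string_classes
    else:
        string_classes = str

    def is_string(obj):
        # isinstance accepts either a single type or a tuple of types,
        # so call sites work unchanged whichever branch ran above.
        return isinstance(obj, string_classes)

    print(is_string("pytest"))  # True on any PyTorch version
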
@@ -2138,7 +2142,7 @@ def runWithPytorchAPIUsageStderr(code):
 def find_free_port():
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        sock.bind(("localhost", 0))
+        sock.bind(('localhost', 0))
         _, port = sock.getsockname()
         return port

diff --git a/tests/torch_api/data_parallel/test_c10d_common.py b/tests/torch_api/data_parallel/test_c10d_common.py
index 41c249b4b..c99590945 100644
--- a/tests/torch_api/data_parallel/test_c10d_common.py
+++ b/tests/torch_api/data_parallel/test_c10d_common.py
@@ -24,7 +24,10 @@
 import torch.multiprocessing as mp
 import torch.nn.functional as F
 from torch import nn
-from torch._six import string_classes
+if hasattr(torch, "_six"):
+    from torch._six import string_classes
+else:
+    string_classes = str
 from bagua.torch_api.data_parallel import DistributedDataParallel
 from tests.internal.torch.common_distributed import (
     MultiProcessTestCase,
@@ -44,6 +47,7 @@
 )

 import bagua.torch_api.data_parallel.functional as bagua_dist
+from bagua.torch_api.contrib.sync_batchnorm import _SYNC_BN_V5, _SYNC_BN_V6, _SYNC_BN_V7

 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -128,6 +132,9 @@ def num_keys_total(self):
         return 5


+@unittest.skipIf(
+    _SYNC_BN_V7, "Skip FileStoreTest for torch >= 2.0.0"
+)
 class FileStoreTest(TestCase, StoreTestBase):
     def setUp(self):
         super(FileStoreTest, self).setUp()
@@ -150,6 +157,9 @@ def _create_store(self):
         return store


+@unittest.skipIf(
+    _SYNC_BN_V7, "Skip PrefixFileStoreTest for torch >= 2.0.0"
+)
 class PrefixFileStoreTest(TestCase, StoreTestBase):
     def setUp(self):
         super(PrefixFileStoreTest, self).setUp()
@@ -172,7 +182,7 @@ def test_address_already_in_use(self):
         if sys.platform == "win32":
             err_msg_reg = "Only one usage of each socket address*"
         else:
-            err_msg_reg = "^Address already in use$"
+            err_msg_reg = "Address already in use"
         with self.assertRaisesRegex(RuntimeError, err_msg_reg):
             addr = DEFAULT_HOSTNAME
             port = find_free_port()
@@ -180,8 +190,8 @@ def test_address_already_in_use(self):
             # Use noqa to silence flake8.
             # Need to store in an unused variable here to ensure the first
             # object is not destroyed before the second object is created.
-            store1 = c10d.TCPStore(addr, port, 1, True)  # noqa: F841
-            store2 = c10d.TCPStore(addr, port, 1, True)  # noqa: F841
+            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
+            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

     # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
     # the user and one additional key used to coordinate all the workers.
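
The hunk above constructs the store through the public torch.distributed namespace (dist.TCPStore) instead of the internal c10d alias, and drops the ^...$ anchors from the regex, presumably because newer releases embed the phrase in a longer error message. A small sketch of the store API being exercised; the host and port are placeholders, with a single-process world:

    import torch.distributed as dist

    host, port = "127.0.0.1", 29500  # placeholder address for this sketch

    # is_master=True makes this process bind and listen on (host, port);
    # constructing a second master store on the same port is what raises
    # the "Address already in use" RuntimeError the test asserts on.
    store = dist.TCPStore(host, port, 1, True)
    store.set("key0", "value0")
    assert store.get("key0") == b"value0"  # values come back as bytes
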
@@ -845,8 +855,15 @@ def test_single_limit_single_dtype(self):
             torch.empty([100], dtype=torch.float),
             torch.empty([50], dtype=torch.float),
         ]
-        result = dist._compute_bucket_assignment_by_size(tensors, [400])
-        self.assertEqual([[0], [1], [2], [3]], result)
+        if _SYNC_BN_V5:
+            result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
+            tensors, [400]
+        )
+            self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
+            self.assertEqual([[0], [1], [2], [3]], result)
+        else:
+            result = dist._compute_bucket_assignment_by_size(tensors, [400])
+            self.assertEqual([[0], [1], [2], [3]], result)

     def test_single_limit_multi_dtype(self):
         tensors = [
@@ -857,8 +874,15 @@ def test_single_limit_multi_dtype(self):
             torch.empty([50], dtype=torch.float),
             torch.empty([25], dtype=torch.double),
         ]
-        result = dist._compute_bucket_assignment_by_size(tensors, [400])
-        self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
+        if _SYNC_BN_V5:
+            result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
+                tensors, [400]
+            )
+            self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
+            self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
+        else:
+            result = dist._compute_bucket_assignment_by_size(tensors, [400])
+            self.assertEqual([[0, 2], [1, 3], [4], [5]], result)

     def test_multi_limit_single_dtype(self):
         tensors = [
@@ -867,8 +891,15 @@ def test_multi_limit_single_dtype(self):
             torch.empty([10], dtype=torch.float),
             torch.empty([10], dtype=torch.float),
         ]
-        result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
-        self.assertEqual([[0], [1, 2], [3]], result)
+        if _SYNC_BN_V5:
+            result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
+                tensors, [40, 80]
+            )
+            self.assertEqual(per_bucket_size_limits, [40, 80, 80])
+            self.assertEqual([[0], [1, 2], [3]], result)
+        else:
+            result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
+            self.assertEqual([[0], [1, 2], [3]], result)

     def test_multi_limit_multi_dtype(self):
         tensors = [
@@ -879,8 +910,15 @@ def test_multi_limit_multi_dtype(self):
             torch.empty([50], dtype=torch.float),
             torch.empty([25], dtype=torch.double),
         ]
-        result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
-        self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
+        if _SYNC_BN_V5:
+            result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
+                tensors, [200, 400]
+            )
+            self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
+            self.assertEqual(per_bucket_size_limits, [200, 200, 400, 400])
+        else:
+            result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
+            self.assertEqual([[0], [1], [2, 4], [3, 5]], result)


 class AbstractCommTest(object):
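
All four bucket tests above branch on _SYNC_BN_V5 because dist._compute_bucket_assignment_by_size started returning a (bucket_indices, per_bucket_size_limits) tuple in torch 1.10, where earlier versions returned only the bucket indices. A small normalizing wrapper sketched under that assumption; the compute_buckets helper name is hypothetical:

    import torch
    import torch.distributed as dist

    def compute_buckets(tensors, size_limits):
        # Returns (bucket_indices, per_bucket_size_limits or None),
        # hiding the cross-version difference in return type.
        out = dist._compute_bucket_assignment_by_size(tensors, size_limits)
        if isinstance(out, tuple):  # torch >= 1.10: (indices, limits)
            return out
        return out, None            # older torch: indices only

    # Four 400-byte float tensors with a 400-byte limit land in four
    # single-element buckets, matching the first test above.
    tensors = [torch.empty([100], dtype=torch.float) for _ in range(4)]
    indices, limits = compute_buckets(tensors, [400])
    print(indices)  # [[0], [1], [2], [3]]
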
@@ -1032,6 +1070,9 @@ def tearDown(self):
         except OSError:
             pass

+    @unittest.skipIf(
+        _SYNC_BN_V6, "Skip test for torch >= 1.11.0"
+    )
     def test_distributed_debug_mode(self):
         # Default should be off
         default_debug_mode = dist._get_debug_mode()
@@ -1057,6 +1098,46 @@ def test_distributed_debug_mode(self):
         with self.assertRaisesRegex(RuntimeError, "to be one of"):
             dist._get_debug_mode()

+    @unittest.skipIf(
+        not _SYNC_BN_V6, "Skip test for torch < 1.11.0"
+    )
+    def test_debug_level(self):
+        try:
+            del os.environ["TORCH_DISTRIBUTED_DEBUG"]
+        except KeyError:
+            pass
+
+        dist.set_debug_level_from_env()
+        # Default should be off
+        default_debug_mode = dist.get_debug_level()
+        self.assertEqual(default_debug_mode, dist.DebugLevel.OFF)
+        mapping = {
+            "OFF": dist.DebugLevel.OFF,
+            "off": dist.DebugLevel.OFF,
+            "oFf": dist.DebugLevel.OFF,
+            "INFO": dist.DebugLevel.INFO,
+            "info": dist.DebugLevel.INFO,
+            "INfO": dist.DebugLevel.INFO,
+            "DETAIL": dist.DebugLevel.DETAIL,
+            "detail": dist.DebugLevel.DETAIL,
+            "DeTaIl": dist.DebugLevel.DETAIL,
+        }
+        invalid_debug_modes = ["foo", 0, 1, -1]
+
+        for mode in mapping.keys():
+            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
+            dist.set_debug_level_from_env()
+            set_debug_mode = dist.get_debug_level()
+            self.assertEqual(
+                set_debug_mode,
+                mapping[mode],
+                f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
+            )
+
+        for mode in invalid_debug_modes:
+            os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
+            with self.assertRaisesRegex(RuntimeError, "The value of TORCH_DISTRIBUTED_DEBUG must"):
+                dist.set_debug_level_from_env()

 if __name__ == "__main__":
     assert (
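
The two debug tests above are gated on _SYNC_BN_V6 because the API changed in PyTorch 1.11: the private dist._get_debug_mode() was superseded by the public dist.get_debug_level() plus an explicit dist.set_debug_level_from_env(). A version-agnostic sketch of reading the level, assuming only that torch.distributed is importable (illustration, not code from the patch):

    import os

    import torch.distributed as dist

    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"

    if hasattr(dist, "set_debug_level_from_env"):
        # torch >= 1.11: re-read the environment variable explicitly,
        # then query the resulting level.
        dist.set_debug_level_from_env()
        level = dist.get_debug_level()  # dist.DebugLevel.INFO
    else:
        # torch < 1.11: the private helper parses the env var when called.
        level = dist._get_debug_mode()
    print(level)
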
From 8a6bc10447d7b28f8f3a131cb064e91ac87a9946 Mon Sep 17 00:00:00 2001
From: Yafen Fang
Date: Fri, 14 Apr 2023 18:46:23 +0200
Subject: [PATCH 2/3] ci: run pytest on every supported version of PyTorch

---
 .../workflows/bagua-python-package-check.yml | 95 +++++++++++++++++++
 1 file changed, 95 insertions(+)

diff --git a/.github/workflows/bagua-python-package-check.yml b/.github/workflows/bagua-python-package-check.yml
index 01d753a9d..1285c94be 100644
--- a/.github/workflows/bagua-python-package-check.yml
+++ b/.github/workflows/bagua-python-package-check.yml
@@ -29,3 +29,98 @@ jobs:
       run: |
         rm -rf bagua bagua_core
         pytest -s --timeout=300 --timeout_method=thread
+  build-ten:
+    runs-on: ubuntu-latest
+    container: baguasys/bagua:master-pytorch-1.10.0-cuda11.3-cudnn8
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          submodules: recursive
+      - run: rustup default stable
+      - name: Install with pip
+        run: |
+          python -m pip install --pre .
+      - name: Install pytest
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install pytest pytest-timeout
+      - name: Test with pytest
+        run: |
+          rm -rf bagua bagua_core
+          pytest -s --timeout=300 --timeout_method=thread
+  build-eleven:
+    runs-on: ubuntu-latest
+    container: baguasys/bagua:master-pytorch-1.11.0-cuda11.3-cudnn8
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          submodules: recursive
+      - run: rustup default stable
+      - name: Install with pip
+        run: |
+          python -m pip install --pre .
+      - name: Install pytest
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install pytest pytest-timeout
+      - name: Test with pytest
+        run: |
+          rm -rf bagua bagua_core
+          pytest -s --timeout=300 --timeout_method=thread
+  build-twelve:
+    runs-on: ubuntu-latest
+    container: baguasys/bagua:master-pytorch-1.12.0-cuda11.3-cudnn8
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          submodules: recursive
+      - run: rustup default stable
+      - name: Install with pip
+        run: |
+          python -m pip install --pre .
+      - name: Install pytest
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install pytest pytest-timeout
+      - name: Test with pytest
+        run: |
+          rm -rf bagua bagua_core
+          pytest -s --timeout=300 --timeout_method=thread
+  build-thirteen:
+    runs-on: ubuntu-latest
+    container: baguasys/bagua:master-pytorch-1.13.0-cuda11.6-cudnn8
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          submodules: recursive
+      - run: rustup default stable
+      - name: Install with pip
+        run: |
+          python -m pip install --pre .
+      - name: Install pytest
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install pytest pytest-timeout
+      - name: Test with pytest
+        run: |
+          rm -rf bagua bagua_core
+          pytest -s --timeout=300 --timeout_method=thread
+  build-twenty:
+    runs-on: ubuntu-latest
+    container: baguasys/bagua:master-pytorch-2.0.0-cuda11.7-cudnn8
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          submodules: recursive
+      - run: rustup default stable
+      - name: Install with pip
+        run: |
+          python -m pip install --pre .
+      - name: Install pytest
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install pytest pytest-timeout
+      - name: Test with pytest
+        run: |
+          rm -rf bagua bagua_core
+          pytest -s --timeout=300 --timeout_method=thread
From 34b857d10ad5d647d011fbd3ec3892c979247f72 Mon Sep 17 00:00:00 2001
From: Yafen Fang
Date: Fri, 14 Apr 2023 19:05:31 +0200
Subject: [PATCH 3/3] ci: fix black formatting

---
 tests/internal/torch/common_utils.py      |  5 +++--
 .../data_parallel/test_c10d_common.py     | 22 +++++++-------------
 2 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/tests/internal/torch/common_utils.py b/tests/internal/torch/common_utils.py
index 8703a3352..4bac474f2 100644
--- a/tests/internal/torch/common_utils.py
+++ b/tests/internal/torch/common_utils.py
@@ -132,6 +132,7 @@

 import torch
 import torch.cuda
+
 if hasattr(torch, "_six"):
     from torch._six import string_classes
 else:
@@ -1380,7 +1381,7 @@ def compare_with_numpy(
         torch.complex64: (1.3e-6, 1e-5),
         torch.complex128: (1e-7, 1e-7),
     }
-    if hasattr(torch, "complex32"): # torch.complex32 was removed in 1.11.0
+    if hasattr(torch, "complex32"):  # torch.complex32 was removed in 1.11.0
         dtype_precisions[torch.complex32] = (0.001, 1e-5)

     # Returns the "default" rtol and atol for comparing scalars or
@@ -2142,7 +2143,7 @@ def runWithPytorchAPIUsageStderr(code):
 def find_free_port():
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        sock.bind(('localhost', 0))
+        sock.bind(("localhost", 0))
         _, port = sock.getsockname()
         return port

diff --git a/tests/torch_api/data_parallel/test_c10d_common.py b/tests/torch_api/data_parallel/test_c10d_common.py
index c99590945..87fc23074 100644
--- a/tests/torch_api/data_parallel/test_c10d_common.py
+++ b/tests/torch_api/data_parallel/test_c10d_common.py
@@ -24,6 +24,7 @@
 import torch.multiprocessing as mp
 import torch.nn.functional as F
 from torch import nn
+
 if hasattr(torch, "_six"):
     from torch._six import string_classes
 else:
@@ -132,9 +133,7 @@ def num_keys_total(self):
         return 5


-@unittest.skipIf(
-    _SYNC_BN_V7, "Skip FileStoreTest for torch >= 2.0.0"
-)
+@unittest.skipIf(_SYNC_BN_V7, "Skip FileStoreTest for torch >= 2.0.0")
 class FileStoreTest(TestCase, StoreTestBase):
     def setUp(self):
         super(FileStoreTest, self).setUp()
@@ -157,9 +156,7 @@ def _create_store(self):
         return store


-@unittest.skipIf(
-    _SYNC_BN_V7, "Skip PrefixFileStoreTest for torch >= 2.0.0"
-)
+@unittest.skipIf(_SYNC_BN_V7, "Skip PrefixFileStoreTest for torch >= 2.0.0")
 class PrefixFileStoreTest(TestCase, StoreTestBase):
     def setUp(self):
         super(PrefixFileStoreTest, self).setUp()
@@ -857,8 +854,8 @@ def test_single_limit_single_dtype(self):
         ]
         if _SYNC_BN_V5:
             result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
-            tensors, [400]
-        )
+                tensors, [400]
+            )
             self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
             self.assertEqual([[0], [1], [2], [3]], result)
         else:
@@ -1070,9 +1067,7 @@ def tearDown(self):
         except OSError:
             pass

-    @unittest.skipIf(
-        _SYNC_BN_V6, "Skip test for torch >= 1.11.0"
-    )
+    @unittest.skipIf(_SYNC_BN_V6, "Skip test for torch >= 1.11.0")
     def test_distributed_debug_mode(self):
         # Default should be off
         default_debug_mode = dist._get_debug_mode()
@@ -1098,9 +1093,7 @@ def test_distributed_debug_mode(self):
         with self.assertRaisesRegex(RuntimeError, "to be one of"):
             dist._get_debug_mode()

-    @unittest.skipIf(
-        not _SYNC_BN_V6, "Skip test for torch < 1.11.0"
-    )
+    @unittest.skipIf(not _SYNC_BN_V6, "Skip test for torch < 1.11.0")
     def test_debug_level(self):
         try:
             del os.environ["TORCH_DISTRIBUTED_DEBUG"]
         except KeyError:
             pass
@@ -1139,6 +1132,7 @@ def test_debug_level(self):
             with self.assertRaisesRegex(RuntimeError, "The value of TORCH_DISTRIBUTED_DEBUG must"):
                 dist.set_debug_level_from_env()

+
 if __name__ == "__main__":
     assert (
         not torch.cuda._initialized
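
The skipIf collapses in the final patch are formatting only, but they show the version-gating idiom the whole series relies on: a module-level LooseVersion flag feeding @unittest.skipIf so each test runs only where its API exists. A self-contained sketch of that idiom; the flag mirrors _SYNC_BN_V6, while the test class itself is hypothetical:

    import unittest

    from distutils.version import LooseVersion

    import torch

    # Evaluated once at import time, like the _SYNC_BN_* flags above.
    _HAS_DEBUG_LEVEL = LooseVersion(torch.__version__) >= LooseVersion("1.11.0")

    class ExampleGatedTest(unittest.TestCase):
        @unittest.skipIf(not _HAS_DEBUG_LEVEL, "Skip test for torch < 1.11.0")
        def test_new_api(self):
            import torch.distributed as dist
            self.assertTrue(hasattr(dist, "get_debug_level"))

        @unittest.skipIf(_HAS_DEBUG_LEVEL, "Skip test for torch >= 1.11.0")
        def test_old_api(self):
            import torch.distributed as dist
            self.assertTrue(hasattr(dist, "_get_debug_mode"))

    if __name__ == "__main__":
        unittest.main()
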