ci: fix pytest errors on higher versions of PyTorch #697

Open · wants to merge 3 commits into base: master
95 changes: 95 additions & 0 deletions .github/workflows/bagua-python-package-check.yml
@@ -29,3 +29,98 @@ jobs:
run: |
rm -rf bagua bagua_core
pytest -s --timeout=300 --timeout_method=thread
build-ten:
wangraying (Member) commented:

I wonder why we are adding so many jobs on different containers here, rather than changing the **pytorch-1.9.0** container above directly?

woqidaideshi (Contributor, Author) replied:

@wangraying This is just to test whether it runs properly on different versions of PyTorch. If this PR is to be merged, the final testing environment will be either PyTorch 1.9.0 or PyTorch 1.13.0.
If the test cases are changed to PyTorch 1.13.0 later, we can still keep the PyTorch 1.9.0 testing environment here.

wangraying (Member) commented on Apr 15, 2023:

I just checked PyTorch Lightning's CI: their pytest CPU tests support multiple torch versions, but the GPU tests mainly run on the latest PyTorch release. Maybe we could do something similar, or just maintain all CI tests on PyTorch 1.13?
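For illustration, the five near-identical jobs added below could be collapsed into a single matrix job, roughly as follows. This is a sketch only: the job name `build-matrix` is hypothetical, and the container tags are copied from the jobs in this PR.

```yaml
# Hypothetical consolidation of the duplicated jobs via a build matrix.
# Container tags are taken from this PR; treat their availability as an assumption.
build-matrix:
  runs-on: ubuntu-latest
  strategy:
    fail-fast: false
    matrix:
      container:
        - baguasys/bagua:master-pytorch-1.10.0-cuda11.3-cudnn8
        - baguasys/bagua:master-pytorch-1.11.0-cuda11.3-cudnn8
        - baguasys/bagua:master-pytorch-1.12.0-cuda11.3-cudnn8
        - baguasys/bagua:master-pytorch-1.13.0-cuda11.6-cudnn8
        - baguasys/bagua:master-pytorch-2.0.0-cuda11.7-cudnn8
  container: ${{ matrix.container }}
  steps:
    - uses: actions/checkout@v2
      with:
        submodules: recursive
    - run: rustup default stable
    - name: Install with pip
      run: python -m pip install --pre .
    - name: Install pytest
      run: |
        python -m pip install --upgrade pip
        python -m pip install pytest pytest-timeout
    - name: Test with pytest
      run: |
        rm -rf bagua bagua_core
        pytest -s --timeout=300 --timeout_method=thread
```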

runs-on: ubuntu-latest
container: baguasys/bagua:master-pytorch-1.10.0-cuda11.3-cudnn8
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- run: rustup default stable
- name: Install with pip
run: |
python -m pip install --pre .
- name: Install pytest
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-timeout
- name: Test with pytest
run: |
rm -rf bagua bagua_core
pytest -s --timeout=300 --timeout_method=thread
build-eleven:
runs-on: ubuntu-latest
container: baguasys/bagua:master-pytorch-1.11.0-cuda11.3-cudnn8
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- run: rustup default stable
- name: Install with pip
run: |
python -m pip install --pre .
- name: Install pytest
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-timeout
- name: Test with pytest
run: |
rm -rf bagua bagua_core
pytest -s --timeout=300 --timeout_method=thread
build-twelve:
runs-on: ubuntu-latest
container: baguasys/bagua:master-pytorch-1.12.0-cuda11.3-cudnn8
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- run: rustup default stable
- name: Install with pip
run: |
python -m pip install --pre .
- name: Install pytest
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-timeout
- name: Test with pytest
run: |
rm -rf bagua bagua_core
pytest -s --timeout=300 --timeout_method=thread
build-thirteen:
runs-on: ubuntu-latest
container: baguasys/bagua:master-pytorch-1.13.0-cuda11.6-cudnn8
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- run: rustup default stable
- name: Install with pip
run: |
python -m pip install --pre .
- name: Install pytest
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-timeout
- name: Test with pytest
run: |
rm -rf bagua bagua_core
pytest -s --timeout=300 --timeout_method=thread
build-twenty:
runs-on: ubuntu-latest
container: baguasys/bagua:master-pytorch-2.0.0-cuda11.7-cudnn8
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- run: rustup default stable
- name: Install with pip
run: |
python -m pip install --pre .
- name: Install pytest
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-timeout
- name: Test with pytest
run: |
rm -rf bagua bagua_core
pytest -s --timeout=300 --timeout_method=thread
3 changes: 3 additions & 0 deletions bagua/torch_api/contrib/sync_batchnorm.py
@@ -26,6 +26,9 @@
) <= LooseVersion("1.6.0")
_SYNC_BN_V3 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
_SYNC_BN_V4 = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
_SYNC_BN_V5 = LooseVersion(torch.__version__) >= LooseVersion("1.10.0")
_SYNC_BN_V6 = LooseVersion(torch.__version__) >= LooseVersion("1.11.0")
_SYNC_BN_V7 = LooseVersion(torch.__version__) >= LooseVersion("2.0.0")


class SyncBatchNorm(_BatchNorm):
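The new flags follow the existing `_SYNC_BN_V3`/`_SYNC_BN_V4` pattern and are consumed later in this PR as `unittest.skipIf` conditions. A minimal sketch of that usage pattern, with a hypothetical test class:

```python
# Sketch: gating a test on the torch version, as this PR does in
# test_c10d_common.py. SomeVersionSensitiveTest is a hypothetical name.
import unittest

from bagua.torch_api.contrib.sync_batchnorm import _SYNC_BN_V7


@unittest.skipIf(_SYNC_BN_V7, "this code path changed in torch >= 2.0.0")
class SomeVersionSensitiveTest(unittest.TestCase):
    def test_something(self):
        self.assertTrue(True)
```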
13 changes: 9 additions & 4 deletions tests/internal/torch/common_utils.py
@@ -132,7 +132,11 @@

import torch
import torch.cuda
from torch._six import string_classes

wangraying (Member) commented:

I wonder, does torch remove this torch._six in higher versions? Do you know which version?
It's better for us to remove these dependencies on private modules.

woqidaideshi (Contributor, Author) replied:

@wangraying torch._six has been removed in PyTorch 2.0.0.
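An equivalent guard can compare version strings instead of probing with `hasattr`; a minimal sketch, assuming the `packaging` module is available (the `hasattr` approach used in this diff avoids that extra dependency):

```python
# Hypothetical alternative to the hasattr() guard in this diff:
# torch._six was removed in PyTorch 2.0.0, so compare versions explicitly.
import torch
from packaging.version import Version

if Version(torch.__version__) < Version("2.0.0"):
    from torch._six import string_classes
else:
    string_classes = str
```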

if hasattr(torch, "_six"):
from torch._six import string_classes
else:
string_classes = str
import torch.backends.cudnn
import torch.backends.mkl
from enum import Enum
@@ -318,7 +322,7 @@ def shell(command, cwd=None, env=None):
#
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
assert not isinstance(
command, torch._six.string_classes
command, string_classes
), "Command to shell should be a list or tuple of tokens"
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
return wait_for_process(p)
@@ -1040,7 +1044,7 @@ def check_slow_test_from_stats(test):
global slow_tests_dict
if slow_tests_dict is None:
if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1":
url = "https://raw.githubusercontent.com/pytorch/test-infra/master/stats/slow-tests.json"
url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json"
slow_tests_dict = fetch_and_cache(".pytorch-slow-tests", url)
else:
slow_tests_dict = {}
@@ -1374,10 +1378,11 @@ def compare_with_numpy(
torch.bfloat16: (0.016, 1e-5),
torch.float32: (1.3e-6, 1e-5),
torch.float64: (1e-7, 1e-7),
torch.complex32: (0.001, 1e-5),
torch.complex64: (1.3e-6, 1e-5),
torch.complex128: (1e-7, 1e-7),
}
if hasattr(torch, "complex32"): # torch.complex32 has been removed from 1.11.0
dtype_precisions[torch.complex32] = (0.001, 1e-5)
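A short usage sketch of this precision table; the lookup shown is hypothetical but mirrors how per-dtype (rtol, atol) defaults are typically consumed:

```python
# Hypothetical lookup against the table above: guard torch.complex32,
# which was removed in 1.11.0, before using it as a key.
import torch

dtype_precisions = {
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
}
if hasattr(torch, "complex32"):
    dtype_precisions[torch.complex32] = (0.001, 1e-5)

rtol, atol = dtype_precisions.get(torch.float64, (0.0, 0.0))
assert torch.allclose(
    torch.tensor(1.0, dtype=torch.float64),
    torch.tensor(1.0 + 5e-8, dtype=torch.float64),
    rtol=rtol,
    atol=atol,
)
```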

# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.
99 changes: 87 additions & 12 deletions tests/torch_api/data_parallel/test_c10d_common.py
@@ -24,7 +24,11 @@
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch import nn
from torch._six import string_classes

if hasattr(torch, "_six"):
from torch._six import string_classes
else:
string_classes = str
from bagua.torch_api.data_parallel import DistributedDataParallel
from tests.internal.torch.common_distributed import (
MultiProcessTestCase,
@@ -44,6 +48,7 @@
)

import bagua.torch_api.data_parallel.functional as bagua_dist
from bagua.torch_api.contrib.sync_batchnorm import _SYNC_BN_V5, _SYNC_BN_V6, _SYNC_BN_V7

wangraying (Member) commented:

I don't think we should couple this test with sync BN.

woqidaideshi (Contributor, Author) replied on Apr 15, 2023:

@wangraying This approach aims to make the test cases compatible with multiple versions of PyTorch.

Do you have any other suggestions for addressing the compatibility issues of the test cases with multiple versions of PyTorch?

wangraying (Member) commented:

Yeah, I agree with using LooseVersion as the condition, but it's kind of confusing to import the condition from sync bn.

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@@ -128,6 +133,7 @@ def num_keys_total(self):
return 5


@unittest.skipIf(_SYNC_BN_V7, "Skip FileStoreTest for torch >= 2.0.0")
class FileStoreTest(TestCase, StoreTestBase):
def setUp(self):
super(FileStoreTest, self).setUp()
@@ -150,6 +156,7 @@ def _create_store(self):
return store


@unittest.skipIf(_SYNC_BN_V7, "Skip PrefixFileStoreTest for torch >= 2.0.0")
class PrefixFileStoreTest(TestCase, StoreTestBase):
def setUp(self):
super(PrefixFileStoreTest, self).setUp()
@@ -172,16 +179,16 @@ def test_address_already_in_use(self):
if sys.platform == "win32":
err_msg_reg = "Only one usage of each socket address*"
else:
err_msg_reg = "^Address already in use$"
err_msg_reg = "Address already in use"
with self.assertRaisesRegex(RuntimeError, err_msg_reg):
addr = DEFAULT_HOSTNAME
port = find_free_port()

# Use noqa to silence flake8.
# Need to store in an unused variable here to ensure the first
# object is not destroyed before the second object is created.
store1 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
store2 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
store1 = dist.TCPStore(addr, port, 1, True) # noqa: F841
store2 = dist.TCPStore(addr, port, 1, True) # noqa: F841

# The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
# the user and one additional key used for coordinate all the workers.
@@ -845,8 +852,15 @@ def test_single_limit_single_dtype(self):
torch.empty([100], dtype=torch.float),
torch.empty([50], dtype=torch.float),
]
result = dist._compute_bucket_assignment_by_size(tensors, [400])
self.assertEqual([[0], [1], [2], [3]], result)
if _SYNC_BN_V5:
result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
tensors, [400]
)
self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
self.assertEqual([[0], [1], [2], [3]], result)
else:
result = dist._compute_bucket_assignment_by_size(tensors, [400])
self.assertEqual([[0], [1], [2], [3]], result)

def test_single_limit_multi_dtype(self):
tensors = [
@@ -857,8 +871,15 @@ def test_single_limit_multi_dtype(self):
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
]
result = dist._compute_bucket_assignment_by_size(tensors, [400])
self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
if _SYNC_BN_V5:
result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
tensors, [400]
)
self.assertTrue(all(size_lim == 400 for size_lim in per_bucket_size_limits))
self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
else:
result = dist._compute_bucket_assignment_by_size(tensors, [400])
self.assertEqual([[0, 2], [1, 3], [4], [5]], result)

def test_multi_limit_single_dtype(self):
tensors = [
@@ -867,8 +888,15 @@ def test_multi_limit_single_dtype(self):
torch.empty([10], dtype=torch.float),
torch.empty([10], dtype=torch.float),
]
result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
self.assertEqual([[0], [1, 2], [3]], result)
if _SYNC_BN_V5:
result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
tensors, [40, 80]
)
self.assertEqual(per_bucket_size_limits, [40, 80, 80])
self.assertEqual([[0], [1, 2], [3]], result)
else:
result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
self.assertEqual([[0], [1, 2], [3]], result)

def test_multi_limit_multi_dtype(self):
tensors = [
@@ -879,8 +907,15 @@ def test_multi_limit_multi_dtype(self):
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
]
result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
if _SYNC_BN_V5:
result, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
tensors, [200, 400]
)
self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
self.assertEqual(per_bucket_size_limits, [200, 200, 400, 400])
else:
result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
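
The four tests above repeat the same version split; under the assumption this PR encodes — that torch >= 1.10 returns a `(bucket_indices, per_bucket_size_limits)` pair while older releases return only `bucket_indices` — the branching could be factored into one hypothetical helper:

```python
# Hypothetical compatibility wrapper for the private torch API exercised
# above. Returns (bucket_indices, per_bucket_size_limits or None).
import torch.distributed as dist

from bagua.torch_api.contrib.sync_batchnorm import _SYNC_BN_V5


def compute_bucket_assignment(tensors, size_limits):
    result = dist._compute_bucket_assignment_by_size(tensors, size_limits)
    if _SYNC_BN_V5:
        return result  # torch >= 1.10: already an (indices, limits) pair
    return result, None  # older torch exposes no per-bucket limits
```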


class AbstractCommTest(object):
@@ -1032,6 +1067,7 @@ def tearDown(self):
except OSError:
pass

@unittest.skipIf(_SYNC_BN_V6, "Skip test for torch >= 1.11.0")
def test_distributed_debug_mode(self):
# Default should be off
default_debug_mode = dist._get_debug_mode()
@@ -1057,6 +1093,45 @@ def test_distributed_debug_mode(self):
with self.assertRaisesRegex(RuntimeError, "to be one of"):
dist._get_debug_mode()

@unittest.skipIf(not _SYNC_BN_V6, "Skip test for torch < 1.11.0")
def test_debug_level(self):
try:
del os.environ["TORCH_DISTRIBUTED_DEBUG"]
except KeyError:
pass

dist.set_debug_level_from_env()
# Default should be off
default_debug_mode = dist.get_debug_level()
self.assertEqual(default_debug_mode, dist.DebugLevel.OFF)
mapping = {
"OFF": dist.DebugLevel.OFF,
"off": dist.DebugLevel.OFF,
"oFf": dist.DebugLevel.OFF,
"INFO": dist.DebugLevel.INFO,
"info": dist.DebugLevel.INFO,
"INfO": dist.DebugLevel.INFO,
"DETAIL": dist.DebugLevel.DETAIL,
"detail": dist.DebugLevel.DETAIL,
"DeTaIl": dist.DebugLevel.DETAIL,
}
invalid_debug_modes = ["foo", 0, 1, -1]

for mode in mapping.keys():
os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
dist.set_debug_level_from_env()
set_debug_mode = dist.get_debug_level()
self.assertEqual(
set_debug_mode,
mapping[mode],
f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
)

for mode in invalid_debug_modes:
os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
with self.assertRaisesRegex(RuntimeError, "The value of TORCH_DISTRIBUTED_DEBUG must"):
dist.set_debug_level_from_env()
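
For context, the torch >= 1.11 API exercised by this new test can be driven standalone; a minimal sketch using the same calls:

```python
# Standalone sketch of the TORCH_DISTRIBUTED_DEBUG workflow tested above
# (torch >= 1.11 only).
import os

import torch.distributed as dist

os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
dist.set_debug_level_from_env()
assert dist.get_debug_level() == dist.DebugLevel.DETAIL
```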


if __name__ == "__main__":
assert (