Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Jun 18, 2024
1 parent 6a81a6d commit 542a439
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

# pylint: disable=missing-class-docstring

BELOW_PY_3_9: bool = sys.version_info < (3, 9)
BELOW_PY_3_10: bool = sys.version_info < (3, 10)


def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any:
"""wrapper for testing the load function, which accounts for version differences"""
if BELOW_PY_3_9:
if BELOW_PY_3_10:
with pytest.warns(UserWarning) as record:
loaded = cls.load(data)
print([x.message for x in record])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@

print(f"{SUPPORS_KW_ONLY = }")

BELOW_PY_3_9: bool = sys.version_info < (3, 9)
BELOW_PY_3_10: bool = sys.version_info < (3, 10)


def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any:
"""wrapper for testing the load function, which accounts for version differences"""
if BELOW_PY_3_9:
if BELOW_PY_3_10:
with pytest.warns(UserWarning) as record:
loaded = cls.load(data)
print([x.message for x in record])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# pylint: disable=missing-class-docstring, unused-variable


BELOW_PY_3_9: bool = sys.version_info < (3, 9)
BELOW_PY_3_10: bool = sys.version_info < (3, 10)


def _loading_test_wrapper(cls, data, assert_record_len: int | None = None) -> Any:
"""wrapper for testing the load function, which accounts for version differences"""
if BELOW_PY_3_9:
if BELOW_PY_3_10:
with pytest.warns(UserWarning) as record:
loaded = cls.load(data)
print([x.message for x in record])
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_mlutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_get_checkpoint_paths_for_run():
assert checkpoint_paths == [(123, checkpoint1_path), (456, checkpoint2_path)]


BELOW_PY_3_9: bool = sys.version_info < (3, 9)
BELOW_PY_3_10: bool = sys.version_info < (3, 9)


def test_register_method():
Expand All @@ -51,14 +51,14 @@ class TestEvalsB:
def other_eval_function():
pass

if BELOW_PY_3_9:
if BELOW_PY_3_10:
assert len(record) == 2
else:
assert len(record) == 0

evalsA = TestEvalsA.evals
evalsB = TestEvalsB.evals
if BELOW_PY_3_9:
if BELOW_PY_3_10:
assert len(evalsA) == 1
assert len(evalsB) == 1
else:
Expand Down

0 comments on commit 542a439

Please sign in to comment.