diff --git a/CHANGELOG.md b/CHANGELOG.md index 2819df9d0..3b68a5ea0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `requires` wrapper ([#1056](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/1056)) ### Changed diff --git a/src/pl_bolts/utils/_dependency.py b/src/pl_bolts/utils/_dependency.py new file mode 100644 index 000000000..81ce2d912 --- /dev/null +++ b/src/pl_bolts/utils/_dependency.py @@ -0,0 +1,28 @@ +import functools +import os +from typing import Any, Callable + +from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache + + +# ToDo: replace with utils wrapper after 0.10 is released +def requires(*module_path_version: str) -> Callable: + """Wrapper for enforcing certain requirements for a particular class or function.""" + + def decorator(func: Callable) -> Callable: + reqs = [ + ModuleAvailableCache(mod_ver) if "." in mod_ver else RequirementCache(mod_ver) + for mod_ver in module_path_version + ] + available = all(map(bool, reqs)) + if not available: + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + msg = os.linesep.join([repr(r) for r in reqs if not bool(r)]) + raise ModuleNotFoundError(f"Required dependencies not available: \n{msg}") + + return wrapper + return func + + return decorator diff --git a/tests/utils/test_dependency.py b/tests/utils/test_dependency.py new file mode 100644 index 000000000..a4fd42da3 --- /dev/null +++ b/tests/utils/test_dependency.py @@ -0,0 +1,28 @@ +import pytest +from pl_bolts.utils._dependency import requires + + +@requires("torch") +def using_torch(): + return True + + +@requires("torch.anything.wrong") +def using_torch_wrong_path(): + return True + + +@requires("torch>99.0") +def using_torch_bad_version(): + return True + + +def test_requires_pass(): + assert using_torch() is True + + +def test_requires_fail(): + with pytest.raises(ModuleNotFoundError, match="Required dependencies not available"): + assert using_torch_wrong_path() + with pytest.raises(ModuleNotFoundError, match="Required dependencies not available"): + assert using_torch_bad_version()