diff --git a/arrayfire_wrapper/_backend.py b/arrayfire_wrapper/_backend.py index ed3864d..2644831 100644 --- a/arrayfire_wrapper/_backend.py +++ b/arrayfire_wrapper/_backend.py @@ -49,11 +49,18 @@ def _get_backend_path_config() -> _BackendPathConfig: platform_name = platform.system() cuda_found = False - # Try to use user provided AF_PATH if explicitly set - af_path = os.environ.get("AF_PATH", None) - af_is_user_path = af_path is not None + # try to use user provided AF_PATH if explicitly set + try: + af_path = Path(os.environ["AF_PATH"]) + af_is_user_path = True + except KeyError: + af_path = None + af_is_user_path = False - cuda_path = os.environ.get("CUDA_PATH", None) + try: + cuda_path = Path(os.environ["CUDA_PATH"]) + except KeyError: + cuda_path = None # Try to find default arrayfire installation paths if platform_name == _SupportedPlatforms.windows.value or _SupportedPlatforms.is_cygwin(platform_name):