Skip to content

Commit

Permalink
[WIP] FEAT Decorator to purge accelerate env vars (#3252)
Browse files Browse the repository at this point in the history
* [WIP] FEAT Decorator to purge accelerate env vars

In some circumstances, calling certain classes or functions can result
in accelerate env vars being set and not being cleaned up afterwards. As
an example, when calling:

TrainingArguments(fp16=True, ...)

The following env var will be set:

ACCELERATE_MIXED_PRECISION=fp16

This can affect subsequent code, since the env var takes precedence over
TrainingArguments(fp16=False). This is especially relevant for unit
testing, where we want to avoid the individual tests to have side
effects on one another. Decorate the unit test function or whole class
with this decorator to ensure that after each test, the env vars are
cleaned up. This works for both unittest.TestCase and normal
classes (pytest); it also works when decorating the parent class.

In its current state, this PR adds the new decorator and tests it, but
the decorator is not yet applied to potentially problematic functions or
classes.

* Linter

* Refactor code to be more readable

---------

Co-authored-by: [[ -z $EMAIL ]] && read -e -p "Enter your email (for git configuration): " EMAIL <[email protected]>
  • Loading branch information
BenjaminBossan and muellerzr authored Nov 25, 2024
1 parent e11d3ce commit 29be478
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/package_reference/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ When setting up 🤗 Accelerate for the first time, rather than running `acceler

[[autodoc]] utils.environment.override_numa_affinity

[[autodoc]] utils.purge_accelerate_environment

## Memory

[[autodoc]] utils.find_executable_batch_size
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
parse_choice_from_env,
parse_flag_from_env,
patch_environment,
purge_accelerate_environment,
set_numa_affinity,
str_to_bool,
)
Expand Down
64 changes: 63 additions & 1 deletion src/accelerate/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import lru_cache
from functools import lru_cache, wraps
from shutil import which
from typing import List, Optional

Expand Down Expand Up @@ -345,3 +345,65 @@ def patch_environment(**kwargs):
os.environ[key] = existing_vars[key]
else:
os.environ.pop(key, None)


def purge_accelerate_environment(func_or_cls):
"""Decorator to clean up accelerate environment variables set by the decorated class or function.
In some circumstances, calling certain classes or functions can result in accelerate env vars being set and not
being cleaned up afterwards. As an example, when calling:
TrainingArguments(fp16=True, ...)
The following env var will be set:
ACCELERATE_MIXED_PRECISION=fp16
This can affect subsequent code, since the env var takes precedence over TrainingArguments(fp16=False). This is
especially relevant for unit testing, where we want to avoid the individual tests to have side effects on one
another. Decorate the unit test function or whole class with this decorator to ensure that after each test, the env
vars are cleaned up. This works for both unittest.TestCase and normal classes (pytest); it also works when
decorating the parent class.
"""
prefix = "ACCELERATE_"

@contextmanager
def env_var_context():
# Store existing accelerate env vars
existing_vars = {k: v for k, v in os.environ.items() if k.startswith(prefix)}
try:
yield
finally:
# Restore original env vars or remove new ones
for key in [k for k in os.environ if k.startswith(prefix)]:
if key in existing_vars:
os.environ[key] = existing_vars[key]
else:
os.environ.pop(key, None)

def wrap_function(func):
@wraps(func)
def wrapper(*args, **kwargs):
with env_var_context():
return func(*args, **kwargs)

wrapper._accelerate_is_purged_environment_wrapped = True
return wrapper

if not isinstance(func_or_cls, type):
return wrap_function(func_or_cls)

# Handle classes by wrapping test methods
def wrap_test_methods(test_class_instance):
for name in dir(test_class_instance):
if name.startswith("test"):
method = getattr(test_class_instance, name)
if callable(method) and not hasattr(method, "_accelerate_is_purged_environment_wrapped"):
setattr(test_class_instance, name, wrap_function(method))
return test_class_instance

# Handle inheritance
wrap_test_methods(func_or_cls)
func_or_cls.__init_subclass__ = classmethod(lambda cls, **kw: wrap_test_methods(cls))
return func_or_cls
150 changes: 150 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
pad_across_processes,
pad_input_tensors,
patch_environment,
purge_accelerate_environment,
recursively_apply,
save,
send_to_device,
Expand Down Expand Up @@ -431,3 +432,152 @@ def test_has_offloaded_params(self):
remove_hook_from_module(model)
attach_align_device_hook(model, offload=True)
assert has_offloaded_params(model)


def set_dummy_accelerate_env_var():
"""Set an accelerate env var
This class emulates the behavior of, for instance, transformers.TrainingArguments, which is allowed to set
accelerate env vars but does not clean them up. E.g.
TrainingArguments(fp16=True, output_dir="/tmp/test")
leaves ACCELERATE_MIXED_PRECISION=fp16 as an env var.
"""
os.environ["ACCELERATE_SOME_ENV_VAR"] = "true"


@purge_accelerate_environment
class MyUnittest(unittest.TestCase):
def test_purge_env_vars_unittest_1(self):
os.environ.pop("ACCELERATE_SOME_ENV_VAR", None)
set_dummy_accelerate_env_var()
assert "ACCELERATE_SOME_ENV_VAR" in os.environ

def test_purge_env_vars_unittest_2(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ


@unittest.skipIf(False, "dummy unittest wrapper")
@purge_accelerate_environment
@unittest.skipUnless(True, "dummy unittest wrapper")
class MyUnittestWithDecorators(unittest.TestCase):
def test_purge_env_vars_unittest_with_wrapper_1(self):
os.environ.pop("ACCELERATE_SOME_ENV_VAR", None)
set_dummy_accelerate_env_var()
assert "ACCELERATE_SOME_ENV_VAR" in os.environ

def test_purge_env_vars_unittest_with_wrapper_2(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ

@unittest.skipIf(False, "dummy unittest wrapper")
def test_purge_env_vars_unittest_with_wrapper_3(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ

@unittest.skipIf(True, "this is always skipped")
def test_purge_env_vars_unittest_with_wrapper_4(self):
# ensure that unittest markers still do their job
assert False


@purge_accelerate_environment
class _BaseCls(unittest.TestCase):
def test_purge_env_vars_unittest_with_inheritance_3(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ


class MyUnittestWithInheritance(_BaseCls):
def test_purge_env_vars_unittest_with_inheritance_1(self):
os.environ.pop("ACCELERATE_SOME_ENV_VAR", None)
set_dummy_accelerate_env_var()
assert "ACCELERATE_SOME_ENV_VAR" in os.environ

def test_purge_env_vars_unittest_with_inheritance_2(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ


@purge_accelerate_environment
class TestMyPytest:
def test_purge_env_vars_pytest_1(self):
os.environ.pop("ACCELERATE_SOME_ENV_VAR", None)
set_dummy_accelerate_env_var()
assert "ACCELERATE_SOME_ENV_VAR" in os.environ

def test_purge_env_vars_pytest_2(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ


@pytest.fixture
def dummy_fixture():
pass


@pytest.mark.skipif(False, reason="dummy pytest wrapper")
@pytest.mark.usefixtures("dummy_fixture")
@purge_accelerate_environment
@pytest.mark.skipif(False, reason="dummy pytest wrapper")
@pytest.mark.usefixtures("dummy_fixture")
class TestPytestWithWrapper:
def test_purge_env_vars_pytest_with_wrapper_1(self):
os.environ.pop("ACCELERATE_SOME_ENV_VAR", None)
set_dummy_accelerate_env_var()
assert "ACCELERATE_SOME_ENV_VAR" in os.environ

def test_purge_env_vars_pytest_with_wrapper_2(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ

@pytest.mark.skipif(False, reason="dummy pytest wrapper")
@pytest.mark.usefixtures("dummy_fixture")
def test_purge_env_vars_pytest_with_wrapper_3(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ

@pytest.mark.skipif(True, reason="this is always skipped")
def test_purge_env_vars_pytest_with_wrapper_4_should_be_skipped(self):
# ensure that pytest markers still do their job
assert False


@purge_accelerate_environment
class _PytestBaseCls:
def test_purge_env_vars_pytest_with_inheritance_3(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ


class TestPytestWithInheritance(_PytestBaseCls):
def test_purge_env_vars_pytest_with_inheritance_1(self):
os.environ.pop("ACCELERATE_SOME_ENV_VAR", None)
set_dummy_accelerate_env_var()
assert "ACCELERATE_SOME_ENV_VAR" in os.environ

def test_purge_env_vars_pytest_with_inheritance_2(self):
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ


@purge_accelerate_environment
def test_purge_env_vars_standalone_1():
os.environ.pop("ACCELERATE_SOME_ENV_VAR", None)
set_dummy_accelerate_env_var()
assert "ACCELERATE_SOME_ENV_VAR" in os.environ


def test_purge_env_vars_standalone_2():
assert "ACCELERATE_SOME_ENV_VAR" not in os.environ


def test_purge_env_vars_restores_previous_values():
# Ensure that purge_accelerate_environment restores values of previous accelerate env vars and does not delete
# untouched env vars.
@purge_accelerate_environment
def dummy_func():
os.environ["ACCELERATE_SOME_ENV_VAR"] = "456"

os.environ["ACCELERATE_SOME_ENV_VAR"] = "1"
os.environ["ACCELERATE_ANOTHER_ENV_VAR"] = "2"

dummy_func()

assert os.environ["ACCELERATE_SOME_ENV_VAR"] == "1"
assert os.environ["ACCELERATE_ANOTHER_ENV_VAR"] == "2"

del os.environ["ACCELERATE_SOME_ENV_VAR"]
del os.environ["ACCELERATE_ANOTHER_ENV_VAR"]

0 comments on commit 29be478

Please sign in to comment.