Skip to content

Commit

Permalink
Enable method config registration from env variable (#1869)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkulhanek authored May 4, 2023
1 parent 9ace02b commit 971c9eb
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/developer_guides/new_methods.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ finally run the following to register the method,
pip install -e .
```

When developing a new method you don't always want to install your code as a package.
Instead, you may use the `NERFSTUDIO_METHOD_CONFIGS` environment variable to temporarily register your custom method.
```
export NERFSTUDIO_METHOD_CONFIGS="my-method=my_method.my_config:MyMethod"
```

## Running custom method

After registering your method you should be able to run the method with,
Expand Down
21 changes: 21 additions & 0 deletions nerfstudio/plugins/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
Module that keeps all registered plugins and allows for plugin discovery.
"""

import importlib
import os
import sys
import typing as t

Expand All @@ -34,6 +36,7 @@
def discover_methods() -> t.Tuple[t.Dict[str, TrainerConfig], t.Dict[str, str]]:
"""
Discovers all methods registered using the `nerfstudio.method_configs` entrypoint.
And also methods in the NERFSTUDIO_METHOD_CONFIGS environment variable.
"""
methods = {}
descriptions = {}
Expand All @@ -48,4 +51,22 @@ def discover_methods() -> t.Tuple[t.Dict[str, TrainerConfig], t.Dict[str, str]]:
specification = t.cast(MethodSpecification, specification)
methods[specification.config.method_name] = specification.config
descriptions[specification.config.method_name] = specification.description

if "NERFSTUDIO_METHOD_CONFIGS" in os.environ:
try:
strings = os.environ["NERFSTUDIO_METHOD_CONFIGS"].split(",")
for definition in strings:
if not definition:
continue
name, path = definition.split("=")
CONSOLE.print(f"[bold green]Info: Loading method {name} from environment variable")
module, config_name = path.split(":")
method_config = getattr(importlib.import_module(module), config_name)
assert isinstance(method_config, MethodSpecification)
methods[name] = method_config.config
descriptions[name] = method_config.description
except Exception: # pylint: disable=broad-except
CONSOLE.print_exception()
CONSOLE.print("[bold red]Error: Could not load methods from environment variable NERFSTUDIO_METHOD_CONFIGS")

return methods, descriptions
21 changes: 21 additions & 0 deletions tests/plugins/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for the nerfstudio.plugins.registry module.
"""
import os
import sys

from nerfstudio.engine.trainer import TrainerConfig
Expand Down Expand Up @@ -49,3 +50,23 @@ def entry_points(group=None):
finally:
# Revert mock
registry.entry_points = entry_points_backup


def test_discover_methods_from_environment_variable():
"""Tests if a custom method from env variable gets properly registered using the discover_methods method"""
old_env = None
try:
old_env = os.environ.get("NERFSTUDIO_METHOD_CONFIGS", None)
os.environ["NERFSTUDIO_METHOD_CONFIGS"] = "test-method-env=test_registry:TestConfig"

# Discover plugins
methods, _ = registry.discover_methods()
assert "test-method-env" in methods
config = methods["test-method-env"]
assert isinstance(config, TrainerConfig)
finally:
# Revert mock
if old_env is not None:
os.environ["NERFSTUDIO_METHOD_CONFIGS"] = old_env
else:
del os.environ["NERFSTUDIO_METHOD_CONFIGS"]

0 comments on commit 971c9eb

Please sign in to comment.