Skip to content

Commit

Permalink
Add more callback hook points (#3023)
Browse files Browse the repository at this point in the history
* Add callback for compose function

* Add example compose callback and test

* Add docs

* Fix CI errors
  • Loading branch information
jesszzzz authored Feb 7, 2025
1 parent 618ab4b commit 0f03eb6
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 11 deletions.
48 changes: 46 additions & 2 deletions hydra/_internal/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import warnings
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from omegaconf import DictConfig, OmegaConf

from hydra.core.singleton import Singleton
from hydra.types import TaskFunction

if TYPE_CHECKING:
from hydra.core.utils import JobReturn


class CallbacksCache(metaclass=Singleton):
    """
    Singleton holding the Callbacks objects that were already built, so that
    composing the config and then starting the run does not instantiate the
    same callbacks a second time.
    """

    # Keyed by id(config); the value is the Callbacks built from that config.
    cache: Dict[int, "Callbacks"]

    def __init__(self) -> None:
        self.cache = dict()

    @staticmethod
    def instance() -> "CallbacksCache":
        """Return the shared CallbacksCache instance."""
        return Singleton.instance(CallbacksCache)  # type: ignore


class Callbacks:
def __init__(self, config: Optional[DictConfig] = None) -> None:
callbacks: List[Any]

def __init__(
    self, config: Optional[DictConfig] = None, check_cache: bool = True
) -> None:
    """
    Instantiate the callbacks listed under `hydra.callbacks` in `config`.

    :param config: composed config; when None, no callbacks are registered
    :param check_cache: True to reuse the callbacks already instantiated for
        this exact config object (looked up by id(config)) instead of
        instantiating them again
    """
    # Always define the attribute: _notify reads self.callbacks, so leaving
    # it unset when config is None would raise AttributeError later on.
    self.callbacks = []
    if config is None:
        return

    cache = CallbacksCache.instance().cache
    if check_cache:
        cached_callback = cache.get(id(config))
        if cached_callback is not None:
            # Share the already-instantiated callback objects.
            self.callbacks = cached_callback.callbacks
            return

    # Kept as a function-level import, as in the original code
    # (presumably to avoid a circular import - TODO confirm).
    from hydra.utils import instantiate

    if OmegaConf.select(config, "hydra.callbacks"):
        for params in config.hydra.callbacks.values():
            self.callbacks.append(instantiate(params))

    # NOTE(review): the key is id(config) and the cache keeps no reference to
    # config itself, so a later object could end up with the same id and hit
    # this entry - confirm this is acceptable for the intended call sites.
    cache[id(config)] = self

def _notify(self, function_name: str, reverse: bool = False, **kwargs: Any) -> None:
callbacks = reversed(self.callbacks) if reverse else self.callbacks
Expand Down Expand Up @@ -63,3 +94,16 @@ def on_job_end(
reverse=True,
**kwargs,
)

def on_compose_config(
    self,
    config: DictConfig,
    config_name: Optional[str],
    overrides: List[str],
) -> None:
    """
    Dispatch the `on_compose_config` event through `_notify`, passing the
    composed config, its name and the overrides as keyword arguments.
    """
    event_kwargs: Dict[str, Any] = {
        "config": config,
        "config_name": config_name,
        "overrides": overrides,
    }
    self._notify(function_name="on_compose_config", **event_kwargs)
15 changes: 15 additions & 0 deletions hydra/_internal/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def get_mode(
with_log_configuration=False,
run_mode=RunMode.MULTIRUN,
validate_sweep_overrides=False,
run_callback=False,
)
return cfg.hydra.mode
except Exception:
Expand Down Expand Up @@ -197,6 +198,7 @@ def show_cfg(
overrides=overrides,
run_mode=RunMode.RUN,
with_log_configuration=False,
run_callback=False,
)
HydraConfig.instance().set_config(cfg)
OmegaConf.set_readonly(cfg.hydra, None)
Expand Down Expand Up @@ -422,6 +424,7 @@ def _print_search_path(
overrides=overrides,
run_mode=run_mode,
with_log_configuration=False,
run_callback=False,
)
HydraConfig.instance().set_config(cfg)
cfg = self.get_sanitized_cfg(cfg, cfg_type="hydra")
Expand Down Expand Up @@ -501,6 +504,7 @@ def _print_config_info(
overrides=overrides,
run_mode=run_mode,
with_log_configuration=False,
run_callback=False,
)
)
HydraConfig.instance().set_config(cfg)
Expand Down Expand Up @@ -581,13 +585,17 @@ def compose_config(
with_log_configuration: bool = False,
from_shell: bool = True,
validate_sweep_overrides: bool = True,
run_callback: bool = True,
) -> DictConfig:
"""
:param config_name:
:param overrides:
:param run_mode: compose config for run or for multirun?
:param with_log_configuration: True to configure logging subsystem from the loaded config
:param from_shell: True if the parameters are passed from the shell. used for more helpful error messages
:param validate_sweep_overrides: True if sweep overrides should be validated
:param run_callback: True if the on_compose_config callback should be called, generally should always
be True except for internal use cases
:return:
"""

Expand All @@ -603,6 +611,13 @@ def compose_config(
global log
log = logging.getLogger(__name__)
self._print_debug_info(config_name, overrides, run_mode)
if run_callback:
callbacks = Callbacks(cfg, check_cache=False)
callbacks.on_compose_config(
config=cfg,
config_name=config_name,
overrides=overrides,
)
return cfg

def _print_plugins_info(
Expand Down
14 changes: 13 additions & 1 deletion hydra/experimental/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import Any
from typing import Any, List, Optional

from omegaconf import DictConfig

Expand Down Expand Up @@ -63,3 +63,15 @@ def on_job_end(
See hydra.core.utils.JobReturn for more.
"""
...

def on_compose_config(
    self,
    config: DictConfig,
    config_name: Optional[str],
    overrides: List[str],
) -> None:
    """
    Hook invoked in the compose phase, before the composed config is handed
    back to the user.

    :param config: the composed config, with overrides already applied
    :param config_name: name of the config that was composed (may be None)
    :param overrides: the overrides given to the compose call
    """
    ...
51 changes: 49 additions & 2 deletions hydra/experimental/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import copy
import logging
import pickle
from pathlib import Path
from typing import Any
from typing import Any, List, Optional

from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf, flag_override

from hydra.core.global_hydra import GlobalHydra
from hydra.core.utils import JobReturn, JobStatus
from hydra.experimental.callback import Callback
from hydra.types import RunMode


class LogJobReturnCallback(Callback):
Expand Down Expand Up @@ -58,3 +61,47 @@ def _save_pickle(self, obj: Any, filename: str, output_dir: Path) -> None:
assert output_dir is not None
with open(str(output_dir / filename), "wb") as file:
pickle.dump(obj, file, protocol=4)


class LogComposeCallback(Callback):
    """Log every compose call: the resulting config, the overrides and the defaults used."""

    def __init__(self) -> None:
        logger_name = f"{__name__}.{self.__class__.__name__}"
        self.log = logging.getLogger(logger_name)

    def on_compose_config(
        self,
        config: DictConfig,
        config_name: Optional[str],
        overrides: List[str],
    ) -> None:
        """
        Log the composed config (without its `hydra` node), where it was found,
        the overrides, and the non-hydra entries of the defaults list.
        """
        loader = GlobalHydra.instance().config_loader()
        defaults_list = loader.compute_defaults_list(
            config_name, overrides, RunMode.RUN
        )
        sources = loader.get_sources()

        # Path of the first source that provides config_name; "unknown" when
        # there is no config name or no source has it.
        config_dir = "unknown"
        if config_name:
            config_dir = next(
                (src.full_path() for src in sources if src.is_config(config_name)),
                config_dir,
            )

        printable = config
        if "hydra" in printable:
            # Drop the hydra node from a copy only; struct/readonly flags are
            # lifted temporarily so the pop is allowed.
            printable = copy.copy(printable)
            with flag_override(printable, ["struct", "readonly"], [False, False]):
                printable.pop("hydra")

        non_hydra_defaults = []
        for default in defaults_list.defaults:
            if default.package.startswith("hydra"):
                continue
            non_hydra_defaults.append(default.config_path)

        # The message body is flush-left on purpose: it is part of the
        # triple-quoted string and is emitted verbatim.
        self.log.info(
            f"""====
Composed config {config_dir}/{str(config_name)}
{OmegaConf.to_yaml(printable)}
----
Includes overrides {overrides}
Used defaults {non_hydra_defaults}
===="""
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
age: 7
name: James Bond

hydra:
job:
name: test
callbacks:
log_compose:
_target_: hydra.experimental.callbacks.LogComposeCallback

defaults:
- config_schema
- _self_
- group: a
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
name: a
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass
from typing import Dict

from omegaconf import MISSING

import hydra
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig


@dataclass
class Config:
    # Schema for the test app's config. Every field defaults to omegaconf's
    # MISSING, so no concrete default value is defined here.
    age: int = MISSING
    name: str = MISSING
    group: Dict[str, str] = MISSING


# Register the Config schema under the name "config_schema" twice: once with
# no group and once in the "test" group.
ConfigStore.instance().store(name="config_schema", node=Config)
ConfigStore.instance().store(name="config_schema", node=Config, group="test")


@hydra.main(version_base=None, config_path=".", config_name="config")
def my_app(cfg: Config) -> None:
    """Print the Hydra job name followed by the name, age and group name from cfg."""
    job_name = HydraConfig().get().job.name
    summary = (
        f"job_name: {job_name}, "
        f"name: {cfg.name}, age: {cfg.age}, group: {cfg.group['name']}"
    )
    print(summary)


if __name__ == "__main__":
    my_app()
21 changes: 21 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,27 @@
r"\[JOB\] on_job_start task_function: <function my_app at 0x[0-9a-fA-F]+>",
id="on_job_start_task_function",
),
param(
"tests/test_apps/app_with_callbacks/app_with_log_compose_callback/my_app.py",
["age=10"],
dedent(
"""\
[HYDRA] ====
Composed config .*tests.test_apps.app_with_callbacks.app_with_log_compose_callback.config
age: 10
name: James Bond
group:
name: a
----
Includes overrides \\[.*'age=10'.*\\]
Used defaults \\['config_schema', 'config', 'group/a'\\]
====
job_name: test, name: James Bond, age: 10, group: a
"""
),
id="on_compose_callback",
),
],
)
def test_app_with_callbacks(
Expand Down
28 changes: 22 additions & 6 deletions website/docs/experimental/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ class Callback:
See hydra.core.utils.JobReturn for more.
"""
...

def on_compose_config(
self,
config: DictConfig,
config_name: Optional[str],
overrides: List[str],
) -> None:
"""
Called during the compose phase and before the config is returned to the user.
config is the composed config with overrides applied.
"""
...
```
</details>

Expand Down Expand Up @@ -109,7 +121,7 @@ def my_app(cfg: DictConfig) -> None:

if __name__ == "__main__":
my_app()
```
```
</div>
<div className="col col--3" >

Expand All @@ -134,7 +146,7 @@ Job ended,uploading...
</div>
</div>

Now let's take a look at the configurations.
Now let's take a look at the configurations.

<div className="row">
<div className="col col--4">
Expand Down Expand Up @@ -177,11 +189,14 @@ hydra:


### Callback ordering
The `on_run_start` or `on_multirun_start` method will get called first,
followed by `on_job_start` (called once for each job).
The `on_compose_config` method will be called first, followed by
`on_run_start` or `on_multirun_start`,
and then `on_job_start` (called once for each job).
After each job `on_job_end` is called, and finally either `on_run_end` or
`on_multirun_end` is called one time before the application exits.

When using the `compose` function directly, only `on_compose_config` will be called.

In the `hydra.callbacks` section of your config, you can use a list to register multiple callbacks. They will be called in the final composed order for `start` events and
in reversed order for `end` events. So, for example, suppose we have the following composed config:
```commandline title="python my_app.py --cfg hydra -p hydra.callbacks"
Expand All @@ -202,5 +217,6 @@ followed by `MyCallback1.on_job_end`.
### Example callbacks

We've included some example callbacks <GithubLink to="hydra/experimental/callbacks.py">here</GithubLink>:
- `LogJobReturnCallback` is especially useful for logging errors when running on a remote cluster (e.g. slurm.)
- `PickleJobInfoCallback` can be used to reproduce a Hydra job. See [here](/experimental/rerun.md) for more.
- `LogJobReturnCallback` is especially useful for logging errors when running on a remote cluster (e.g. slurm.)
- `PickleJobInfoCallback` can be used to reproduce a Hydra job. See [here](/experimental/rerun.md) for more.
- `LogComposeCallback` logs every newly composed config, together with the overrides and defaults that were used to compose it.

0 comments on commit 0f03eb6

Please sign in to comment.