From eb5f3f904fdf64f646b9bf51f238031bbf83dd1b Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Fri, 24 Sep 2021 13:28:34 -0400 Subject: [PATCH] Add Orion Extension concept [OC-343] --- src/orion/ext/__init__.py | 95 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 src/orion/ext/__init__.py diff --git a/src/orion/ext/__init__.py b/src/orion/ext/__init__.py new file mode 100644 index 0000000000..3f34861bbe --- /dev/null +++ b/src/orion/ext/__init__.py @@ -0,0 +1,95 @@ +"""Defines extension mechanism for third party to hook into Orion""" + + +class EventDelegate: + def __init__(self, name, deferred=False, parent=None) -> None: + self.handlers = [] + self.deferred_calls = [] + self.name = name + self.parent = parent + self.deferred = deferred + event_manager.add(name, self) + self.bad_handlers = [] + + def remove(self, function) -> bool: + try: + self.handlers.remove(function) + return True + except ValueError: + return False + + def add(self, function): + self.handlers.append(function) + + def broadcast(self, *args, **kwargs): + if not self.deferred: + self._execute(args, kwargs) + return + + self.deferred_calls.append((args, kwargs)) + + def _execute(self, args, kwargs): + for fun in self.handlers: + try: + fun(*args, _parent=self.parent, **kwargs) + except Exception as err: + event_manager.broadcast(self.name, fun, err, args=(args, kwargs)) + + def execute(self): + self.bad_handlers = [] + + for args, kwargs in self.deferred_calls: + self._execute(args, kwargs) + + + +class OrionExtensionManager: + """Manages third party extensions for Orion""" + + def __init__(self): + self._events = {} + + self._get_event('error') + self._get_event('start_experiment') + self._get_event('new_trial') + self._get_event('end_trial') + self._get_event('end_experiment') + + + def _get_event(self, key): + delegate = self._events.get(key) + + if delegate is None: + delegate = EventDelegate(key) + self._events[key] = delegate + + return delegate + + def register(self, ext): + """Register a new extensions""" + for name, delegate in self._events.items(): + if hasattr(ext, name): + delegate.add(getattr(ext, name)) + + def unregister(self, ext): + """Remove an extensions if it was already registered""" + for name, delegate in self._events.items(): + if hasattr(ext, name): + delegate.remove(getattr(ext, name)) + + +class OrionExtension: + """Base orion extension interface you need to implement""" + + def error(self, *args, **kwargs): + return + + def start_experiment(self, *args, **kwargs): + return + + def new_trial(self, *args, **kwargs): + return + + def end_experiment(self, *args, **kwargs): + return +