Skip to content

Commit

Permalink
Add Orion Extension concept [OC-343]
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay committed Sep 24, 2021
1 parent da00153 commit eb5f3f9
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions src/orion/ext/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Defines extension mechanism for third party to hook into Orion"""


class EventDelegate:
def __init__(self, name, deferred=False, parent=None) -> None:
self.handlers = []
self.deferred_calls = []
self.name = name
self.parent = parent
self.deferred = deferred
event_manager.add(name, self)
self.bad_handlers = []

def remove(self, function) -> bool:
try:
self.handlers.remove(function)
return True
except ValueError:
return False

def add(self, function):
self.handlers.append(function)

def broadcast(self, *args, **kwargs):
if not self.deferred:
self._execute(args, kwargs)
return

self.deferred_calls.append((args, kwargs))

def _execute(self, args, kwargs):
for fun in self.handlers:
try:
fun(*args, _parent=self.parent, **kwargs)
except Exception as err:
event_manager.broadcast(self.name, fun, err, args=(args, kwargs))

def execute(self):
self.bad_handlers = []

for args, kwargs in self.deferred_calls:
self._execute(args, kwargs)



class OrionExtensionManager:
"""Manages third party extensions for Orion"""

def __init__(self):
self._events = {}

self._get_event('error')
self._get_event('start_experiment')
self._get_event('new_trial')
self._get_event('end_trial')
self._get_event('end_experiment')


def _get_event(self, key):
delegate = self._events.get(key)

if delegate is None:
delegate = EventDelegate(key)
self._events[key] = delegate

return delegate

def register(self, ext):
"""Register a new extensions"""
for name, delegate in self._events.items():
if hasattr(ext, name):
delegate.add(getattr(ext, name))

def unregister(self, ext):
"""Remove an extensions if it was already registered"""
for name, delegate in self._events.items():
if hasattr(ext, name):
delegate.remove(getattr(ext, name))


class OrionExtension:
"""Base orion extension interface you need to implement"""

def error(self, *args, **kwargs):
return

def start_experiment(self, *args, **kwargs):
return

def new_trial(self, *args, **kwargs):
return

def end_experiment(self, *args, **kwargs):
return

0 comments on commit eb5f3f9

Please sign in to comment.