diff --git a/custom_components/stateful_scenes/StatefulScenes.py b/custom_components/stateful_scenes/StatefulScenes.py index 92f2ff1..01fda02 100644 --- a/custom_components/stateful_scenes/StatefulScenes.py +++ b/custom_components/stateful_scenes/StatefulScenes.py @@ -219,6 +219,7 @@ def __init__(self, hass: HomeAssistant, scene_conf: dict) -> None: self._debounce_time: float = 0 self.callback = None + self.callback_funcs = {} self.schedule_update = None self.states = {entity_id: False for entity_id in self.entities} self.restore_states = {entity_id: None for entity_id in self.entities} @@ -295,8 +296,12 @@ def set_restore_on_deactivate(self, restore_on_deactivate): """Set the restore on deactivate flag.""" self._restore_on_deactivate = restore_on_deactivate - def register_callback(self, state_change_func, schedule_update_func): + def register_callback(self): """Register callback.""" + schedule_update_func = self.callback_funcs.get("schedule_update_func", None) + state_change_func = self.callback_funcs.get("state_change_func", None) + if schedule_update_func is None or state_change_func is None: + raise ValueError("No callback functions provided for scene.") self.schedule_update = schedule_update_func self.callback = state_change_func( self.hass, self.entities.keys(), self.update_callback diff --git a/custom_components/stateful_scenes/switch.py b/custom_components/stateful_scenes/switch.py index 76e049d..a902f6e 100644 --- a/custom_components/stateful_scenes/switch.py +++ b/custom_components/stateful_scenes/switch.py @@ -105,6 +105,10 @@ def __init__(self, scene) -> None: self._icon = scene.icon self._attr_unique_id = f"stateful_{scene.id}" + self._scene.callback_funcs = { + "state_change_func":async_track_state_change_event, + "schedule_update_func":self.schedule_update_ha_state + } self.register_callback() @property @@ -156,10 +160,7 @@ def update(self) -> None: def register_callback(self) -> None: """Register callback to update hass when state changes.""" - self._scene.register_callback( - state_change_func=async_track_state_change_event, - schedule_update_func=self.schedule_update_ha_state, - ) + self._scene.register_callback() def unregister_callback(self) -> None: """Unregister callback."""