diff --git a/golem/core/optimisers/adaptive/mab_agents/contextual_mab_agent.py b/golem/core/optimisers/adaptive/mab_agents/contextual_mab_agent.py index 1dc9a55b..3a20ac24 100644 --- a/golem/core/optimisers/adaptive/mab_agents/contextual_mab_agent.py +++ b/golem/core/optimisers/adaptive/mab_agents/contextual_mab_agent.py @@ -1,6 +1,6 @@ import random from functools import partial -from typing import Union, Sequence, Optional, List +from typing import Union, Sequence, Optional, List, Callable import numpy as np from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy @@ -14,10 +14,18 @@ class ContextualMultiArmedBanditAgent(OperatorAgent): """ Contextual Multi-Armed bandit. Observations can be encoded with simple context agent without - using NN to guarantee convergence. """ + using NN to guarantee convergence. + + :param actions: types of mutations + :param context_agent_type: function to convert an observation to its embedding. Can be specified as a + ContextAgentTypeEnum member or as a Callable function. 
+ :param available_operations: names of the operations passed to the context agent (used only when context_agent_type is not a Callable) + :param n_jobs: value forwarded as the n_jobs keyword alongside the learning and neighborhood policies + :param enable_logging: whether logging is enabled + """ def __init__(self, actions: Sequence[ActType], - context_agent_type: ContextAgentTypeEnum, + context_agent_type: Union[ContextAgentTypeEnum, Callable], available_operations: List[str], n_jobs: int = 1, enable_logging: bool = True): @@ -29,8 +37,9 @@ def __init__(self, actions: Sequence[ActType], learning_policy=LearningPolicy.UCB1(alpha=1.25), neighborhood_policy=NeighborhoodPolicy.Clusters(), n_jobs=n_jobs) - self._context_agent = partial(ContextAgentsRepository.agent_class_by_id(context_agent_type), - available_operations=available_operations) + self._context_agent = context_agent_type if isinstance(context_agent_type, Callable) else \ + partial(ContextAgentsRepository.agent_class_by_id(context_agent_type), + available_operations=available_operations) self._is_fitted = False def _initial_fit(self, obs: ObsType): diff --git a/golem/core/optimisers/genetic/operators/mutation.py b/golem/core/optimisers/genetic/operators/mutation.py index aefb9ca9..6044e789 100644 --- a/golem/core/optimisers/genetic/operators/mutation.py +++ b/golem/core/optimisers/genetic/operators/mutation.py @@ -55,7 +55,7 @@ def _init_operator_agent(graph_gen_params: GraphGenerationParams, agent = ContextualMultiArmedBanditAgent( actions=parameters.mutation_types, context_agent_type=parameters.context_agent_type, - available_operations=graph_gen_params.node_factory.available_nodes, + available_operations=graph_gen_params.node_factory.get_all_available_operations(), n_jobs=requirements.n_jobs) elif kind == MutationAgentTypeEnum.neural_bandit: agent = NeuralContextualMultiArmedBanditAgent(actions=parameters.mutation_types, diff --git a/golem/core/optimisers/opt_node_factory.py b/golem/core/optimisers/opt_node_factory.py index 89e8c031..c23af6d8 100644 --- a/golem/core/optimisers/opt_node_factory.py +++ b/golem/core/optimisers/opt_node_factory.py @@ -1,7 +1,7 @@ import random 
from abc import abstractmethod, ABC from random import choice -from typing import Optional, Iterable +from typing import Optional, Iterable, List from golem.core.optimisers.graph import OptNode @@ -34,6 +34,13 @@ def get_node(self, **kwargs) -> Optional[OptNode]: """ pass + @abstractmethod + def get_all_available_operations(self) -> List[str]: + """ + Returns all available models and data operations. + """ + pass + class DefaultOptNodeFactory(OptNodeFactory): """Default node factory that either randomly selects @@ -47,6 +54,12 @@ def __init__(self, self.available_nodes = tuple(available_node_types) if available_node_types else None self._num_node_types = num_node_types or 1000 + def get_all_available_operations(self) -> Optional[List[str]]: + """ + Returns all available models and data operations: the node types given at construction, or None if none were given. + """ + return self.available_nodes + def exchange_node(self, node: OptNode) -> OptNode: return self.get_node()