Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify Contextual MAB param #167

Merged
merged 3 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions golem/core/optimisers/adaptive/mab_agents/contextual_mab_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from functools import partial
from typing import Union, Sequence, Optional, List
from typing import Union, Sequence, Optional, List, Callable

import numpy as np
from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy
Expand All @@ -14,10 +14,18 @@

class ContextualMultiArmedBanditAgent(OperatorAgent):
""" Contextual Multi-Armed bandit. Observations can be encoded with simple context agent without
using NN to guarantee convergence. """
using NN to guarantee convergence.

:param actions: types of mutations
:param context_agent_type: how an observation is converted to its embedding. Either a
ContextAgentTypeEnum member (resolved through ContextAgentsRepository) or a Callable that is used as is.
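:param available_operations: operations forwarded to the context agent obtained from ContextAgentsRepository
:param n_jobs: n_jobs value forwarded to the underlying bandit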
:param enable_logging: bool logging flag
"""

def __init__(self, actions: Sequence[ActType],
context_agent_type: ContextAgentTypeEnum,
context_agent_type: Union[ContextAgentTypeEnum, Callable],
available_operations: List[str],
n_jobs: int = 1,
enable_logging: bool = True):
Expand All @@ -29,8 +37,9 @@ def __init__(self, actions: Sequence[ActType],
learning_policy=LearningPolicy.UCB1(alpha=1.25),
neighborhood_policy=NeighborhoodPolicy.Clusters(),
n_jobs=n_jobs)
self._context_agent = partial(ContextAgentsRepository.agent_class_by_id(context_agent_type),
available_operations=available_operations)
self._context_agent = context_agent_type if isinstance(context_agent_type, Callable) else \
partial(ContextAgentsRepository.agent_class_by_id(context_agent_type),
available_operations=available_operations)
self._is_fitted = False

def _initial_fit(self, obs: ObsType):
Expand Down
2 changes: 1 addition & 1 deletion golem/core/optimisers/genetic/operators/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _init_operator_agent(graph_gen_params: GraphGenerationParams,
agent = ContextualMultiArmedBanditAgent(
actions=parameters.mutation_types,
context_agent_type=parameters.context_agent_type,
available_operations=graph_gen_params.node_factory.available_nodes,
available_operations=graph_gen_params.node_factory.get_all_available_operations(),
n_jobs=requirements.n_jobs)
elif kind == MutationAgentTypeEnum.neural_bandit:
agent = NeuralContextualMultiArmedBanditAgent(actions=parameters.mutation_types,
Expand Down
15 changes: 14 additions & 1 deletion golem/core/optimisers/opt_node_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
from abc import abstractmethod, ABC
from random import choice
from typing import Optional, Iterable
from typing import Optional, Iterable, List

from golem.core.optimisers.graph import OptNode

Expand Down Expand Up @@ -34,6 +34,13 @@ def get_node(self, **kwargs) -> Optional[OptNode]:
"""
pass

@abstractmethod
def get_all_available_operations(self) -> List[str]:
    """
    Return every available operation (both models and data operations).

    Abstract: each concrete node factory has to provide its own implementation.
    """


class DefaultOptNodeFactory(OptNodeFactory):
"""Default node factory that either randomly selects
Expand All @@ -47,6 +54,12 @@ def __init__(self,
self.available_nodes = tuple(available_node_types) if available_node_types else None
self._num_node_types = num_node_types or 1000

def get_all_available_operations(self) -> Optional[List[str]]:
    """
    Returns all available models and data operations.

    :return: a new list with the node types stored in ``available_nodes``,
        or ``None`` when no node types are stored.
    """
    # ``available_nodes`` is stored as a tuple (or None); convert it so the
    # returned value actually matches the declared ``List[str]`` type.
    if self.available_nodes is None:
        return None
    return list(self.available_nodes)

def exchange_node(self, node: OptNode) -> OptNode:
    """Return a replacement node obtained from ``get_node()``; ``node`` itself is not consulted."""
    replacement = self.get_node()
    return replacement

Expand Down
Loading