diff --git a/.gitignore b/.gitignore index ea2b8d9..2ed4672 100644 --- a/.gitignore +++ b/.gitignore @@ -128,10 +128,12 @@ dmypy.json # Pyre type checker .pyre/ - # High-fly cluster config .hfai .hfignore # PyCharm .idea + +# Models +models/ \ No newline at end of file diff --git a/hironaka/README.md b/hironaka/README.md new file mode 100644 index 0000000..bfd7efe --- /dev/null +++ b/hironaka/README.md @@ -0,0 +1,16 @@ +# hironaka +This is the base folder of the hironaka library. + +Submodules: + + - [hironaka.core](core) + - [hironaka.gym_env](gym_env) + - [hironaka.policy](policy) + - [hironaka.policy_players](policy_players) + - [hironaka.src](src) + - [hironaka.validator](validator) + +Python scripts and non-essential functions: + + - [.train/](../train) + - [.util/](util) \ No newline at end of file diff --git a/hironaka/abs/PointsNumpy.py b/hironaka/abs/PointsNumpy.py deleted file mode 100644 index 555db24..0000000 --- a/hironaka/abs/PointsNumpy.py +++ /dev/null @@ -1,8 +0,0 @@ -from hironaka.abs.PointsBase import PointsBase - - -class PointsNumpy(PointsBase): # INCOMPLETE - """ - Storing points using numpy arrays. 
- """ - pass diff --git a/hironaka/abs/__init__.py b/hironaka/abs/__init__.py deleted file mode 100644 index 2dabb6e..0000000 --- a/hironaka/abs/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .Points import * diff --git a/hironaka/agent.py b/hironaka/agent.py index 6cddf1d..70c944b 100644 --- a/hironaka/agent.py +++ b/hironaka/agent.py @@ -1,8 +1,10 @@ import abc +from typing import List import numpy as np -from .abs import Points +from .core import Points +from .policy import Policy class Agent(abc.ABC): @@ -33,3 +35,21 @@ def move(self, points: Points, coords, inplace=True): points.shift(coords, actions) points.get_newton_polytope() return actions + + +class PolicyAgent(Agent): + def __init__(self, policy: Policy): + self._policy = policy + + def move(self, points: Points, coords: List[List[int]], inplace=True): + assert len(coords) == points.batch_size # TODO: wrap the move method for the abstract "Agent" with sanity checks? + + features = points.get_features() + + actions = self._policy.predict((features, coords)) + + if inplace: + points.shift(coords, actions) + points.get_newton_polytope() + + return actions diff --git a/hironaka/abs/Points.py b/hironaka/core/Points.py similarity index 66% rename from hironaka/abs/Points.py rename to hironaka/core/Points.py index 721aec5..ce2ddad 100644 --- a/hironaka/abs/Points.py +++ b/hironaka/core/Points.py @@ -2,23 +2,29 @@ import numpy as np -from hironaka.abs.PointsBase import PointsBase -from hironaka.src import shift_lst, get_newton_polytope_lst, get_shape, scale_points, reposition_lst +from hironaka.core.PointsBase import PointsBase +from hironaka.src import shift_lst, get_newton_polytope_lst, get_shape, scale_points, reposition_lst, \ + get_newton_polytope_approx_lst class Points(PointsBase): """ When dealing with small batches, small dimension and small point numbers, list is much better than numpy. 
""" - config_keys = ['value_threshold'] + subcls_config_keys = ['value_threshold', 'use_precise_newton_polytope'] + copied_attributes = ['distinguished_points'] def __init__(self, points: Union[List[List[List[int]]], np.ndarray], value_threshold: Optional[int] = 1e8, + use_precise_newton_polytope: Optional[bool] = False, + distinguished_points: Optional[Union[List[int], None]] = None, config_kwargs: Optional[Dict[str, Any]] = None, **kwargs): config = kwargs if config_kwargs is None else {**config_kwargs, **kwargs} self.value_threshold = value_threshold + self.use_precise_newton_polytope = use_precise_newton_polytope + self.distinguished_points = distinguished_points # Be lenient and allow numpy array as input. # The input might already be -1 padded arrays. Thus, we do a thorough check to clean that up. @@ -61,7 +67,34 @@ def _reposition(self, points: Any, inplace: Optional[bool] = True): return reposition_lst(points, inplace=inplace) def _get_newton_polytope(self, points: Any, inplace: Optional[bool] = True): - return get_newton_polytope_lst(points, inplace=inplace, get_ended=False) + # Mark distinguished points + if self.distinguished_points is not None: + # Apply marks to the distinguished points before the operation + for b in range(self.batch_size): + if self.distinguished_points[b] is None: + continue + self.points[b][self.distinguished_points[b]].append('d') + + if self.use_precise_newton_polytope: + result = get_newton_polytope_lst(points, inplace=inplace) + else: + result = get_newton_polytope_approx_lst(points, inplace=inplace, get_ended=False) + + # Recover the locations of distinguished points + if self.distinguished_points is not None: + transformed_points = points if inplace else result + for b in range(self.batch_size): + if self.distinguished_points[b] is None: + continue + distinguished_point_index = None + for i in range(len(transformed_points[b])): + if transformed_points[b][i][-1] == 'd': + distinguished_point_index = i + 
transformed_points[b][i].pop() + break + self.distinguished_points[b] = distinguished_point_index + + return result def _get_shape(self, points: Any): return get_shape(points) diff --git a/hironaka/abs/PointsBase.py b/hironaka/core/PointsBase.py similarity index 89% rename from hironaka/abs/PointsBase.py rename to hironaka/core/PointsBase.py index 63ebf5b..a22511b 100644 --- a/hironaka/abs/PointsBase.py +++ b/hironaka/core/PointsBase.py @@ -1,4 +1,5 @@ import abc +from copy import deepcopy from typing import Any, Optional, List @@ -35,7 +36,11 @@ class PointsBase(abc.ABC): """ # You MUST define `config_keys` when inheriting. # Keys in `config_keys` will be tracked when calling the `copy()` method. - config_keys: List[str] + subcls_config_keys: List[str] + # Keys in `copied_attributes` will be directly copied during `copy()`. They MUST be initialized. + copied_attributes: List[str] + # Keys in `base_attributes` will be copied. But they are shared in all subclasses and do not need to re-initialize. + base_attributes = ['ended_each_batch', 'ended', 'batch_size', 'max_num_points', 'dimension'] def __init__(self, points: Any, @@ -53,10 +58,6 @@ def __init__(self, else: self.config = kwargs - # Update keys if modified or created in subclass - for key in self.config_keys: - self.config[key] = getattr(self, key) - # Check the shape of `points`. 
shape = self._get_shape(points) if len(shape) == 2: @@ -66,7 +67,6 @@ def __init__(self, points = self._add_batch_axis(points) if len(shape) != 3: raise Exception("input dimension must be 2 or 3.") - self.points = points self.batch_size = self.config.get('points_batch_size', shape[0]) @@ -81,6 +81,13 @@ def __init__(self, # will also be updated on point-changing modifications including `get_newton_polytope` self.ended_each_batch = [False] * self.batch_size + # Update keys in `self.copied_attributes` + for key in self.copied_attributes: + if hasattr(self, key): + self.config[key] = getattr(self, key) + else: + raise Exception("Must initialize keys in 'subcls_config_keys' before calling super().__init__.") + def copy(self, points: Optional[Any] = None): """ Copy the object. @@ -93,7 +100,14 @@ def copy(self, points: Optional[Any] = None): new_points = self._points_copy(self.points) else: new_points = self._points_copy(points) - return self.__class__(new_points, **self.config) + new_points = self.__class__(new_points, **self.config) + + for key in self.copied_attributes + self.base_attributes: + if hasattr(self, key): + setattr(new_points, key, deepcopy(getattr(self, key))) + else: + raise Exception(f"Attribute {key} is not initialized.") + return new_points def shift(self, coords: List[List[int]], axis: List[int], inplace=True): """ @@ -129,7 +143,7 @@ def get_newton_polytope(self, inplace=True): return None else: new_points = self.copy(points=r) - new_points.ended_each_batch = ended_each_batch + new_points.ended_each_batch = ended_each_batch # TODO: duplicate? new_points.ended = ended return new_points diff --git a/hironaka/core/PointsNumpy.py b/hironaka/core/PointsNumpy.py new file mode 100644 index 0000000..25d6a67 --- /dev/null +++ b/hironaka/core/PointsNumpy.py @@ -0,0 +1,8 @@ +from hironaka.core.PointsBase import PointsBase + + +class PointsNumpy(PointsBase): # TODO:INCOMPLETE + """ + Storing points using numpy arrays. 
+ """ + pass diff --git a/hironaka/abs/PointsTensor.py b/hironaka/core/PointsTensor.py similarity index 70% rename from hironaka/abs/PointsTensor.py rename to hironaka/core/PointsTensor.py index b65f142..371cfcb 100644 --- a/hironaka/abs/PointsTensor.py +++ b/hironaka/core/PointsTensor.py @@ -3,40 +3,45 @@ import numpy as np import torch -from hironaka.abs.PointsBase import PointsBase -from hironaka.src import shift_lst, get_newton_polytope_lst, get_shape, scale_points, get_batched_padded_array -from hironaka.src._torch_ops import shift_torch, get_newton_polytope_torch, reposition_torch +from hironaka.core.PointsBase import PointsBase +from hironaka.src import get_batched_padded_array, rescale_torch +from hironaka.src import shift_torch, get_newton_polytope_torch, reposition_torch class PointsTensor(PointsBase): - config_keys = ['value_threshold', 'device_key', 'padded_value'] + subcls_config_keys = ['value_threshold', 'device_key', 'padding_value'] + copied_attributes = ['distinguished_points'] def __init__(self, points: Union[torch.Tensor, List[List[List[int]]], np.ndarray], value_threshold: Optional[int] = 1e8, device_key: Optional[str] = 'cpu', - padded_value: Optional[float] = -1.0, + padding_value: Optional[float] = -1.0, + distinguished_points: Optional[List[int]] = None, config_kwargs: Optional[Dict[str, Any]] = None, **kwargs): config = kwargs if config_kwargs is None else {**config_kwargs, **kwargs} self.value_threshold = value_threshold - # It's better to require a fixed shape of the tensor implementation. + assert padding_value < 0, f"'padding_value' must be a negative number. Got {padding_value} instead." 
if isinstance(points, list): points = torch.FloatTensor( get_batched_padded_array(points, new_length=config['max_num_points'], - constant_value=config.get('padded_value', -1))) - elif isinstance(points, (torch.Tensor, np.ndarray)): + constant_value=padding_value)) + elif isinstance(points, np.ndarray): points = torch.FloatTensor(points) + elif isinstance(points, torch.Tensor): + points = points.type(torch.float32) else: raise Exception(f"Input must be a Tensor, a numpy array or a nested list. Got {type(points)}.") self.batch_size, self.max_num_points, self.dimension = points.shape self.device_key = device_key - self.padded_value = padded_value + self.padding_value = padding_value + self.distinguished_points = distinguished_points super().__init__(points, **config) self.device = torch.device(self.device_key) @@ -62,19 +67,19 @@ def _shift(self, coords: List[List[int]], axis: List[int], inplace: Optional[bool] = True): - return shift_torch(points, coords, axis, inplace=inplace) + return shift_torch(points, coords, axis, inplace=inplace, padding_value=self.padding_value) def _get_newton_polytope(self, points: torch.Tensor, inplace: Optional[bool] = True): - return get_newton_polytope_torch(points, inplace=inplace) + return get_newton_polytope_torch(points, inplace=inplace, padding_value=self.padding_value) def _get_shape(self, points: torch.Tensor): return points.shape def _reposition(self, points: torch.Tensor, inplace: Optional[bool] = True): - return reposition_torch(points, inplace=inplace) + return reposition_torch(points, inplace=inplace, padding_value=self.padding_value) def _rescale(self, points: torch.Tensor, inplace: Optional[bool] = True): - return points / torch.reshape(torch.amax(points, (1, 2)), (-1, 1, 1)) + return rescale_torch(points, inplace=inplace, padding_value=self.padding_value) def _points_copy(self, points: torch.Tensor): return points.clone().detach() diff --git a/hironaka/core/README.md b/hironaka/core/README.md new file mode 100644 index 
0000000..cf4296a --- /dev/null +++ b/hironaka/core/README.md @@ -0,0 +1,64 @@ +# hironaka.core +This is the core functionality whose classes + - save collections of points, + - perform transformations (Newton polytope, shift, rescale, etc.), + - provide features, states, etc. + +The base class is `PointsBase`, and the subclasses are currently `Points` and `PointsTensor`. + +`PointsNumpy` is currently unnecessary and the implementation is postponed. + +## .PointsBase +This interface is an abstraction of a collection of points used in the Hironaka game and its variations. + +A couple of important notes: + - The points are stored as 3d objects (nested lists, tensors, etc.). The 3 axes represent: + - (batches, points, point coordinates) + - Subclasses must define `subcls_config_keys, copied_attributes`. + - `subcls_config_keys` defines the config keys that will be tracked and initialized when creating a copy object in the `.copy()` method. In other words, they are configs that stay unchanged throughout space transformations. + - `copied_attributes` defines the keys of the attributes that will be directly copied after initialization in `.copy()`. In other words, they may be changed during space transformations and are part of the state information (e.g., `.ended`). + - Shape of points can be specified with config parameters `points_batch_size, dimension, max_num_points`. If they are not given, it will look over the point data and initialize `.batch_size, .dimension, .max_num_points` attributes. + +Must implement: +` +_get_shape +_get_newton_polytope +_reposition +_shift +_rescale +_points_copy +_add_batch_axis +_get_batch_ended +` + +Feel free to override: +` +get_features +_get_max_num_points +` + +## .Points +It stores the points in nested lists and performs list-based transformations. The nested lists do not have to be of uniform shape (not padded). 
+For example: + +``` +[ + [ + (7, 5, 3, 8), (8, 9, 8, 18), (8, 3, 17, 8), + (11, 11, 1, 19), (11, 12, 18, 6), (16, 11, 5, 6) + ], + [ + (0, 1, 0, 1), (0, 2, 0, 0), (1, 0, 0, 1), + (1, 0, 1, 0), (1, 1, 0, 0) + ] +] +``` + +This is a batch of 2 separate sets of points. They are of different sizes. + +## .PointsTensor +It performs everything using PyTorch tensors. + +Major differences between `PointsTensor` and `Points` + - `PointsTensor.points` has a fixed shape. Removed points will be replaced by `padding_value` which must be a negative number. + - In `Points.points`, the order of points may change during `get_newton_polytope()`. But for `PointsTensor.points`, the order will never change. When points are removed, we still maintain the original order without moving surviving points next to each other. As a result, `distinguished_points` marks the distinguished point in each set of points, but is never changed under any operations. \ No newline at end of file diff --git a/hironaka/core/__init__.py b/hironaka/core/__init__.py new file mode 100644 index 0000000..c4e1db6 --- /dev/null +++ b/hironaka/core/__init__.py @@ -0,0 +1,2 @@ +from .Points import * +from .PointsTensor import * diff --git a/hironaka/cpp/README.md b/hironaka/cpp/README.md new file mode 100644 index 0000000..d271cba --- /dev/null +++ b/hironaka/cpp/README.md @@ -0,0 +1,3 @@ +# C++ modules (deprecated) +This was originally written for numpy operations but is currently not in use. +We leave it here for potential future revival. 
\ No newline at end of file diff --git a/hironaka/gameHironaka.py b/hironaka/gameHironaka.py index 9c77cdd..473b34e 100644 --- a/hironaka/gameHironaka.py +++ b/hironaka/gameHironaka.py @@ -1,7 +1,7 @@ import logging from typing import Optional, Union -from .abs import Points +from .core import Points from .agent import Agent from .game import Game from .host import Host diff --git a/hironaka/gameThom.py b/hironaka/gameThom.py deleted file mode 100644 index e74309a..0000000 --- a/hironaka/gameThom.py +++ /dev/null @@ -1,21 +0,0 @@ -class GameThom: - def __init__(self, points, host, agent): - self.state = points - self.host = host - self.agent = agent - self.coordHistory = [] - self.moveHistory = [] - self.stopped = False - - def step(self): - if self.stopped: - return - coords = self.host.select_coord(self.state) - new_state, action = self.agent.move(self.state, coords) - - self.state = new_state - self.coordHistory.append(coords) - self.moveHistory.append(action) - - if len(self.state) == 1: - self.stopped = True diff --git a/hironaka/envs/HironakaAgentEnv.py b/hironaka/gym_env/HironakaAgentEnv.py similarity index 98% rename from hironaka/envs/HironakaAgentEnv.py rename to hironaka/gym_env/HironakaAgentEnv.py index 586a915..5e82004 100644 --- a/hironaka/envs/HironakaAgentEnv.py +++ b/hironaka/gym_env/HironakaAgentEnv.py @@ -4,7 +4,7 @@ from gym import spaces from hironaka.agent import Agent -from hironaka.envs.HironakaBase import HironakaBase +from hironaka.gym_env.HironakaBase import HironakaBase from hironaka.src import decode_action diff --git a/hironaka/envs/HironakaBase.py b/hironaka/gym_env/HironakaBase.py similarity index 98% rename from hironaka/envs/HironakaBase.py rename to hironaka/gym_env/HironakaBase.py index b29ddf2..7b6706b 100644 --- a/hironaka/envs/HironakaBase.py +++ b/hironaka/gym_env/HironakaBase.py @@ -5,9 +5,8 @@ import numpy as np from gym import spaces -from hironaka.abs import Points -from hironaka.src import get_padded_array, 
get_gym_version_in_float -from hironaka.util import generate_points +from hironaka.core import Points +from hironaka.src import get_padded_array, get_gym_version_in_float, generate_points ObsType = TypeVar("ObsType") ActType = TypeVar("ActType") diff --git a/hironaka/envs/HironakaHostEnv.py b/hironaka/gym_env/HironakaHostEnv.py similarity index 98% rename from hironaka/envs/HironakaHostEnv.py rename to hironaka/gym_env/HironakaHostEnv.py index 6de4846..0fad3a5 100644 --- a/hironaka/envs/HironakaHostEnv.py +++ b/hironaka/gym_env/HironakaHostEnv.py @@ -3,7 +3,7 @@ import numpy as np from gym import spaces -from hironaka.envs.HironakaBase import HironakaBase +from hironaka.gym_env.HironakaBase import HironakaBase from hironaka.host import Host diff --git a/hironaka/gym_env/README.md b/hironaka/gym_env/README.md new file mode 100644 index 0000000..f1cb696 --- /dev/null +++ b/hironaka/gym_env/README.md @@ -0,0 +1,8 @@ +# hironaka.gym_env +This contains gym wrappers of host/agent environments. + +Note: + - The naming is indeed confusing, and I do not see a better way of naming them. **Please remember** + - HironakaAgentEnv: a gym environment that takes an `Agent` as an initialization parameter. The `Agent` object is fixed throughout the game, and receives actions from an unknown `Host`. + - HironakaHostEnv: a gym environment that takes an `Host` as an initialization parameter. The `Host` object is fixed throughout the game, and receives actions from an unknown `Agent`. 
+ diff --git a/hironaka/envs/__init__.py b/hironaka/gym_env/__init__.py similarity index 100% rename from hironaka/envs/__init__.py rename to hironaka/gym_env/__init__.py diff --git a/hironaka/host.py b/hironaka/host.py index 20205c4..08d5c91 100644 --- a/hironaka/host.py +++ b/hironaka/host.py @@ -1,9 +1,11 @@ import abc from itertools import combinations +from typing import Optional import numpy as np -from .abs import Points +from .core import Points +from .policy import Policy class Host(abc.ABC): @@ -61,3 +63,21 @@ def select_coord(self, points: Points, debug=False): result.append([0, 1]) return result + + +class PolicyHost(Host): + def __init__(self, + policy: Policy, + use_discrete_actions_for_host: Optional[bool] = False, + **kwargs): + self._policy = policy + self.use_discrete_actions_for_host = kwargs.get('use_discrete_actions_for_host', use_discrete_actions_for_host) + + def select_coord(self, points: Points, debug=False): + features = points.get_features() + + coords = self._policy.predict(features) # return multi-binary array + result = [] + for b in range(coords.shape[0]): + result.append(np.where(coords[b] == 1)[0]) + return result diff --git a/hironaka/policy/README.md b/hironaka/policy/README.md new file mode 100644 index 0000000..d7f0872 --- /dev/null +++ b/hironaka/policy/README.md @@ -0,0 +1,14 @@ +# hironaka.policy +A Policy is a function mapping an observation to an action. This is a wrapper class of such a function. The inside can be a neural network, a hardcoded strategy, or whatever. + +Need to implement: +` +__init__ +predict +` + +Note that `input_preprocess_for_host` and `input_preprocess_for_agent` are merely helper functions for input preprocessing of list-based observations (e.g., class `Points`). +Feel free to override them for different purposes. + +## .NNPolicy +It wraps a neural network inside. 
\ No newline at end of file diff --git a/hironaka/policy_players/PolicyAgent.py b/hironaka/policy_players/PolicyAgent.py deleted file mode 100644 index 0b42a4b..0000000 --- a/hironaka/policy_players/PolicyAgent.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import List - -from hironaka.abs import Points -from hironaka.agent import Agent -from hironaka.policy.Policy import Policy - - -class PolicyAgent(Agent): - def __init__(self, policy: Policy): - self._policy = policy - - def move(self, points: Points, coords: List[List[int]], inplace=True): - assert len(coords) == points.batch_size # TODO: wrap the move method for the abstract "Agent" with sanity checks? - - features = points.get_features() - - actions = self._policy.predict((features, coords)) - - if inplace: - points.shift(coords, actions) - points.get_newton_polytope() - - return actions diff --git a/hironaka/policy_players/PolicyHost.py b/hironaka/policy_players/PolicyHost.py deleted file mode 100644 index 27b9b50..0000000 --- a/hironaka/policy_players/PolicyHost.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Optional - -import numpy as np - -from hironaka.abs import Points -from hironaka.host import Host -from hironaka.policy.Policy import Policy - - -class PolicyHost(Host): - def __init__(self, - policy: Policy, - use_discrete_actions_for_host: Optional[bool] = False, - **kwargs): - self._policy = policy - self.use_discrete_actions_for_host = kwargs.get('use_discrete_actions_for_host', use_discrete_actions_for_host) - - def select_coord(self, points: Points, debug=False): - features = points.get_features() - - coords = self._policy.predict(features) # return multi-binary array - result = [] - for b in range(coords.shape[0]): - result.append(np.where(coords[b] == 1)[0]) - return result diff --git a/hironaka/policy_players/__init__.py b/hironaka/policy_players/__init__.py deleted file mode 100644 index d9a06dc..0000000 --- a/hironaka/policy_players/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from 
.PolicyHost import PolicyHost -from .PolicyAgent import PolicyAgent diff --git a/hironaka/src/README.md b/hironaka/src/README.md new file mode 100644 index 0000000..c8f15d6 --- /dev/null +++ b/hironaka/src/README.md @@ -0,0 +1,2 @@ +# hironaka.src +This is the collection of core functions that implement Newton polytope, shifting, rescaling, etc. for different data types, as well as a bunch of helper functions. \ No newline at end of file diff --git a/hironaka/src/__init__.py b/hironaka/src/__init__.py index dd730ee..4c63f16 100644 --- a/hironaka/src/__init__.py +++ b/hironaka/src/__init__.py @@ -1,3 +1,3 @@ from ._list_ops import * -from ._np_ops import * from ._snippets import * +from ._torch_ops import * diff --git a/hironaka/src/_list_ops.py b/hironaka/src/_list_ops.py index 0d08bfe..66355e1 100644 --- a/hironaka/src/_list_ops.py +++ b/hironaka/src/_list_ops.py @@ -1,9 +1,12 @@ from typing import List +import numpy as np + from ._snippets import get_shape +from scipy.spatial import ConvexHull -def get_newton_polytope_approx_lst(points: List[List[List[int]]], inplace=True, get_ended=False): +def get_newton_polytope_approx_lst(points: List[List[List[float]]], inplace=True, get_ended=False): """ A simple-minded quick-and-dirty method to obtain an approximation of Newton Polytope disregarding convexity. Returns: @@ -51,15 +54,36 @@ def get_newton_polytope_approx_lst(points: List[List[List[int]]], inplace=True, return (new_points, ended_each_batch) if get_ended else new_points -def get_newton_polytope_lst(points: List[List[List[int]]], inplace=True, get_ended=False): +def get_newton_polytope_lst(points: List[List[List[float]]], inplace=True): """ Get the Newton Polytope for a set of points. + TODO: this is perhaps a slow implementation. Must improve! 
""" - return get_newton_polytope_approx_lst(points, inplace=inplace, get_ended=get_ended) - # TODO: perhaps change to a more precise algo to obtain Newton Polytope + assert len(get_shape(points)) == 3 + + result = [] + for pts in points: + pts_np = np.array(pts) + maximum = np.max(pts_np) + dimension = pts_np.shape[1] + + # Add points that are very far-away. + extra = np.full((dimension, dimension), maximum * 2) * (~np.diag([True] * dimension)) + enhanced_pts = np.concatenate((pts_np, extra), axis=0) + + vertices = ConvexHull(enhanced_pts).vertices + newton_polytope_indices = vertices[vertices < len(pts_np)] + result.append(pts_np[newton_polytope_indices, :].tolist()) + + result = get_newton_polytope_approx_lst(result, inplace=False, get_ended=False) + + if inplace: + points[:, :, :] = result + else: + return result -def shift_lst(points: List[List[List[int]]], coords: List[List[int]], axis: List[int], inplace=True): +def shift_lst(points: List[List[List[float]]], coords: List[List[int]], axis: List[int], inplace=True): """ Shift a set of points according to the rule of Hironaka game. """ @@ -70,21 +94,21 @@ def shift_lst(points: List[List[List[int]]], coords: List[List[int]], axis: List if inplace: for b in range(batch_num): - if axis[b] is None: + if axis[b] not in coords[b]: continue for i in range(len(points[b])): points[b][i][axis[b]] = sum([points[b][i][k] for k in coords[b]]) else: result = [[ [ - sum([x[k] for k in coord]) if ax is not None and i == ax else x[i] + sum([x[k] for k in coord]) if ax in coord and i == ax else x[i] for i in range(dim) ] for x in point ] for point, coord, ax in zip(points, coords, axis)] return result -def reposition_lst(points: List[List[List[int]]], inplace=True): +def reposition_lst(points: List[List[List[float]]], inplace=True): """ Reposition all batches of points so that each of them hits all coordinate planes. 
""" diff --git a/hironaka/src/_snippets.py b/hironaka/src/_snippets.py index 9d0ad87..bccc07e 100644 --- a/hironaka/src/_snippets.py +++ b/hironaka/src/_snippets.py @@ -1,6 +1,7 @@ from typing import List, Union import numpy as np +import numbers def get_shape(o): @@ -12,17 +13,27 @@ def get_shape(o): If the nested list/tuple objects are not of uniform shape, this function becomes pointless. Therefore, being uniform is an assumption before using this snippet. - Anti-example: o = [ ([1,2,3],2,3),(2,3,4) ], it will return (2,3,3). - It's intuitively wrong but this function is not responsible for checking the uniformity. + It is intuitively wrong but this function is not responsible for checking the uniformity. + + For the last axis, it also removes ONE non-number entries at the end of o[0]...[0]. + - Example: o = [[1,2,3,'d'],[2,3,4]], it will return (2,3) + - Anti-example: o = [[1,2,3,'d','d'],[2,3,4]], it will return (2,4) + It is again intuitively tricky, but it is only designed to allow for one non-number mark + and sanity check is not our responsibility. Also, if it hits a length-0 object, it will just stop. 
""" unwrapped = o shape = [] + last = None while isinstance(unwrapped, (list, tuple)) and unwrapped: shape.append(len(unwrapped)) + last = unwrapped[-1] unwrapped = unwrapped[0] + if not isinstance(last, numbers.Number): + shape[-1] -= 1 return tuple(shape) @@ -155,3 +166,11 @@ def mask_encoded_action(dimension: int): result[1 << i] = 0 return result + + +def generate_points(n: int, dimension=3, max_value=50): + return [[np.random.randint(max_value) for _ in range(dimension)] for _ in range(n)] + + +def generate_batch_points(n: int, batch_num=1, dimension=3, max_value=50): + return [[[np.random.randint(max_value) for _ in range(dimension)] for _ in range(n)] for _ in range(batch_num)] diff --git a/hironaka/src/_torch_ops.py b/hironaka/src/_torch_ops.py index 472d8d8..5a3d798 100644 --- a/hironaka/src/_torch_ops.py +++ b/hironaka/src/_torch_ops.py @@ -5,7 +5,9 @@ from hironaka.src import batched_coord_list_to_binary -def get_newton_polytope_approx_torch(points: torch.Tensor, inplace: Optional[bool] = True): +def get_newton_polytope_approx_torch(points: torch.Tensor, + inplace: Optional[bool] = True, + padding_value: Optional[float] = -1.): assert len(points.shape) == 3 batch_size, max_num_points, dimension = points.shape @@ -27,21 +29,26 @@ def get_newton_polytope_approx_torch(points: torch.Tensor, inplace: Optional[boo points_to_remove = ((difference >= 0) & diag_filter & filter_matrix).all(3).any(2) points_to_remove = points_to_remove.unsqueeze(2).repeat(1, 1, dimension) + r = points * ~points_to_remove + torch.full(points.shape, padding_value) * points_to_remove + if inplace: - points[:, :, :] = points * ~points_to_remove + torch.full(points.shape, -1.0) * points_to_remove + points[:, :, :] = r return None else: - return points * ~points_to_remove + torch.full(points.shape, -1.0) * points_to_remove + return r -def get_newton_polytope_torch(points: torch.Tensor, inplace: Optional[bool] = True): - return get_newton_polytope_approx_torch(points, inplace=inplace) 
+def get_newton_polytope_torch(points: torch.Tensor, + inplace: Optional[bool] = True, + padding_value: Optional[float] = -1.): + return get_newton_polytope_approx_torch(points, inplace=inplace, padding_value=padding_value) def shift_torch(points: torch.Tensor, coord: Union[torch.Tensor, List[List[int]]], axis: Union[torch.Tensor, List[int]], - inplace=True): + inplace: Optional[bool] = True, + padding_value: Optional[float] = -1.): """ note: If coord is a list, it is assumed to be lists of chosen coordinates. @@ -85,25 +92,37 @@ def shift_torch(points: torch.Tensor, trans_matrix = trans_matrix.unsqueeze(1).repeat(1, max_num_points, 1, 1) transformed_points = torch.matmul(trans_matrix, points.unsqueeze(3)).squeeze(3) - result = (transformed_points * available_points) + torch.full(points.shape, -1.0) * ~available_points + r = (transformed_points * available_points) + torch.full(points.shape, padding_value) * ~available_points if inplace: - points[:, :, :] = result + points[:, :, :] = r return None else: - return result + return r -def reposition_torch(points: torch.Tensor, inplace: Optional[bool] = True): +def reposition_torch(points: torch.Tensor, + inplace: Optional[bool] = True, + padding_value: Optional[float] = -1.): available_points = points.ge(0) maximum = torch.max(points) preprocessed = points * available_points + torch.full(points.shape, maximum + 1) * ~available_points coordinate_minimum = torch.amin(preprocessed, 1) unfiltered_result = points - coordinate_minimum.unsqueeze(1).repeat(1, points.shape[1], 1) - result = unfiltered_result * available_points + torch.full(points.shape, -1.0) * ~available_points + r = unfiltered_result * available_points + torch.full(points.shape, padding_value) * ~available_points if inplace: - points[:, :, :] = result + points[:, :, :] = r return None else: - return result + return r + + +def rescale_torch(points: torch.Tensor, inplace: Optional[bool] = True, padding_value: Optional[float] = -1.): + available_points = 
points.ge(0) + r = points * available_points / torch.reshape(torch.amax(points, (1, 2)), (-1, 1, 1)) + \ + padding_value * ~available_points + if inplace: + points[:, :, :] = r + else: + return r \ No newline at end of file diff --git a/hironaka/util/README.md b/hironaka/util/README.md new file mode 100644 index 0000000..f49e004 --- /dev/null +++ b/hironaka/util/README.md @@ -0,0 +1,2 @@ +# hironaka.util +This is a collection of helper functions that do not impact the core functionality. \ No newline at end of file diff --git a/hironaka/util/__init__.py b/hironaka/util/__init__.py index 6331a7f..114d7ee 100644 --- a/hironaka/util/__init__.py +++ b/hironaka/util/__init__.py @@ -1,2 +1 @@ -from .geom import * from .search import * diff --git a/hironaka/util/geom.py b/hironaka/util/geom.py deleted file mode 100644 index 051bbb7..0000000 --- a/hironaka/util/geom.py +++ /dev/null @@ -1,9 +0,0 @@ -import numpy as np - - -def generate_points(n: int, dimension=3, max_value=50): - return [[np.random.randint(max_value) for _ in range(dimension)] for _ in range(n)] - - -def generate_batch_points(n: int, batch_num=1, dimension=3, max_value=50): - return [[[np.random.randint(max_value) for _ in range(dimension)] for _ in range(n)] for _ in range(batch_num)] diff --git a/hironaka/util/search.py b/hironaka/util/search.py index b8605d3..df952ad 100644 --- a/hironaka/util/search.py +++ b/hironaka/util/search.py @@ -1,6 +1,6 @@ from collections import deque -from hironaka.abs import Points +from hironaka.core import Points from hironaka.host import Host @@ -20,8 +20,6 @@ def search_depth(points: Points, host: Host, debug=False): max_depth = max(max_depth, depth) coords = host.select_coord(current, debug=debug) - # print(current, depth) - for i in coords[0]: nxt = current.copy() nxt.shift(coords, [i]) diff --git a/hironaka/validator/HironakaValidator.py b/hironaka/validator/HironakaValidator.py index ef7e7f7..ef36447 100644 --- a/hironaka/validator/HironakaValidator.py +++ 
b/hironaka/validator/HironakaValidator.py @@ -1,9 +1,9 @@ from typing import Optional, Dict, Any import logging -from hironaka.abs import Points +from hironaka.core import Points from hironaka.gameHironaka import GameHironaka -from hironaka.util import generate_batch_points +from hironaka.src import generate_batch_points class HironakaValidator(GameHironaka): diff --git a/hironaka/validator/README.md b/hironaka/validator/README.md new file mode 100644 index 0000000..28693c1 --- /dev/null +++ b/hironaka/validator/README.md @@ -0,0 +1,2 @@ +# hironaka.validator +This contains validators that take in agents and hosts to evaluate them. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3740d6a..1edb69b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ numpy==1.22.4 treelib==1.6.1 gym==0.22 -torch~=1.12.0 \ No newline at end of file +torch~=1.12.0 +scipy~=1.8.1 +PyYAML~=6.0 \ No newline at end of file diff --git a/test/testGame.py b/test/testGame.py index ddb44f8..179289f 100644 --- a/test/testGame.py +++ b/test/testGame.py @@ -3,6 +3,7 @@ from hironaka.agent import RandomAgent from hironaka.gameHironaka import GameHironaka from hironaka.host import Zeillinger +from hironaka.src import generate_points, generate_batch_points from hironaka.util import * diff --git a/test/testGymEnv.py b/test/testGymEnv.py index 7d9b12a..19000bb 100644 --- a/test/testGymEnv.py +++ b/test/testGymEnv.py @@ -4,19 +4,19 @@ import numpy as np from gym.envs.registration import register -from hironaka.abs import Points +from hironaka.core import Points from hironaka.agent import RandomAgent from hironaka.host import Zeillinger, RandomHost register( id='hironaka/HironakaHost-v0', - entry_point='hironaka.envs:HironakaHostEnv', + entry_point='hironaka.gym_env:HironakaHostEnv', max_episode_steps=10000, ) register( id='hironaka/HironakaAgent-v0', - entry_point='hironaka.envs:HironakaAgentEnv', + entry_point='hironaka.gym_env:HironakaAgentEnv', 
max_episode_steps=10000, ) diff --git a/test/testPoints.py b/test/testPoints.py index 7fb2f59..0056d25 100644 --- a/test/testPoints.py +++ b/test/testPoints.py @@ -2,9 +2,8 @@ import numpy as np -from hironaka.abs.Points import Points -from hironaka.src import make_nested_list -from hironaka.util import generate_points +from hironaka.core.Points import Points +from hironaka.src import make_nested_list, generate_points class TestPoints(unittest.TestCase): @@ -63,6 +62,7 @@ def test_operations2(self): [[[16, 11, 5, 6], [11, 12, 18, 6], [11, 11, 1, 19], [8, 3, 17, 8], [7, 5, 3, 8]]] )) q.get_newton_polytope() + assert str(q) == str(p.get_newton_polytope(inplace=False)) assert str(q) == str(r2) @@ -138,3 +138,25 @@ def test_reposition(self): a = points.reposition(inplace=False) assert str(points) == str(r) assert str(points) == str(a) + + def test_distinguished_elements(self): + points = Points(make_nested_list( + [(7, 5, 3, 8), (8, 1, 8, 18), (8, 3, 17, 8), + (11, 11, 1, 19), (11, 12, 18, 6), (16, 11, 5, 6)] + ), distinguished_points=[2]) + + points.get_newton_polytope() + d_ind = points.distinguished_points[0] + assert tuple(points.points[0][d_ind]) == (8, 3, 17, 8) + + points.shift([[0, 1]], [0]) + points.get_newton_polytope() + d_ind = points.distinguished_points[0] + assert tuple(points.points[0][d_ind]) == (11, 3, 17, 8) + + points.shift([[0, 2]], [0]) + points.shift([[2, 3]], [2]) + points.shift([[0, 1]], [1]) + points.get_newton_polytope() + d_ind = points.distinguished_points[0] + assert d_ind is None diff --git a/test/testPolicy.py b/test/testPolicy.py index 7e9ecd3..0aae418 100644 --- a/test/testPolicy.py +++ b/test/testPolicy.py @@ -2,12 +2,12 @@ from torch import nn -from hironaka.abs import Points +from hironaka.core import Points from hironaka.gameHironaka import GameHironaka from hironaka.policy.NNPolicy import NNPolicy -from hironaka.policy_players.PolicyAgent import PolicyAgent -from hironaka.policy_players.PolicyHost import PolicyHost -from 
hironaka.util import generate_batch_points +from hironaka.agent import PolicyAgent +from hironaka.host import PolicyHost +from hironaka.src import generate_batch_points class NN(nn.Module): diff --git a/test/testSearch.py b/test/testSearch.py index f9004ba..f3ff050 100644 --- a/test/testSearch.py +++ b/test/testSearch.py @@ -2,7 +2,7 @@ from treelib import Tree -from hironaka.abs import Points +from hironaka.core import Points from hironaka.host import Zeillinger from hironaka.src import make_nested_list from hironaka.util import search_depth, search_tree diff --git a/test/testTorchPoints.py b/test/testTorchPoints.py index 93cf629..0f49261 100644 --- a/test/testTorchPoints.py +++ b/test/testTorchPoints.py @@ -2,8 +2,8 @@ import torch -from hironaka.abs.PointsTensor import PointsTensor -from hironaka.src._torch_ops import get_newton_polytope_torch, shift_torch, reposition_torch +from hironaka.core.PointsTensor import PointsTensor +from hironaka.src import get_newton_polytope_torch, shift_torch, reposition_torch class testTorchPoints(unittest.TestCase): @@ -37,6 +37,16 @@ class testTorchPoints(unittest.TestCase): [-1., -1., -1., -1.], [-1., -1., -1., -1.]]]) + rs = torch.Tensor([[[ 0.0000, 0.2000, 0.1000, 0.1000], + [-1.0000, -1.0000, -1.0000, -1.0000], + [ 0.3000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 1.0000, 0.5000, 0.0000]], + + [[ 0.0000, 0.0000, 0.4000, 1.0000], + [ 0.2000, 0.0000, 0.0000, 0.0000], + [-1.0000, -1.0000, -1.0000, -1.0000], + [-1.0000, -1.0000, -1.0000, -1.0000]]]) + def test_functions(self): p = torch.FloatTensor( [ @@ -44,7 +54,6 @@ def test_functions(self): [[0, 1, 3, 5], [1, 1, 1, 1], [9, 8, 2, 1], [-1, -1, -1, -1]] ] ) - assert torch.all(get_newton_polytope_torch(p, inplace=False).eq(self.r)) get_newton_polytope_torch(p, inplace=True) assert torch.all(p.eq(self.r)) @@ -72,3 +81,5 @@ def test_pointstensor(self): assert str(pts) == str(self.r2) pts.reposition() assert str(pts) == str(self.r3) + pts.rescale() + assert str(pts) == str(self.rs) \ No 
newline at end of file diff --git a/test/testUtil.py b/test/testUtil.py index 52d7a97..6023063 100644 --- a/test/testUtil.py +++ b/test/testUtil.py @@ -2,7 +2,7 @@ import numpy as np -from hironaka.src import get_batched_padded_array, batched_coord_list_to_binary +from hironaka.src import get_batched_padded_array, batched_coord_list_to_binary, get_newton_polytope_lst, get_shape class TestUtil(unittest.TestCase): @@ -24,3 +24,53 @@ def test_batched_coord_to_bin(self): coords = [[1, 2, 3], [2, 0, 1]] r = np.array([[0, 1, 1, 1], [1, 1, 1, 0]]) assert (batched_coord_list_to_binary(coords, 4) == r).all() + + def test_true_newton_polytope(self): + p = [[[1., 0.], [0.9, 0.9], [0., 1.]]] + r = [[[1., 0.], [0., 1.]]] + + assert str(get_newton_polytope_lst(p, inplace=False)) == str(r) + + p = [[[0.37807224, 0.60967653, 0.50641324]]] + assert str(get_newton_polytope_lst(p, inplace=False)) == str(p) + + p = [ + [[0.11675344, 0.39038985, 0.55826897, 0.06529552], + [0.9846373, 0.45638349, 0.70517085, 0.90032522], + [0.01027646, 0.11461289, 0.89243383, 0.634063], + [0.58811481, 0.99114348, 0.61889408, 0.59967777], + [0.91356043, 0.62654142, 0.69501398, 0.68474988], + [0.88135114, 0.30110585, 0.04229966, 0.03769748], + [0.37982495, 0.17156216, 0.33440668, 0.48339728], + [0.12123305, 0.15986878, 0.11907919, 0.59999993], + [0.9496461, 0.16063278, 0.42188375, 0.66339718], + [0.59075721, 0.17488182, 0.89326396, 0.01449242]], + [[0.64929492, 0.8896327, 0.98860123, 0.52941554], + [0.25994605, 0.03554693, 0.43534583, 0.19954576], + [0.62238657, 0.33769715, 0.2672676, 0.67115147], + [0.23643443, 0.51686672, 0.72861238, 0.0351913], + [0.3788386, 0.67130138, 0.87033132, 0.4363841], + [0.30030881, 0.11823987, 0.20820786, 0.49078142], + [0.25722259, 0.32548102, 0.97916295, 0.0842389], + [0.06561767, 0.55689435, 0.70502167, 0.27102844], + [0.38096357, 0.59775385, 0.97628977, 0.60265799], + [0.28909349, 0.08945314, 0.80995294, 0.63317]] + ] + + r = [ + [[0.88135114, 0.30110585, 0.04229966, 
0.03769748], + [0.59075721, 0.17488182, 0.89326396, 0.01449242], + [0.12123305, 0.15986878, 0.11907919, 0.59999993], + [0.11675344, 0.39038985, 0.55826897, 0.06529552], + [0.01027646, 0.11461289, 0.89243383, 0.634063]], + [[0.30030881, 0.11823987, 0.20820786, 0.49078142], + [0.25994605, 0.03554693, 0.43534583, 0.19954576], + [0.25722259, 0.32548102, 0.97916295, 0.0842389], + [0.23643443, 0.51686672, 0.72861238, 0.0351913], + [0.06561767, 0.55689435, 0.70502167, 0.27102844]]] + + assert str(get_newton_polytope_lst(p, inplace=False)) == str(r) + + def test_get_shape_extra_character(self): + p = [[1, 2, 3, 'd'], [2, 3, 4]] + assert get_shape(p) == (2,3) \ No newline at end of file diff --git a/test/testZeillinger.py b/test/testZeillinger.py index 73075a7..3347752 100644 --- a/test/testZeillinger.py +++ b/test/testZeillinger.py @@ -1,6 +1,6 @@ import unittest -from hironaka.abs import Points +from hironaka.core import Points from hironaka.host import Zeillinger from hironaka.src import make_nested_list diff --git a/train/config.yml b/train/config.yml new file mode 100644 index 0000000..24f7e5e --- /dev/null +++ b/train/config.yml @@ -0,0 +1,22 @@ +global: + dimension: 3 + max_value: 20 + masked: true + max_number_points: 20 + use_cuda: true + normalized: false + value_threshold: 100000000 + step_threshold: 200 # HironakaAgentEnv, Validator only + fixed_penalty_crossing_threshold: 0 + stop_at_threshold: true # HironakaAgentEnv, Validator only + improve_efficiency: true + scale_observation: true + reward_based_on_point_reduction: true + use_discrete_actions_for_host: true +training: + epoch: 1 + batch_size: 32 + save_frequency: 1 # save model every {save_frequency} epochs. 
+ total_timestep: 10000 +models: + version_string: 'vanilla_v0' diff --git a/train/train_sb3.py b/train/train_sb3.py new file mode 100644 index 0000000..899579e --- /dev/null +++ b/train/train_sb3.py @@ -0,0 +1,115 @@ +import logging +import os +import pathlib +import argparse +import sys + +sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.resolve())) + +from hironaka.policy import NNPolicy +from hironaka.validator import HironakaValidator + +import gym +from gym.envs.registration import register +import yaml + +from stable_baselines3 import DQN + +from hironaka.agent import RandomAgent, ChooseFirstAgent, PolicyAgent +from hironaka.host import Zeillinger, RandomHost, PolicyHost + +register( + id='hironaka/HironakaHost-v0', + entry_point='hironaka.gym_env:HironakaHostEnv', + max_episode_steps=10000, +) + +register( + id='hironaka/HironakaAgent-v0', + entry_point='hironaka.gym_env:HironakaAgentEnv', + max_episode_steps=10000, +) + +sb3_policy_config = { + "net_arch": [256] * 32, + "normalize_images": True} + + +def main(config_file: str): + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + if not logger.hasHandlers(): + logger.addHandler(logging.StreamHandler(sys.stdout)) + + model_path = 'models' + if config_file is None: + config_file = 'train/config.yml' + if not os.path.exists(model_path): + logger.info("Created 'models/'.") + os.makedirs(model_path) + else: + logger.warning("Model folder 'models/' already exists.") + + with open(config_file, "r") as stream: + config = yaml.safe_load(stream) # Generate the config as a dict object + + training_config = config['global'] + + epoch = config['training']['epoch'] + batch_size = config['training']['batch_size'] + save_frequency = config['training']['save_frequency'] + total_timestep = config['training']['total_timestep'] + + version_string = config['models']['version_string'] + + env_h = gym.make("hironaka/HironakaHost-v0", host=Zeillinger(), config_kwargs=training_config) + + for i in 
range(epoch): + model_a = DQN("MultiInputPolicy", env_h, verbose=0, policy_kwargs=sb3_policy_config, batch_size=batch_size) + model_a.learn(total_timesteps=total_timestep) + + p_a = NNPolicy(model_a.q_net.q_net, mode='agent', eval_mode=True, config_kwargs=training_config) + nnagent = PolicyAgent(p_a) + env_a = gym.make("hironaka/HironakaAgent-v0", agent=nnagent, config_kwargs=training_config) + + model_h = DQN("MlpPolicy", env_a, verbose=0, policy_kwargs=sb3_policy_config, batch_size=batch_size, gamma=1) + model_h.learn(total_timesteps=total_timestep) + + p_h = NNPolicy(model_h.q_net.q_net, mode='host', eval_mode=True, config_kwargs=training_config) + nnhost = PolicyHost(p_h, **training_config) + env_h = gym.make("hironaka/HironakaHost-v0", host=nnhost, config_kwargs=training_config) + + # Validation + + if i % save_frequency == 0: + print(f"Epoch {i*5}") + print("agent validation:") + agents = [nnagent, RandomAgent(), ChooseFirstAgent()] + # agents = [] + for agent in agents: + validator = HironakaValidator(Zeillinger(), agent, config_kwargs=config) + result = validator.playoff(1000) + print(str(type(agent)).split("'")[-2].split(".")[-1]) + print(f" - number of games:{len(result)}") + print(f"host validation:") + hosts = [nnhost, RandomHost(), Zeillinger()] + for host in hosts: + validator = HironakaValidator(host, nnagent, config_kwargs=config) + result = validator.playoff(1000) + + print(str(type(host)).split("'")[-2].split(".")[-1]) + print(f" - number of games:{len(result)}") + + # Save model + if i % save_frequency == 0: + model_a.save(f"{model_path}/{version_string}_epoch_{i}_agent") + model_h.save(f"{model_path}/{version_string}_epoch_{i}_host") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="train the host and agent.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-c", "--config_file", help="Specify config file location.") + args = parser.parse_args() + config_args = vars(args) + 
main(**config_args)