From 5c3e2aee474b5e80979ac014da901e262138afad Mon Sep 17 00:00:00 2001 From: Void Date: Fri, 15 Jul 2022 19:02:05 -0400 Subject: [PATCH] refining and fixing PointsTensor which pave the way for more efficient training codes. --- hironaka/agent.py | 20 ++++++++++++ hironaka/core/PointsTensor.py | 24 +++++++------- hironaka/cpp/README.md | 2 +- hironaka/host.py | 20 ++++++++++++ hironaka/policy_players/PolicyAgent.py | 23 ------------- hironaka/policy_players/PolicyHost.py | 25 -------------- hironaka/policy_players/README.md | 2 -- hironaka/policy_players/__init__.py | 2 -- hironaka/src/_torch_ops.py | 45 ++++++++++++++++++-------- hironaka/util/__init__.py | 1 - hironaka/util/geom.py | 0 test/testPolicy.py | 4 +-- test/testTorchPoints.py | 13 +++++++- train/train_sb3.py | 5 ++- 14 files changed, 102 insertions(+), 84 deletions(-) delete mode 100644 hironaka/policy_players/PolicyAgent.py delete mode 100644 hironaka/policy_players/PolicyHost.py delete mode 100644 hironaka/policy_players/README.md delete mode 100644 hironaka/policy_players/__init__.py delete mode 100644 hironaka/util/geom.py diff --git a/hironaka/agent.py b/hironaka/agent.py index 526ceed..70c944b 100644 --- a/hironaka/agent.py +++ b/hironaka/agent.py @@ -1,8 +1,10 @@ import abc +from typing import List import numpy as np from .core import Points +from .policy import Policy class Agent(abc.ABC): @@ -33,3 +35,21 @@ def move(self, points: Points, coords, inplace=True): points.shift(coords, actions) points.get_newton_polytope() return actions + + +class PolicyAgent(Agent): + def __init__(self, policy: Policy): + self._policy = policy + + def move(self, points: Points, coords: List[List[int]], inplace=True): + assert len(coords) == points.batch_size # TODO: wrap the move method for the abstract "Agent" with sanity checks? + + features = points.get_features() + + actions = self._policy.predict((features, coords)) + + if inplace: + points.shift(coords, actions) + points.get_newton_polytope() + + return actions diff --git a/hironaka/core/PointsTensor.py b/hironaka/core/PointsTensor.py index abcaafa..371cfcb 100644 --- a/hironaka/core/PointsTensor.py +++ b/hironaka/core/PointsTensor.py @@ -4,41 +4,43 @@ import torch from hironaka.core.PointsBase import PointsBase -from hironaka.src import shift_lst, get_newton_polytope_lst, get_shape, scale_points, get_batched_padded_array +from hironaka.src import get_batched_padded_array, rescale_torch from hironaka.src import shift_torch, get_newton_polytope_torch, reposition_torch class PointsTensor(PointsBase): - subcls_config_keys = ['value_threshold', 'device_key', 'padded_value'] + subcls_config_keys = ['value_threshold', 'device_key', 'padding_value'] copied_attributes = ['distinguished_points'] def __init__(self, points: Union[torch.Tensor, List[List[List[int]]], np.ndarray], value_threshold: Optional[int] = 1e8, device_key: Optional[str] = 'cpu', - padded_value: Optional[float] = -1.0, + padding_value: Optional[float] = -1.0, distinguished_points: Optional[List[int]] = None, config_kwargs: Optional[Dict[str, Any]] = None, **kwargs): config = kwargs if config_kwargs is None else {**config_kwargs, **kwargs} self.value_threshold = value_threshold - assert padded_value < 0, f"'padded_value' must be a negative number. Got {padded_value} instead." + assert padding_value < 0, f"'padding_value' must be a negative number. Got {padding_value} instead." if isinstance(points, list): points = torch.FloatTensor( get_batched_padded_array(points, new_length=config['max_num_points'], - constant_value=padded_value)) - elif isinstance(points, (torch.Tensor, np.ndarray)): + constant_value=padding_value)) + elif isinstance(points, np.ndarray): points = torch.FloatTensor(points) + elif isinstance(points, torch.Tensor): + points = points.type(torch.float32) else: raise Exception(f"Input must be a Tensor, a numpy array or a nested list. Got {type(points)}.") self.batch_size, self.max_num_points, self.dimension = points.shape self.device_key = device_key - self.padded_value = padded_value + self.padding_value = padding_value self.distinguished_points = distinguished_points super().__init__(points, **config) @@ -65,19 +67,19 @@ def _shift(self, coords: List[List[int]], axis: List[int], inplace: Optional[bool] = True): - return shift_torch(points, coords, axis, inplace=inplace) + return shift_torch(points, coords, axis, inplace=inplace, padding_value=self.padding_value) def _get_newton_polytope(self, points: torch.Tensor, inplace: Optional[bool] = True): - return get_newton_polytope_torch(points, inplace=inplace) + return get_newton_polytope_torch(points, inplace=inplace, padding_value=self.padding_value) def _get_shape(self, points: torch.Tensor): return points.shape def _reposition(self, points: torch.Tensor, inplace: Optional[bool] = True): - return reposition_torch(points, inplace=inplace) + return reposition_torch(points, inplace=inplace, padding_value=self.padding_value) def _rescale(self, points: torch.Tensor, inplace: Optional[bool] = True): - return points / torch.reshape(torch.amax(points, (1, 2)), (-1, 1, 1)) + r = rescale_torch(points, inplace=inplace, padding_value=self.padding_value) def _points_copy(self, points: torch.Tensor): return points.clone().detach() diff --git a/hironaka/cpp/README.md b/hironaka/cpp/README.md index 6ef426c..d271cba 100644 --- a/hironaka/cpp/README.md +++ b/hironaka/cpp/README.md @@ -1,3 +1,3 @@ -# C++ modules +# C++ modules (depreciated) This was originally written for numpy operations but is currently not in use. We leave it here for potential future revival. \ No newline at end of file diff --git a/hironaka/host.py b/hironaka/host.py index 0b1d975..08d5c91 100644 --- a/hironaka/host.py +++ b/hironaka/host.py @@ -1,9 +1,11 @@ import abc from itertools import combinations +from typing import Optional import numpy as np from .core import Points +from .policy import Policy class Host(abc.ABC): @@ -61,3 +63,21 @@ def select_coord(self, points: Points, debug=False): result.append([0, 1]) return result + + +class PolicyHost(Host): + def __init__(self, + policy: Policy, + use_discrete_actions_for_host: Optional[bool] = False, + **kwargs): + self._policy = policy + self.use_discrete_actions_for_host = kwargs.get('use_discrete_actions_for_host', use_discrete_actions_for_host) + + def select_coord(self, points: Points, debug=False): + features = points.get_features() + + coords = self._policy.predict(features) # return multi-binary array + result = [] + for b in range(coords.shape[0]): + result.append(np.where(coords[b] == 1)[0]) + return result diff --git a/hironaka/policy_players/PolicyAgent.py b/hironaka/policy_players/PolicyAgent.py deleted file mode 100644 index af9dac2..0000000 --- a/hironaka/policy_players/PolicyAgent.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import List - -from hironaka.core import Points -from hironaka.agent import Agent -from hironaka.policy.Policy import Policy - - -class PolicyAgent(Agent): - def __init__(self, policy: Policy): - self._policy = policy - - def move(self, points: Points, coords: List[List[int]], inplace=True): - assert len(coords) == points.batch_size # TODO: wrap the move method for the abstract "Agent" with sanity checks? - - features = points.get_features() - - actions = self._policy.predict((features, coords)) - - if inplace: - points.shift(coords, actions) - points.get_newton_polytope() - - return actions diff --git a/hironaka/policy_players/PolicyHost.py b/hironaka/policy_players/PolicyHost.py deleted file mode 100644 index 968537c..0000000 --- a/hironaka/policy_players/PolicyHost.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Optional - -import numpy as np - -from hironaka.core import Points -from hironaka.host import Host -from hironaka.policy.Policy import Policy - - -class PolicyHost(Host): - def __init__(self, - policy: Policy, - use_discrete_actions_for_host: Optional[bool] = False, - **kwargs): - self._policy = policy - self.use_discrete_actions_for_host = kwargs.get('use_discrete_actions_for_host', use_discrete_actions_for_host) - - def select_coord(self, points: Points, debug=False): - features = points.get_features() - - coords = self._policy.predict(features) # return multi-binary array - result = [] - for b in range(coords.shape[0]): - result.append(np.where(coords[b] == 1)[0]) - return result diff --git a/hironaka/policy_players/README.md b/hironaka/policy_players/README.md deleted file mode 100644 index fe688cc..0000000 --- a/hironaka/policy_players/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# hironaka.policy_players -This submodule contains players (host, agent) that take actions based on `Policy`. \ No newline at end of file diff --git a/hironaka/policy_players/__init__.py b/hironaka/policy_players/__init__.py deleted file mode 100644 index d9a06dc..0000000 --- a/hironaka/policy_players/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .PolicyHost import PolicyHost -from .PolicyAgent import PolicyAgent diff --git a/hironaka/src/_torch_ops.py b/hironaka/src/_torch_ops.py index 472d8d8..5a3d798 100644 --- a/hironaka/src/_torch_ops.py +++ b/hironaka/src/_torch_ops.py @@ -5,7 +5,9 @@ from hironaka.src import batched_coord_list_to_binary -def get_newton_polytope_approx_torch(points: torch.Tensor, inplace: Optional[bool] = True): +def get_newton_polytope_approx_torch(points: torch.Tensor, + inplace: Optional[bool] = True, + padding_value: Optional[float] = -1.): assert len(points.shape) == 3 batch_size, max_num_points, dimension = points.shape @@ -27,21 +29,26 @@ def get_newton_polytope_approx_torch(points: torch.Tensor, inplace: Optional[boo points_to_remove = ((difference >= 0) & diag_filter & filter_matrix).all(3).any(2) points_to_remove = points_to_remove.unsqueeze(2).repeat(1, 1, dimension) + r = points * ~points_to_remove + torch.full(points.shape, padding_value) * points_to_remove + if inplace: - points[:, :, :] = points * ~points_to_remove + torch.full(points.shape, -1.0) * points_to_remove + points[:, :, :] = r return None else: - return points * ~points_to_remove + torch.full(points.shape, -1.0) * points_to_remove + return r -def get_newton_polytope_torch(points: torch.Tensor, inplace: Optional[bool] = True): - return get_newton_polytope_approx_torch(points, inplace=inplace) +def get_newton_polytope_torch(points: torch.Tensor, + inplace: Optional[bool] = True, + padding_value: Optional[float] = -1.): + return get_newton_polytope_approx_torch(points, inplace=inplace, padding_value=padding_value) def shift_torch(points: torch.Tensor, coord: Union[torch.Tensor, List[List[int]]], axis: Union[torch.Tensor, List[int]], - inplace=True): + inplace: Optional[bool] = True, + padding_value: Optional[float] = -1.): """ note: If coord is a list, it is assumed to be lists of chosen coordinates. @@ -85,25 +92,37 @@ def shift_torch(points: torch.Tensor, trans_matrix = trans_matrix.unsqueeze(1).repeat(1, max_num_points, 1, 1) transformed_points = torch.matmul(trans_matrix, points.unsqueeze(3)).squeeze(3) - result = (transformed_points * available_points) + torch.full(points.shape, -1.0) * ~available_points + r = (transformed_points * available_points) + torch.full(points.shape, padding_value) * ~available_points if inplace: - points[:, :, :] = result + points[:, :, :] = r return None else: - return result + return r -def reposition_torch(points: torch.Tensor, inplace: Optional[bool] = True): +def reposition_torch(points: torch.Tensor, + inplace: Optional[bool] = True, + padding_value: Optional[float] = -1.): available_points = points.ge(0) maximum = torch.max(points) preprocessed = points * available_points + torch.full(points.shape, maximum + 1) * ~available_points coordinate_minimum = torch.amin(preprocessed, 1) unfiltered_result = points - coordinate_minimum.unsqueeze(1).repeat(1, points.shape[1], 1) - result = unfiltered_result * available_points + torch.full(points.shape, -1.0) * ~available_points + r = unfiltered_result * available_points + torch.full(points.shape, padding_value) * ~available_points if inplace: - points[:, :, :] = result + points[:, :, :] = r return None else: - return result + return r + + +def rescale_torch(points: torch.Tensor, inplace: Optional[bool] = True, padding_value: Optional[float] = -1.): + available_points = points.ge(0) + r = points * available_points / torch.reshape(torch.amax(points, (1, 2)), (-1, 1, 1)) + \ + padding_value * ~available_points + if inplace: + points[:, :, :] = r + else: + return r \ No newline at end of file diff --git a/hironaka/util/__init__.py b/hironaka/util/__init__.py index 6331a7f..114d7ee 100644 --- a/hironaka/util/__init__.py +++ b/hironaka/util/__init__.py @@ -1,2 +1 @@ -from .geom import * from .search import * diff --git a/hironaka/util/geom.py b/hironaka/util/geom.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/testPolicy.py b/test/testPolicy.py index 22328f4..0aae418 100644 --- a/test/testPolicy.py +++ b/test/testPolicy.py @@ -5,8 +5,8 @@ from hironaka.core import Points from hironaka.gameHironaka import GameHironaka from hironaka.policy.NNPolicy import NNPolicy -from hironaka.policy_players.PolicyAgent import PolicyAgent -from hironaka.policy_players.PolicyHost import PolicyHost +from hironaka.agent import PolicyAgent +from hironaka.host import PolicyHost from hironaka.src import generate_batch_points diff --git a/test/testTorchPoints.py b/test/testTorchPoints.py index 0cd9f08..0f49261 100644 --- a/test/testTorchPoints.py +++ b/test/testTorchPoints.py @@ -37,6 +37,16 @@ class testTorchPoints(unittest.TestCase): [-1., -1., -1., -1.], [-1., -1., -1., -1.]]]) + rs = torch.Tensor([[[ 0.0000, 0.2000, 0.1000, 0.1000], + [-1.0000, -1.0000, -1.0000, -1.0000], + [ 0.3000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 1.0000, 0.5000, 0.0000]], + + [[ 0.0000, 0.0000, 0.4000, 1.0000], + [ 0.2000, 0.0000, 0.0000, 0.0000], + [-1.0000, -1.0000, -1.0000, -1.0000], + [-1.0000, -1.0000, -1.0000, -1.0000]]]) + def test_functions(self): p = torch.FloatTensor( [ @@ -44,7 +54,6 @@ def test_functions(self): [[0, 1, 3, 5], [1, 1, 1, 1], [9, 8, 2, 1], [-1, -1, -1, -1]] ] ) - assert torch.all(get_newton_polytope_torch(p, inplace=False).eq(self.r)) get_newton_polytope_torch(p, inplace=True) assert torch.all(p.eq(self.r)) @@ -72,3 +81,5 @@ def test_pointstensor(self): assert str(pts) == str(self.r2) pts.reposition() assert str(pts) == str(self.r3) + pts.rescale() + assert str(pts) == str(self.rs) \ No newline at end of file diff --git a/train/train_sb3.py b/train/train_sb3.py index f889102..899579e 100644 --- a/train/train_sb3.py +++ b/train/train_sb3.py @@ -7,7 +7,6 @@ sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.resolve())) from hironaka.policy import NNPolicy -from hironaka.policy_players import PolicyHost, PolicyAgent from hironaka.validator import HironakaValidator import gym @@ -16,8 +15,8 @@ from stable_baselines3 import DQN -from hironaka.agent import RandomAgent, ChooseFirstAgent -from hironaka.host import Zeillinger, RandomHost +from hironaka.agent import RandomAgent, ChooseFirstAgent, PolicyAgent +from hironaka.host import Zeillinger, RandomHost, PolicyHost register( id='hironaka/HironakaHost-v0',