Skip to content

Commit

Permalink
refining and fixing PointsTensor which pave the way for more efficien…
Browse files Browse the repository at this point in the history
…t training codes.
  • Loading branch information
honglu2875 committed Jul 15, 2022
1 parent 8f85135 commit 5c3e2ae
Show file tree
Hide file tree
Showing 14 changed files with 102 additions and 84 deletions.
20 changes: 20 additions & 0 deletions hironaka/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import abc
from typing import List

import numpy as np

from .core import Points
from .policy import Policy


class Agent(abc.ABC):
Expand Down Expand Up @@ -33,3 +35,21 @@ def move(self, points: Points, coords, inplace=True):
points.shift(coords, actions)
points.get_newton_polytope()
return actions


class PolicyAgent(Agent):
    """Agent whose actions are produced by a `Policy` object."""

    def __init__(self, policy: Policy):
        """Store the policy that will be queried on every move."""
        self._policy = policy

    def move(self, points: Points, coords: List[List[int]], inplace=True):
        """
        Ask the policy for actions given the current points and the chosen coordinates.

        The policy receives the tuple `(features, coords)`, where `features` comes from
        `points.get_features()`. When `inplace` is true, `points` is shifted with the
        predicted actions and its Newton polytope is recomputed before returning.

        Returns the actions predicted by the policy.
        """
        # TODO: wrap the move method for the abstract "Agent" with sanity checks?
        assert len(coords) == points.batch_size

        chosen_actions = self._policy.predict((points.get_features(), coords))
        if not inplace:
            # Caller only wants the prediction; leave `points` untouched.
            return chosen_actions

        points.shift(coords, chosen_actions)
        points.get_newton_polytope()
        return chosen_actions
24 changes: 13 additions & 11 deletions hironaka/core/PointsTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,43 @@
import torch

from hironaka.core.PointsBase import PointsBase
from hironaka.src import shift_lst, get_newton_polytope_lst, get_shape, scale_points, get_batched_padded_array
from hironaka.src import get_batched_padded_array, rescale_torch
from hironaka.src import shift_torch, get_newton_polytope_torch, reposition_torch


class PointsTensor(PointsBase):
subcls_config_keys = ['value_threshold', 'device_key', 'padded_value']
subcls_config_keys = ['value_threshold', 'device_key', 'padding_value']
copied_attributes = ['distinguished_points']

def __init__(self,
points: Union[torch.Tensor, List[List[List[int]]], np.ndarray],
value_threshold: Optional[int] = 1e8,
device_key: Optional[str] = 'cpu',
padded_value: Optional[float] = -1.0,
padding_value: Optional[float] = -1.0,
distinguished_points: Optional[List[int]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
**kwargs):
config = kwargs if config_kwargs is None else {**config_kwargs, **kwargs}
self.value_threshold = value_threshold

assert padded_value < 0, f"'padded_value' must be a negative number. Got {padded_value} instead."
assert padding_value < 0, f"'padding_value' must be a negative number. Got {padding_value} instead."

if isinstance(points, list):
points = torch.FloatTensor(
get_batched_padded_array(points,
new_length=config['max_num_points'],
constant_value=padded_value))
elif isinstance(points, (torch.Tensor, np.ndarray)):
constant_value=padding_value))
elif isinstance(points, np.ndarray):
points = torch.FloatTensor(points)
elif isinstance(points, torch.Tensor):
points = points.type(torch.float32)
else:
raise Exception(f"Input must be a Tensor, a numpy array or a nested list. Got {type(points)}.")

self.batch_size, self.max_num_points, self.dimension = points.shape

self.device_key = device_key
self.padded_value = padded_value
self.padding_value = padding_value
self.distinguished_points = distinguished_points

super().__init__(points, **config)
Expand All @@ -65,19 +67,19 @@ def _shift(self,
coords: List[List[int]],
axis: List[int],
inplace: Optional[bool] = True):
return shift_torch(points, coords, axis, inplace=inplace)
return shift_torch(points, coords, axis, inplace=inplace, padding_value=self.padding_value)

def _get_newton_polytope(self, points: torch.Tensor, inplace: Optional[bool] = True):
return get_newton_polytope_torch(points, inplace=inplace)
return get_newton_polytope_torch(points, inplace=inplace, padding_value=self.padding_value)

def _get_shape(self, points: torch.Tensor):
return points.shape

def _reposition(self, points: torch.Tensor, inplace: Optional[bool] = True):
return reposition_torch(points, inplace=inplace)
return reposition_torch(points, inplace=inplace, padding_value=self.padding_value)

def _rescale(self, points: torch.Tensor, inplace: Optional[bool] = True):
    """
    Rescale `points` by delegating to `rescale_torch`.

    `inplace` and this object's `padding_value` are forwarded, and the result of
    `rescale_torch` is returned as-is, in the same way the sibling wrappers
    (`_shift`, `_get_newton_polytope`, `_reposition`) return the result of their
    torch operation.
    """
    # The result must be returned: assigning it to a local and falling off the end
    # would hand `None` to the caller even when `inplace` is false.
    return rescale_torch(points, inplace=inplace, padding_value=self.padding_value)

def _points_copy(self, points: torch.Tensor):
return points.clone().detach()
Expand Down
2 changes: 1 addition & 1 deletion hironaka/cpp/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# C++ modules
# C++ modules (deprecated)
This was originally written for numpy operations but is currently not in use.
We leave it here for potential future revival.
20 changes: 20 additions & 0 deletions hironaka/host.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import abc
from itertools import combinations
from typing import Optional

import numpy as np

from .core import Points
from .policy import Policy


class Host(abc.ABC):
Expand Down Expand Up @@ -61,3 +63,21 @@ def select_coord(self, points: Points, debug=False):
result.append([0, 1])

return result


class PolicyHost(Host):
    """Host whose coordinate choices are produced by a `Policy` object."""

    def __init__(self,
                 policy: Policy,
                 use_discrete_actions_for_host: Optional[bool] = False,
                 **kwargs):
        """Store the policy and the `use_discrete_actions_for_host` flag."""
        self._policy = policy
        # A keyword with this name is always bound to the named parameter, so the lookup
        # in `kwargs` falls back to the parameter's value.
        discrete_flag = kwargs.get('use_discrete_actions_for_host', use_discrete_actions_for_host)
        self.use_discrete_actions_for_host = discrete_flag

    def select_coord(self, points: Points, debug=False):
        """
        Query the policy with the features of `points` and convert its prediction
        into a list with one entry per row of the prediction.

        Each entry holds the indices at which that row equals 1
        (the first element of `np.where(row == 1)`).
        """
        prediction = self._policy.predict(points.get_features())  # return multi-binary array
        return [np.where(prediction[idx] == 1)[0] for idx in range(prediction.shape[0])]
23 changes: 0 additions & 23 deletions hironaka/policy_players/PolicyAgent.py

This file was deleted.

25 changes: 0 additions & 25 deletions hironaka/policy_players/PolicyHost.py

This file was deleted.

2 changes: 0 additions & 2 deletions hironaka/policy_players/README.md

This file was deleted.

2 changes: 0 additions & 2 deletions hironaka/policy_players/__init__.py

This file was deleted.

45 changes: 32 additions & 13 deletions hironaka/src/_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from hironaka.src import batched_coord_list_to_binary


def get_newton_polytope_approx_torch(points: torch.Tensor, inplace: Optional[bool] = True):
def get_newton_polytope_approx_torch(points: torch.Tensor,
inplace: Optional[bool] = True,
padding_value: Optional[float] = -1.):
assert len(points.shape) == 3
batch_size, max_num_points, dimension = points.shape

Expand All @@ -27,21 +29,26 @@ def get_newton_polytope_approx_torch(points: torch.Tensor, inplace: Optional[boo
points_to_remove = ((difference >= 0) & diag_filter & filter_matrix).all(3).any(2)
points_to_remove = points_to_remove.unsqueeze(2).repeat(1, 1, dimension)

r = points * ~points_to_remove + torch.full(points.shape, padding_value) * points_to_remove

if inplace:
points[:, :, :] = points * ~points_to_remove + torch.full(points.shape, -1.0) * points_to_remove
points[:, :, :] = r
return None
else:
return points * ~points_to_remove + torch.full(points.shape, -1.0) * points_to_remove
return r


def get_newton_polytope_torch(points: torch.Tensor, inplace: Optional[bool] = True):
return get_newton_polytope_approx_torch(points, inplace=inplace)
def get_newton_polytope_torch(points: torch.Tensor,
inplace: Optional[bool] = True,
padding_value: Optional[float] = -1.):
return get_newton_polytope_approx_torch(points, inplace=inplace, padding_value=padding_value)


def shift_torch(points: torch.Tensor,
coord: Union[torch.Tensor, List[List[int]]],
axis: Union[torch.Tensor, List[int]],
inplace=True):
inplace: Optional[bool] = True,
padding_value: Optional[float] = -1.):
"""
note:
If coord is a list, it is assumed to be lists of chosen coordinates.
Expand Down Expand Up @@ -85,25 +92,37 @@ def shift_torch(points: torch.Tensor,
trans_matrix = trans_matrix.unsqueeze(1).repeat(1, max_num_points, 1, 1)

transformed_points = torch.matmul(trans_matrix, points.unsqueeze(3)).squeeze(3)
result = (transformed_points * available_points) + torch.full(points.shape, -1.0) * ~available_points
r = (transformed_points * available_points) + torch.full(points.shape, padding_value) * ~available_points

if inplace:
points[:, :, :] = result
points[:, :, :] = r
return None
else:
return result
return r


def reposition_torch(points: torch.Tensor, inplace: Optional[bool] = True):
def reposition_torch(points: torch.Tensor,
inplace: Optional[bool] = True,
padding_value: Optional[float] = -1.):
available_points = points.ge(0)
maximum = torch.max(points)

preprocessed = points * available_points + torch.full(points.shape, maximum + 1) * ~available_points
coordinate_minimum = torch.amin(preprocessed, 1)
unfiltered_result = points - coordinate_minimum.unsqueeze(1).repeat(1, points.shape[1], 1)
result = unfiltered_result * available_points + torch.full(points.shape, -1.0) * ~available_points
r = unfiltered_result * available_points + torch.full(points.shape, padding_value) * ~available_points
if inplace:
points[:, :, :] = result
points[:, :, :] = r
return None
else:
return result
return r


def rescale_torch(points: torch.Tensor, inplace: Optional[bool] = True, padding_value: Optional[float] = -1.):
    """
    Divide every batch of `points` by that batch's maximum entry.

    Entries that are negative are treated as padding: they do not get rescaled and
    are written back as `padding_value`.

    A batch whose maximum is not positive (e.g. all remaining points sit at the
    origin, or the batch contains only padding) is divided by 1 instead, so that
    the result never contains NaN from a 0/0 division.

    Parameters:
        points: tensor indexed as (batch, point, coordinate); the maximum is taken
            over dimensions 1 and 2.
        inplace: if true, overwrite `points` and return None; otherwise return a
            new tensor and leave `points` untouched.
        padding_value: value written to the padding entries of the result.
    """
    available_points = points.ge(0)

    batch_max = torch.reshape(torch.amax(points, (1, 2)), (-1, 1, 1))
    # Guard the denominator: a zero maximum would turn the whole batch into NaN.
    safe_max = torch.where(batch_max > 0, batch_max, torch.ones_like(batch_max))

    r = points * available_points / safe_max + padding_value * ~available_points

    if inplace:
        points[:, :, :] = r
        return None
    return r
1 change: 0 additions & 1 deletion hironaka/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .geom import *
from .search import *
Empty file removed hironaka/util/geom.py
Empty file.
4 changes: 2 additions & 2 deletions test/testPolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from hironaka.core import Points
from hironaka.gameHironaka import GameHironaka
from hironaka.policy.NNPolicy import NNPolicy
from hironaka.policy_players.PolicyAgent import PolicyAgent
from hironaka.policy_players.PolicyHost import PolicyHost
from hironaka.agent import PolicyAgent
from hironaka.host import PolicyHost
from hironaka.src import generate_batch_points


Expand Down
13 changes: 12 additions & 1 deletion test/testTorchPoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,23 @@ class testTorchPoints(unittest.TestCase):
[-1., -1., -1., -1.],
[-1., -1., -1., -1.]]])

rs = torch.Tensor([[[ 0.0000, 0.2000, 0.1000, 0.1000],
[-1.0000, -1.0000, -1.0000, -1.0000],
[ 0.3000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 1.0000, 0.5000, 0.0000]],

[[ 0.0000, 0.0000, 0.4000, 1.0000],
[ 0.2000, 0.0000, 0.0000, 0.0000],
[-1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000]]])

def test_functions(self):
p = torch.FloatTensor(
[
[[1, 2, 3, 4], [2, 3, 4, 5], [4, 1, 2, 3], [1, 6, 7, 3]],
[[0, 1, 3, 5], [1, 1, 1, 1], [9, 8, 2, 1], [-1, -1, -1, -1]]
]
)

assert torch.all(get_newton_polytope_torch(p, inplace=False).eq(self.r))
get_newton_polytope_torch(p, inplace=True)
assert torch.all(p.eq(self.r))
Expand Down Expand Up @@ -72,3 +81,5 @@ def test_pointstensor(self):
assert str(pts) == str(self.r2)
pts.reposition()
assert str(pts) == str(self.r3)
pts.rescale()
assert str(pts) == str(self.rs)
5 changes: 2 additions & 3 deletions train/train_sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.resolve()))

from hironaka.policy import NNPolicy
from hironaka.policy_players import PolicyHost, PolicyAgent
from hironaka.validator import HironakaValidator

import gym
Expand All @@ -16,8 +15,8 @@

from stable_baselines3 import DQN

from hironaka.agent import RandomAgent, ChooseFirstAgent
from hironaka.host import Zeillinger, RandomHost
from hironaka.agent import RandomAgent, ChooseFirstAgent, PolicyAgent
from hironaka.host import Zeillinger, RandomHost, PolicyHost

register(
id='hironaka/HironakaHost-v0',
Expand Down

0 comments on commit 5c3e2ae

Please sign in to comment.