Skip to content

Commit

Permalink
Merge pull request #28 from honglu2875/training
Browse files Browse the repository at this point in the history
Training
  • Loading branch information
honglu2875 authored Jul 17, 2022
2 parents 2a8dc8a + 5c3e2ae commit d80e4e9
Show file tree
Hide file tree
Showing 46 changed files with 576 additions and 169 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,12 @@ dmypy.json
# Pyre type checker
.pyre/


# High-fly cluster config
.hfai
.hfignore

# PyCharm
.idea

# Models
models/
16 changes: 16 additions & 0 deletions hironaka/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# hironaka
This is the base folder of the hironaka library.

Submodules:

- [hironaka.core](core)
- [hironaka.gym_env](gym_env)
- [hironaka.policy](policy)
- [hironaka.policy_players](policy_players)
- [hironaka.src](src)
- [hironaka.validator](validator)

Python scripts and non-essential functions:

- [.train/](../train)
- [.util/](util)
8 changes: 0 additions & 8 deletions hironaka/abs/PointsNumpy.py

This file was deleted.

1 change: 0 additions & 1 deletion hironaka/abs/__init__.py

This file was deleted.

22 changes: 21 additions & 1 deletion hironaka/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import abc
from typing import List

import numpy as np

from .abs import Points
from .core import Points
from .policy import Policy


class Agent(abc.ABC):
Expand Down Expand Up @@ -33,3 +35,21 @@ def move(self, points: Points, coords, inplace=True):
points.shift(coords, actions)
points.get_newton_polytope()
return actions


class PolicyAgent(Agent):
def __init__(self, policy: Policy):
self._policy = policy

def move(self, points: Points, coords: List[List[int]], inplace=True):
assert len(coords) == points.batch_size # TODO: wrap the move method for the abstract "Agent" with sanity checks?

features = points.get_features()

actions = self._policy.predict((features, coords))

if inplace:
points.shift(coords, actions)
points.get_newton_polytope()

return actions
41 changes: 37 additions & 4 deletions hironaka/abs/Points.py → hironaka/core/Points.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,29 @@

import numpy as np

from hironaka.abs.PointsBase import PointsBase
from hironaka.src import shift_lst, get_newton_polytope_lst, get_shape, scale_points, reposition_lst
from hironaka.core.PointsBase import PointsBase
from hironaka.src import shift_lst, get_newton_polytope_lst, get_shape, scale_points, reposition_lst, \
get_newton_polytope_approx_lst


class Points(PointsBase):
"""
When dealing with small batches, small dimension and small point numbers, list is much better than numpy.
"""
config_keys = ['value_threshold']
subcls_config_keys = ['value_threshold', 'use_precise_newton_polytope']
copied_attributes = ['distinguished_points']

def __init__(self,
points: Union[List[List[List[int]]], np.ndarray],
value_threshold: Optional[int] = 1e8,
use_precise_newton_polytope: Optional[bool] = False,
distinguished_points: Optional[Union[List[int], None]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
**kwargs):
config = kwargs if config_kwargs is None else {**config_kwargs, **kwargs}
self.value_threshold = value_threshold
self.use_precise_newton_polytope = use_precise_newton_polytope
self.distinguished_points = distinguished_points

# Be lenient and allow numpy array as input.
# The input might already be -1 padded arrays. Thus, we do a thorough check to clean that up.
Expand Down Expand Up @@ -61,7 +67,34 @@ def _reposition(self, points: Any, inplace: Optional[bool] = True):
return reposition_lst(points, inplace=inplace)

def _get_newton_polytope(self, points: Any, inplace: Optional[bool] = True):
return get_newton_polytope_lst(points, inplace=inplace, get_ended=False)
# Mark distinguished points
if self.distinguished_points is not None:
# Apply marks to the distinguished points before the operation
for b in range(self.batch_size):
if self.distinguished_points[b] is None:
continue
self.points[b][self.distinguished_points[b]].append('d')

if self.use_precise_newton_polytope:
result = get_newton_polytope_lst(points, inplace=inplace)
else:
result = get_newton_polytope_approx_lst(points, inplace=inplace, get_ended=False)

# Recover the locations of distinguished points
if self.distinguished_points is not None:
transformed_points = points if inplace else result
for b in range(self.batch_size):
if self.distinguished_points[b] is None:
continue
distinguished_point_index = None
for i in range(len(transformed_points[b])):
if transformed_points[b][i][-1] == 'd':
distinguished_point_index = i
transformed_points[b][i].pop()
break
self.distinguished_points[b] = distinguished_point_index

return result

def _get_shape(self, points: Any):
return get_shape(points)
Expand Down
30 changes: 22 additions & 8 deletions hironaka/abs/PointsBase.py → hironaka/core/PointsBase.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
from copy import deepcopy
from typing import Any, Optional, List


Expand Down Expand Up @@ -35,7 +36,11 @@ class PointsBase(abc.ABC):
"""
# You MUST define `subcls_config_keys` when inheriting.
# Keys in `subcls_config_keys` will be tracked when calling the `copy()` method.
config_keys: List[str]
subcls_config_keys: List[str]
# Keys in `copied_attributes` will be directly copied during `copy()`. They MUST be initialized.
copied_attributes: List[str]
# Keys in `base_attributes` will be copied. They are shared by all subclasses and do not need to be re-initialized.
base_attributes = ['ended_each_batch', 'ended', 'batch_size', 'max_num_points', 'dimension']

def __init__(self,
points: Any,
Expand All @@ -53,10 +58,6 @@ def __init__(self,
else:
self.config = kwargs

# Update keys if modified or created in subclass
for key in self.config_keys:
self.config[key] = getattr(self, key)

# Check the shape of `points`.
shape = self._get_shape(points)
if len(shape) == 2:
Expand All @@ -66,7 +67,6 @@ def __init__(self,
points = self._add_batch_axis(points)
if len(shape) != 3:
raise Exception("input dimension must be 2 or 3.")

self.points = points

self.batch_size = self.config.get('points_batch_size', shape[0])
Expand All @@ -81,6 +81,13 @@ def __init__(self,
# will also be updated on point-changing modifications including `get_newton_polytope`
self.ended_each_batch = [False] * self.batch_size

# Update keys in `self.copied_attributes`
for key in self.copied_attributes:
if hasattr(self, key):
self.config[key] = getattr(self, key)
else:
raise Exception("Must initialize keys in 'subcls_config_keys' before calling super().__init__.")

def copy(self, points: Optional[Any] = None):
"""
Copy the object.
Expand All @@ -93,7 +100,14 @@ def copy(self, points: Optional[Any] = None):
new_points = self._points_copy(self.points)
else:
new_points = self._points_copy(points)
return self.__class__(new_points, **self.config)
new_points = self.__class__(new_points, **self.config)

for key in self.copied_attributes + self.base_attributes:
if hasattr(self, key):
setattr(new_points, key, deepcopy(getattr(self, key)))
else:
raise Exception(f"Attribute {key} is not initialized.")
return new_points

def shift(self, coords: List[List[int]], axis: List[int], inplace=True):
"""
Expand Down Expand Up @@ -129,7 +143,7 @@ def get_newton_polytope(self, inplace=True):
return None
else:
new_points = self.copy(points=r)
new_points.ended_each_batch = ended_each_batch
new_points.ended_each_batch = ended_each_batch # TODO: duplicate?
new_points.ended = ended
return new_points

Expand Down
8 changes: 8 additions & 0 deletions hironaka/core/PointsNumpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from hironaka.core.PointsBase import PointsBase


class PointsNumpy(PointsBase): # TODO:INCOMPLETE
"""
Storing points using numpy arrays.
"""
pass
31 changes: 18 additions & 13 deletions hironaka/abs/PointsTensor.py → hironaka/core/PointsTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,45 @@
import numpy as np
import torch

from hironaka.abs.PointsBase import PointsBase
from hironaka.src import shift_lst, get_newton_polytope_lst, get_shape, scale_points, get_batched_padded_array
from hironaka.src._torch_ops import shift_torch, get_newton_polytope_torch, reposition_torch
from hironaka.core.PointsBase import PointsBase
from hironaka.src import get_batched_padded_array, rescale_torch
from hironaka.src import shift_torch, get_newton_polytope_torch, reposition_torch


class PointsTensor(PointsBase):
config_keys = ['value_threshold', 'device_key', 'padded_value']
subcls_config_keys = ['value_threshold', 'device_key', 'padding_value']
copied_attributes = ['distinguished_points']

def __init__(self,
points: Union[torch.Tensor, List[List[List[int]]], np.ndarray],
value_threshold: Optional[int] = 1e8,
device_key: Optional[str] = 'cpu',
padded_value: Optional[float] = -1.0,
padding_value: Optional[float] = -1.0,
distinguished_points: Optional[List[int]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
**kwargs):
config = kwargs if config_kwargs is None else {**config_kwargs, **kwargs}
self.value_threshold = value_threshold

# It's better to require a fixed shape of the tensor implementation.
assert padding_value < 0, f"'padding_value' must be a negative number. Got {padding_value} instead."

if isinstance(points, list):
points = torch.FloatTensor(
get_batched_padded_array(points,
new_length=config['max_num_points'],
constant_value=config.get('padded_value', -1)))
elif isinstance(points, (torch.Tensor, np.ndarray)):
constant_value=padding_value))
elif isinstance(points, np.ndarray):
points = torch.FloatTensor(points)
elif isinstance(points, torch.Tensor):
points = points.type(torch.float32)
else:
raise Exception(f"Input must be a Tensor, a numpy array or a nested list. Got {type(points)}.")

self.batch_size, self.max_num_points, self.dimension = points.shape

self.device_key = device_key
self.padded_value = padded_value
self.padding_value = padding_value
self.distinguished_points = distinguished_points

super().__init__(points, **config)
self.device = torch.device(self.device_key)
Expand All @@ -62,19 +67,19 @@ def _shift(self,
coords: List[List[int]],
axis: List[int],
inplace: Optional[bool] = True):
return shift_torch(points, coords, axis, inplace=inplace)
return shift_torch(points, coords, axis, inplace=inplace, padding_value=self.padding_value)

def _get_newton_polytope(self, points: torch.Tensor, inplace: Optional[bool] = True):
return get_newton_polytope_torch(points, inplace=inplace)
return get_newton_polytope_torch(points, inplace=inplace, padding_value=self.padding_value)

def _get_shape(self, points: torch.Tensor):
return points.shape

def _reposition(self, points: torch.Tensor, inplace: Optional[bool] = True):
return reposition_torch(points, inplace=inplace)
return reposition_torch(points, inplace=inplace, padding_value=self.padding_value)

def _rescale(self, points: torch.Tensor, inplace: Optional[bool] = True):
return points / torch.reshape(torch.amax(points, (1, 2)), (-1, 1, 1))
r = rescale_torch(points, inplace=inplace, padding_value=self.padding_value)

def _points_copy(self, points: torch.Tensor):
return points.clone().detach()
Expand Down
64 changes: 64 additions & 0 deletions hironaka/core/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# hironaka.core
This is the core functionality whose classes
- save collections of points,
- perform transformations (Newton polytope, shift, rescale, etc.),
- provide features, states, etc.

The base class is `PointsBase`, and the subclasses are currently `Points` and `PointsTensor`.

`PointsNumpy` is currently unnecessary and the implementation is postponed.

## .PointsBase
This interface is an abstraction of a collection of points used in the Hironaka game and its variations.

A couple important notes:
- The points are stored as 3d objects (nested lists, tensors, etc.). The 3 axes represent:
- (batches, points, point coordinates)
- Subclasses must define `subcls_config_keys, copied_attributes`.
- `subcls_config_keys` defines the config keys that will be tracked and initialized when creating a copy object in `.copy()` method. In other words, they are configs that stay unchanged throughout space transformations.
- `copied_attributes` defines the keys of the attributes that will be directly copied after initialization in `.copy()`. In other words, they may change during space transformations and are part of the state information (e.g., `.ended`).
- The shape of the points can be specified with the config parameters `points_batch_size, dimension, max_num_points`. If they are not given, the `.batch_size, .dimension, .max_num_points` attributes are inferred from the point data.

Must implement:
`
_get_shape
_get_newton_polytope
_reposition
_shift
_rescale
_points_copy
_add_batch_axis
_get_batch_ended
`

Feel free to override:
`
get_features
_get_max_num_points
`

## .Points
It stores the points in nested lists and performs list-based transformations. The nested lists do not have to be of uniform shape (not padded).
For example:

```
[
[
(7, 5, 3, 8), (8, 9, 8, 18), (8, 3, 17, 8),
(11, 11, 1, 19), (11, 12, 18, 6), (16, 11, 5, 6)
],
[
(0, 1, 0, 1), (0, 2, 0, 0), (1, 0, 0, 1),
(1, 0, 1, 0), (1, 1, 0, 0)
]
]
```

This is a batch of 2 separate sets of points. They are of different sizes.

## .PointsTensor
It performs everything using PyTorch tensors.

Major differences between `PointsTensor` and `Points`
- `PointsTensor.points` has a fixed shape. Removed points will be replaced by `padding_value`, which must be a negative number.
- In `Points.points`, the order of points may change during `get_newton_polytope()`. But for `PointsTensor.points`, the order will never change. When points are removed, we still maintain the original order without moving surviving points next to each other. As a result, `distinguished_points` marks the distinguished point in each set of points, but is never changed under any operations.
2 changes: 2 additions & 0 deletions hironaka/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .Points import *
from .PointsTensor import *
3 changes: 3 additions & 0 deletions hironaka/cpp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# C++ modules (deprecated)
This was originally written for numpy operations but is currently not in use.
We leave it here for potential future revival.
2 changes: 1 addition & 1 deletion hironaka/gameHironaka.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Optional, Union

from .abs import Points
from .core import Points
from .agent import Agent
from .game import Game
from .host import Host
Expand Down
Loading

0 comments on commit d80e4e9

Please sign in to comment.