Commit 7953f0e

Fix bug when batch size > 1 and add training tips

williamshen-nz committed Sep 17, 2023
1 parent e801737 commit 7953f0e

Showing 8 changed files with 185 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
.idea/
__pycache__

# Fast downward and results directory
fast_downward/
results/
19 changes: 19 additions & 0 deletions README.md
@@ -3,6 +3,8 @@ STRIPS-HGN is a framework for learning domain-independent planning heuristics co

For any issues, please contact the authors of the paper. You can find our emails in the [paper](https://shen.nz/papers/shen-stripshgn-20.pdf).

**New (2023-09)** - fixed a bug when using batch size > 1 and added [tips for making training more stable](#making-training-more-stable).

## Usage
### Directory Structure
- `benchmarks`: contains the PDDL domain and problems
@@ -29,7 +31,24 @@ I would recommend using a virtual environment. The entry point for training is `
- Domain-Independent Experiments
3. Optimise Hypergraph Networks

### Making Training More Stable
Training STRIPS-HGN can be quite unstable, which is partly why we used the *k*-fold method during training. Here are some tips to make training more stable:

- Remove the ReLU activation in the output transform of `EncodeProcessDecode` (i.e., remove the `nn.ReLU(inplace=True)` from `edge_model`, `node_model`, and `global_model`).
  - When STRIPS-HGN is initialized with weights such that the output is negative, the ReLU zeroes the gradients and the network never learns. Removing the ReLU and keeping just the linear layer seems to solve this problem.
  - Note that you will then need to clamp the outputs with `torch.clamp(outputs, min=0)` (i.e., `max(0, output)` elementwise), so the outputs are non-negative and hence valid heuristics.
- Use a smaller learning rate such as 1e-4 (the paper used 1e-3), or automatically reduce the learning rate when the loss plateaus (you can try `torch.optim.lr_scheduler.ReduceLROnPlateau`). A minimal sketch of both tips follows this list.
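
A rough sketch of both tips combined (this is not code from the repository; the hidden size, dummy data, and scheduler settings are assumptions for illustration):

```python
import torch
import torch.nn as nn

# Output transform with the trailing nn.ReLU(inplace=True) removed:
# just a linear layer. The hidden size of 32 is an assumption.
output_transform = nn.Linear(32, 1)
optimizer = torch.optim.Adam(output_transform.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10
)

for epoch in range(5):
    features = torch.randn(8, 32)         # dummy batch of 8 embeddings
    targets = torch.rand(8, 1) * 10       # dummy heuristic targets
    outputs = output_transform(features)  # may be negative during training
    loss = nn.functional.mse_loss(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())           # reduce LR if the loss plateaus

# At evaluation time, clamp so the predicted heuristics are non-negative.
heuristics = torch.clamp(output_transform(torch.randn(8, 32)), min=0)
```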

Thanks to Carlos Núñez Molina (University of Granada) and Dillon Chen (Australian National University) for these tips.
With these changes, you can try reducing the number of folds, or potentially not using folds at all.

### Past Updates
**17th September 2023**: Added tips for addressing unstable training and fixed a bug where merging `HypergraphsTuple`s did not correctly accumulate the node indices.
This meant that using a batch size > 1 would not work properly. Thanks to Carlos Núñez Molina for raising these issues.

**Update - 19th April 2021**: apologies for the delay. Unfortunately, I have not had time to work on research, so I am releasing the code-base as-is.

**Update - 21st Sept 2020**: the implementation is ready but just needs some final cleaning up and testing. Due to work and other circumstances, I have been unable to spend much time on research. I hope to update this repository by [ICAPS](https://icaps20.icaps-conference.org/) (October 19-30). For those who want a copy of the code in the meantime, please email me.
14 changes: 14 additions & 0 deletions src/strips_hgn/models/hypergraph_nets_adaptor.py
@@ -171,6 +171,20 @@ def _stack_features(attr_name, force_matrix=True):
]
)

    # Need to increase indices for each hypergraph based on number of nodes
    # in previous hypergraphs
    n_edge_cumsum = torch.cumsum(n_edge, dim=0)
    n_node_cumsum = torch.cumsum(n_node, dim=0)
    for idx, (n_edge_prev, n_edge_cur) in enumerate(
        zip(n_edge_cumsum, n_edge_cumsum[1:])
    ):
        receivers[n_edge_prev:n_edge_cur][
            receivers[n_edge_prev:n_edge_cur] != -1
        ] += n_node_cumsum[idx]
        senders[n_edge_prev:n_edge_cur][
            senders[n_edge_prev:n_edge_cur] != -1
        ] += n_node_cumsum[idx]

    # Check padding consistent across hypergraphs
    assert len(set(h.zero_padding for h in graphs_tuple_list)) == 1
    zero_padding = graphs_tuple_list[0].zero_padding
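For intuition, the index-offsetting above can be reproduced standalone on a toy pair of hypergraphs (the tensors below are invented for illustration and are not from the repository):

```python
import torch

# Two hypergraphs with 3 and 2 nodes; -1 marks padded hyperedge slots.
n_node = torch.tensor([3, 2])
n_edge = torch.tensor([2, 2])
receivers = torch.tensor([[0, 2, -1],    # edges of graph 0
                          [1, -1, -1],
                          [0, 1, -1],    # edges of graph 1 (local indices)
                          [1, -1, -1]])

n_edge_cumsum = torch.cumsum(n_edge, dim=0)  # tensor([2, 4])
n_node_cumsum = torch.cumsum(n_node, dim=0)  # tensor([3, 5])
for idx, (prev, cur) in enumerate(zip(n_edge_cumsum, n_edge_cumsum[1:])):
    # Offset the next graph's node indices by the total node count so far,
    # leaving the -1 padding untouched.
    receivers[prev:cur][receivers[prev:cur] != -1] += n_node_cumsum[idx]

print(receivers)
# tensor([[ 0,  2, -1],
#         [ 1, -1, -1],
#         [ 3,  4, -1],   <- graph 1 now uses global node indices 3..4
#         [ 4, -1, -1]])
```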
Empty file added tests/test_models/__init__.py
1 change: 1 addition & 0 deletions tests/test_models/assets/README.md
@@ -0,0 +1 @@
Assets provided by Carlos Núñez Molina, a PhD student at the University of Granada.
31 changes: 31 additions & 0 deletions tests/test_models/assets/domain.pddl
@@ -0,0 +1,31 @@
(define (domain blocksworld-4ops)
  (:requirements :strips)
  (:predicates (clear ?x)
               (on-table ?x)
               (arm-empty)
               (holding ?x)
               (on ?x ?y))

  (:action pickup
    :parameters (?ob)
    :precondition (and (clear ?ob) (on-table ?ob) (arm-empty))
    :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob))
                 (not (arm-empty))))

  (:action putdown
    :parameters (?ob)
    :precondition (holding ?ob)
    :effect (and (clear ?ob) (arm-empty) (on-table ?ob)
                 (not (holding ?ob))))

  (:action stack
    :parameters (?ob ?underob)
    :precondition (and (clear ?underob) (holding ?ob))
    :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob)
                 (not (clear ?underob)) (not (holding ?ob))))

  (:action unstack
    :parameters (?ob ?underob)
    :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty))
    :effect (and (holding ?ob) (clear ?underob)
                 (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty)))))
26 changes: 26 additions & 0 deletions tests/test_models/assets/problem.pddl
@@ -0,0 +1,26 @@
(define (problem BW-rand-5)
  (:domain blocksworld-4ops)
  (:objects b1 b2 b3 b4 b5)
  (:init
    (arm-empty)
    (on b1 b2)
    (on b2 b4)
    (on-table b3)
    (on b4 b3)
    (on b5 b1)
    (clear b5)
  )
  (:goal
    (and
      (on b1 b5)
      (on b5 b2))
  )
)


;; pddl-generators:
;; command: /home/skunk/ibm-ugr/pddl-generator/pddl_generators/blocksworld-4ops/../__blocksworld/blocksworld.sh 4ops 5 34000 /home/skunk/ibm-ugr/pddlsl/experiments/data/blocksworld/train/2e237657707f4819bb0dcc0d59c18c43.pddl
;; dict: {"blocks": 5}
;; date: 2023-08-08T04:47:15.319856
;; end:

88 changes: 88 additions & 0 deletions tests/test_models/test_hypergraph_nets_adaptor.py
@@ -0,0 +1,88 @@
import os

import pytest
import torch

from hypergraph_nets.hypergraphs import HypergraphsTuple
from strips_hgn.features.global_features import NumberOfNodesAndEdgesGlobalFeatureMapper
from strips_hgn.features.hyperedge_features import ComplexHyperedgeFeatureMapper
from strips_hgn.features.node_features import PropositionInStateAndGoal
from strips_hgn.hypergraph.delete_relaxation import DeleteRelaxationHypergraphView
from strips_hgn.models.hypergraph_nets_adaptor import merge_hypergraphs_tuple
from strips_hgn.planning.strips import _PyperplanSTRIPSProblem
from strips_hgn.workflows.base_workflow import BaseFeatureMappingWorkflow

_module_path = os.path.dirname(os.path.abspath(__file__))

_domain_path = os.path.join(_module_path, "assets/domain.pddl")
_problem_path = os.path.join(_module_path, "assets/problem.pddl")


@pytest.fixture
def hg_tuple() -> HypergraphsTuple:
    max_num_add_effects = 10
    max_num_preconditions = 10

    # Set up STRIPS-HGN so we can get a hypergraphs tuple for the initial state
    strips_problem = _PyperplanSTRIPSProblem(_domain_path, _problem_path)
    problem_dr_hypergraph = DeleteRelaxationHypergraphView(strips_problem)
    state_hypergraph_encoder = BaseFeatureMappingWorkflow(
        global_feature_mapper_cls=NumberOfNodesAndEdgesGlobalFeatureMapper,
        node_feature_mapper_cls=PropositionInStateAndGoal,
        hyperedge_feature_mapper_cls=ComplexHyperedgeFeatureMapper,
        max_receivers=max_num_add_effects,
        max_senders=max_num_preconditions,
    )

    hg_tuple = state_hypergraph_encoder._get_input_hypergraphs_tuple(
        current_state=strips_problem.initial_state, hypergraph=problem_dr_hypergraph
    )
    yield hg_tuple


def test_merge_hypergraphs_tuple_single_element(hg_tuple: HypergraphsTuple):
    merged_hg_tuple = merge_hypergraphs_tuple([hg_tuple])
    assert torch.equal(merged_hg_tuple.edges, hg_tuple.edges)
    assert torch.equal(merged_hg_tuple.nodes, hg_tuple.nodes)
    assert torch.equal(merged_hg_tuple.globals, hg_tuple.globals)
    assert torch.equal(merged_hg_tuple.receivers, hg_tuple.receivers)
    assert torch.equal(merged_hg_tuple.senders, hg_tuple.senders)
    assert torch.equal(merged_hg_tuple.n_node, hg_tuple.n_node)
    assert torch.equal(merged_hg_tuple.n_edge, hg_tuple.n_edge)


def test_merge_hypergraphs_tuple_duplicate(hg_tuple: HypergraphsTuple):
    """Test duplicating the same hypergraph tuple"""
    # Effectively duplicate the hypergraph tuple
    merged_hg_tuple = merge_hypergraphs_tuple([hg_tuple, hg_tuple])
    assert merged_hg_tuple.total_n_node == 2 * hg_tuple.total_n_node
    assert merged_hg_tuple.total_n_edge == 2 * hg_tuple.total_n_edge

    # Check the node indices have been accumulated correctly (i.e., we do not
    # reuse the same indices for the two merged hypergraphs). There are twice
    # as many unique indices minus one, since the -1 padding index is shared.
    og_node_idxs = set(
        hg_tuple.receivers.flatten().tolist() + hg_tuple.senders.flatten().tolist()
    )
    merged_node_idxs = set(
        merged_hg_tuple.receivers.flatten().tolist()
        + merged_hg_tuple.senders.flatten().tolist()
    )
    assert len(merged_node_idxs) == 2 * len(og_node_idxs) - 1

    # Check the node indices in the hyperedges have been accumulated correctly
    duplicated_hg_receivers = merged_hg_tuple.receivers[hg_tuple.total_n_edge :]
    duplicated_hg_senders = merged_hg_tuple.senders[hg_tuple.total_n_edge :]
    assert (
        len(duplicated_hg_receivers)
        == len(duplicated_hg_senders)
        == hg_tuple.total_n_edge
    )

    # Compute the expected indices, making sure we maintain the -1 padding.
    # Clone first so we do not mutate the fixture's tensors in place.
    expected_hg_receivers = hg_tuple.receivers.clone()
    expected_hg_receivers[hg_tuple.receivers != -1] += hg_tuple.total_n_node
    assert torch.equal(duplicated_hg_receivers, expected_hg_receivers)

    expected_hg_senders = hg_tuple.senders.clone()
    expected_hg_senders[hg_tuple.senders != -1] += hg_tuple.total_n_node
    assert torch.equal(duplicated_hg_senders, expected_hg_senders)
