-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix bug when batch size > 1 and add training tips
- Loading branch information
1 parent
e801737
commit 7953f0e
Showing
8 changed files
with
185 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# JetBrains IDE project settings
.idea/
# Python bytecode caches
__pycache__

# Fast downward and results directory
fast_downward/
results/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Assets provided by Carlos Núñez Molina, PhD student from the University of Granada. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
;; Classic 4-operator Blocksworld domain (STRIPS requirements only).
(define (domain blocksworld-4ops)
  (:requirements :strips)
  ;; clear ?x     - nothing sits on top of ?x
  ;; on-table ?x  - ?x rests directly on the table
  ;; arm-empty    - the gripper is not holding anything
  ;; holding ?x   - the gripper holds ?x
  ;; on ?x ?y     - ?x sits directly on ?y
  (:predicates (clear ?x)
               (on-table ?x)
               (arm-empty)
               (holding ?x)
               (on ?x ?y))

  ;; Pick a clear block up from the table with the empty gripper.
  (:action pickup
    :parameters (?ob)
    :precondition (and (clear ?ob) (on-table ?ob) (arm-empty))
    :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob))
                 (not (arm-empty))))

  ;; Put the held block down on the table.
  (:action putdown
    :parameters (?ob)
    :precondition (holding ?ob)
    :effect (and (clear ?ob) (arm-empty) (on-table ?ob)
                 (not (holding ?ob))))

  ;; Place the held block on top of another, clear block.
  (:action stack
    :parameters (?ob ?underob)
    :precondition (and (clear ?underob) (holding ?ob))
    :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob)
                 (not (clear ?underob)) (not (holding ?ob))))

  ;; Pick a clear block up off the block beneath it.
  (:action unstack
    :parameters (?ob ?underob)
    :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty))
    :effect (and (holding ?ob) (clear ?underob)
                 (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty)))))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
;; Randomly generated 5-block Blocksworld instance (generator metadata below).
(define (problem BW-rand-5)
  (:domain blocksworld-4ops)
  (:objects b1 b2 b3 b4 b5 )
  ;; Initial state: a single tower, top to bottom b5 b1 b2 b4 b3, with b3 on the table.
  (:init
    (arm-empty)
    (on b1 b2)
    (on b2 b4)
    (on-table b3)
    (on b4 b3)
    (on b5 b1)
    (clear b5)
  )
  ;; Goal: the partial tower b1 on b5 on b2.
  (:goal
    (and
      (on b1 b5)
      (on b5 b2))
  )
)


;; pddl-generators:
;; command: /home/skunk/ibm-ugr/pddl-generator/pddl_generators/blocksworld-4ops/../__blocksworld/blocksworld.sh 4ops 5 34000 /home/skunk/ibm-ugr/pddlsl/experiments/data/blocksworld/train/2e237657707f4819bb0dcc0d59c18c43.pddl
;; dict: {"blocks": 5}
;; date: 2023-08-08T04:47:15.319856
;; end:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import os | ||
|
||
import pytest | ||
import torch | ||
|
||
from hypergraph_nets.hypergraphs import HypergraphsTuple | ||
from strips_hgn.features.global_features import NumberOfNodesAndEdgesGlobalFeatureMapper | ||
from strips_hgn.features.hyperedge_features import ComplexHyperedgeFeatureMapper | ||
from strips_hgn.features.node_features import PropositionInStateAndGoal | ||
from strips_hgn.hypergraph.delete_relaxation import DeleteRelaxationHypergraphView | ||
from strips_hgn.models.hypergraph_nets_adaptor import merge_hypergraphs_tuple | ||
from strips_hgn.planning.strips import _PyperplanSTRIPSProblem | ||
from strips_hgn.workflows.base_workflow import BaseFeatureMappingWorkflow | ||
|
||
_module_path = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
_domain_path = os.path.join(_module_path, "assets/domain.pddl") | ||
_problem_path = os.path.join(_module_path, "assets/problem.pddl") | ||
|
||
|
||
@pytest.fixture
def hg_tuple() -> HypergraphsTuple:
    """Encode the bundled blocksworld problem's initial state as a HypergraphsTuple.

    Builds the STRIPS problem from the asset PDDL files, wraps it in a
    delete-relaxation hypergraph view, and runs the feature-mapping workflow
    over the initial state.
    """
    receiver_limit = 10  # max number of add effects per hyperedge
    sender_limit = 10    # max number of preconditions per hyperedge

    # Set up STRIPS-HGN so we can get a hypergraph tuple of the initial state.
    problem = _PyperplanSTRIPSProblem(_domain_path, _problem_path)
    hypergraph_view = DeleteRelaxationHypergraphView(problem)
    encoder = BaseFeatureMappingWorkflow(
        global_feature_mapper_cls=NumberOfNodesAndEdgesGlobalFeatureMapper,
        node_feature_mapper_cls=PropositionInStateAndGoal,
        hyperedge_feature_mapper_cls=ComplexHyperedgeFeatureMapper,
        max_receivers=receiver_limit,
        max_senders=sender_limit,
    )

    yield encoder._get_input_hypergraphs_tuple(
        current_state=problem.initial_state, hypergraph=hypergraph_view
    )
|
||
|
||
def test_merge_hypergraphs_tuple_single_element(hg_tuple: HypergraphsTuple):
    """Merging a singleton list must reproduce the original tuple field by field."""
    merged = merge_hypergraphs_tuple([hg_tuple])
    for field in ("edges", "nodes", "globals", "receivers", "senders", "n_node", "n_edge"):
        assert torch.equal(getattr(merged, field), getattr(hg_tuple, field))
|
||
|
||
def test_merge_hypergraphs_tuple_duplicate(hg_tuple: HypergraphsTuple):
    """Test duplicating the same hypergraph tuple.

    Merging [hg_tuple, hg_tuple] should double the node/edge counts and shift
    the second copy's receiver/sender indices by the original node count,
    while leaving the -1 padding entries untouched.
    """
    # Effectively duplicate the hypergraph tuple
    merged_hg_tuple = merge_hypergraphs_tuple([hg_tuple, hg_tuple])
    assert merged_hg_tuple.total_n_node == 2 * hg_tuple.total_n_node
    assert merged_hg_tuple.total_n_edge == 2 * hg_tuple.total_n_edge

    # Check the node indices have been accumulated correctly (i.e., we do not use the same indices for the two merged
    # hypergraphs). There are twice as many distinct indices minus 1 because the -1 padding index is shared by both copies.
    og_node_idxs = set(
        hg_tuple.receivers.flatten().tolist() + hg_tuple.senders.flatten().tolist()
    )
    merged_node_idxs = set(
        merged_hg_tuple.receivers.flatten().tolist()
        + merged_hg_tuple.senders.flatten().tolist()
    )
    assert len(merged_node_idxs) == 2 * len(og_node_idxs) - 1

    # Check the node indices in the hyperedges have been accumulated correctly
    duplicated_hg_receivers = merged_hg_tuple.receivers[hg_tuple.total_n_edge :]
    duplicated_hg_senders = merged_hg_tuple.senders[hg_tuple.total_n_edge :]
    assert (
        len(duplicated_hg_receivers)
        == len(duplicated_hg_senders)
        == hg_tuple.total_n_edge
    )

    # Compute expected indices for the second copy, maintaining the -1 padding.
    # Clone before mutating: assigning `hg_tuple.receivers` directly would alias
    # the fixture's tensor, and the in-place `+=` below would then corrupt it
    # for any assertion that runs afterwards.
    expected_hg_receivers = hg_tuple.receivers.clone()
    expected_hg_receivers[hg_tuple.receivers != -1] += hg_tuple.total_n_node
    assert torch.equal(duplicated_hg_receivers, expected_hg_receivers)

    expected_hg_senders = hg_tuple.senders.clone()
    expected_hg_senders[hg_tuple.senders != -1] += hg_tuple.total_n_node
    assert torch.equal(duplicated_hg_senders, expected_hg_senders)