Skip to content

Commit

Permalink
Extend Graphworld rendering to Blocksworld.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Oct 18, 2023
1 parent 2ec79de commit 78a8b98
Show file tree
Hide file tree
Showing 10 changed files with 528 additions and 256 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PDDL = "2c8894f9-daa1-498a-9e3a-26edd9623db8"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand All @@ -26,6 +27,7 @@ GeometryBasics = "0.4"
GraphMakie = "0.5"
Graphs = "1.4"
Makie = "0.19.5, 0.20"
NetworkLayout = "0.4.5"
OrderedCollections = "1"
PDDL = "0.2.13"
PlanningDomains = "0.1.3"
Expand Down
2 changes: 1 addition & 1 deletion src/PDDLViz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Base: @kwdef

using PDDL, SymbolicPlanners
using Makie, GraphMakie
using Graphs, GraphMakie.NetworkLayout
using Graphs, NetworkLayout
using FileIO, Base64
using OrderedCollections
using DocStringExtensions
Expand Down
68 changes: 49 additions & 19 deletions src/renderers/graphworld/graphworld.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
export GraphworldRenderer

using GraphMakie.NetworkLayout: AbstractLayout
include("layouts.jl")

"""
GraphworldRenderer(; options...)
Customizable renderer for domains with fixed locations connected in a graph.
Customizable renderer for domains with fixed locations and movable objects
connected in a graph. The layout of the graph can be controlled with the
`graph_layout` option, which takes a function that returns an `AbstractLayout`
given the number of locations.
By default, the graph is laid out using the `StressLocSpringMov` layout, which
arranges the first `n_locs` nodes via stress minimization, and uses
spring/repulsion for the remaining nodes.
# General options
Expand All @@ -14,24 +21,42 @@ $(TYPEDFIELDS)
@kwdef mutable struct GraphworldRenderer <: Renderer
"Default figure resolution, in pixels."
resolution::Tuple{Int, Int} = (800, 800)
"Function or `AbstractLayout` that maps a graph to node locations."
graph_layout::Union{Function, AbstractLayout} = Stress()
"Whether the graph edges are directed."
is_directed::Bool = false
"Function `n_locs -> (graph -> positions)` that returns an AbstractLayout."
graph_layout::Function = n_locs -> StressLocSpringMov(n_locs=n_locs)
"Whether the edges between locations are directed."
is_loc_directed::Bool = false
"Whether the edges between movable objects are directed."
is_mov_directed::Bool = false
"Whether there are edges between movable objects."
has_mov_edges::Bool = false
"PDDL objects that correspond to fixed locations."
locations::Vector{Const} = Const[]
"PDDL object types that correspond to fixed locations."
location_types::Vector{Symbol} = [:location]
location_types::Vector{Symbol} = Symbol[]
"PDDL objects that correspond to movable objects."
movables::Vector{Const} = Const[]
"PDDL object types that correspond to movable objects."
movable_types::Vector{Symbol} = [:movable]
"Function `(dom, state, a, b) -> Bool` that checks if `(a, b)` is present."
edge_fn::Function = (d, s, a, b) -> a != b
"Function `(dom, state, a, b) -> String` that returns a label for `(a, b)`."
edge_label_fn::Function = (d, s, a, b) -> ""
"Function `(dom, state, x, loc) -> Bool` that returns if `x` is at `loc`."
at_loc_fn::Function = (d, s, x, loc) -> false
movable_types::Vector{Symbol} = Symbol[]
"Function `(dom, s, l1, l2) -> Bool` that checks if `l1` connects to `l2`."
loc_edge_fn::Function = (d, s, l1, l2) -> l1 != l2
"Function `(dom, s, l1, l2) -> String` that labels edge `(l1, l2)`."
loc_edge_label_fn::Function = (d, s, l1, l2) -> ""
"Function `(dom, s, obj, loc) -> Bool` that checks if `mov` is at `loc`."
mov_loc_edge_fn::Function = (d, s, mov, loc) -> false
"Function `(dom, s, obj, loc) -> String` that labels edge `(mov, loc)`."
mov_loc_edge_label_fn::Function = (d, s, mov, loc) -> ""
"Function `(dom, s, m1, m2) -> Bool` that checks if `m1` connects to `m2`."
mov_edge_fn::Function = (d, s, m1, m2) -> false
"Function `(dom, s, m1, m2) -> String` that labels edge `(m1, m2)`."
mov_edge_label_fn::Function = (d, s, o1, o2) -> ""
"Location object renderers, of the form `(domain, state, loc) -> Graphic`."
loc_renderers::Dict{Const, Function} = Dict{Const, Function}()
"Movable object renderers, of the form `(domain, state, obj) -> Graphic`."
mov_renderers::Dict{Const, Function} = Dict{Const, Function}()
"Per-type location renderers, of the form `(domain, state, loc) -> Graphic`."
loc_renderers::Dict{Symbol, Function} = Dict{Symbol, Function}()
"Per-type object renderers, of the form `(domain, state, obj) -> Graphic`."
obj_renderers::Dict{Symbol, Function} = Dict{Symbol, Function}()
loc_type_renderers::Dict{Symbol, Function} = Dict{Symbol, Function}()
"Per-type movable renderers, of the form `(domain, state, obj) -> Graphic`."
mov_type_renderers::Dict{Symbol, Function} = Dict{Symbol, Function}()
"Default options for graph rendering, passed to the `graphplot` recipe."
graph_options::Dict{Symbol, Any} = Dict{Symbol, Any}(
:node_size => 0.05,
Expand All @@ -40,14 +65,19 @@ $(TYPEDFIELDS)
:nlabels_align => (:center, :center),
:elabels_fontsize => 16,
)
"Default display options for axis."
axis_options::Dict{Symbol, Any} = Dict{Symbol, Any}(
:aspect => 1,
:autolimitaspect => 1,
:hidedecorations => true
)
"Default options for state rendering."
state_options::Dict{Symbol, Any} = default_state_options(GraphworldRenderer)
end

function new_canvas(renderer::GraphworldRenderer)
figure = Figure(resolution=renderer.resolution)
axis = Axis(figure[1, 1])
return Canvas(figure, axis)
return Canvas(figure)
end

include("state.jl")
Expand Down
114 changes: 114 additions & 0 deletions src/renderers/graphworld/layouts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using NetworkLayout: AbstractLayout

"""
StressLocSpringMov(; kwargs...)(adj_matrix)
Returns a layout that first places the first `n_locs` nodes using stress
minimization, then places the remaining nodes using spring/repulsion.
## Keyword Arguments
- `dim = 2`, `Ptype = Float64`: Dimension and output type.
- `n_locs = 0`: Number of nodes to place using stress minimization.
- `stress_kwargs = Dict{Symbol, Any}()`: Keyword arguments for `Stress`.
- `spring_kwargs = Dict{Symbol, Any}(:C => 0.3)`: Keyword arguments for `Spring`.
"""
struct StressLocSpringMov{Dim, Ptype} <: AbstractLayout{Dim, Ptype}
n_locs::Int
stress_kwargs::Dict{Symbol, Any}
spring_kwargs::Dict{Symbol, Any}
end

function StressLocSpringMov(;
dim = 2,
Ptype = Float64,
n_locs = 0,
stress_kwargs = Dict{Symbol, Any}(),
spring_kwargs = Dict{Symbol, Any}(:C => 0.3)
)
return StressLocSpringMov{dim, Ptype}(n_locs, stress_kwargs, spring_kwargs)
end

function NetworkLayout.layout(
algo::StressLocSpringMov{Dim, Ptype}, adj_matrix::AbstractMatrix
) where {Dim, Ptype}
n_nodes = NetworkLayout.assertsquare(adj_matrix)
stress = Stress(;dim=Dim, Ptype=Ptype, algo.stress_kwargs...)
loc_positions = stress(adj_matrix[1:algo.n_locs, 1:algo.n_locs])
init_positions = resize!(copy(loc_positions), n_nodes)
for i in algo.n_locs+1:n_nodes
for j in 1:algo.n_locs
adj_matrix[i, j] == 0 && continue
init_positions[i] = loc_positions[j]
break
end
end
spring = Spring(;dim=Dim, Ptype=Ptype, pin=loc_positions,
initialpos=init_positions, algo.spring_kwargs...)
return spring(adj_matrix)
end

struct BlocksworldLayout{Ptype} <: AbstractLayout{2, Ptype}
n_locs::Int
block_width::Ptype
block_height::Ptype
block_gap::Ptype
table_height::Ptype
gripper_height::Ptype
end

function BlocksworldLayout(;
Ptype = Float64,
n_locs = 2,
block_width = Ptype(1.0),
block_height = Ptype(1.0),
block_gap = Ptype(0.5),
table_height = block_height,
gripper_height = table_height + (n_locs - 2 + 1) * block_height
)
return BlocksworldLayout{Ptype}(
n_locs,
block_width,
block_height,
block_gap,
table_height,
gripper_height
)
end

function NetworkLayout.layout(
algo::BlocksworldLayout{Ptype}, adj_matrix::AbstractMatrix
) where {Ptype}
n_nodes = NetworkLayout.assertsquare(adj_matrix)
n_blocks = n_nodes - algo.n_locs
graph = SimpleDiGraph(adj_matrix)
positions = Vector{Point2{Ptype}}(undef, n_nodes)
# Set table and gripper location
x_mid = n_blocks * (algo.block_width + algo.block_gap) / Ptype(2)
positions[1] = Point2{Ptype}(x_mid, algo.table_height/2)
positions[2] = Point2{Ptype}(x_mid, algo.gripper_height+algo.block_height/2)
# Compute base locations
x_start = (algo.block_width + algo.block_gap) / Ptype(2)
for i in 1:(algo.n_locs-2)
x = (i - 1) * (algo.block_width + algo.block_gap) + x_start
y = algo.table_height - algo.block_height / Ptype(2)
positions[2 + i] = Point2{Ptype}(x, y)
end
# Compute block locations for towers rooted at each base
for base in 3:algo.n_locs
stack = [(i, base) for i in inneighbors(graph, base)]
while !isempty(stack)
(node, parent) = pop!(stack)
x, y = positions[parent]
y += algo.block_height
positions[node] = Point2{Ptype}(x, y)
for child in inneighbors(graph, node)
push!(stack, (child, node))
end
end
end
# Compute block locations for blocks held in gripper
for node in inneighbors(graph, 2)
positions[node] = copy(positions[2])
end
return positions
end
Loading

0 comments on commit 78a8b98

Please sign in to comment.