From 78a8b983fbc85c18a3f5b4b2dbf95c4923c09735 Mon Sep 17 00:00:00 2001 From: Xuan Date: Wed, 18 Oct 2023 00:03:05 -0400 Subject: [PATCH] Extend Graphworld rendering to Blocksworld. --- Project.toml | 2 + src/PDDLViz.jl | 2 +- src/renderers/graphworld/graphworld.jl | 68 ++++++++++---- src/renderers/graphworld/layouts.jl | 114 ++++++++++++++++++++++++ src/renderers/graphworld/state.jl | 118 ++++++++++++++----------- test/graphworld/blocksworld.jl | 93 +++++++++++++++++++ test/graphworld/test.jl | 82 ++--------------- test/graphworld/zeno_travel.jl | 82 +++++++++++++++++ test/gridworld/doors_keys_gems.jl | 110 +++++++++++++++++++++++ test/gridworld/test.jl | 113 +---------------------- 10 files changed, 528 insertions(+), 256 deletions(-) create mode 100644 src/renderers/graphworld/layouts.jl create mode 100644 test/graphworld/blocksworld.jl create mode 100644 test/graphworld/zeno_travel.jl create mode 100644 test/gridworld/doors_keys_gems.jl diff --git a/Project.toml b/Project.toml index f7f8060..fe57e5d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PDDL = "2c8894f9-daa1-498a-9e3a-26edd9623db8" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -26,6 +27,7 @@ GeometryBasics = "0.4" GraphMakie = "0.5" Graphs = "1.4" Makie = "0.19.5, 0.20" +NetworkLayout = "0.4.5" OrderedCollections = "1" PDDL = "0.2.13" PlanningDomains = "0.1.3" diff --git a/src/PDDLViz.jl b/src/PDDLViz.jl index 47b85a9..a5388fd 100644 --- a/src/PDDLViz.jl +++ b/src/PDDLViz.jl @@ -4,7 +4,7 @@ using Base: @kwdef using PDDL, SymbolicPlanners using Makie, GraphMakie -using Graphs, GraphMakie.NetworkLayout +using Graphs, NetworkLayout using FileIO, Base64 using OrderedCollections using DocStringExtensions diff --git a/src/renderers/graphworld/graphworld.jl b/src/renderers/graphworld/graphworld.jl index 3bfb210..6c2d8ae 100644 --- a/src/renderers/graphworld/graphworld.jl +++ b/src/renderers/graphworld/graphworld.jl @@ -1,11 +1,18 @@ export GraphworldRenderer -using GraphMakie.NetworkLayout: AbstractLayout +include("layouts.jl") """ GraphworldRenderer(; options...) -Customizable renderer for domains with fixed locations connected in a graph. +Customizable renderer for domains with fixed locations and movable objects +connected in a graph. The layout of the graph can be controlled with the +`graph_layout` option, which takes a function that returns an `AbstractLayout` +given the number of locations. + +By default, the graph is laid out using the `StressLocSpringMov` layout, which +arranges the first `n_locs` nodes via stress minimization, and uses +spring/repulsion for the remaining nodes. # General options @@ -14,24 +21,42 @@ $(TYPEDFIELDS) @kwdef mutable struct GraphworldRenderer <: Renderer "Default figure resolution, in pixels." resolution::Tuple{Int, Int} = (800, 800) - "Function or `AbstractLayout` that maps a graph to node locations." - graph_layout::Union{Function, AbstractLayout} = Stress() - "Whether the graph edges are directed." - is_directed::Bool = false + "Function `n_locs -> (graph -> positions)` that returns an AbstractLayout." + graph_layout::Function = n_locs -> StressLocSpringMov(n_locs=n_locs) + "Whether the edges between locations are directed." + is_loc_directed::Bool = false + "Whether the edges between movable objects are directed." + is_mov_directed::Bool = false + "Whether there are edges between movable objects." + has_mov_edges::Bool = false + "PDDL objects that correspond to fixed locations." + locations::Vector{Const} = Const[] "PDDL object types that correspond to fixed locations." - location_types::Vector{Symbol} = [:location] + location_types::Vector{Symbol} = Symbol[] + "PDDL objects that correspond to movable objects." + movables::Vector{Const} = Const[] "PDDL object types that correspond to movable objects." - movable_types::Vector{Symbol} = [:movable] - "Function `(dom, state, a, b) -> Bool` that checks if `(a, b)` is present." - edge_fn::Function = (d, s, a, b) -> a != b - "Function `(dom, state, a, b) -> String` that returns a label for `(a, b)`." - edge_label_fn::Function = (d, s, a, b) -> "" - "Function `(dom, state, x, loc) -> Bool` that returns if `x` is at `loc`." - at_loc_fn::Function = (d, s, x, loc) -> false + movable_types::Vector{Symbol} = Symbol[] + "Function `(dom, s, l1, l2) -> Bool` that checks if `l1` connects to `l2`." + loc_edge_fn::Function = (d, s, l1, l2) -> l1 != l2 + "Function `(dom, s, l1, l2) -> String` that labels edge `(l1, l2)`." + loc_edge_label_fn::Function = (d, s, l1, l2) -> "" + "Function `(dom, s, obj, loc) -> Bool` that checks if `mov` is at `loc`." + mov_loc_edge_fn::Function = (d, s, mov, loc) -> false + "Function `(dom, s, obj, loc) -> String` that labels edge `(mov, loc)`." + mov_loc_edge_label_fn::Function = (d, s, mov, loc) -> "" + "Function `(dom, s, m1, m2) -> Bool` that checks if `m1` connects to `m2`." + mov_edge_fn::Function = (d, s, m1, m2) -> false + "Function `(dom, s, m1, m2) -> String` that labels edge `(m1, m2)`." + mov_edge_label_fn::Function = (d, s, o1, o2) -> "" + "Location object renderers, of the form `(domain, state, loc) -> Graphic`." + loc_renderers::Dict{Const, Function} = Dict{Const, Function}() + "Movable object renderers, of the form `(domain, state, obj) -> Graphic`." + mov_renderers::Dict{Const, Function} = Dict{Const, Function}() "Per-type location renderers, of the form `(domain, state, loc) -> Graphic`." - loc_renderers::Dict{Symbol, Function} = Dict{Symbol, Function}() - "Per-type object renderers, of the form `(domain, state, obj) -> Graphic`." - obj_renderers::Dict{Symbol, Function} = Dict{Symbol, Function}() + loc_type_renderers::Dict{Symbol, Function} = Dict{Symbol, Function}() + "Per-type movable renderers, of the form `(domain, state, obj) -> Graphic`." + mov_type_renderers::Dict{Symbol, Function} = Dict{Symbol, Function}() "Default options for graph rendering, passed to the `graphplot` recipe." graph_options::Dict{Symbol, Any} = Dict{Symbol, Any}( :node_size => 0.05, @@ -40,14 +65,19 @@ $(TYPEDFIELDS) :nlabels_align => (:center, :center), :elabels_fontsize => 16, ) + "Default display options for axis." + axis_options::Dict{Symbol, Any} = Dict{Symbol, Any}( + :aspect => 1, + :autolimitaspect => 1, + :hidedecorations => true + ) "Default options for state rendering." state_options::Dict{Symbol, Any} = default_state_options(GraphworldRenderer) end function new_canvas(renderer::GraphworldRenderer) figure = Figure(resolution=renderer.resolution) - axis = Axis(figure[1, 1]) - return Canvas(figure, axis) + return Canvas(figure) end include("state.jl") diff --git a/src/renderers/graphworld/layouts.jl b/src/renderers/graphworld/layouts.jl new file mode 100644 index 0000000..2ad2b77 --- /dev/null +++ b/src/renderers/graphworld/layouts.jl @@ -0,0 +1,114 @@ +using NetworkLayout: AbstractLayout + +""" + StressLocSpringMov(; kwargs...)(adj_matrix) + +Returns a layout that first places the first `n_locs` nodes using stress +minimization, then places the remaining nodes using spring/repulsion. + +## Keyword Arguments +- `dim = 2`, `Ptype = Float64`: Dimension and output type. +- `n_locs = 0`: Number of nodes to place using stress minimization. +- `stress_kwargs = Dict{Symbol, Any}()`: Keyword arguments for `Stress`. +- `spring_kwargs = Dict{Symbol, Any}(:C => 0.3)`: Keyword arguments for `Spring`. +""" +struct StressLocSpringMov{Dim, Ptype} <: AbstractLayout{Dim, Ptype} + n_locs::Int + stress_kwargs::Dict{Symbol, Any} + spring_kwargs::Dict{Symbol, Any} +end + +function StressLocSpringMov(; + dim = 2, + Ptype = Float64, + n_locs = 0, + stress_kwargs = Dict{Symbol, Any}(), + spring_kwargs = Dict{Symbol, Any}(:C => 0.3) +) + return StressLocSpringMov{dim, Ptype}(n_locs, stress_kwargs, spring_kwargs) +end + +function NetworkLayout.layout( + algo::StressLocSpringMov{Dim, Ptype}, adj_matrix::AbstractMatrix +) where {Dim, Ptype} + n_nodes = NetworkLayout.assertsquare(adj_matrix) + stress = Stress(;dim=Dim, Ptype=Ptype, algo.stress_kwargs...) + loc_positions = stress(adj_matrix[1:algo.n_locs, 1:algo.n_locs]) + init_positions = resize!(copy(loc_positions), n_nodes) + for i in algo.n_locs+1:n_nodes + for j in 1:algo.n_locs + adj_matrix[i, j] == 0 && continue + init_positions[i] = loc_positions[j] + break + end + end + spring = Spring(;dim=Dim, Ptype=Ptype, pin=loc_positions, + initialpos=init_positions, algo.spring_kwargs...) + return spring(adj_matrix) +end + +struct BlocksworldLayout{Ptype} <: AbstractLayout{2, Ptype} + n_locs::Int + block_width::Ptype + block_height::Ptype + block_gap::Ptype + table_height::Ptype + gripper_height::Ptype +end + +function BlocksworldLayout(; + Ptype = Float64, + n_locs = 2, + block_width = Ptype(1.0), + block_height = Ptype(1.0), + block_gap = Ptype(0.5), + table_height = block_height, + gripper_height = table_height + (n_locs - 2 + 1) * block_height +) + return BlocksworldLayout{Ptype}( + n_locs, + block_width, + block_height, + block_gap, + table_height, + gripper_height + ) +end + +function NetworkLayout.layout( + algo::BlocksworldLayout{Ptype}, adj_matrix::AbstractMatrix +) where {Ptype} + n_nodes = NetworkLayout.assertsquare(adj_matrix) + n_blocks = n_nodes - algo.n_locs + graph = SimpleDiGraph(adj_matrix) + positions = Vector{Point2{Ptype}}(undef, n_nodes) + # Set table and gripper location + x_mid = n_blocks * (algo.block_width + algo.block_gap) / Ptype(2) + positions[1] = Point2{Ptype}(x_mid, algo.table_height/2) + positions[2] = Point2{Ptype}(x_mid, algo.gripper_height+algo.block_height/2) + # Compute base locations + x_start = (algo.block_width + algo.block_gap) / Ptype(2) + for i in 1:(algo.n_locs-2) + x = (i - 1) * (algo.block_width + algo.block_gap) + x_start + y = algo.table_height - algo.block_height / Ptype(2) + positions[2 + i] = Point2{Ptype}(x, y) + end + # Compute block locations for towers rooted at each base + for base in 3:algo.n_locs + stack = [(i, base) for i in inneighbors(graph, base)] + while !isempty(stack) + (node, parent) = pop!(stack) + x, y = positions[parent] + y += algo.block_height + positions[node] = Point2{Ptype}(x, y) + for child in inneighbors(graph, node) + push!(stack, (child, node)) + end + end + end + # Compute block locations for blocks held in gripper + for node in inneighbors(graph, 2) + positions[node] = copy(positions[2]) + end + return positions +end diff --git a/src/renderers/graphworld/state.jl b/src/renderers/graphworld/state.jl index 04e27a7..12d288f 100644 --- a/src/renderers/graphworld/state.jl +++ b/src/renderers/graphworld/state.jl @@ -11,53 +11,55 @@ function render_state!( end # Extract or construct main axis ax = get(canvas.blocks, 1) do - _ax = Axis(canvas.layout[1,1], aspect=1) + axis_options = copy(renderer.axis_options) + delete!(axis_options, :hidedecorations) + _ax = Axis(canvas.layout[1, 1]; axis_options...) push!(canvas.blocks, _ax) return _ax end # Extract objects from state - locations = reduce(vcat, [PDDL.get_objects(domain, state[], t) - for t in renderer.location_types]) - movables = reduce(vcat, [PDDL.get_objects(domain, state[], t) - for t in renderer.movable_types]) + locations = [sort!(PDDL.get_objects(domain, state[], t), by=string) + for t in renderer.location_types] + locations = prepend!(reduce(vcat, locations, init=Const[]), renderer.locations) + movables = [sort!(PDDL.get_objects(domain, state[], t), by=string) + for t in renderer.movable_types] + movables = prepend!(reduce(vcat, movables, init=Const[]), renderer.movables) # Build static location graph + is_directed = renderer.is_loc_directed || renderer.is_mov_directed n_locs = length(locations) - loc_graph = renderer.is_directed ? - SimpleDiGraph(n_locs) : SimpleGraph(n_locs) + loc_graph = is_directed ? SimpleDiGraph(n_locs) : SimpleGraph(n_locs) for (i, a) in enumerate(locations), (j, b) in enumerate(locations) - if renderer.edge_fn(domain, state[], a, b) - add_edge!(loc_graph, i, j) - end + renderer.loc_edge_fn(domain, state[], a, b) || continue + add_edge!(loc_graph, i, j) + is_directed && !renderer.is_loc_directed || continue + add_edge!(loc_graph, j, i) end # Add movable objects to graph graph = @lift begin g = copy(loc_graph) - for obj in movables + # Add edges between locations and movable objects + for (i, mov) in enumerate(movables) add_vertex!(g) - for (i, loc) in enumerate(locations) - if renderer.at_loc_fn(domain, $state, obj, loc) - add_edge!(g, i, nv(g)) - continue + for (j, loc) in enumerate(locations) + if renderer.mov_loc_edge_fn(domain, $state, mov, loc) + add_edge!(g, n_locs + i, j) + break end end end - g - end - # Construct layout for graph including movable objects - loc_pos = renderer.graph_layout(loc_graph) - layout = @lift begin - init_pos = copy(loc_pos) - for i in n_locs+1:nv($graph) - nbs = inneighbors($graph, i) - if isempty(nbs) - push!(init_pos, Point2f(0, 0)) - else - push!(init_pos, loc_pos[nbs[1]]) + # Add edges between movable objects + if renderer.has_mov_edges + for (i, a) in enumerate(movables), (j, b) in enumerate(movables) + renderer.mov_edge_fn(domain, $state, a, b) || continue + add_edge!(g, n_locs + i, n_locs + j) + is_directed && !renderer.is_mov_directed || continue + add_edge!(g, n_locs + j, n_locs + i) end end - Spring(; pin=loc_pos, initialpos=init_pos, - C=get(options, :movable_spring_constant, 0.3)) + g end + # Construct layout for graph including movable objects + layout = renderer.graph_layout(n_locs) # Define node and edge labels loc_labels = get(options, :show_location_labels, true) ? string.(locations) : fill("", length(locations)) @@ -73,11 +75,18 @@ function render_state!( end labels = Vector{String}(undef, n_edges) for (i, e) in enumerate(edges($graph)) - if e.src > n_locs || e.dst > n_locs - labels[i] = "" - else + if e.src <= n_locs && e.dst <= n_locs a, b = locations[e.src], locations[e.dst] - labels[i] = renderer.edge_label_fn(domain, $state, a, b) + labels[i] = renderer.loc_edge_label_fn(domain, $state, a, b) + elseif e.src > n_locs && e.dst > n_locs + a = movables[e.src - n_locs] + b = movables[e.dst - n_locs] + labels[i] = renderer.mov_edge_label_fn(domain, $state, a, b) + elseif e.src <= n_locs && e.dst > n_locs + a = movables[e.dst - n_locs] + b = locations[e.src] + labels[i] = + renderer.mov_loc_edge_label_fn(domain, $state, a, b) end end labels @@ -100,19 +109,27 @@ function render_state!( nlabels=node_labels, elabels=edge_labels, edge_color=edge_colors, renderer.graph_options...) canvas.plots[:graph] = gp + canvas.observables[:node_pos] = gp[:node_pos] # Update node label offsets - offset_mult = get(options, :label_offset_mult, 0.2) + label_offset = get(options, :label_offset, 0.15) map!(gp.nlabels_offset, gp.node_pos) do node_pos mean_pos = sum(node_pos[1:n_locs]) / n_locs - offsets = offset_mult .* (node_pos .- mean_pos) + dir = node_pos .- mean_pos + mag = [GeometryBasics.norm(d) for d in dir] + dir = dir ./ mag + offsets = label_offset .* dir return offsets end # Render location graphics if get(options, :show_location_graphics, true) for (i, loc) in enumerate(locations) - type = PDDL.get_objtype(state[], loc) - r = get(renderer.loc_renderers, type, nothing) - r === nothing && continue + r = get(renderer.loc_renderers, loc, nothing) + if r === nothing + loc in PDDL.get_objects(state[]) || continue + type = PDDL.get_objtype(state[], loc) + r = get(renderer.loc_type_renderers, type, nothing) + r === nothing && continue + end graphic = @lift begin pos = $(gp.node_pos)[i] translate(r(domain, $state, loc), pos[1], pos[2]) @@ -124,9 +141,13 @@ function render_state!( # Render movable object graphics if get(options, :show_movable_graphics, true) for (i, obj) in enumerate(movables) - type = PDDL.get_objtype(state[], obj) - r = get(renderer.obj_renderers, type, nothing) - r === nothing && continue + r = get(renderer.mov_renderers, obj, nothing) + if r === nothing + obj in PDDL.get_objects(state[]) || continue + type = PDDL.get_objtype(state[], obj) + r = get(renderer.mov_type_renderers, type, nothing) + r === nothing && continue + end graphic = @lift begin pos = $(gp.node_pos)[n_locs + i] translate(r(domain, $state, obj), pos[1], pos[2]) @@ -135,10 +156,10 @@ function render_state!( canvas.plots[Symbol("$(obj)_graphic")] = plt end end - # Final axis modifications - hidedecorations!(ax) - autolimits!(ax) - ax.aspect = 1 + # Hide decorations if flag is specified + if get(renderer.axis_options, :hidedecorations, true) + hidedecorations!(ax) + end return canvas end @@ -153,9 +174,7 @@ end - `movable_node_color = :gray`: Color of movable object nodes. - `movable_edge_color = (:mediumpurple, 0.75)`: Color of edges between locations and movable objects. -- `movable_spring_constant = 0.3`: Controls how much movable objects are - repelled from other nodes. -- `label_offset_mult = 0.2`: Multiplier for the offset of labels from their +- `label_offset = 0.15`: How much labels are offset from the center of their corresponding objects. Larger values move the labels further away. """ default_state_options(R::Type{GraphworldRenderer}) = Dict{Symbol,Any}( @@ -168,6 +187,5 @@ default_state_options(R::Type{GraphworldRenderer}) = Dict{Symbol,Any}( :location_edge_color => :black, :movable_node_color => :gray, :movable_edge_color => (:mediumpurple, 0.75), - :movable_spring_constant => 0.3, - :label_offset_mult => 0.2, + :label_offset => 0.15 ) diff --git a/test/graphworld/blocksworld.jl b/test/graphworld/blocksworld.jl new file mode 100644 index 0000000..4dc7682 --- /dev/null +++ b/test/graphworld/blocksworld.jl @@ -0,0 +1,93 @@ +using PDDLViz, GLMakie, GraphMakie +using PDDL, SymbolicPlanners, PlanningDomains + +# Load blocksworld domain and problem +domain = load_domain(:blocksworld) +problem = load_problem(:blocksworld, 5) + +# Construct initial state from domain and problem +state = initstate(domain, problem) + +# Construct graphworld renderer +cmap = Makie.colorschemes[:plasma][1:8:256] +renderer = PDDLViz.GraphworldRenderer( + graph_layout = n_locs -> PDDLViz.BlocksworldLayout(n_locs=n_locs), + is_loc_directed = true, + is_mov_directed = true, + has_mov_edges = true, + locations = [pddl"(table)", pddl"(gripper)"], + location_types = [:block], + movable_types = [:block], + loc_edge_fn = (d, s, a, b) -> false, + mov_loc_edge_fn = (d, s, x, loc) -> begin + if x == loc + s[Compound(:ontable, [x])] + elseif loc.name == :gripper + s[Compound(:holding, [x])] + else + false + end + end, + mov_edge_fn = (d, s, x, y) -> s[Compound(:on, [x, y])], + loc_renderers = Dict{Const, Function}( + pddl"(table)" => (d, s, loc) -> begin + n_blocks = length(PDDL.get_objects(s, :block)) + width = 1.5 * n_blocks + PDDLViz.RectShape( + 0.0, 0.0, width, 1.0, + color = :grey60, strokewidth=2.0 + ) + end, + ), + mov_type_renderers = Dict{Symbol, Function}( + :block => (d, s, o) -> MultiGraphic( + PDDLViz.SquareShape( + 0.0, 0.0, 1.0, + color=cmap[mod(hash(o.name), length(cmap))+1], + strokewidth=2.0 + ), + TextGraphic( + string(o.name), 0, 0, 3/4*length(string(o.name)), + font=:bold, color=:white, strokecolor=:black, strokewidth=1.0 + ) + ) + ), + axis_options = Dict{Symbol, Any}( + :aspect => DataAspect(), + # :autolimitaspect => 1, + :xautolimitmargin => (0.0, 0.0), + # :yautolimitmargin => (0.0, 0.0), + :limits => (0.0, nothing, 0.0, nothing), + :hidedecorations => true + ), + state_options = Dict{Symbol, Any}( + :show_location_labels => false, + :show_movable_labels => false, + :show_edge_labels => false, + :show_location_graphics => true, + :show_movable_graphics => true + ), + graph_options = Dict{Symbol, Any}( + :node_size => 0.0, + :node_attr => (markerspace=:data,), + :nlabels_fontsize => 20, + :nlabels_align => (:center, :center), + :elabels_fontsize => 16, + ) +) + +# Render initial state +canvas = renderer(domain, state) + +# Render animation +plan = @pddl( + "(unstack f e)", "(put-down f)", + "(unstack e b)", "(put-down e)", + "(unstack d a)", "(stack d e)", + "(unstack a c)", "(stack a f)", + "(pick-up c)", "(stack c d)", + "(pick-up b)", "(stack b c)", + "(unstack a f)", "(stack a b)" +) +anim = anim_plan!(canvas, renderer, domain, state, plan, framerate=2) +save("blocksworld.mp4", anim) diff --git a/test/graphworld/test.jl b/test/graphworld/test.jl index 699e85e..4164c02 100644 --- a/test/graphworld/test.jl +++ b/test/graphworld/test.jl @@ -1,77 +1,7 @@ -# Test gridworld rendering -using PDDLViz, GLMakie, GraphMakie -using PDDL, SymbolicPlanners, PlanningDomains +@testset "blocksworld" begin + include("blocksworld.jl") +end -# Load example graph-based domain and problem -domain = load_domain(:zeno_travel) -problem = load_problem(:zeno_travel, 3) - -# Construct initial state from domain and problem -state = initstate(domain, problem) - -# Construct graphworld renderer -cmap = PDDLViz.colorschemes[:vibrant] -renderer = PDDLViz.GraphworldRenderer( - graph_layout=GraphMakie.NetworkLayout.Stress(), - location_types = [:city], - movable_types = [:movable], - edge_fn = (d, s, a, b) -> a != b, - edge_label_fn = (d, s, a, b) -> string(s[Compound(:distance, [a, b])]), - at_loc_fn = (d, s, x, loc) -> s[Compound(:at, [x, loc])], - loc_renderers = Dict{Symbol, Function}( - :city => (d, s, loc) -> CityGraphic( - 0, 0, 0.25, color=cmap[parse(Int, string(loc.name)[end])+1] - ) - ), - obj_renderers = Dict{Symbol, Function}( - :person => (d, s, o) -> HumanGraphic( - 0, 0, 0.15, color=cmap[parse(Int, string(o.name)[end])] - ), - :aircraft => (d, s, o) -> MultiGraphic( - MarkerGraphic( - '✈', 0, 0, 0.2, color=cmap[parse(Int, string(o.name)[end])] - ), - HumanGraphic( - 0, 0, 0.1, color=:black, - visible=satisfy(d, s, Compound(:in, [Var(:X), o])) - ) - ), - ), - state_options = Dict{Symbol, Any}( - :show_location_labels => true, - :show_movable_labels => true, - :show_edge_labels => true, - :show_location_graphics => true, - :show_movable_graphics => true, - :label_offset_mult => 0.25, - :movable_node_color => (:black, 0.0), - ), - graph_options = Dict{Symbol, Any}( - :node_size => 0.03, - :node_attr => (markerspace=:data,), - :nlabels_fontsize => 20, - :nlabels_align => (:center, :center), - :elabels_fontsize => 16, - ) -) - -# Render initial state -canvas = renderer(domain, state) - -# Render animation -plan = @pddl("(refuel plane1)", "(fly plane1 city0 city2)", - "(board person1 plane1 city2)", "(fly plane1 city2 city1)", - "(debark person1 plane1 city1)", "(fly plane1 city1 city2)") -renderer.state_options[:show_edge_labels] = false -anim = anim_plan(renderer, domain, state, plan, framerate=1) -save("zeno_travel.mp4", anim) - -# Convert animation frames to storyboard -storyboard = render_storyboard( - anim, [1, 3, 4, 5, 6, 7], figscale=0.65, n_rows=2, - xlabels=["t=1", "t=3", "t=4", "t=5", "t=6", "t=7"], - subtitles=["(i) Initial state", "(ii) Plane flies to city 2", - "(iii) Person 1 boards plane", "(iv) Plane flies to city 1", - "(v) Person 1 debarks plane", "(vi) Plane flies back to city 2"], - xlabelsize=18, subtitlesize=22 -) +@testset "zeno-travel" begin + include("zeno_travel.jl") +end diff --git a/test/graphworld/zeno_travel.jl b/test/graphworld/zeno_travel.jl new file mode 100644 index 0000000..f98ca17 --- /dev/null +++ b/test/graphworld/zeno_travel.jl @@ -0,0 +1,82 @@ +using PDDLViz, GLMakie, GraphMakie +using PDDL, SymbolicPlanners, PlanningDomains + +# Load example graph-based domain and problem +domain = load_domain(:zeno_travel) +problem = load_problem(:zeno_travel, 3) + +# Construct initial state from domain and problem +state = initstate(domain, problem) + +# Construct graphworld renderer +cmap = PDDLViz.colorschemes[:vibrant] +renderer = PDDLViz.GraphworldRenderer( + has_mov_edges = true, + location_types = [:city], + movable_types = [:movable], + loc_edge_fn = (d, s, a, b) -> a != b, + loc_edge_label_fn = (d, s, a, b) -> string(s[Compound(:distance, [a, b])]), + mov_loc_edge_fn = (d, s, x, loc) -> s[Compound(:at, [x, loc])], + mov_edge_fn = (d, s, x, y) -> begin + terms = [Compound(:person, Term[x]), Compound(:aircraft, Term[y]), + Compound(:in, Term[x, y])] + return satisfy(d, s, terms) + end, + loc_type_renderers = Dict{Symbol, Function}( + :city => (d, s, loc) -> CityGraphic( + 0, 0, 0.25, color=cmap[parse(Int, string(loc.name)[end])+1] + ) + ), + mov_type_renderers = Dict{Symbol, Function}( + :person => (d, s, o) -> HumanGraphic( + 0, 0, 0.15, color=cmap[parse(Int, string(o.name)[end])] + ), + :aircraft => (d, s, o) -> MarkerGraphic( + '✈', 0, 0, 0.2, color=cmap[parse(Int, string(o.name)[end])] + ) + ), + state_options = Dict{Symbol, Any}( + :show_location_labels => true, + :show_movable_labels => true, + :show_edge_labels => true, + :show_location_graphics => true, + :show_movable_graphics => true, + :label_offset => 0.15, + :movable_node_color => (:black, 0.0), + ), + axis_options = Dict{Symbol, Any}( + :aspect => 1, + :autolimitaspect => 1, + :xautolimitmargin => (0.2, 0.2), + :yautolimitmargin => (0.2, 0.2), + :hidedecorations => true + ), + graph_options = Dict{Symbol, Any}( + :node_size => 0.03, + :node_attr => (markerspace=:data,), + :nlabels_fontsize => 20, + :nlabels_align => (:center, :center), + :elabels_fontsize => 16, + ) +) + +# Render initial state +canvas = renderer(domain, state) + +# Render animation +plan = @pddl("(refuel plane1)", "(fly plane1 city0 city2)", + "(board person1 plane1 city2)", "(fly plane1 city2 city1)", + "(debark person1 plane1 city1)", "(fly plane1 city1 city2)") +renderer.state_options[:show_edge_labels] = false +anim = anim_plan!(canvas, renderer, domain, state, plan, framerate=1) +save("zeno_travel.mp4", anim) + +# Convert animation frames to storyboard +storyboard = render_storyboard( + anim, [1, 3, 4, 5, 6, 7], figscale=0.65, n_rows=2, + xlabels=["t=1", "t=3", "t=4", "t=5", "t=6", "t=7"], + subtitles=["(i) Initial state", "(ii) Plane flies to city 2", + "(iii) Person 1 boards plane", "(iv) Plane flies to city 1", + "(v) Person 1 debarks plane", "(vi) Plane flies back to city 2"], + xlabelsize=18, subtitlesize=22 +) diff --git a/test/gridworld/doors_keys_gems.jl b/test/gridworld/doors_keys_gems.jl new file mode 100644 index 0000000..6238d94 --- /dev/null +++ b/test/gridworld/doors_keys_gems.jl @@ -0,0 +1,110 @@ +# Test gridworld rendering +using PDDLViz, GLMakie +using PDDL, SymbolicPlanners, PlanningDomains + +# Load example gridworld domain and problem +domain = load_domain(:doors_keys_gems) +problem = load_problem(:doors_keys_gems, 3) + +# Load array extension to PDDL +PDDL.Arrays.register!() + +# Construct initial state from domain and problem +state = initstate(domain, problem) + +# Construct gridworld renderer +gem_colors = PDDLViz.colorschemes[:vibrant] +renderer = PDDLViz.GridworldRenderer( + resolution = (600, 700), + agent_renderer = (d, s) -> HumanGraphic(color=:black), + obj_renderers = Dict( + :key => (d, s, o) -> KeyGraphic( + visible=!s[Compound(:has, [o])] + ), + :door => (d, s, o) -> LockedDoorGraphic( + visible=s[Compound(:locked, [o])] + ), + :gem => (d, s, o) -> GemGraphic( + visible=!s[Compound(:has, [o])], + color=gem_colors[parse(Int, string(o.name)[end])] + ) + ), + show_inventory = true, + inventory_fns = [(d, s, o) -> s[Compound(:has, [o])]], + inventory_types = [:item] +) + +# Render initial state +canvas = renderer(domain, state) + +# Render plan +plan = @pddl("(right)", "(right)", "(right)", "(up)", "(up)") +renderer(canvas, domain, state, plan) + +# Render trajectory +trajectory = PDDL.simulate(domain, state, plan) +canvas = renderer(domain, trajectory) + +# Render path search solution +astar = AStarPlanner(GoalCountHeuristic(), save_search=true, + save_search_order=true, max_nodes=100) +sol = astar(domain, state, pddl"(has gem2)") +canvas = renderer(domain, state, sol) + +# Render policy solution +heuristic = PlannerHeuristic(AStarPlanner(GoalCountHeuristic(), max_nodes=20)) +rtdp = RTDP(heuristic=heuristic, n_rollouts=5, max_depth=20) +policy = rtdp(domain, state, pddl"(has gem1)") +canvas = renderer(domain, state, policy) + +# Animate plan +plan = collect(sol) +anim = anim_plan(renderer, domain, state, plan; trail_length=10) +save("doors_keys_gems.mp4", anim) + +# Animate path search planning +canvas = renderer(domain, state) +sol_anim, sol = anim_solve!(canvas, renderer, astar, + domain, state, pddl"(has gem1)") +save("doors_keys_gems_astar.mp4", sol_anim) + +# Animate RTDP planning +canvas = renderer(domain, state) +sol_anim, sol = anim_solve!(canvas, renderer, rtdp, + domain, state, pddl"(has gem2)") +save("doors_keys_gems_rtdp.mp4", sol_anim) + +# Animate RTHS planning +rths = RTHS(GoalCountHeuristic(), n_iters=5, max_nodes=15) +canvas = renderer(domain, state) +sol_anim, sol = anim_solve!(canvas, renderer, rths, + domain, state, pddl"(has gem1)") +save("doors_keys_gems_rths.mp4", sol_anim) + +# Convert animation frames to storyboard +storyboard = render_storyboard( + anim, [1, 14, 17, 24], figscale=0.75, + xlabels=["t=1", "t=14", "t=17", "t=24"], + subtitles=["(i) Initial state", "(ii) Agent picks up key", + "(iii) Agent unlocks door", "(iv) Agent picks up gem"], + xlabelsize=18, subtitlesize=22 +) + +# Construct multiple canvases on the same figure +figure = Figure(resolution=(1200, 700)) +canvas1 = new_canvas(renderer, figure[1, 1]) +canvas2 = new_canvas(renderer, figure[1, 2]) +renderer(canvas1, domain, state) +renderer(canvas2, domain, state, plan) + +# Add controller +canvas = renderer(domain, state) +controller = KeyboardController( + Keyboard.up => pddl"(up)", + Keyboard.down => pddl"(down)", + Keyboard.left => pddl"(left)", + Keyboard.right => pddl"(right)", + Keyboard.z, Keyboard.x, Keyboard.c, Keyboard.v +) +add_controller!(canvas, controller, domain, state; show_controls=true) +remove_controller!(canvas, controller) diff --git a/test/gridworld/test.jl b/test/gridworld/test.jl index 6238d94..659e6a9 100644 --- a/test/gridworld/test.jl +++ b/test/gridworld/test.jl @@ -1,110 +1,3 @@ -# Test gridworld rendering -using PDDLViz, GLMakie -using PDDL, SymbolicPlanners, PlanningDomains - -# Load example gridworld domain and problem -domain = load_domain(:doors_keys_gems) -problem = load_problem(:doors_keys_gems, 3) - -# Load array extension to PDDL -PDDL.Arrays.register!() - -# Construct initial state from domain and problem -state = initstate(domain, problem) - -# Construct gridworld renderer -gem_colors = PDDLViz.colorschemes[:vibrant] -renderer = PDDLViz.GridworldRenderer( - resolution = (600, 700), - agent_renderer = (d, s) -> HumanGraphic(color=:black), - obj_renderers = Dict( - :key => (d, s, o) -> KeyGraphic( - visible=!s[Compound(:has, [o])] - ), - :door => (d, s, o) -> LockedDoorGraphic( - visible=s[Compound(:locked, [o])] - ), - :gem => (d, s, o) -> GemGraphic( - visible=!s[Compound(:has, [o])], - color=gem_colors[parse(Int, string(o.name)[end])] - ) - ), - show_inventory = true, - inventory_fns = [(d, s, o) -> s[Compound(:has, [o])]], - inventory_types = [:item] -) - -# Render initial state -canvas = renderer(domain, state) - -# Render plan -plan = @pddl("(right)", "(right)", "(right)", "(up)", "(up)") -renderer(canvas, domain, state, plan) - -# Render trajectory -trajectory = PDDL.simulate(domain, state, plan) -canvas = renderer(domain, trajectory) - -# Render path search solution -astar = AStarPlanner(GoalCountHeuristic(), save_search=true, - save_search_order=true, max_nodes=100) -sol = astar(domain, state, pddl"(has gem2)") -canvas = renderer(domain, state, sol) - -# Render policy solution -heuristic = PlannerHeuristic(AStarPlanner(GoalCountHeuristic(), max_nodes=20)) -rtdp = RTDP(heuristic=heuristic, n_rollouts=5, max_depth=20) -policy = rtdp(domain, state, pddl"(has gem1)") -canvas = renderer(domain, state, policy) - -# Animate plan -plan = collect(sol) -anim = anim_plan(renderer, domain, state, plan; trail_length=10) -save("doors_keys_gems.mp4", anim) - -# Animate path search planning -canvas = renderer(domain, state) -sol_anim, sol = anim_solve!(canvas, renderer, astar, - domain, state, pddl"(has gem1)") -save("doors_keys_gems_astar.mp4", sol_anim) - -# Animate RTDP planning -canvas = renderer(domain, state) -sol_anim, sol = anim_solve!(canvas, renderer, rtdp, - domain, state, pddl"(has gem2)") -save("doors_keys_gems_rtdp.mp4", sol_anim) - -# Animate RTHS planning -rths = RTHS(GoalCountHeuristic(), n_iters=5, max_nodes=15) -canvas = renderer(domain, state) -sol_anim, sol = anim_solve!(canvas, renderer, rths, - domain, state, pddl"(has gem1)") -save("doors_keys_gems_rths.mp4", sol_anim) - -# Convert animation frames to storyboard -storyboard = render_storyboard( - anim, [1, 14, 17, 24], figscale=0.75, - xlabels=["t=1", "t=14", "t=17", "t=24"], - subtitles=["(i) Initial state", "(ii) Agent picks up key", - "(iii) Agent unlocks door", "(iv) Agent picks up gem"], - xlabelsize=18, subtitlesize=22 -) - -# Construct multiple canvases on the same figure -figure = Figure(resolution=(1200, 700)) -canvas1 = new_canvas(renderer, figure[1, 1]) -canvas2 = new_canvas(renderer, figure[1, 2]) -renderer(canvas1, domain, state) -renderer(canvas2, domain, state, plan) - -# Add controller -canvas = renderer(domain, state) -controller = KeyboardController( - Keyboard.up => pddl"(up)", - Keyboard.down => pddl"(down)", - Keyboard.left => pddl"(left)", - Keyboard.right => pddl"(right)", - Keyboard.z, Keyboard.x, Keyboard.c, Keyboard.v -) -add_controller!(canvas, controller, domain, state; show_controls=true) -remove_controller!(canvas, controller) +@testset "doors-keys-gems" begin + include("doors_keys_gems.jl") +end