From 0b29623176acc064fbfe19f85dc2353b992ceded Mon Sep 17 00:00:00 2001 From: Xuan Date: Tue, 11 Jul 2023 09:37:54 -0400 Subject: [PATCH] Support multi-agent gridworld policy rendering. --- src/renderers/gridworld/anim_rtdp.jl | 5 +- src/renderers/gridworld/anim_rths.jl | 22 +-- src/renderers/gridworld/policy.jl | 229 +++++++++++++++++++-------- 3 files changed, 182 insertions(+), 74 deletions(-) diff --git a/src/renderers/gridworld/anim_rtdp.jl b/src/renderers/gridworld/anim_rtdp.jl index 89f62b3..6368f7a 100644 --- a/src/renderers/gridworld/anim_rtdp.jl +++ b/src/renderers/gridworld/anim_rtdp.jl @@ -113,8 +113,9 @@ function (cb::AnimSolveCallback{GridworldRenderer})( prev_state = state end # Render / update value heatmap - if renderer.has_agent && !haskey(canvas.plots, :policy_values) - render_sol!(canvas, renderer, domain, state_obs, sol_obs; options...) + if (renderer.has_agent && !haskey(canvas.plots, :agent_policy_values)) || + (!isempty(objects) && !haskey(canvas.plots, Symbol("$(objects[1])_policy_values"))) + render_sol!(canvas, renderer, domain, state_obs, sol_obs; options...) else state_obs.val = visited[1] sol_obs[] = sol diff --git a/src/renderers/gridworld/anim_rths.jl b/src/renderers/gridworld/anim_rths.jl index c2d25c9..80513ac 100644 --- a/src/renderers/gridworld/anim_rths.jl +++ b/src/renderers/gridworld/anim_rths.jl @@ -58,11 +58,12 @@ function (cb::AnimSolveCallback{GridworldRenderer})( scatter!(ax, agent_loc, color=agent_color, markerspace=:data, marker=loc_marker, markersize=loc_markersize) end - if !isempty(objects) && !haskey(canvas.plots, :rths_obj_locs) - canvas.plots[:rths_obj_locs] = scatter!( - ax, obj_locs, color=obj_colors, markerspace=:data, - marker=loc_marker, markersize=loc_markersize - ) + for (obj, loc, col) in zip(objects, obj_locs, obj_colors) + if !haskey(canvas.plots, Symbol("rths_$(obj)_loc")) + canvas.plots[Symbol("rths_$(obj)_loc")] = + scatter!(ax, loc, color=col, markerspace=:data, + marker=loc_marker, markersize=loc_markersize) + end end # Update location observables if renderer.has_agent @@ -73,9 +74,11 @@ function (cb::AnimSolveCallback{GridworldRenderer})( end # Reset search locations if iteration has completed if isnothing(act) - empty!(search_agent_locs[]) - empty!(search_agent_dirs[]) - notify(search_agent_locs) + if renderer.has_agent + empty!(search_agent_locs[]) + empty!(search_agent_dirs[]) + notify(search_agent_locs) + end for (ls, ds) in zip(search_obj_locs, search_obj_dirs) empty!(ls[]) empty!(ds[]) @@ -83,7 +86,8 @@ function (cb::AnimSolveCallback{GridworldRenderer})( end end # Render / update value heatmap - if renderer.has_agent && !haskey(canvas.plots, :policy_values) + if (renderer.has_agent && !haskey(canvas.plots, :agent_policy_values)) || + (!isempty(objects) && !haskey(canvas.plots, Symbol("$(objects[1])_policy_values"))) render_sol!(canvas, renderer, domain, state_obs, sol_obs; options...) else state_obs.val = cur_state diff --git a/src/renderers/gridworld/policy.jl b/src/renderers/gridworld/policy.jl index 310cab1..8e41653 100644 --- a/src/renderers/gridworld/policy.jl +++ b/src/renderers/gridworld/policy.jl @@ -11,99 +11,202 @@ function render_sol!( ax = canvas.blocks[1] # Update options options = merge(renderer.trajectory_options, options) - max_states = get(options, :max_states, 200) + max_states = get(options, :max_policy_states, 200) arrowmarker = get(options, :track_arrowmarker, '▶') stopmarker = get(options, :track_stopmarker, '⦿') # Set up observables for agent - agent_locs = Observable(Point2f[]) - agent_markers = Observable(Char[]) - agent_rotations = Observable(Float64[]) - agent_values = Observable(Float64[]) + if renderer.has_agent + agent_locs = Observable(Point2f[]) + agent_values = Observable(Float64[]) + agent_markers = Observable(Char[]) + agent_rotations = Observable(Float64[]) + end + # Set up observables for tracked objects + objects = get(options, :tracked_objects, Const[]) + types = get(options, :tracked_types, Symbol[]) + for ty in types + objs = PDDL.get_objects(domain, state, ty) + append!(objects, objs) + end + obj_locs = [Observable(Point2f[]) for _ in 1:length(objects)] + obj_values = [Observable(Float64[]) for _ in 1:length(objects)] # Update observables for reachable states onany(sol, state) do sol, init_state - # Clear previous values + # Update agent observables if renderer.has_agent + # Clear previous values empty!(agent_locs[]) empty!(agent_markers[]) empty!(agent_rotations[]) empty!(agent_values[]) - end - # Iterate over reachable states up to limit - queue = [init_state] - visited = Set{UInt}() - while !isempty(queue) && length(visited) < max_states - state = popfirst!(queue) - state_id = hash(state) - state_id in visited && continue - push!(visited, state_id) - # Get state value and best action - val = SymbolicPlanners.get_value(sol, state) - best_act = SymbolicPlanners.best_action(sol, state) - # Get agent location - renderer.has_agent || continue - height = size(state[renderer.grid_fluents[1]], 1) - loc = Point2f(gw_agent_loc(renderer, state, height)) - # Terminate if location has already been encountered - loc in agent_locs[] && continue - # Update agent observables - push!(agent_locs[], loc) - next_state = transition(domain, state, best_act) - next_loc = Point2f(gw_agent_loc(renderer, next_state, height)) - marker = loc == next_loc ? stopmarker : arrowmarker - push!(agent_markers[], marker) - rotation = atan(next_loc[2] - loc[2], next_loc[1] - loc[1]) - push!(agent_rotations[], rotation) - push!(agent_values[], val) - # Add next states to queue - push!(queue, next_state) - for act in available(domain, state) - next_state = transition(domain, state, act) + # Iterate over reachable agent locations up to limit + queue = [init_state] + visited = Set{UInt}() + while !isempty(queue) && length(visited) < max_states + state = popfirst!(queue) + state_id = hash(state) + state_id in visited && continue + push!(visited, state_id) + # Get state value and best action + val = SymbolicPlanners.get_value(sol, state) + best_act = SymbolicPlanners.best_action(sol, state) + # Get agent location + height = size(state[renderer.grid_fluents[1]], 1) + loc = Point2f(gw_agent_loc(renderer, state, height)) + # Terminate if location has already been encountered + loc in agent_locs[] && continue + # Append agent location and value, etc. + push!(agent_locs[], loc) + next_state = transition(domain, state, best_act) + next_loc = Point2f(gw_agent_loc(renderer, next_state, height)) + marker = loc == next_loc ? stopmarker : arrowmarker + push!(agent_markers[], marker) + rotation = atan(next_loc[2] - loc[2], next_loc[1] - loc[1]) + push!(agent_rotations[], rotation) + push!(agent_values[], val) + # Add next states to queue push!(queue, next_state) + for act in available(domain, state) + next_state = transition(domain, state, act) + push!(queue, next_state) + end end - end - # Trigger updates - if renderer.has_agent + # Trigger updates notify(agent_locs) notify(agent_markers) notify(agent_rotations) notify(agent_values) end + # Update observables for tracked objects + for (obj, locs, vals) in zip(objects, obj_locs, obj_values) + # Clear previous values + empty!(locs[]) + empty!(vals[]) + # Add initial location and value + push!(locs[], Point2f(gw_obj_loc(renderer, init_state, obj))) + push!(vals[], SymbolicPlanners.get_value(sol, init_state)) + # Add locations and values of neighboring states + for act in available(domain, init_state) + next_state = transition(domain, init_state, act) + next_loc = Point2f(gw_obj_loc(renderer, next_state, obj)) + next_loc in locs[] && continue + push!(locs[], next_loc) + push!(vals[], SymbolicPlanners.get_value(sol, next_state)) + end + # Trigger updates + notify(locs) + notify(vals) + end end notify(sol) - # Render policy information - if renderer.has_agent - # Render state value heatmap - if get(options, :show_value_heatmap, true) - cmap = get(options, :value_colormap) do - cgrad(Makie.ColorSchemes.viridis, alpha=0.5) - end - marker = Polygon(Point2f.([(-.5, -.5), (-.5, .5), - (.5, .5), (.5, -.5)])) + # Render state value heatmap + if get(options, :show_value_heatmap, true) + cmap = get(options, :value_colormap) do + cgrad(Makie.ColorSchemes.viridis, alpha=0.5) + end + if renderer.has_agent + marker = _policy_heatmap_marker() plt = scatter!(ax, agent_locs, color=agent_values, colormap=cmap, - marker=marker, markerspace=:data, markersize=1.0) + marker=marker, markerspace=:data, markersize=1.0) Makie.translate!(plt, 0.0, 0.0, -0.5) - canvas.plots[:policy_values] = plt + canvas.plots[:agent_policy_values] = plt end - # Render best actions at each location - if get(options, :show_actions, true) - markersize = get(options, :track_markersize, 0.3) - color = get(options, :agent_color, :black) - plt = scatter!(ax, agent_locs, marker=agent_markers, - rotations=agent_rotations, markersize=markersize, - color=color, markerspace=:data) - canvas.plots[:policy_actions] = plt + for (i, obj) in enumerate(objects) + marker = _policy_heatmap_marker(length(objects), i) + locs, vals = obj_locs[i], obj_values[i] + plt = scatter!(ax, locs, color=vals, colormap=cmap, + marker=marker, markerspace=:data, markersize=1.0) + Makie.translate!(plt, 0.0, 0.0, -0.5) + canvas.plots[Symbol("$(obj)_policy_values")] = plt end - # Render state value labels at each location - if get(options, :show_value_labels, true) - label_locs = @lift $agent_locs .+ Point2f(0.0, 0.25) + end + # Render best agent actions at each location + if get(options, :show_actions, true) && renderer.has_agent + markersize = get(options, :track_markersize, 0.3) + color = get(options, :agent_color, :black) + plt = scatter!(ax, agent_locs, marker=agent_markers, + rotations=agent_rotations, markersize=markersize, + color=color, markerspace=:data) + canvas.plots[:agent_policy_actions] = plt + end + # Render state value labels at each location + if get(options, :show_value_labels, true) + if renderer.has_agent + offset = _policy_label_offset() + label_locs = @lift $agent_locs .+ offset labels = @lift map($agent_values) do val @sprintf("%.1f", val) end plt = text!(ax, label_locs; text=labels, color=:black, fontsize=0.2, markerspace=:data, align=(:center, :center)) - canvas.plots[:policy_value_labels] = plt + canvas.plots[:agent_policy_labels] = plt + end + for (i, obj) in enumerate(objects) + locs, vals = obj_locs[i], obj_values[i] + label_locs = @lift $locs .+ _policy_label_offset(length(objects), i) + labels = @lift map($vals) do val + @sprintf("%.1f", val) + end + fontsize = length(objects) > 2 ? 0.15 : 0.2 + plt = text!(ax, label_locs; text=labels, color=:black, + fontsize=fontsize, markerspace=:data, + align=(:center, :center)) + canvas.plots[Symbol("$(obj)_policy_labels")] = plt end end return canvas end + +function _policy_heatmap_marker(n::Int = 1, i::Int = 1) + if n <= 1 # Square marker for single agent + return Polygon(Point2f.([(-.5, -.5), (-.5, .5), (.5, .5), (.5, -.5)])) + elseif n <= 2 # Bottom left and top right triangles for 2 agents + if i == 1 + return Polygon(Point2f.([(-.5, -.5), (-.5, .5), (.5, -.5)])) + elseif i == 2 + return Polygon(Point2f.([(.5, .5), (.5, -.5), (-.5, .5)])) + end + elseif n <= 4 # Four triangles for 4 or less agents + if i == 1 + return Polygon(Point2f.([(-.5, -.5), (-.5, .5), (0.0, 0.0)])) + elseif i == 2 + return Polygon(Point2f.([(-.5, .5), (.5, .5), (0.0, 0.0)])) + elseif i == 3 + return Polygon(Point2f.([(.5, .5), (.5, -.5), (0.0, 0.0)])) + elseif i == 4 + return Polygon(Point2f.([(.5, -.5), (-.5, -.5), (0.0, 0.0)])) + end + else # Circle marker for more than 4 agents + angle = 2*pi*i/n + x, y = 2/n*cos(angle), 2/n*sin(angle) + points = decompose(Point2f, Circle(Point2f(x, y), 1/n)) + return Polygon(points) + end +end + +function _policy_label_offset(n::Int=1, i::Int=1) + if n <= 1 + return Point2f(0.0, 0.25) + elseif n <= 2 + if i == 1 + return Point2f(-0.2, -0.2) + elseif i == 2 + return Point2f(0.2, 0.2) + end + elseif n <= 4 + if i == 1 + return Point2f(-0.3, 0.0) + elseif i == 2 + return Point2f(0.0, 0.3) + elseif i == 3 + return Point2f(0.3, 0.0) + elseif i == 4 + return Point2f(0.0, -0.3) + end + else + angle = 2*pi*i/n + x, y = 2/n*cos(angle), 2/n*sin(angle) + return Point2f(x, y) + end +end