diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a30f2c6..330935a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -7,16 +7,25 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: '1' arch: x64 + - name: Set up SSH + uses: webfactory/ssh-agent@v0.9.0 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + - name: Adding WildcardArrays + run: | + julia --project -e 'using Pkg; Pkg.add(PackageSpec(url="git@github.com:licioromao/WildcardArrays.jl.git")); Pkg.instantiate()' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v1 + - uses: actions/checkout@master + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v1 with: file: ./lcov.info flags: unittests - name: codecov-umbrella + name: codecov-umbrellav diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 68ca87e..3051c1a 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -24,3 +24,5 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} # COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }} + + diff --git a/.gitignore b/.gitignore index 0ee3d17..432a33a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,11 @@ *.jl.cov *.jl.*.cov *.jl.mem -Manifest.toml \ No newline at end of file +Manifest.toml +scripts/ + + +test/sources/txt-files/ +test/sources/ejs/txt-files/ +test/sources/cheng-examples/txt-files/ +.*.swp diff --git a/Project.toml b/Project.toml index 1a4aa11..ee3adeb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,27 +1,29 @@ name = "POMDPFiles" uuid = "9cf5b727-2e06-5671-8c87-8c6b0f729d5d" repo = "https://github.com/JuliaPOMDP/POMDPFiles.jl" -version = "0.2.4" +version = "0.3.0" [deps] -POMDPModels = 
"355abbd5-f08e-5560-ac9e-8b5f2592a0ca" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPXFiles = "c6f6ee83-58c6-5336-a19f-2c76817e1af6" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +WildcardArrays = "05d16eb0-506c-4fac-9187-f7ce7e253b09" [compat] POMDPTools = "0.1.4" POMDPXFiles = "0.2" POMDPs = "0.9" Reexport = "0.2, 1" -julia = "1" +WildcardArrays = "0.1.0" +julia = "1.6" [extras] Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Downloads", "Test", "SHA"] diff --git a/README.md b/README.md index eec99ef..35a3f87 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,85 @@ # POMDPFiles +[![Build Status](https://github.com/licioromao/POMDPFiles.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/licioromao/POMDPFiles.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![codecov](https://codecov.io/gh/licioromao/POMDPFiles.jl/branch/main/graph/badge.svg?token=btTBnBTQyw)](https://codecov.io/gh/licioromao/POMDPFiles.jl) -[![Build Status](https://github.com/JuliaPOMDP/POMDPFiles.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/JuliaPOMDP/POMDPFiles.jl/actions/workflows/CI.yml/) -[![codecov](https://codecov.io/gh/JuliaPOMDP/POMDPFiles.jl/branch/master/graph/badge.svg?token=6pQE1gHKIz)](https://codecov.io/gh/JuliaPOMDP/POMDPFiles.jl) + -Writes POMDP files (https://www.pomdp.org/code/pomdp-file-spec.html) for use in [POMDPSolve.jl](https://github.com/JuliaPOMDP/POMDPSolve.jl) from [POMDPs.jl models](https://github.com/JuliaPOMDP/POMDPs.jl). 
+This package constitutes the interface between the [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl) package and the file format .pomdp defined in [POMDP.org](https://www.pomdp.org/code/pomdp-file-spec.html). + +## Installation + +Use the following command to install this package: +```julia +] add git@github.com:licioromao/POMDPFiles.jl.git +``` +## API + +```julia +WildcardArrayPOMDP(s::Int, a::Int, o::Int, initial_state::InitialStateParam, discount::Float64, T::WildcardArray{Float64, 3}, O::WildcardArray{Float64, 3}, R::WildcardArray{Float64, 4}) + +WildcardArrayPOMDP(filename::String) +``` + +Constructors for the `WildcardArrayPOMDP` type. We allow the user to create this type either from a file in the .POMDP format or by manually specifying the number of actions, initial distribution, and transitions using the type [WildcardArrays](https://github.com/sisl/WildcardArrays.jl). The API for `InitialStateParam` is described below. + +```julia +SWildcardArrayPOMDP(filename::String) + +statenames(m::SWildcardArrayPOMDP) +actionnames(m::SWildcardArrayPOMDP) +obsnames(m::SWildcardArrayPOMDP) +``` + +To deal with POMDP specifications where states, actions, and observations are specified with strings, e.g., `ss = ["warm", "very-warm"], aa = ["north", "south", "east", "west"]`, one may use the `SWildcardArrayPOMDP` type. More details on the differences between these two types are presented below. Three methods, `statenames`, `actionnames`, and `obsnames`, are defined to retrieve the names associated with the corresponding field of an `SWildcardArrayPOMDP` type. + +> **Warning:** The functions `statenames`, `actionnames`, and `obsnames` are not implemented for the `WildcardArrayPOMDP` type.
+ +```julia +mutable struct InitialStateParam + number::Int + type_of_distribution::String + support_of_distribution::Set{Int} + value_of_distribution::Vector{Float64} + +end +InitialStateParam(number::Int) = InitialStateParam(number, " ", Set{Int}([]), Vector{Float64}([])) +InitialStateParam() = InitialStateParam(0) +``` + +This is the interface used to define the initial distribution of a `WildcardArrayPOMDP`. It contains information about the number of states, the support of the distribution, and a probability vector representing the initial distribution. The parameter `type_of_distribution` is either `"uniform"` or `"general distribution"`. + + +## Quick example + +In the example below we download the *paint.95.POMDP* file from [POMDP.org](https://www.pomdp.org/examples/paint.95.POMDP) and parse its content into an `SWildcardArrayPOMDP`, a type defined in this package. We then illustrate a few functionalities from *POMDPs.jl*. + +```julia +using HTTP, POMDPs, POMDPFiles + +mktempdir() do tmp_dir + url = "https://www.pomdp.org/examples/paint.95.POMDP" + tmp_file_name = joinpath(tmp_dir, "paint.95.POMDP") + HTTP.download(url, tmp_file_name) + + pomdp = SWildcardArrayPOMDP(tmp_file_name) + + states(pomdp) +end +``` +Some of the examples at [POMDP.org](https://www.pomdp.org/examples), for instance `mini-hall2`, specify a POMDP without associating names with states, actions, and observations. In these cases, one may use the `WildcardArrayPOMDP` type as described in the example below. + +```julia +using HTTP, POMDPs, POMDPFiles + +mktempdir() do tmp_dir + url = "https://www.pomdp.org/examples/mini-hall2.POMDP" + tmp_file_name = joinpath(tmp_dir, "mini-hall2.POMDP") + HTTP.download(url, tmp_file_name) + + pomdp = WildcardArrayPOMDP(tmp_file_name) + + initialstate(pomdp) +end +``` + +Using `SWildcardArrayPOMDP` in the previous example would allow us to refer to states, actions, and observations by means of $0$-based indexing.
diff --git a/src/POMDPFiles.jl b/src/POMDPFiles.jl index 6ad747d..02e536b 100644 --- a/src/POMDPFiles.jl +++ b/src/POMDPFiles.jl @@ -4,13 +4,21 @@ using Reexport using POMDPs using POMDPTools using Printf -using POMDPModels: TabularPOMDP +using WildcardArrays +using WildcardArrays: WildcardArray +using LinearAlgebra @reexport using POMDPXFiles # for POMDPAlphas -import POMDPs: action, value +import POMDPs: transition, reward, discount, observation, states, stateindex, actions, actionindex, observations, obsindex, initialstate -export read_alpha, read_pomdp +export InitialStateParam +include("types.jl") + +export WildcardArrayPOMDP, SWildcardArrayPOMDP, statenames, actionnames, obsnames +include("WildcardArrayPOMDPs.jl") + +export read_alpha, read_pomdp include("reader.jl") export numericprint, symbolicprint diff --git a/src/WildcardArrayPOMDPs.jl b/src/WildcardArrayPOMDPs.jl new file mode 100644 index 0000000..cf6db43 --- /dev/null +++ b/src/WildcardArrayPOMDPs.jl @@ -0,0 +1,146 @@ +""" + WildcardArrayPOMDP is the main data structure of the package. It is used to represent a POMDP problem from a file. States, actions, and observations are the integers 1:ns, 1:na, and 1:no. + + ns: number of states + na: number of actions + no: number of observations + + support_initialstate: support of the initial state distribution + initialstate_distribution: probability of each state under the initial distribution (an empty vector is treated as uniform over all states) + + discount: discount factor + + T: transition probabilities, indexed as T[a, s, sp] + O: observation probabilities, indexed as O[a, sp, o] + R: rewards, indexed as R[a, s, sp, o] +""" +struct WildcardArrayPOMDP <: POMDP{Int, Int, Int} + ns::Int + na::Int + no::Int + + support_initialstate::Set{Int} + initialstate_distribution::Vector{Float64} + + discount::Float64 + + T::WildcardArray{Float64, 3} + O::WildcardArray{Float64, 3} + R::WildcardArray{Float64, 4} +end + +""" + Constructors for the WildcardArrayPOMDP type.
+""" +WildcardArrayPOMDP(s::Int, a::Int, o::Int, initial_state::InitialStateParam, discount::Float64, T::WildcardArray{Float64, 3}, O::WildcardArray{Float64, 3}, R::WildcardArray{Float64, 4})= WildcardArrayPOMDP(s, a, o, support(initial_state), prob(initial_state), discount, T, O, R) +WildcardArrayPOMDP(filename::String) = read_pomdp(filename; output=:WildcardArrayPOMDP) + +""" + Implementing the functions required by the POMDP interface. See [POMDPs.jl](https://juliapomdp.github.io/POMDPs.jl/latest/) for more details on the interface. States, actions, and observations are the ranges 1:ns, 1:na, and 1:no; the index functions return their argument unchanged and raise an error for any value outside the corresponding range. +""" +states(m::WildcardArrayPOMDP) = 1:m.ns +stateindex(m::WildcardArrayPOMDP, i::Int) = (1 <= i <= m.ns) ? i : error("Querying states outside the allowable range.") + +actions(m::WildcardArrayPOMDP) = 1:m.na +actionindex(m::WildcardArrayPOMDP, i::Int) = (1 <= i <= m.na) ? i : error("Querying input outside the allowable range.") + +observations(m::WildcardArrayPOMDP) = 1:m.no +obsindex(m::WildcardArrayPOMDP, i::Int) = (1 <= i <= m.no) ? i : error("Querying observations outside the allowable range.") + +function initialstate(m::WildcardArrayPOMDP) + if !isempty(m.initialstate_distribution) + return SparseCat(states(m), m.initialstate_distribution) + else + return SparseCat(states(m), 1/m.ns*ones(m.ns)) + end +end + +function transition(m::WildcardArrayPOMDP, s::Int, a::Int) + prob_val = [m.T[a,s,sp] for sp in states(m)] + return SparseCat(states(m), prob_val) +end + +function observation(m::WildcardArrayPOMDP, a::Int, sp::Int) + prob_obs = [m.O[a, sp, obs] for obs in observations(m)] + return SparseCat(observations(m), prob_obs) +end + +reward(m::WildcardArrayPOMDP, s::Int, a::Int, sp::Int, obs::Int) = m.R[a,s,sp,obs] +reward(m::WildcardArrayPOMDP, s::Int, a::Int, sp::Int) = m.R[a,s,sp,1] +reward(m::WildcardArrayPOMDP, s::Int, a::Int) = m.R[a,s,1,1] + +discount(m::WildcardArrayPOMDP) = m.discount + +# Data structure with names +""" + SWildcardArrayPOMDP is used whenever the names of the states, actions, and observations are known.
It is used to represent a POMDP problem from a file. + + dic_states: dictionary mapping each state name to its integer index + dic_actions: dictionary mapping each action name to its integer index + dic_obs: dictionary mapping each observation name to its integer index + pomdp: underlying WildcardArrayPOMDP structure; the dictionaries must have exactly pomdp.ns, pomdp.na, and pomdp.no entries +""" +struct SWildcardArrayPOMDP <: POMDP{String, String, String} + dic_states::Dict{String, Int} + dic_actions::Dict{String, Int} + dic_obs::Dict{String, Int} + pomdp::WildcardArrayPOMDP + + function SWildcardArrayPOMDP(dic_ss::Dict{String, Int}, dic_aa::Dict{String, Int}, dic_oo::Dict{String, Int}, pomdp::WildcardArrayPOMDP) + @assert length(dic_ss) == pomdp.ns + @assert length(dic_aa) == pomdp.na + @assert length(dic_oo) == pomdp.no + + new(dic_ss, dic_aa, dic_oo, pomdp) + end +end +SWildcardArrayPOMDP(filename::String) = read_pomdp(filename; output=:SWildcardArrayPOMDP) +""" + Implementing the functions required by the POMDP interface. See [POMDPs.jl](https://juliapomdp.github.io/POMDPs.jl/latest/) for more details on the interface.
+""" +states(m::SWildcardArrayPOMDP) = states(m.pomdp) +stateindex(m::SWildcardArrayPOMDP, key::Int) = stateindex(m.pomdp, key) +statenames(m::SWildcardArrayPOMDP) = first.(sort!(collect(m.dic_states); by=last)) # names ordered by their index, not by Dict hash order +function stateindex(m::SWildcardArrayPOMDP, key::String) + i = m.dic_states[key] + return stateindex(m, i) +end + +actions(m::SWildcardArrayPOMDP) = actions(m.pomdp) +actionindex(m::SWildcardArrayPOMDP, key::Int) = actionindex(m.pomdp, key) +actionnames(m::SWildcardArrayPOMDP) = first.(sort!(collect(m.dic_actions); by=last)) # names ordered by their index, not by Dict hash order +function actionindex(m::SWildcardArrayPOMDP, key::String) + i = m.dic_actions[key] + return actionindex(m, i) +end + +observations(m::SWildcardArrayPOMDP) = observations(m.pomdp) +obsindex(m::SWildcardArrayPOMDP, i::Int) = obsindex(m.pomdp, i) +obsnames(m::SWildcardArrayPOMDP) = first.(sort!(collect(m.dic_obs); by=last)) # names ordered by their index, not by Dict hash order +function obsindex(m::SWildcardArrayPOMDP, key::String) + i = m.dic_obs[key] + return obsindex(m, i) +end + +initialstate(m::SWildcardArrayPOMDP) = initialstate(m.pomdp) + +transition(m::SWildcardArrayPOMDP, s::Int, a::Int) = transition(m.pomdp, s, a) +function transition(m::SWildcardArrayPOMDP, s::String, a::String) + is = m.dic_states[s]; ia = m.dic_actions[a] + return transition(m, is, ia) +end + +observation(m::SWildcardArrayPOMDP, a::Int, sp::Int) = observation(m.pomdp, a, sp) +function observation(m::SWildcardArrayPOMDP, a::String, sp::String) + isp = m.dic_states[sp]; ia = m.dic_actions[a] + return observation(m, ia, isp) +end + +reward(m::SWildcardArrayPOMDP, s::Int, a::Int, sp::Int, obs::Int) = reward(m.pomdp, s, a, sp, obs) +reward(m::SWildcardArrayPOMDP, s::String, a::String, sp::String, obs::String) = reward(m.pomdp, m.dic_states[s], m.dic_actions[a], m.dic_states[sp], m.dic_obs[obs]) +reward(m::SWildcardArrayPOMDP, s::Int, a::Int, sp::Int) = reward(m.pomdp, s, a, sp) +reward(m::SWildcardArrayPOMDP, s::String, a::String, sp::String) = reward(m.pomdp, m.dic_states[s], m.dic_actions[a], m.dic_states[sp]) +reward(m::SWildcardArrayPOMDP, s::Int, a::Int) =
reward(m.pomdp, s, a) +reward(m::SWildcardArrayPOMDP, s::String, a::String) = reward(m.pomdp, m.dic_states[s], m.dic_actions[a]) + +discount(m::SWildcardArrayPOMDP) = discount(m.pomdp) \ No newline at end of file diff --git a/src/reader.jl b/src/reader.jl index 2e8b0ed..a099d8b 100644 --- a/src/reader.jl +++ b/src/reader.jl @@ -1,5 +1,4 @@ const REGEX_FLOATING_POINT = r"[-+]?[0-9]*\.?[0-9]+" - """ Read a `.alpha` file as generated by pomdp-solve. Works the same was as `read_pomdp` in `POMDPXFile.jl`. @@ -43,7 +42,7 @@ function read_alpha(filename::AbstractString) alpha_vector_line_indeces = Int[] vector_length = -1 - for i in 1:length(lines) + for i in eachindex(lines) matches = collect((m.match for m = eachmatch(REGEX_FLOATING_POINT, lines[i]))) @@ -80,256 +79,274 @@ function read_alpha(filename::AbstractString) return alpha_vectors, alpha_actions end +""" + Read a .pomdp file following the specification at http://www.pomdp.org/code/pomdp-file-spec.html and return a WildcardArrayPOMDP or an SWildcardArrayPOMDP object, selected by the `output` keyword, that can be used within the POMDPs.jl interface.
+""" +function read_pomdp(filename::String; output::Symbol = :SWildcardArrayPOMDP) + lines = open(readlines, filename) |> remove_comments_and_white_space + + # Getting info from preamble + regex_filtering_preamble = r"\s*[RTO]\s*:" + preamble = lines[1:findfirst(startswith.(lines, regex_filtering_preamble))-1] + preamble_dict = check_preamble_fields(join(preamble, "\n")) + discount, type_reward, actions, states, observations = process_preamble(preamble_dict) + + dic_states = Dict(string(nn) => index for (index, nn) in enumerate(names(states))) # needed here to process the initial state + + # # # # Processing the initial distribution + init_state_tuple = Dict((kk,vv) for (kk, vv) in preamble_dict if kk in ["start", "start include", "start exclude"]) + sorted_keys = sort(collect(init_state_tuple), by=x->x[2].priority) + initialstate = InitialStateParam() + + for (kk,vv) in sorted_keys + initialstate_content = vv.value + types = [Float64, Int] + tmp_initialstate_content = map(x->tryparse.(x, string.(split(initialstate_content))), types) + + if isequal(kk, "start") + if isequal(initialstate_content, "uniform") + initialstate.support_of_distribution = Set([i for i in Base.OneTo(number(states))]) + initialstate.value_of_distribution = (1/number(states))*ones(number(states)) + initialstate.type_of_distribution = "uniform" + initialstate.number = number(states) -function read_pomdp(filename::AbstractString) - lines = open(readlines, filename) + else + # Either a vector suming to one + initialstate.number = number(states) + if all(x -> !isnothing(x), tmp_initialstate_content[1]) # or a vector of floats + @assert test_if_probability(tmp_initialstate_content[1]) + + # Saving content on InitialStateParam + initialstate.value_of_distribution = tmp_initialstate_content[1] + initialstate.support_of_distribution = Set(findall(x -> x > 0, initialstate.value_of_distribution)) + initialstate.type_of_distribution = "general distribution" + + elseif all(x -> !isnothing(x), 
tmp_initialstate_content[2]) # or a vector of integers + @assert (all(x -> x >= 1 && x <= number(states), tmp_initialstate_content[2])) + + # Saving content on InitialStateParam + initialstate.support_of_distribution = Set(tmp_initialstate_content[2]) + initialstate.value_of_distribution = vec((1/length(initialstate.support_of_distribution))*sum(Diagonal(ones(Float64, number(states)))[:, collect(initialstate.support_of_distribution)], dims=2)) + initialstate.type_of_distribution = "uniform" + + elseif all(x-> x in names(states), string.(split(initialstate_content))) # or a vector of names + # Saving content on InitialStateParam + init_state = map(x -> dic_states[x], string.(split(initialstate_content))) + initialstate.support_of_distribution = Set(init_state) + initialstate.value_of_distribution = vec((1/length(initialstate.support_of_distribution))*sum(Diagonal(ones(Float64, number(states)))[:, collect(initialstate.support_of_distribution)], dims=2)) + initialstate.type_of_distribution = "uniform" - discount = 0 - num_states = 0 - num_actions = 0 - num_observations = 0 + else + error("Unable to parse the initial condition.") + end + end + elseif isequal(kk, "start include") + if all(x -> !isnothing(x), tmp_initialstate_content[2]) # or a vector of integers + initialstate.support_of_distribution = union(initialstate.support_of_distribution, Set(tmp_initialstate_content[2])) # union the sets + elseif all(x-> x in names(states), string.(split(initialstate_content))) # or a vector of names + init_state = map(x -> dic_states[x], string.(split(initialstate_content))) + initialstate.support_of_distribution = union(initialstate.support_of_distribution, Set(init_state)) + end - states = 0 - actions = 0 - observations = 0 + initialstate.value_of_distribution = vec((1/length(initialstate.support_of_distribution))*sum(Diagonal(ones(Float64, number(states)))[:, collect(initialstate.support_of_distribution)], dims=2)) + initialstate.type_of_distribution = "uniform" + 
initialstate.number = number(states) - T_lines = Vector{Int64}() - O_lines = Vector{Int64}() - R_lines = Vector{Int64}() + elseif isequal(kk, "start exclude") + if all(x -> !isnothing(x), tmp_initialstate_content[2]) # or a vector of integers + initialstate.support_of_distribution = setdiff(initialstate.support_of_distribution, Set(tmp_initialstate_content[2])) + elseif all(x-> x in names(states), string.(split(initialstate_content))) # or a vector of names + init_state = map(x -> dic_states[x], string.(split(initialstate_content))) + initialstate.support_of_distribution = setdiff(initialstate.support_of_distribution, Set(init_state)) + end - lines = map(lines) do line - line[1:something(findfirst('#', line), length(line))] + if !isempty(initialstate.support_of_distribution) + initialstate.value_of_distribution = vec((1/length(initialstate.support_of_distribution))*sum(Diagonal(ones(Float64, number(states)))[:, collect(initialstate.support_of_distribution)], dims=2)) + end + initialstate.type_of_distribution = "uniform" + initialstate.number = number(states) + else + error("Unable to parse the initial condition.") + end end - for i in 1:length(lines) - if length(lines[i]) > 0 - if occursin(r"discount:", lines[i]) && lines[i][1] != '#' - discount = parse(Float64, match(REGEX_FLOATING_POINT, lines[i]).match) + sorted_fields = order_of_transition_reward_observation(lines, 1) + + files_transition = [] + files_obs = [] + files_values = [] + # Finding the chunk of the file with the transition, observation, and reward specifications + for (index, (type_of_matrix, line_number)) in enumerate(sorted_fields) + if index < length(sorted_fields) + range_spec = line_number:(sorted_fields[index+1][2] -1) + if isequal(type_of_matrix, "T") + files_transition = lines[range_spec] end - if occursin(r"states:", lines[i]) && lines[i][1] != '#' - states = split(strip(lines[i]), ' ') - if length(states) > 2 - num_states = length(states) - 1 - states = states[2:end] - else - num_states = 
parse(Int64, states[2]) - states = collect(string(i) for i in 0:num_states-1) - end + if isequal(type_of_matrix, "O") + files_obs = lines[range_spec] end - if occursin(r"actions:", lines[i]) && lines[i][1] != '#' - actions = split(strip(lines[i]), ' ') - if length(actions) > 2 - num_actions = length(actions) - 1 - actions = actions[2:end] - else - num_actions = parse(Int64, actions[2]) - actions = collect(string(i) for i in 0:num_actions-1) - end + if isequal(type_of_matrix, "R") + files_values = lines[range_spec] end - if occursin(r"observations:", lines[i]) && lines[i][1] != '#' - observations = split(strip(lines[i]), ' ') - if length(observations) > 2 - num_observations = length(observations) - 1 - observations = observations[2:end] - else - num_observations = parse(Int64, observations[2]) - observations = collect(string(i) for i in 0:num_observations-1) - end - end - if occursin(r"T:|T :", lines[i]) - push!(T_lines, i) + else + range_spec = (line_number:length(lines)) + if isequal(type_of_matrix, "T") + files_transition = lines[range_spec] end - if occursin(r"O:|O :", lines[i]) - push!(O_lines, i) + if isequal(type_of_matrix, "O") + files_obs = lines[range_spec] end - if occursin(r"R:|R :", lines[i]) - push!(R_lines, i) + if isequal(type_of_matrix, "R") + files_values = lines[range_spec] end end end - T = zeros(num_states, num_actions, num_states) - O = zeros(num_observations, num_actions, num_states) - R = zeros(num_states, num_actions) - - ind1 = 0 - ind2 = 0 - ind3 = 0 - - if length(T_lines) > 0 - if length(findall(x->x==':', lines[T_lines[1]])) == 3 - for t in T_lines - l = replace(lines[t], ':'=>' ') - line = split(l, ' ') - line = collect(strip(i) for i in line) - deleteat!(line, findall(x->x=="", line)) - if line[3] == "*" - ind1 = collect(1:length(states)) - else - ind1 = findall(x->x==line[3], states) - end - if line[2] == "*" - ind2 = collect(1:length(actions)) - else - ind2 = findall(x->x==line[2], actions) - end - if line[4] == "*" - ind3 = 
collect(1:length(states)) - else - ind3 = findall(x->x==line[4], states) - end - T[ind3, ind2, ind1] .= parse(Float64, line[5]) - end - elseif length(findall(x->x==':', lines[T_lines[1]])) == 2 - for t in T_lines - l = t+1 - act = strip(split(lines[t], ':')[2]) - st = strip(split(lines[t], ':')[3]) - i = findfirst(x->x==act, actions) - j = findfirst(x->x==st, states) - T[:,i,j] = collect((parse(Float64, m.match) for m = eachmatch(REGEX_FLOATING_POINT, lines[l]))) - end - else - for t in T_lines - l = t+1 - id = findall(strip(lines[l]), "identity") - un = findall(strip(lines[l]), "uniform") - act = strip(split(lines[t], ':')[2]) - i = findfirst(x->x==act, actions) - if length(id) > 0 - for j in 1:num_states - T[j,i,j] = 1 - l += 1 - end - elseif length(un) > 0 - for j in 1:num_states - T[:,i,j] = ones(num_states)./num_states - l += 1 - end - else - for j in 1:num_states - T[:,i,j] = collect((parse(Float64, m.match) for m = eachmatch(REGEX_FLOATING_POINT, lines[l]))) - l += 1 - end - end - end + # Processing transition probability (dimensions: action, state, next state) + str_trans = join(files_transition, "\n") + vv = [names(actions), names(states), names(states)] + wc_trans = WildcardArray(str_trans, vv) + + # # Processing observation probability (dimensions: action, state, observation) + str_obs = join(files_obs, "\n") + vv = [names(actions), names(states), names(observations)] + wc_obs = WildcardArray(str_obs, vv) + + # # Processing rewards (dimensions: action, state, next state, observation) + str_values = join(files_values, "\n") + vv = [names(actions), names(states), names(states), names(observations)] + wc_values = WildcardArray(str_values, vv) + + pomdp_struc = WildcardArrayPOMDP(number(states), number(actions), number(observations), initialstate, discount[1], wc_trans, wc_obs, wc_values) + + if output == :WildcardArrayPOMDP + return pomdp_struc + + elseif output == :SWildcardArrayPOMDP + dic_action = Dict(string(nn) => index for (index, nn) in enumerate(names(actions))) + dic_obs = Dict(string(nn) => index for (index, nn) in enumerate(names(observations))) + + return
SWildcardArrayPOMDP(dic_states, dic_action, dic_obs, pomdp_struc) + else + error("Output type invalid") + end +end + +################ Auxiliary functions ################## +""" + test_if_probability(prob::Union{Vector{Float64}, Vector{Nothing}, Nothing};rtol=1e-3) + + Built-in function that tests whether a vector is a probability distribution. It checks if the elements are between 0 and 1 and if the sum of the elements is approximately 1. The function returns true if the vector is a probability distribution and false otherwise. +""" +function test_if_probability(prob::Union{Vector{Float64}, Vector{Nothing}, Nothing};rtol=1e-3) + if isnothing(prob) || eltype(prob) == Nothing + return false + else + between_0_1 = all(x -> 0 <= x <= 1, prob) + return (between_0_1 && isapprox(sum(prob), 1; rtol=rtol)) ? true : false + end +end +""" + remove_comments_and_white_space(file::Vector{String}) is used by read_pomdp to remove comments and white spaces from the file. This function allows for some standardization of process of parsing files. 
+""" +function remove_comments_and_white_space(file::Vector{String}) + processed_file = [] + + for line in file + without_comments = replace(line, r"#.*" => "") |> strip + + if !isempty(without_comments) + push!(processed_file, without_comments) end end - if length(O_lines) > 0 - if length(findall(x->x==':', lines[O_lines[1]])) == 3 - for t in O_lines - l = replace(lines[t], ':'=>' ') - line = split(l, ' ') - line = collect(strip(i) for i in line) - deleteat!(line, findall(x->x=="", line)) - if line[4] == "*" - ind1 = collect(1:length(observations)) - else - ind1 = findall(x->x==line[4], observations) - end - if line[2] == "*" - ind2 = collect(1:length(actions)) - else - ind2 = findall(x->x==line[2], actions) - end - if line[3] == "*" - ind3 = collect(1:length(states)) - else - ind3 = findall(x->x==line[3], states) - end - O[ind1, ind2, ind3] .= parse(Float64, line[5]) - end - elseif length(findall(x->x==':', lines[O_lines[1]])) == 2 - for t in T_lines - l = t+1 - act = strip(split(lines[t], ':')[2]) - st = strip(split(lines[t], ':')[3]) - i = findfirst(x->x==act, actions) - j = findfirst(x->x==st, states) - O[:,i,j] = collect((parse(Float64, m.match) for m = eachmatch(REGEX_FLOATING_POINT, lines[l]))) - end + return Vector{String}(filter(x -> !isempty(x), processed_file)) +end +""" + convert_to_data_structure(field::String, preamble::Dict{String,String}) is used by read_pomdp to convert the information in the preamble into an intermidiate format before passing it into a ContainerNames object. +""" +function convert_to_data_structure(field::String, preamble::Dict{String,Any}) + entry = preamble[field] + entry = replace(entry, r"\"+" => "") + + return !isnothing(tryparse(Int64, entry)) ? parse(Int64, entry) : string.(split(entry)) +end +""" + order_of_transition_reward_observation(file_lines::Vector{String}, start_line::Int64) is used by read_pomdp to find the order of the transition, reward, and observation matrices in the file. 
+""" +function order_of_transition_reward_observation(file_lines::Vector{String}, start_line::Int64) + key_field = ["O", "T", "R"] + regex_fields = Vector{String}() + + [push!(regex_fields, "\\s*$field\\s*:") for field in key_field] + + indices = map(x-> findfirst(startswith.(file_lines, Regex(x))), regex_fields) + + dict_scanning = Dict(field => indices[ii] for (ii, field) in enumerate(key_field)) + sorted_fields = sort(collect(pairs(dict_scanning)), by=x->x[2]) + + return sorted_fields +end + +######### Auxiliary functions -- PREAMBLE ############### +""" + check_preamble_fields(preamble::String) is used by read_pomdp to check if the preamble of the file has all the necessary fields. An error is issued if one of the fields "discount", "values", "states", "actions", or "observations" is missing. +""" +function check_preamble_fields(preamble::String) + key_fields = ["discount", "values", "states", "actions", "observations"] + preamble_vec = string.(split(preamble, "\n")) + + preamble_dict = Dict{String, Any}() + field_dict = Dict{String, Int64}() + + # Checking whether the preamble has all the necessary fields + for field in key_fields + reg_expr = Regex("\\s*$(field)\\s*:") + index = findfirst(startswith.(preamble_vec, reg_expr)) + + if !isnothing(index) + field_dict[field] = index else - for t in O_lines - l = t+1 - un = findall(strip(lines[l]), "uniform") - act = strip(split(lines[t], ':')[2]) - if act == "*" - if length(un) > 0 - for j in 1:num_states - for i in 1:num_actions - O[:,i,j] = ones(num_observations)./num_observations - end - l += 1 - end - else - for j in 1:num_states - for i in 1:num_actions - O[:,i,j] = collect((parse(Float64, m.match) for m = eachmatch(REGEX_FLOATING_POINT, lines[l]))) - end - l += 1 - end - end - else - i = findfirst(x->x==act, actions) - if length(un) > 0 - for j in 1:num_states - O[:,i,j] = ones(num_observations)./num_observations - l += 1 - end - else - for j in 1:num_states - O[:,i,j] = collect((parse(Float64, m.match) for m 
= eachmatch(REGEX_FLOATING_POINT, lines[l]))) - l += 1 - end - end - end - end + error("Missing field $(field) in the file") end end - if length(R_lines) > 0 - if length(findall(x->x==':', lines[R_lines[1]])) == 4 - for t in R_lines - l = replace(lines[t], ':'=>' ') - line = split(l, ' ') - line = collect(strip(i) for i in line) - deleteat!(line, findall(x->x=="", line)) - if line[3] == "*" - ind1 = collect(1:length(states)) - else - ind1 = findall(x->x==line[3], states) - end - if line[2] == "*" - ind2 = collect(1:length(actions)) - else - ind2 = findall(x->x==line[2], actions) - end - R[ind1, ind2] .= parse(Float64, line[6]) - end - elseif length(findall(x->x==':', lines[R_lines[1]])) == 3 - for t in R_lines - l = t+1 - act = strip(split(lines[t], ':')[2]) - i = findfirst(x->x==act, actions) - for j in 1:num_states - T[j,i,:] = collect((parse(Float64, m.match) for m = eachmatch(REGEX_FLOATING_POINT, lines[l]))) - l += 1 - end - end + regex_preamble = r"\s*(.*)\s*:\s+([\d\D]*?)(?=(.*:)|$)" + + for (ii,m) in enumerate(eachmatch(regex_preamble, preamble)) + field = strip(m.captures[1], ['\n', '\r', ' ', '\"']) + content = strip(m.captures[end-1], ['\n', '\r', ' ', '\"']) + + if field in ["start", "start include", "start exclude"] + preamble_dict[field] = WildcardArrays.PriorityValue(ii, string(content)) else - for t in R_lines - l = t+1 - act = strip(split(lines[t], ':')[2]) - i = findfirst(x->x==act, actions) - for j in 1:num_states - T[j,i,:] = collect((parse(Float64, m.match) for m = eachmatch(REGEX_FLOATING_POINT, lines[l]))) - l += 1 - end - end + preamble_dict[field] = content end end - m = TabularPOMDP(T, R, O, discount) - return m + return preamble_dict end +""" + process_preamble(preamble::Dict{String, String}) is used by read_pomdp to process the preamble of the file and check if the fields "discount", "values", "states", "actions", and "observations" have the correct syntax. 
The output are the discount, values, actions, states, and observations parameters, where actions, states, and observations are converted into ContainerNames objects. +""" +function process_preamble(preamble::Dict{String, Any}) + # checking discount syntax => it must be a float number + discount = parse(Float64, preamble["discount"]) + if ~(0 <= discount <= 1) + error("Discount parameter must be a number between zero and one") + end + + # checking value syntax => either "reward" or "cost" + values_type = preamble["values"] + values_param = [(isequal(values_type,"reward")) || (isequal(values_type,"cost") || isequal(values_type, "rewards") || isequal(values_type, "costs")) ? values_type : error("Invalid specification for the objective function.")] + # checking actions syntax => either an integer or a collection of names + actions_param = convert_to_data_structure("actions", preamble) + # checking states syntax => either an integer or a collection of names + states_param = convert_to_data_structure("states", preamble) + # checking observation syntax => either an integer or a collection of names + observations_param = convert_to_data_structure("observations", preamble) + + + + return discount, values_param, ContainerNames(actions_param), ContainerNames(states_param), ContainerNames(observations_param) +end \ No newline at end of file diff --git a/src/types.jl b/src/types.jl new file mode 100644 index 0000000..11ae2d1 --- /dev/null +++ b/src/types.jl @@ -0,0 +1,60 @@ +## Types to deal with the memory saving feature +""" + ContainerNames + + names: Vector{String} - a vector with the names of the actions, states, or observations. + number: Int - the number of actions, states, or observations. + + This type was created to store the names and number of actions, states, and observations when parsed from the preamble of a .pomdp file format. 
+""" +struct ContainerNames + names::Vector{String} + number::Int +end + +""" + This constructor can be used to create a ContainerNames object when the names of the actions, states, or observations are not known. Note that by default the startindex is 0. +""" +function ContainerNames(number::Int;startindex::Int=0) + names = [string(i) for i in startindex:(number-1)] + return ContainerNames(names, number) +end +""" + We can also build a ContainerNames object by passing the names of the actions, states, or observations. +""" +ContainerNames(names_of_actions::Vector{String}) = ContainerNames(names_of_actions, length(names_of_actions)) + +Base.names(a::ContainerNames) = a.names + +""" + number returns the number of actions, states, or observations in the ContainerNames object. +""" +number(a::ContainerNames) = a.number + +""" + InitialStateParam was created to store the initial state distribution of a POMDP. The distribution can be of any type, but it is stored as a vector of Float64 values. +""" +mutable struct InitialStateParam + number::Int + type_of_distribution::String + support_of_distribution::Set{Int} + value_of_distribution::Vector{Float64} + +end +InitialStateParam(number::Int) = InitialStateParam(number, " ", Set{Int}([]), Vector{Float64}([])) +InitialStateParam() = InitialStateParam(0) + +""" + number returns the number of states in the initial state distribution. +""" +number(init::InitialStateParam) = init.size_of_states + +""" + support returns the support of the initial state distribution. +""" +support(init::InitialStateParam) = init.support_of_distribution + +""" + value returns the a vector with the initial state distribution. 
+""" +prob(init::InitialStateParam) = init.value_of_distribution diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..b5f82ef --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,16 @@ +[deps] +Cascadia = "54eefc05-d75b-58de-a785-1a3403f0919f" +CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Gumbo = "708ec375-b3d6-5a57-a7ce-8257bf98657a" +HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" +POMDPFiles = "9cf5b727-2e06-5671-8c87-8c6b0f729d5d" +POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +WildcardArrays = "05d16eb0-506c-4fac-9187-f7ce7e253b09" + +[compat] +POMDPFiles = "0.3.0" +WildcardArrays = "0.1.0" diff --git a/test/aux_func.jl b/test/aux_func.jl new file mode 100644 index 0000000..f809ba6 --- /dev/null +++ b/test/aux_func.jl @@ -0,0 +1,184 @@ +function save_files(cleanup::Bool=true) + url_source = "http://pomdp.org/examples/" + data_source = HTTP.get(url_source) + parsed_web = data_source.body |> String + + parsed_html = parsehtml(parsed_web) + ss = Selector("a") + links = eachmatch(ss, parsed_html.root) + + tmp_dir_tar_gz = mktempdir(tempdir(); prefix="POMDPtargz_", cleanup=cleanup) + tmp_dir_gz = mktempdir(tempdir(); prefix="POMDPgz_", cleanup=cleanup) + + vec_dir_tar_gz = Vector{String}() + tmp_dir_pomdp = mktempdir(tempdir(); prefix="POMDPfiles_pomdp_", cleanup=cleanup) + tmp_dir_files_gz = mktempdir(tempdir(); prefix="POMDPfiles_gz_", cleanup=cleanup) + + for link in links + href = link.attributes["href"] + if occursin(r".tar.gz$", href) + complete_url = startswith(href, "http") ? 
href : url_source * href + + println(complete_url) + regex_getnames = r"/([^/]*?)\.tar\.gz$" + + tmp_file_name = tmp_dir_tar_gz * "/" * match(regex_getnames, complete_url).captures[1] * ".tar.gz" + + HTTP.download(complete_url, tmp_file_name) + println("Saved tar.gz files in ", tmp_file_name, "\n\n\n") + + tmp_files_tar_gz = mktempdir(tempdir(); prefix="POMDPfiles_tar_gz_", cleanup=cleanup) + push!(vec_dir_tar_gz, tmp_files_tar_gz) + println("Extracting files in ", tmp_files_tar_gz, "\n\n\n") + + open(tmp_file_name, "r") do gz_file + decomp_gz_file = GzipDecompressorStream(gz_file) + + Tar.extract(decomp_gz_file, tmp_files_tar_gz) + close(decomp_gz_file) + end + end + + if occursin(r".POMDP$", href) + complete_url = startswith(href, "http") ? href : url_source * href + + regex_getnames = r"/([^/]*?)\.[Pp][Oo][Mm][Dd][Pp]$" + temp_file_name = tmp_dir_pomdp * "/" * match(regex_getnames, complete_url).captures[1] * ".POMDP" + HTTP.download(complete_url, temp_file_name) + + println("Saved file: ", temp_file_name, " in ", tmp_dir_pomdp, "\n\n\n") + + end + + if occursin(r"(? 
T_values_mit[i] for i in eachindex(T_tuple_mit)) + + O_tuple_mit = [(1,7,5), (2, 75, 17), (4, 203, 8)] + O_values_mit = [0.121500, 0.056700, 0.002250] + O_dict_mit = Dict(O_tuple_mit[i] => O_values_mit[i] for i in eachindex(O_tuple_mit)) + + # 4 semicolon with wildcards + R_tuple_mit = [(4, 185, 2, 3), (4, 185, 1, 1), (4, 169, 1, 1)] + R_values_mit = [-1, -1, 1] + R_dict_mit = Dict(R_tuple_mit[i] => R_values_mit[i] for i in eachindex(R_tuple_mit)) + + mit_dict = Dict("T" => T_dict_mit, "O" => O_dict_mit, "R" => R_dict_mit) + +# hallway.POMDP tests set-up + T_tuple_hallway = [(1,1,1), (3, 1, 2)] + T_values_hallway = [1, 0.7] + T_dict_hallway = Dict(T_tuple_hallway[i] => T_values_hallway[i] for i in eachindex(T_tuple_hallway)) + + # 3 semicolon with vector of probabilities + O_tuple_hallway = [(3,1,4), (5,28,1), (4, 57, 21), (2, 60, 1)] + O_values_hallway = [0.076949, 0.085737, 1, 0] + O_dict_hallway = Dict(O_tuple_hallway[i] => O_values_hallway[i] for i in eachindex(O_tuple_hallway)) + + # 4 semicolon with wildcards + R_tuple_hallway = [(1,1,1,1), (4, 30, 57, 20)] + R_values_hallway = [0, 1] + R_dict_hallway = Dict(R_tuple_hallway[i] => R_values_hallway[i] for i in eachindex(R_tuple_hallway)) + hallway_dict = Dict("T" => T_dict_hallway, "O" => O_dict_hallway, "R" => R_dict_hallway) + +# bulkhead_A.POMDP tests set-up + # Values given by transition probabilities and wild cards + T_tuple_bulkhead_A = [(4,1,4), (4, 4, 1), (4,5,8), (6,1,10), (1,2,2), (1,3,3), (2,1,2)] + T_values_bulkhead_A = [0.97, 0, 0.98, 1, 1, 1, 0] + T_dict_bulkhead_A = Dict(T_tuple_bulkhead_A[i] => T_values_bulkhead_A[i] for i in eachindex(T_tuple_bulkhead_A)) + + # Values given by transition probabilities and wild cards + O_tuple_bulkhead_A = [(3,1,1), (3,9,4), (3,9,6), (5,1,1)] + O_values_bulkhead_A = [0,0.25,0.75, 1] + O_dict_bulkhead_A = Dict(O_tuple_bulkhead_A[i] => O_values_bulkhead_A[i] for i in eachindex(O_tuple_bulkhead_A)) + + # 4 semicolon with wildcards + R_tuple_bulkhead_A = [(5,7,2,1), 
(5,1,4,5)] + R_values_bulkhead_A = [45000, -15000] + R_dict_bulkhead_A = Dict(R_tuple_bulkhead_A[i] => R_values_bulkhead_A[i] for i in eachindex(R_tuple_bulkhead_A)) + bulkhead_A_dict = Dict("T" => T_dict_bulkhead_A, "O" => O_dict_bulkhead_A, "R" => R_dict_bulkhead_A) + +# baseball.POMDP: only due to ir being a large file. The way transitions are specified are "boring" and have been extensively tested with the previous examples + T_tuple_baseball = [(2,7301,7681), (6,7681,7681)] + T_values_baseball = [0.9, 1] + T_dict_baseball = Dict(T_tuple_baseball[i] => T_values_baseball[i] for i in eachindex(T_tuple_baseball)) + + O_tuple_baseball = [(1,218,4), (3,4017,5)] + O_values_baseball = [1, 0] + O_dict_baseball = Dict(O_tuple_baseball[i] => O_values_baseball[i] for i in eachindex(O_tuple_baseball)) + + R_tuple_baseball = [(6,6144,1,1), (6,6145,4,5), (6,6145,1,1)] + R_values_baseball = [3,4,4] + R_dict_baseball = Dict(R_tuple_baseball[i] => R_values_baseball[i] for i in eachindex(R_tuple_baseball)) + baseball_dict = Dict("T" => T_dict_baseball, "O" => O_dict_baseball, "R" => R_dict_baseball) + +# tiger_95.POMDP: testing all transitions, observations and rewards + T_tuple_tiger_95 = [(1,1,1), (1,1,2), (1,2,1),(1,2,2), (2,1,1), (2,1,2), (2,2,1), (2,2,2), (3,1,1), (3,1,2), (3,2,1), (3,2,2)] + T_values_tiger_95 = [1,0,0,1,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5] + T_dict_tiger_95 = Dict(T_tuple_tiger_95[i] => T_values_tiger_95[i] for i in eachindex(T_tuple_tiger_95)) + + O_tuple_tiger_95 = [(1,1,1), (1,1,2), (1,2,1),(1,2,2), (2,1,1), (2,1,2), (2,2,1), (2,2,2), (3,1,1), (3,1,2), (3,2,1), (3,2,2)] + O_values_tiger_95 = [0.85,0.15,0.15,0.85,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5] + O_dict_tiger_95 = Dict(O_tuple_tiger_95[i] => O_values_tiger_95[i] for i in eachindex(O_tuple_tiger_95)) + + + R_tuple_tiger_95 = [(1,1,1,1),(1,2,1,1),(2,1,1,1),(2,2,1,1),(3,1,1,1),(3,2,1,1), (1,1,2,1)] + R_values_tiger_95 = [-1,-1,-100,10,10,-100,-1] + R_dict_tiger_95 = Dict(R_tuple_tiger_95[i] => 
R_values_tiger_95[i] for i in eachindex(R_tuple_tiger_95)) + tiger_95_dict = Dict("T" => T_dict_tiger_95, "O" => O_dict_tiger_95, "R" => R_dict_tiger_95) + + + name_to_dic = Dict("mit" => mit_dict, "baseball" => baseball_dict, + "tiger.95" => tiger_95_dict, "hallway" => hallway_dict, "bulkhead.A" => bulkhead_A_dict) + + return Dict(name => name_to_dic[name] for name in nn_individual_tests) +end + diff --git a/test/data/tiger.numeric.pomdp b/test/data/tiger.numeric.pomdp deleted file mode 100644 index 65dbd88..0000000 --- a/test/data/tiger.numeric.pomdp +++ /dev/null @@ -1,55 +0,0 @@ -discount: 0.95 -values: reward -states: 2 -actions: 3 -observations: 2 - -T: 0 -1.0 0.0 -0.0 1.0 - -T: 1 -0.5 0.5 -0.5 0.5 - -T: 2 -0.5 0.5 -0.5 0.5 - -O: 0 -0.85 0.15000000000000002 -0.15000000000000002 0.85 - -O: 1 -0.5 0.5 -0.5 0.5 - -O: 2 -0.5 0.5 -0.5 0.5 - -R: 0 : 0 : 0 : 0 -1.0 -R: 0 : 0 : 0 : 1 -1.0 -R: 0 : 0 : 1 : 0 -1.0 -R: 0 : 0 : 1 : 1 -1.0 -R: 0 : 1 : 0 : 0 -1.0 -R: 0 : 1 : 0 : 1 -1.0 -R: 0 : 1 : 1 : 0 -1.0 -R: 0 : 1 : 1 : 1 -1.0 -R: 1 : 0 : 0 : 0 -100.0 -R: 1 : 0 : 0 : 1 -100.0 -R: 1 : 0 : 1 : 0 -100.0 -R: 1 : 0 : 1 : 1 -100.0 -R: 1 : 1 : 0 : 0 10.0 -R: 1 : 1 : 0 : 1 10.0 -R: 1 : 1 : 1 : 0 10.0 -R: 1 : 1 : 1 : 1 10.0 -R: 2 : 0 : 0 : 0 10.0 -R: 2 : 0 : 0 : 1 10.0 -R: 2 : 0 : 1 : 0 10.0 -R: 2 : 0 : 1 : 1 10.0 -R: 2 : 1 : 0 : 0 -100.0 -R: 2 : 1 : 0 : 1 -100.0 -R: 2 : 1 : 1 : 0 -100.0 -R: 2 : 1 : 1 : 1 -100.0 - diff --git a/test/data/tiger.symbolic.pomdp b/test/data/tiger.symbolic.pomdp deleted file mode 100644 index 2f9c5de..0000000 --- a/test/data/tiger.symbolic.pomdp +++ /dev/null @@ -1,64 +0,0 @@ -discount: 0.95 -values: reward -states: tiger-left tiger-right -actions: listen open-left open-right -observations: tiger-left tiger-right - -# ------------------------------------------------------------------- -# TRANSITIONS -T: * : * : * 0.0 -T: listen : tiger-left : tiger-left : 1.0 -T: listen : tiger-right : tiger-right : 1.0 -T: open-left : tiger-left : tiger-left : 0.5 
-T: open-left : tiger-left : tiger-right : 0.5 -T: open-left : tiger-right : tiger-left : 0.5 -T: open-left : tiger-right : tiger-right : 0.5 -T: open-right : tiger-left : tiger-left : 0.5 -T: open-right : tiger-left : tiger-right : 0.5 -T: open-right : tiger-right : tiger-left : 0.5 -T: open-right : tiger-right : tiger-right : 0.5 - -# ------------------------------------------------------------------- -# OBSERVATIONS -O: * : * : * 0.0 -O: listen : tiger-left : tiger-left 0.85 -O: listen : tiger-left : tiger-right 0.15000000000000002 -O: listen : tiger-right : tiger-left 0.15000000000000002 -O: listen : tiger-right : tiger-right 0.85 -O: open-left : tiger-left : tiger-left 0.5 -O: open-left : tiger-left : tiger-right 0.5 -O: open-left : tiger-right : tiger-left 0.5 -O: open-left : tiger-right : tiger-right 0.5 -O: open-right : tiger-left : tiger-left 0.5 -O: open-right : tiger-left : tiger-right 0.5 -O: open-right : tiger-right : tiger-left 0.5 -O: open-right : tiger-right : tiger-right 0.5 - -# ------------------------------------------------------------------- -# REWARDS -R: * : * : * : * 0.0 -R: listen : tiger-left : tiger-left : tiger-left -1.0 -R: listen : tiger-left : tiger-left : tiger-right -1.0 -R: listen : tiger-left : tiger-right : tiger-left -1.0 -R: listen : tiger-left : tiger-right : tiger-right -1.0 -R: listen : tiger-right : tiger-left : tiger-left -1.0 -R: listen : tiger-right : tiger-left : tiger-right -1.0 -R: listen : tiger-right : tiger-right : tiger-left -1.0 -R: listen : tiger-right : tiger-right : tiger-right -1.0 -R: open-left : tiger-left : tiger-left : tiger-left -100.0 -R: open-left : tiger-left : tiger-left : tiger-right -100.0 -R: open-left : tiger-left : tiger-right : tiger-left -100.0 -R: open-left : tiger-left : tiger-right : tiger-right -100.0 -R: open-left : tiger-right : tiger-left : tiger-left 10.0 -R: open-left : tiger-right : tiger-left : tiger-right 10.0 -R: open-left : tiger-right : tiger-right : tiger-left 10.0 -R: 
open-left : tiger-right : tiger-right : tiger-right 10.0 -R: open-right : tiger-left : tiger-left : tiger-left 10.0 -R: open-right : tiger-left : tiger-left : tiger-right 10.0 -R: open-right : tiger-left : tiger-right : tiger-left 10.0 -R: open-right : tiger-left : tiger-right : tiger-right 10.0 -R: open-right : tiger-right : tiger-left : tiger-left -100.0 -R: open-right : tiger-right : tiger-left : tiger-right -100.0 -R: open-right : tiger-right : tiger-right : tiger-left -100.0 -R: open-right : tiger-right : tiger-right : tiger-right -100.0 - diff --git a/test/reader.jl b/test/reader.jl deleted file mode 100644 index 76f0109..0000000 --- a/test/reader.jl +++ /dev/null @@ -1,45 +0,0 @@ -using POMDPs -using POMDPFiles -using POMDPTools -using POMDPModels -using Test - -pomdpfiles = filter(endswith(".pomdp"), readdir(TEST_SOURCES; join=true)) - -@testset "Reading \"Litmann's 1D POMDP\"" begin - file = first(filter(endswith("1d.pomdp"), pomdpfiles)) - pomdp = read_pomdp(file) - - @test has_consistent_distributions(pomdp; atol=1e-5) - - @test discount(pomdp) == 0.75 - @test length(states(pomdp)) == 4 - @test length(actions(pomdp)) == 2 - @test length(observations(pomdp)) == 2 - - @test POMDPTools.has_consistent_distributions(pomdp, atol=1e-5) - - T = transition(pomdp, 1, 1) - @test pdf(T, 1) == 1.0 && rand(T) == 1 - O = observation(pomdp, 1, 1) - @test pdf(O, 1) == 1.0 && rand(O) == 1 - @test reward(pomdp, 1, 1) == 1.0 -end - -@testset "Reading \"Parr & Russell's POMDP\"" begin - file = first(filter(endswith("parr95.95.pomdp"), pomdpfiles)) - pomdp = read_pomdp(file) - - @test has_consistent_distributions(pomdp; atol=1e-5) - - @test discount(pomdp) == 0.95 - @test length(actions(pomdp)) == 3 - @test length(observations(pomdp)) == 6 - @test length(states(pomdp)) == 7 - - T = transition(pomdp, 1, 1) - @test pdf(T, 2) == 0.5 && in(rand(T), [2, 3]) - O = observation(pomdp, 1, 1) - @test pdf(O, 1) == 1.0 && rand(O) == 1 - @test reward(pomdp, 6, 1) == 2.0 -end diff --git 
a/test/runtests.jl b/test/runtests.jl index 76ac5e9..9ef70e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,18 +1,68 @@ -using POMDPFiles -using POMDPModels -using POMDPTools -using Test -using Downloads +using HTTP, Gumbo, Cascadia, CodecZlib, Tar +using POMDPs, POMDPFiles, Test +using Distributions, StatsBase -# Tests drawn from https://pomdp.org/examples/ +include("aux_func.jl") -const TEST_SOURCES = joinpath(dirname(@__FILE__), "sources") -const TEST_DATA = joinpath(dirname(@__FILE__), "data") +regex_filename = r"/([^/]*?)\.[Pp][Oo][Mm][Dd][Pp]$" -println("Running tests:") +ff_pomdp, ff_gz, ff_tar_gz = save_files(false) +files_path = [ff_pomdp, ff_gz, ff_tar_gz...] +all_files_path = read_pomdp_dir(files_path) + +nn_individual_tests = ["mit", "hallway", "bulkhead.A", "baseball", "tiger.95"] +individual_tests = set_individual_tests(nn_individual_tests) -println("Running `reader` tests:") -include("reader.jl") +problems = ["concert", "ejs1", "ejs2", "ejs4", "ejs5", "ejs6", "ejs7"] -println("Running `writer` tests:") -include("writer.jl") \ No newline at end of file +for file_path in all_files_path + ff_name = match(regex_filename, file_path).captures[1] + + if !isnothing(ff_name) && !(ff_name in problems) + @testset "Testing file: $(ff_name)" begin + pomdp_read = SWildcardArrayPOMDP(file_path) + + if ff_name in nn_individual_tests + for key in keys(individual_tests[ff_name]) + if isequal(key, "T") + @test all([pomdp_read.pomdp.T[k...] == vv for (k,vv) in zip(keys(individual_tests[ff_name]["T"]), values(individual_tests[ff_name]["T"]))]) + end + + if isequal(key, "O") + @test all([pomdp_read.pomdp.O[k...] == vv for (k,vv) in zip(keys(individual_tests[ff_name]["O"]), values(individual_tests[ff_name]["O"]))]) + end + + if isequal(key, "R") + @test all([pomdp_read.pomdp.R[k...] 
== vv for (k,vv) in zip(keys(individual_tests[ff_name]["R"]), values(individual_tests[ff_name]["R"]))]) + end + end + end + + if ff_name in ["baseball"] + vec_states = sample(states(pomdp_read), 100, replace=false) + vec_actions = sample(actions(pomdp_read), 6, replace=false) + elseif ff_name in ["aloha.30"] + vec_states = sample(states(pomdp_read), Int(ceil(0.3*pomdp_read.pomdp.ns)), replace=false) + vec_actions = sample(actions(pomdp_read), Int(ceil(0.3*pomdp_read.pomdp.na)), replace=false) + else + vec_states = states(pomdp_read) + vec_actions = actions(pomdp_read) + end + + if !isempty(pomdp_read.pomdp.support_initialstate) + prob = [val.second for val in initialstate(pomdp_read)] + @test POMDPFiles.test_if_probability(prob) + end + + for s in vec_states + for a in vec_actions + prob = [pomdp_read.pomdp.T[a,s, sp] for sp in states(pomdp_read)] + @test POMDPFiles.test_if_probability(prob) + + prob = [pomdp_read.pomdp.O[a,s, o] for o in observations(pomdp_read)] + @test POMDPFiles.test_if_probability(prob) + end + end + end + end +end \ No newline at end of file diff --git a/test/sources/1d.pomdp b/test/sources/1d.pomdp deleted file mode 100644 index 4cdecfe..0000000 --- a/test/sources/1d.pomdp +++ /dev/null @@ -1,28 +0,0 @@ -# Downloaded from https://pomdp.org/examples/1d.POMDP -# Michael's 1D maze - -discount: 0.75 -values: reward -states: left middle right goal -actions: w0 e0 -observations: nothing goal - -T: w0 -1.0 0.0 0.0 0.0 -1.0 0.0 0.0 0.0 -0.0 0.0 0.0 1.0 -0.333333 0.333333 0.333333 0.0 - -T: e0 -0.0 1.0 0.0 0.0 -0.0 0.0 0.0 1.0 -0.0 0.0 1.0 0.0 -0.333333 0.333333 0.333333 0.0 - -O: * -1.0 0.0 -1.0 0.0 -1.0 0.0 -0.0 1.0 - -R: * : * : goal : goal 1.0 diff --git a/test/sources/parr95.95.pomdp b/test/sources/parr95.95.pomdp deleted file mode 100644 index 9542ddb..0000000 --- a/test/sources/parr95.95.pomdp +++ /dev/null @@ -1,48 +0,0 @@ -# Downloaded from https://pomdp.org/examples/parr95.95.POMDP -# This example is from Parr and Russell's paper on the 
SPOVA RL -# algorithm from IJCAI'95. - -discount: 0.95 -values: reward -states: I hi-A lo-A C D plus1 minus1 -actions: a b c -observations: I A C D plus1 minus1 - -start include: I - -T : * : I : hi-A 0.5 -T : * : I : lo-A 0.5 - -T : a : hi-A : C 1.0 -T : b : hi-A : minus1 1.0 -T : c : hi-A : plus1 1.0 - -T : a : lo-A : D 1.0 -T : b : lo-A : plus1 1.0 -T : c : lo-A : minus1 1.0 - -T : a : C : hi-A 1.0 -T : b : C : I 1.0 -T : c : C : I 1.0 - -T : a : D : lo-A 1.0 -T : b : D : I 1.0 -T : c : D : I 1.0 - -T : * : plus1 : I 1.0 - -T : * : minus1 : I 1.0 - -O : * : I : I 1.0 -O : * : hi-A : A 1.0 -O : * : lo-A : A 1.0 -O : * : C : C 1.0 -O : * : D : D 1.0 -O : * : plus1 : plus1 1.0 -O : * : minus1 : minus1 1.0 - -# The paper has +1 and -1, but the SPOVA stuff requires -# non-negative rewards -R: * : plus1 : * : * 2.0 -R: * : minus1 : * : * 0.0 - diff --git a/test/writer.jl b/test/writer.jl deleted file mode 100644 index b05de2d..0000000 --- a/test/writer.jl +++ /dev/null @@ -1,26 +0,0 @@ -using POMDPFiles -using POMDPTools -using POMDPModels -using Test -using SHA - -@testset "Writing TigerPOMDP" begin - pomdp = TigerPOMDP() - @testset "Numeric Representation" begin - filename = tempname() - numericprint(filename, pomdp) - file_sha = open(sha256, filename) - disk_sha = open(sha256, joinpath(TEST_DATA, "tiger.numeric.pomdp")) - @test all(file_sha .== disk_sha) - end - @testset "Readable Representation" begin - filename = tempname() - sname = (idx) -> ["tiger-left", "tiger-right"][idx + 1] - aname = (idx) -> ["listen", "open-left", "open-right"][idx + 1] - oname = (idx) -> ["tiger-left", "tiger-right"][idx + 1] - symbolicprint(filename, pomdp; sname=sname, aname=aname, oname=oname) - file_sha = open(sha256, filename) - disk_sha = open(sha256, joinpath(TEST_DATA, "tiger.symbolic.pomdp")) - @test all(file_sha .== disk_sha) - end -end diff --git a/tt.jl b/tt.jl new file mode 100644 index 0000000..c055796 --- /dev/null +++ b/tt.jl @@ -0,0 +1,11 @@ +using HTTP, POMDPs, 
POMDPFiles + +mktempdir() do tmp_dir + url = "https://www.pomdp.org/examples/mini-hall2.POMDP" + tmp_file_name = joinpath(tmp_dir, "paint.95.POMDP") + HTTP.download(url, tmp_file_name) + + pomdp = SWildcardArrayPOMDP(tmp_file_name) + states(pomdp) +end +