diff --git a/src/LaserTag.jl b/src/LaserTag.jl index eb086f8..66c5d5e 100644 --- a/src/LaserTag.jl +++ b/src/LaserTag.jl @@ -20,29 +20,23 @@ export LTState, CMeas, DMeas, - LaserTagVis, - - MoveTowards, + LaserTagVis, MoveTowards, MoveTowardsSampled, OptimalMLSolver, OptimalML, BestExpectedSolver, - BestExpected, - - DESPOTEmu, - - gen_lasertag, + BestExpected, DESPOTEmu, gen_lasertag, cpp_emu_lasertag, tikz_pic, n_clear_cells -const Coord = SVector{2, Int} -const CMeas = MVector{8, Float64} -const DMeas = MVector{8, Int} +const Coord = SVector{2,Int} +const CMeas = MVector{8,Float64} +const DMeas = MVector{8,Int} -const C_SAME_LOC = fill!(MVector{8, Float64}(undef), -1.0) -const D_SAME_LOC = fill!(MVector{8, Int64}(undef), -1) +const C_SAME_LOC = fill!(MVector{8,Float64}(undef), -1.0) +const D_SAME_LOC = fill!(MVector{8,Int64}(undef), -1) @auto_hash_equals struct LTState # XXX auto_hash_equals isn't correct for terminal robot::Coord @@ -72,16 +66,16 @@ obs_type(om::ObsModel) = obs_type(typeof(om)) include("distance_cache.jl") -@with_kw struct LaserTagPOMDP{M<:ObsModel, O<:Union{CMeas, DMeas}} <: POMDP{LTState, Int, O} - tag_reward::Float64 = 10.0 - step_cost::Float64 = 1.0 - discount::Float64 = 0.95 - floor::Floor = Floor(7, 11) - obstacles::Set{Coord} = Set{Coord}() - robot_init::Union{Coord, Nothing} = nothing - diag_actions::Bool = false - dcache::LTDistanceCache = LTDistanceCache(floor, obstacles) - obs_model::M = DESPOTEmu(floor, 2.5) +@with_kw struct LaserTagPOMDP{M<:ObsModel,O<:Union{CMeas,DMeas}} <: POMDP{LTState,Int,O} + tag_reward::Float64 = 10.0 + step_cost::Float64 = 1.0 + discount::Float64 = 0.95 + floor::Floor = Floor(7, 11) + obstacles::Set{Coord} = Set{Coord}() + robot_init::Union{Coord,Nothing} = nothing + diag_actions::Bool = false + dcache::LTDistanceCache = LTDistanceCache(floor, obstacles) + obs_model::M = DESPOTEmu(floor, 2.5) end ltfloor(m::LaserTagPOMDP) = m.floor @@ -143,6 +137,17 @@ function POMDPs.reward(p::LaserTagPOMDP, s::LTState, a::Int, sp::LTState) end end + +# add a function to transform a LTState to a vector +function POMDPs.convert_s(T::Type{<:AbstractArray}, s::LTState, p::LaserTagPOMDP) + return convert(T, vcat(s.robot, s.opponent, s.terminal)) +end + +# transform a vector to a LTState +function POMDPs.convert_s(T::Type{LTState}, v::AbstractArray{Float64}, p::LaserTagPOMDP) + return LTState(Coord(v[1], v[2]), Coord(v[3], v[4]), v[5]) +end + POMDPs.isterminal(p::LaserTagPOMDP, s::LTState) = s.terminal POMDPs.discount(p::LaserTagPOMDP) = p.discount diff --git a/test/runtests.jl b/test/runtests.jl index 4a9f080..afe9cef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,14 @@ tikz_pic(LaserTagVis(p)) # discrete p = gen_lasertag() +# check convert_s function +b0 = initialstate(p) +s_test = rand(b0) +v_s_test = convert_s(Vector{Float64}, s_test, p) +s_back = convert_s(LTState, v_s_test, p) +@test s_back == s_test + + # check observation model consistency rng = MersenneTwister(12) N = 1_000_000