From 2b8c7d4379000c5b1383fb39598bde54c4aa9715 Mon Sep 17 00:00:00 2001 From: yangyou95 <37275563+yangyou95@users.noreply.github.com> Date: Tue, 13 Dec 2022 19:47:18 +0100 Subject: [PATCH 1/3] Update LaserTag.jl add convert_s function --- src/LaserTag.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/LaserTag.jl b/src/LaserTag.jl index eb086f8..96a1727 100644 --- a/src/LaserTag.jl +++ b/src/LaserTag.jl @@ -143,6 +143,11 @@ function POMDPs.reward(p::LaserTagPOMDP, s::LTState, a::Int, sp::LTState) end end +# add a function to transform a LTState to a vector +function POMDPs.convert_s(T::Type{<:AbstractArray}, s::LTState, p::LaserTagPOMDP) + return convert(T, vcat(s.robot, s.opponent, [s.terminal])) +end + POMDPs.isterminal(p::LaserTagPOMDP, s::LTState) = s.terminal POMDPs.discount(p::LaserTagPOMDP) = p.discount From c18fa55e1d612d95d372ba848a938b4189718ac6 Mon Sep 17 00:00:00 2001 From: yangyou95 <37275563+yangyou95@users.noreply.github.com> Date: Tue, 20 Dec 2022 23:55:41 +0100 Subject: [PATCH 2/3] Update LaserTag.jl --- src/LaserTag.jl | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/LaserTag.jl b/src/LaserTag.jl index 96a1727..2402011 100644 --- a/src/LaserTag.jl +++ b/src/LaserTag.jl @@ -20,29 +20,23 @@ export LTState, CMeas, DMeas, - LaserTagVis, - - MoveTowards, + LaserTagVis, MoveTowards, MoveTowardsSampled, OptimalMLSolver, OptimalML, BestExpectedSolver, - BestExpected, - - DESPOTEmu, - - gen_lasertag, + BestExpected, DESPOTEmu, gen_lasertag, cpp_emu_lasertag, tikz_pic, n_clear_cells -const Coord = SVector{2, Int} -const CMeas = MVector{8, Float64} -const DMeas = MVector{8, Int} +const Coord = SVector{2,Int} +const CMeas = MVector{8,Float64} +const DMeas = MVector{8,Int} -const C_SAME_LOC = fill!(MVector{8, Float64}(undef), -1.0) -const D_SAME_LOC = fill!(MVector{8, Int64}(undef), -1) +const C_SAME_LOC = fill!(MVector{8,Float64}(undef), -1.0) +const D_SAME_LOC = fill!(MVector{8,Int64}(undef), -1) @auto_hash_equals struct LTState # XXX auto_hash_equals isn't correct for terminal robot::Coord @@ -72,16 +66,16 @@ obs_type(om::ObsModel) = obs_type(typeof(om)) include("distance_cache.jl") -@with_kw struct LaserTagPOMDP{M<:ObsModel, O<:Union{CMeas, DMeas}} <: POMDP{LTState, Int, O} - tag_reward::Float64 = 10.0 - step_cost::Float64 = 1.0 - discount::Float64 = 0.95 - floor::Floor = Floor(7, 11) - obstacles::Set{Coord} = Set{Coord}() - robot_init::Union{Coord, Nothing} = nothing - diag_actions::Bool = false - dcache::LTDistanceCache = LTDistanceCache(floor, obstacles) - obs_model::M = DESPOTEmu(floor, 2.5) +@with_kw struct LaserTagPOMDP{M<:ObsModel,O<:Union{CMeas,DMeas}} <: POMDP{LTState,Int,O} + tag_reward::Float64 = 10.0 + step_cost::Float64 = 1.0 + discount::Float64 = 0.95 + floor::Floor = Floor(7, 11) + obstacles::Set{Coord} = Set{Coord}() + robot_init::Union{Coord,Nothing} = nothing + diag_actions::Bool = false + dcache::LTDistanceCache = LTDistanceCache(floor, obstacles) + obs_model::M = DESPOTEmu(floor, 2.5) end ltfloor(m::LaserTagPOMDP) = m.floor @@ -143,9 +137,15 @@ function POMDPs.reward(p::LaserTagPOMDP, s::LTState, a::Int, sp::LTState) end end + # add a function to transform a LTState to a vector function POMDPs.convert_s(T::Type{<:AbstractArray}, s::LTState, p::LaserTagPOMDP) - return convert(T, vcat(s.robot, s.opponent, [s.terminal])) + return convert(T, vcat(s.robot, s.opponent, s.terminal)) +end + +# transform a vector to a LTState +function POMDPs.convert_s(T::Type{LTState}, v::AbstractArray{Float64}, p::LaserTagPOMDP) + return LTState([v[1], v[2]], [v[3], v[4]], v[5]) end POMDPs.isterminal(p::LaserTagPOMDP, s::LTState) = s.terminal From da28f02243ea4d20cf45adc08791299d98a647f1 Mon Sep 17 00:00:00 2001 From: yangyou95 <37275563+yangyou95@users.noreply.github.com> Date: Tue, 17 Jan 2023 13:32:59 +0100 Subject: [PATCH 3/3] update LaserTag convert_s and test functions --- src/LaserTag.jl | 2 +- test/runtests.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/LaserTag.jl b/src/LaserTag.jl index 2402011..66c5d5e 100644 --- a/src/LaserTag.jl +++ b/src/LaserTag.jl @@ -145,7 +145,7 @@ end # transform a vector to a LTState function POMDPs.convert_s(T::Type{LTState}, v::AbstractArray{Float64}, p::LaserTagPOMDP) - return LTState([v[1], v[2]], [v[3], v[4]], v[5]) + return LTState(Coord(v[1], v[2]), Coord(v[3], v[4]), v[5]) end POMDPs.isterminal(p::LaserTagPOMDP, s::LTState) = s.terminal diff --git a/test/runtests.jl b/test/runtests.jl index 4a9f080..afe9cef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,14 @@ tikz_pic(LaserTagVis(p)) # discrete p = gen_lasertag() +# check convert_s function +b0 = initialstate(p) +s_test = rand(b0) +v_s_test = convert_s(Vector{Float64}, s_test, p) +s_back = convert_s(LTState, v_s_test, p) +@test s_back == s_test + + # check observation model consistency rng = MersenneTwister(12) N = 1_000_000