Skip to content

Commit

Permalink
added LeafBelief so that you can get the state at a leaf node (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg authored May 18, 2021
1 parent 4a7ca07 commit 29650be
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BasicPOMCP"
uuid = "d721219e-3fc6-5570-a8ef-e5402f47c49e"
repo = "https://github.com/JuliaPOMDP/BasicPOMCP.jl"
version = "0.3.4"
version = "0.3.5"

[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/Minimal_Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@
}
],
"source": [
"filter = SIRParticleFilter(pomdp, 1000)\n",
"filter = BootstrapFilter(pomdp, 1000)\n",
"for (s,a,r,sp,o) in stepthrough(pomdp, planner, filter, \"s,a,r,sp,o\")\n",
" @show (s,a,r,sp,o)\n",
"end"
Expand Down
40 changes: 33 additions & 7 deletions src/BasicPOMCP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export
default_action,

BeliefNode,
AOHistoryBelief,
LeafNodeBelief,
AbstractPOMCPSolver,

PORollout,
Expand All @@ -55,7 +55,10 @@ export

D3Tree,
node_tag,
tooltip_tag
tooltip_tag,

# deprecated
AOHistoryBelief

abstract type AbstractPOMCPSolver <: Solver end

Expand Down Expand Up @@ -147,14 +150,37 @@ function POMCPTree(pomdp::POMDP, b, sz::Int=1000)
)
end

struct AOHistoryBelief{H<:NTuple{<:Any, <:NamedTuple{(:a, :o)}}}
struct LeafNodeBelief{H, S} <: AbstractParticleBelief{S}
hist::H
sp::S
end
POMDPs.currentobs(h::LeafNodeBelief) = h.hist[end].o
POMDPs.history(h::LeafNodeBelief) = h.hist

# particle belief interface
ParticleFilters.n_particles(b::LeafNodeBelief) = 1
ParticleFilters.particles(b::LeafNodeBelief) = (b.sp,)
ParticleFilters.weights(b::LeafNodeBelief) = (1.0,)
ParticleFilters.weighted_particles(b::LeafNodeBelief) = (b.sp=>1.0,)
ParticleFilters.weight_sum(b::LeafNodeBelief) = 1.0
ParticleFilters.weight(b::LeafNodeBelief, i) = i == 1 ? 1.0 : 0.0

function ParticleFilters.particle(b::LeafNodeBelief, i)
@assert i == 1
return b.sp
end
POMDPs.currentobs(h::AOHistoryBelief) = h.hist[end].o
POMDPs.history(h::AOHistoryBelief) = h.hist

function insert_obs_node!(t::POMCPTree, pomdp::POMDP, ha::Int, o)
acts = actions(pomdp, AOHistoryBelief(tuple((a=t.a_labels[ha], o=o))))
POMDPs.mean(b::LeafNodeBelief) = b.sp
POMDPs.mode(b::LeafNodeBelief) = b.sp
POMDPs.support(b::LeafNodeBelief) = (b.sp,)
POMDPs.pdf(b::LeafNodeBelief{<:Any, S}, s::S) where S = float(s == b.sp)
POMDPs.rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:LeafNodeBelief}) = s[].sp

# old deprecated name
const AOHistoryBelief = LeafNodeBelief

function insert_obs_node!(t::POMCPTree, pomdp::POMDP, ha::Int, sp, o)
acts = actions(pomdp, LeafNodeBelief(tuple((a=t.a_labels[ha], o=o)), sp))
push!(t.total_n, 0)
push!(t.children, sizehint!(Int[], length(acts)))
push!(t.o_labels, o)
Expand Down
2 changes: 1 addition & 1 deletion src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ function simulate(p::POMCPPlanner, s, hnode::POMCPObsNode, steps::Int)

hao = get(t.o_lookup, (ha, o), 0)
if hao == 0
hao = insert_obs_node!(t, p.problem, ha, o)
hao = insert_obs_node!(t, p.problem, ha, sp, o)
v = estimate_value(p.solved_estimator,
p.problem,
sp,
Expand Down
21 changes: 19 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using POMDPSimulators
using POMDPModelTools
using POMDPTesting
using POMDPLinter: @requirements_info, @show_requirements, requirements_info
using ParticleFilters: n_particles, particles, particle, weights, weighted_particles, weight_sum, weight

import POMDPs:
transition,
Expand All @@ -24,7 +25,7 @@ import POMDPs:
struct ConstObsPOMDP <: POMDP{Bool, Symbol, Bool} end
updater(problem::ConstObsPOMDP) = DiscreteUpdater(problem)
initialstate(::ConstObsPOMDP) = BoolDistribution(0.0)
transition(p::ConstObsPOMDP, s::Bool, a::Symbol) = BoolDistribution(0.5)
transition(p::ConstObsPOMDP, s::Bool, a::Symbol) = BoolDistribution(0.0)
observation(p::ConstObsPOMDP, a::Symbol, sp::Bool) = BoolDistribution(1.0)
reward(p::ConstObsPOMDP, s::Bool, a::Symbol, sp::Bool) = 1.
discount(p::ConstObsPOMDP) = 0.9
Expand All @@ -50,10 +51,26 @@ end;

@testset "belief dependent actions" begin
pomdp = ConstObsPOMDP()
function POMDPs.actions(m::ConstObsPOMDP, b::AOHistoryBelief)
function POMDPs.actions(m::ConstObsPOMDP, b::LeafNodeBelief)
@test currentobs(b) == true
@test history(b)[end].o == true
@test history(b)[end].a == :the_only_action
@test mean(b) == 0.0
@test mode(b) == 0.0
@test only(support(b)) == false
@test pdf(b, false) == 1.0
@test pdf(b, true) == 0.0
@test rand(b) == false
@test n_particles(b) == 1
@test only(particles(b)) == false
@test only(weights(b)) == 1.0
@test only(weighted_particles(b)) == (false => 1.0)
@test weight_sum(b) == 1.0
@test weight(b, 1) == 1.0
@test particle(b, 1) == false

# old type name - this can be removed when upgrading versions
@test b isa AOHistoryBelief
return actions(m)
end

Expand Down

2 comments on commit 29650be

@zsunberg
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37002

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.5 -m "<description of version>" 29650be6aebbbbf111fb1fff920580c1e3d32007
git push origin v0.3.5

Please sign in to comment.