WrappedContext and separation between sampling and evaluation #249
IIRC it is not necessarily the log joint, some samplers require a special log density evaluation by default (which also doesn't belong to any of the other categories). Probably this could be solved though with a SamplingContext, as suggested before.
This seems reasonable but I wonder: Do we actually need another context? Isn't it sufficient to have an interface that allows to |
BTW I think also with |
Another observation: it also seems that one of the main problems with the |
Ah, good point 👍
Yes because you might want to dispatch based on the
I like this. I was thinking that the I'll make an attempt at adding |
Do we? I thought one would just call something like |
Yeah okay fair, if we introduce a bunch of |
Do you have an example where type inference breaks? I would have assumed that it's not a problem for the compiler, at least not with reasonably deep contexts (eg there are a lot of recursive definitions in SciML which work fine). |
I don't have an example right now, but I can send you some pictures of my battlescars from implementing compositions in Bijectors.jl! I don't like not adding it to the type though, as it's more difficult to add back later if we realize we need vs. removing it if we realize we don't need it. But I agree, we shouldn't do it without reason. I'll do some checks 👍 |
Seems like our But if you remove the julia> @code_warntype PrefixContext{:c}(MiniBatchContext(1.0, PrefixContext{:a}(PrefixContext{:b}())))
Variables
#self#::Type{PrefixContext{:c, C, LeafCtx} where {C, LeafCtx}}
ctx::MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}
Body::PrefixContext{:c, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}
1 ─ %1 = $(Expr(:static_parameter, 1))::Core.Const(:c)
│ %2 = DynamicPPL.typeof(ctx)::Core.Const(MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext})
│ %3 = Core.apply_type(DynamicPPL.PrefixContext, %1, %2, $(Expr(:static_parameter, 2)))::Core.Const(PrefixContext{:c, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext})
│ %4 = Core.fieldtype(%3, 1)::Core.Const(MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext})
│ %5 = Base.convert(%4, ctx)::MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}
│ %6 = %new(%3, %5)::PrefixContext{:c, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}
└── return %6
julia> @code_warntype PrefixContext{:d}(PrefixContext{:c}(MiniBatchContext(1.0, PrefixContext{:a}(PrefixContext{:b}()))))
Variables
#self#::Type{PrefixContext{:d, C, LeafCtx} where {C, LeafCtx}}
ctx::PrefixContext{:c, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}
Body::PrefixContext{_A, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext} where _A
1 ─ %1 = DynamicPPL.Symbol($(Expr(:static_parameter, 2)), DynamicPPL.PREFIX_SEPARATOR, $(Expr(:static_parameter, 1)))::Symbol
│ %2 = Core.apply_type(DynamicPPL.PrefixContext, %1)::Type{PrefixContext{_A, C, LeafCtx} where {C, LeafCtx}} where _A
│ %3 = Base.getproperty(ctx, :ctx)::MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}
│ %4 = (%2)(%3)::PrefixContext{_A, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext} where _A
└── return %4 That is not a lot of nesting before it fails 😕 |
But for julia> ctx = MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, PrefixContext{:a}(PrefixContext{:b}())))))))))))));
julia> @code_warntype DynamicPPL.unwrap(ctx)
Variables
#self#::Core.Const(DynamicPPL.unwrap)
ctx::MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}
Body::DefaultContext
1 ─ %1 = Base.getproperty(ctx, :ctx)::MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}
│ %2 = DynamicPPL.unwrap(%1)::Core.Const(DefaultContext())
└── return %2 Hmmmmmmm. Not sure what to think. |
If the bug is fixed, both examples infer correctly it seems (with an additional Edit: Even the MiniBatch example infers correctly. |
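For reference, the kind of inference check discussed above can be reproduced on toy types. This is only a minimal sketch: `LeafContext` and `ScaleContext` are hypothetical stand-ins (not DynamicPPL's types), and `Test.@inferred` is used in place of reading `@code_warntype` output by eye.

```julia
using Test

abstract type AbstractContext end

# Hypothetical stand-in for a leaf context such as DefaultContext.
struct LeafContext <: AbstractContext end

# Hypothetical stand-in for a wrapper such as MiniBatchContext.
struct ScaleContext{T,C<:AbstractContext} <: AbstractContext
    scale::T
    ctx::C
end

# Recursive unwrapping: each call peels off one wrapper layer.
unwrap(ctx::LeafContext) = ctx
unwrap(ctx::ScaleContext) = unwrap(ctx.ctx)

nested = ScaleContext(1.0, ScaleContext(1.0, ScaleContext(1.0, LeafContext())))

# `@inferred` throws if the inferred return type does not match the actual one.
leaf = @inferred unwrap(nested)
```

Since the argument type shrinks with every recursive call, the compiler can usually follow the recursion here; whether that still holds for much deeper or more irregular nestings is exactly the open question in this thread.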
Yeah, because I guess it starts using EDIT: Just for the record, in the example where it failed I explicitly removed the |
But even so, isn't having |
Probably in some extreme cases but I'm not sure if this is relevant (already the examples above seem a bit constructed). Moreover, the compiler has to specialize all methods for different leaf types, so possibly there is a lot of unnecessary recompilation due to the additional type parameters whereas actually one would need specialize on the leaf type only in certain functions. |
Another argument for not having a |
Hmm, but we're already specializing on the child-context so at least shouldn't make things harder? Also there are cases where we don't actually need to
See I'm kind of the opposite opinion because generally how you work with a f(ctx::AbstractContext) = # actual computation ...
f(ctx::WrappedContext) = f(unwrap(ctx)) without worrying about |
While it's true that the number of methods that have to be compiled does not increase if the information is redundant, additional type parameters still put more stress on the compiler. It has to be carried around everywhere and, as far as I know, the compiler doesn't realize that it does not increase the information content.
I'm not completely sure if I understand you correctly, but it seems this would just motivate to have both
You just specialize |
Regarding the last point, one can even add a trait That's common practice e.g. for iterators ( |
But say I implement a method for
IMHO it wouldn't be a fallback; it would be the implementation. Other than simple rules for 1-level nesting, there's no other good way to implement this on a case-by-case basis for nested unwrappedtype(context) = typeof(unwrap(context)) you need to compile Also, just to note: we're not really putting any restrictions on the relationships between the contexts here. We can always just specialize the methods for any particular
Yeaaah but this can quickly become unwieldy though + a user can't add to the union => they need to manually overload that method for their particular
Sure! But as you've pointed out, dispatch is preferable to Also there is discussion regarding these two approaches in JuliaLang/julia#31563. Of course there's probably going to be greater latency with overloads of arrays, but this comment (JuliaLang/julia#31563 (comment)) indicates that the "intuition" that a |
Haha I see you're not convinced by Julia's (sometimes a bit excessive, I guess) use of traits 😄
I think one has to separate two things here: I don't like the pattern function f(...)
...
y = g(...)
if y isa SomeType
...
elseif y isa SomeOtherType
...
else
...
end
...
end since one has to reimplement the complete function function f(...)
...
y = g(...)
z = h(y, ...) # or just h(y, ...)
...
end then one has to implement only However, in the wrapper example f(context::AbstractContext) = ...
f(context::WrapperContext) = ... the situation is different. As mentioned, of course, one can always specialize struct IsWrapper end
struct IsNoWrapper end
iswrapper(context::AbstractContext) = IsNoWrapper()
iswrapper(context::PrefixContext) = IsWrapper()
f(context::AbstractContext) = f(context, iswrapper(context))
function f(context::AbstractContext, ::IsWrapper)
....
end
function f(context::AbstractContext, ::IsNoWrapper)
...
end Then new contexts can declare if they are a wrapper or not and then benefit from the default implementation, and, of course, it is still possible to just implement iswrapper(context::AbstractContext) = false
iswrapper(context::PrefixContext) = true
f(context::AbstractContext) = f(context, Val{iswrapper(context)}())
function f(context::AbstractContext, iswrapper::Val{false})
....
end
function f(context::AbstractContext, iswrapper::Val{true})
...
end And if there are really only these two possibilities - either it is a wrapper or not - then this can be shortened to iswrapper(context::AbstractContext) = false
iswrapper(context::PrefixContext) = true
function f(context::AbstractContext)
if iswrapper(context)
...
else
...
end
end The compiler will elide the unused branch automatically, so this is equivalent to the much more verbose trait implementation above. The main difference between this
The main advantage is that you do not enforce super types. Imagine there is some other orthogonal behaviour that should be shared among different contexts such as if they sample values or just evaluate them. Then you can't encode both the wrapper behaviour and this other behaviour in the type hierarchy and so you have to decide for one of the default implementations and reimplement the other one for your context. Instead, if you use functions such as Additionally, I'd argue that in the |
Whelp, I can't say no to the interface-approach now that you've written such a long and nice explanation 😞 But I'm with you on the pattern and am also aware of it (though it was still useful to read what you wrote above; very nicely laid out all the different approaches). My point is that I want the separation between any other context and a Another question, do you have any concrete ideas on how things should look internally with a
The issue that I'm struggling a bit to resolve right now is that Immediate idea of solution is to introduce Also, I'm guessing we're going to leave |
Just a quick comment:
IIRC this is only the case since we drop the context when forwarding |
Yes, this is because we drop it when we forward to |
Seems better to just let them work with two different versions of |
Also, say we overload |
I wonder, do we need a dedicated evaluation context at all? Could just by default every context be an evaluation context, and if you want to sample you have to combine it with a sampler in a |
I do like this! But this is like a more complicated I prefer inner-most because:
I'm also unsatisfied with these being the outer-most context rather than inner-most. But yes, the above too.
IMO this is a bit unfair though. You're technically not wrong, but in reality, essentially no contexts will actually touch Similarly for primitive contexts, the bug will always be in
You mean if they only want to implement samplers? Sure, but it's the case if you want to implement contexts. I think we really should allow custom contexts to be implemented "easily" (though we should still be careful of adding them to DPPL), as there is some really neat stuff you could do with it.
I mean, I like this idea regardless of this PR and our discussions!:)
Crap, I meant to change this but clearly forgot to; sorry. You're right, it's not "less flexible" in what you can do with it; of course, if you're overloading
Separate thing
Btw, another question I would like to hear your thoughts on, independent of the ongoing
This is achieved as follows:
This has a couple of nice features IMO:
|
I think the order of how the contexts are applied isn't significantly different in the proposals, is it?
Could you explain what you mean with "Simpler
I don't think any error checking is needed at construction time. |
"wrapped context is applied after the
So same functionality, but IMO the first one seems easier/more intuitive. Also, how does the wrapper-contexts know if they're sampling or not? I.e. how can a wrapper-context say "stop wrapping this in Btw, where did the
IMO And yeah, functionality is the same. But do you really think that rewrapping at every step along the way and then potentially stopping at some later point using the Functors.jl approach is easier to follow/more intuitive than "change leaf, i.e. call Btw, for the record: we can remove the
But like, why though? This doesn't seem like a good argument when there is an approach that could make it so that it never happens, right?
But bugs will happen, even if we only use it internally. So if there's a way to ensure that a particular error never occurs, that seems nice, no? |
Yo! What about renaming |
Excellent work - I left many comments below.
src/context_implementations.jl
Outdated
|
||
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), | ||
accumulate the log probability, and return the sampled value. | ||
|
||
Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. | ||
""" | ||
function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) | ||
value, logp = tilde(rng, ctx, sampler, right, vn, inds, vi) | ||
function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) |
nice improvement!
src/context_implementations.jl
Outdated
""" | ||
tilde_assume(rng, ctx, sampler, right, vn, inds, vi) | ||
tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) | ||
|
||
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), | ||
accumulate the log probability, and return the sampled value. | ||
|
||
Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. |
Should this be updated?
Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. | |
Falls back to `tilde_assume(rng, ctx, sampler, right, vn, inds, vi)`. |
src/context_implementations.jl
Outdated
function _tilde(rng, sampler, right, vn::VarName, vi) | ||
return assume(rng, sampler, right, vn, vi) | ||
function tilde_assume(rng, ctx::SamplingContext, sampler, right, left, vn, inds, vi) | ||
return assume(rng, sampler, right, left, vn, inds, vi) |
Might be a good idea to wrap `sampler` and `rng` into `SamplingContext`. It makes the function signature more consistent and the code cleaner.
The plan is to move `sampler` and `rng` into `SamplingContext` yeah, but we still want to unwrap those arguments before reaching `assume`. We don't want developers outside of DPPL to have to use contexts; contexts should be something they never really see. Contexts should also never change the behavior of `assume`.
src/context_implementations.jl
Outdated
function _tilde(rng, sampler, right::NamedDist, vn::VarName, vi) | ||
return _tilde(rng, sampler, right.dist, right.name, vi) | ||
function tilde_assume( | ||
rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn, inds, vi |
Similarly, we can wrap `sampler` into `EvaluationContext` or remove it (if it's not needed).
If we end up putting the `sampler` and `rng` into `SamplingContext`, I agree 👍
For removing it: def something we want to do at some point, but right now it's a bit difficult I think.
src/context_implementations.jl
Outdated
function tilde_observe( | ||
ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi | ||
) | ||
return observe(sampler, right, left, vi) |
Similar to `assume`, might be good to wrap `sampler` into `ctx`, and pass `ctx` through the call stack instead.
Don't want to pass `ctx` to `observe` nor `assume`. Agree with simplifying `tilde_observe` and `tilde_observe!` by removing `sampler` from the signature though.
@@ -0,0 +1,54 @@ | |||
function tilde_assume( |
Maybe split the code into multiple files in future PRs, so that we can keep track of actual code changes in this PR.
In general I agree with you, but I did this because it was already starting to look pretty nasty in the review-section 😕 So many changes to the diffs didn't "match" up.
@@ -4,80 +4,49 @@ | |||
The `DefaultContext` is used by default to compute log the joint probability of the data |
We no longer have a `DefaultContext`.
include("contexts/prior.jl") | ||
include("contexts/likelihood.jl") | ||
include("contexts/minibatch.jl") | ||
include("contexts/prefix.jl") |
We have two sub folders for contexts related code at the moment, i.e. `contexts/` and `context_implementations/`. I suggest that we merge these two sub folders into one `contexts/`. Also, we should merge `contexts/prior.jl` and `context_implementations/prior.jl` into one file `contexts/prior.jl`.
It's because of inclusion order, e.g. we need the definition of some of the contexts in `model.jl` but at that point we cannot yet define impls of `tilde_assume`, etc., so that's why we have this split.
It could possibly be simplified though! Just figured I'd stick with what we were already doing for this PR.
@@ -1,13 +1,15 @@ | |||
# Context version |
Maybe consider moving this file into `contexts/loglikelihoods.jl`.
@@ -1399,90 +1399,3 @@ function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) | |||
|
|||
return indices | |||
end | |||
|
Do we have a replacement mechanism for `setval_and_resample!`?
It's no longer needed, since the `_and_resample!` part is now just defined by using a `SamplingContext`.
Some notes for myself for the meeting (no need to reply in advance, will probably also be extended and modified): General comment:
Alternatively, it might be sufficient (?) to be able to pass a value only to the non-wrapper contexts. An example with the non- struct ConditionContext{names,V<:NamedTuple{names},C<:AbstractContext} <: AbstractContext
values::V
context::C
end
function unwrap_childcontext(context::ConditionContext{names}) where {names}
child = context.context
function reconstruct_conditioncontext(c::AbstractContext)
return ConditionContext(context.values, c)
end
return child, reconstruct_conditioncontext
end
# drop `SamplingContext` if not needed
function tilde_assume(context::SamplingContext{<:ConditionContext{names}}, right, vn::VarName{sym}, inds, vi) where {sym,names}
return if sym in names
tilde_assume(context.context, right, vn, inds, vi)
else
c, reconstruct_context = unwrap_childcontext(context)
child_of_c, _ = unwrap_childcontext(c)
tilde_assume(reconstruct_context(child_of_c), right, vn, inds, vi)
end
end
# propagate `ConditionContext` to leaf contexts such as `JointContext` etc.
function tilde_assume(context::ConditionContext{names}, right, vn::VarName{sym}, inds, vi) where {names,sym}
c, reconstruct_context = unwrap_childcontext(context)
child_of_c, reconstruct_c = unwrap_childcontext(c)
return if child_of_c === nothing
if sym in names
# pass the value to the leaf context
tilde_assume(c, right, vn, inds, vi, context.values[sym])
else
# pass no value
tilde_assume(c, right, vn, inds, vi)
end
else
tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi)
end
end |
I'm going to do it anyways, even if just to organize my own thoughts on your notes:)
A bit uncertain what you mean here. AFAIK we wrap the RNG and sampler in
You mean situations where you want to do something like function tilde_assume(ctx::MyContext, ...)
f(...)
value, logp = tilde_assume(childcontext(ctx), ...)
g(...)
return value
end where the "primitive" context can be changed within the nested A couple of notes on this:
Or am I misunderstanding something here?
IMO yes, because:
It also opens up for the opportunity of being a bit more explicit about what should be sampled, etc., e.g. we could use
Technically, yes, but:
A bit confused here; why can't this be achieved through
True, but as mentioned above, I think there are cases where you'd like to also sample.
Very much in favour of removing it if possible. I added/kept it because AFAIK it's still being used for filtering based on variables so figured it was necessary to ensure that logp computations were performed correctly.
EDIT: Ah, you mean the "inner-most" one/"primitive" context. Yeah actually this might be a great idea! Don't see a reason why this wouldn't work:) EDIT 2: Eeeh actually it's a bit awkward with contexts such as Btw, one thing I don't think I've heard your thoughts on yet @devmotion: what do you think about the suggestion of passing a |
Just to summarize, we decided to split the work into several different PRs:
The final part is then passing an additional
AFAIK, the
You mean to be returned? If so, yeah I agree. But I would say that this would fall under the "allowing post-hooks" in the tilde-statements, i.e. maybe something to think about after we have the override-feature as it could be trivially added by just returning more information.
We could insert the
As mentioned above, I'm pretty certain we can just remove
Won't this lead to |
Closing this in favour of the PRs listed above. |
This PR introduces the simplification of the tilde-callstack as discussed in #249. Copy-pasted from there:
- Remove unnecessary complexity in `~` implementation.
  - Current calling hierarchy for a `~` statement is:
    - `tilde_assume` -> `tilde(rng, ...)` -> `_tilde(rng, ...)` -> `assume`
    - `tilde_observe` -> `tilde(...)` -> `_tilde(...)` -> `observe`
    - Similarly for `dot_tilde_assume` and `dot_tilde_observe`.
  - This is super-confusing and difficult to debug.
  - `_tilde` is currently only used for `NamedDist` to allow overriding the variable-name used for a particular `~` statement.
  - Propose the following changes:
    - Remove `_tilde` and handle `NamedDist` _before_ calling `tilde_assume`, etc. by using an `unpack_right_vns` (and `unpack_right_left_vns` for dot-statements) (thanks to @devmotion)
    - Rename `tilde_assume` (`tilde_observe`) to `tilde_assume!` (`tilde_observe!`), and `tilde(rng, ...)` (`tilde(...)`) to `tilde_assume(rng, ...)` (`tilde_observe(...)`).
    - `tilde_assume!` simply calls `tilde_assume` followed by `acclogp(varinfo, result_from_tilde_assume)`, so the `!` here is to indicate that it's mutating the `logp` field in `VarInfo`.

Co-authored-by: Hong Ge <hg344@cam.ac.uk>
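To make the proposed two-layer split concrete, here is a minimal, self-contained sketch. Everything in it is a toy: `ToyVarInfo`, `toy_assume`, `toy_tilde_assume`, `toy_tilde_assume!` and the local `acclogp!` are hypothetical stand-ins chosen for illustration, not DynamicPPL's actual types or signatures, and a standard normal is hard-coded so no packages are needed.

```julia
# Toy stand-in for VarInfo: it only tracks the accumulated log probability.
mutable struct ToyVarInfo
    logp::Float64
end

# Add `logp` to the accumulator and return the container.
acclogp!(vi::ToyVarInfo, logp) = (vi.logp += logp; vi)

# Log density of a standard normal.
stdnormal_logpdf(x) = -0.5 * (x^2 + log(2π))

# "assume" in this sketch only evaluates: it returns the value and its log density.
toy_assume(value) = (value, stdnormal_logpdf(value))

# Non-mutating layer: returns `(value, logp)` and never touches `vi`.
toy_tilde_assume(value, vi) = toy_assume(value)

# Mutating layer: calls the non-mutating one, accumulates `logp`, returns the value.
function toy_tilde_assume!(value, vi)
    value, logp = toy_tilde_assume(value, vi)
    acclogp!(vi, logp)
    return value
end

vi = ToyVarInfo(0.0)
x = toy_tilde_assume!(0.0, vi)
```

The point of the split is that only the `!` variant has a side effect, so contexts that want to intercept or rescale the log probability can overload the non-mutating layer alone.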
This PR introduces an `abstract type WrappedContext <: AbstractContext` which contexts that wrap other contexts should subtype, e.g. `MiniBatchContext`.

We also introduce clear separation between sampling and evaluation.

Proposal (as of [2021-05-24])

- `abstract type WrappedContext{LeafCtx} <: AbstractContext`
  - `unwrap(ctx)`: recursively unwraps the `WrappedContext`, eventually returning the "leaf-context".
  - `unwrappedtype(ctx)`: immediately returns `LeafCtx` from `WrappedContext{LeafCtx}`. This is useful as we cannot always rely on the Julia compiler to be able to infer the return type of `unwrap` for deeper nested wrappings.
  - `childcontext(ctx)`: returns the child-context. Useful for traversing nested wrapped contexts.
  - `rewrap(parent, leaf)`: this will essentially reconstruct `parent` with `leaf` as the leaf-context.
    - `ctx isa MiniBatchContext{..., <:SamplingContext}` which is currently in "sampling-mode", and want to switch to evaluation-mode: `rewrap(ctx, EvaluationContext())`.
- `PriorContext`
- `LikelihoodContext`
- `MinibatchContext`
- `abstract type PrimitiveContext <: AbstractContext`
  - `EvaluationContext`: only evaluation of the model is to take place, i.e. no sampling of variables.
  - `SamplingContext`: sampling and evaluation of model takes place.
- `PrimitiveContext` should call the underlying `assume`/`observe` methods. `WrapperContext` should never do so.
- `WrapperContext` to change between the "fundamental" modes, e.g. sampling and evaluation, by `rewrap`-ing accordingly.
- `PrimitiveContext`.
- Remove unnecessary complexity in `~` implementation.
  - Current calling hierarchy for a `~` statement is:
    - `tilde_assume` -> `tilde(rng, ...)` -> `_tilde(rng, ...)` -> `assume`
    - `tilde_observe` -> `tilde(...)` -> `_tilde(...)` -> `observe`
    - Similarly for `dot_tilde_assume` and `dot_tilde_observe`.
  - This is super-confusing and difficult to debug.
  - `_tilde` is currently only used for `NamedDist` to allow overriding the variable-name used for a particular `~` statement.
  - Propose the following changes:
    - Remove `_tilde` and handle `NamedDist` before calling `tilde_assume`, etc. by using an `unpack_right_vns` (and `unpack_right_left_vns` for dot-statements) (thanks to @devmotion)
    - Rename `tilde_assume` (`tilde_observe`) to `tilde_assume!` (`tilde_observe!`), and `tilde(rng, ...)` (`tilde(...)`) to `tilde_assume(rng, ...)` (`tilde_observe(...)`).
    - `tilde_assume!` simply calls `tilde_assume` followed by `acclogp(varinfo, result_from_tilde_assume)`, so the `!` here is to indicate that it's mutating the `logp` field in `VarInfo`.

Things to consider
- `WrappedContext` and `PrimitiveContext`, or should they all just be `AbstractContext`?
- `SamplingContext` to be an inner-most context as in this PR, or an outer-most context as in https://github.com/TuringLang/DynamicPPL.jl/compare/dw/samplingcontext ?
- `SamplingContext` at every call to `tilde_assume` and then choose not to rewrap if we only want to evaluate.
- `tilde_assume(..., ::SamplingContext, ...)` as described in WrappedContext and separation between sampling and evaluation #249 (comment).
- `rewrap`.
- `rng` and `sampler` part of `SamplingContext`, as proposed in Make Sampler a field of DefaultContext (renamed to SamplingContext) #80 and implemented in https://github.com/TuringLang/DynamicPPL.jl/compare/dw/samplingcontext ?
- `value`/`left` argument to `tilde_assume` (not `tilde_assume!`) and `assume`, as in `dot_tilde_assume`, which is initialized to `nothing` by `tilde_assume!`.
- `assume` in most cases does not have to touch `VarInfo` at all, but instead just works with the `value` passed to it ⟹ simplifies `assume` implementations.
- `~` without touching `VarInfo`! Useful for several reasons:
- `ConditionContext` would not need to mutate `VarInfo` to do its job, but instead just replace `value::Nothing` with the value it wants.
- `ConditionContext` with a `NamedTuple`.
- `ConditionContext` with `NamedTuple` or `ComponentArray`, thus avoiding having to go through `VarInfo` ⟹ improves performance (no need to pass through `reconstruct` nor `getindex(::VarInfo, ...)`)
- `tilde_assume(..., ctx::SamplingContext, ...)` call `assume(rng, ...)` if `value::Nothing`, but call `assume(...)` which only computes `logpdf` if `value` is not `nothing`.
- `ContextualModel` as explained in Generalization of abstract model functions AbstractPPL.jl#10 (comment).

Motivation
Currently we have the following contexts in DynamicPPL:

- `DefaultContext` (should this be renamed to `JointContext`?)
- `PriorContext`
- `LikelihoodContext`
- `MiniBatchContext`
- `PrefixContext`

These fall into two categories:

1. `tilde` with `PriorContext` simply returns `0` rather than calling `_tilde(sampler, right, left, vi)` as `DefaultContext` would. In this we have:
   - `DefaultContext`
   - `PriorContext`
   - `LikelihoodContext`
2. `tilde` of `MiniBatchContext` defers the actual computation to its wrapped context by calling `tilde(ctx.ctx, ...)` and then multiplies the result with a scalar. In this we have:
   - `MiniBatchContext`
   - `PrefixContext`
AFAIK it's unlikely that we'll add much more to the first category, but for the second category we have lots of possibilities that are likely to make it into DPPL at some point, e.g. `EvaluationContext` (#242) and `ConditionContext` (TuringLang/AbstractPPL.jl#10 (comment)). Therefore, having a nice way of dealing with nested wrapped contexts (i.e. nested contexts in the second category above) is useful. And without a nice way of handling this, one will quickly run into issues.

An immediate issue is `EvaluationContext`. The idea of `EvaluationContext` is that we provide values for all the parameters used in the model, thus making `tilde_assume` "redundant". The implementation becomes very easy since all we have to do is overload `tilde_assume` to instead extract the corresponding value from the given `NamedTuple` rather than sample it, call `tilde_observe` with this value, and finally return the extracted value:

DynamicPPL.jl/src/context_implementations.jl
Lines 59 to 78 in d65337d
But, as you can see in the lines above, we need to be careful when we're handling special contexts, e.g. `PriorContext`, since calling `tilde_observe` with a `PriorContext` is a no-op which it shouldn't be when we're calling it from `tilde_assume`! Okay, so we make a check whether the child-context is a `PriorContext` and all is good.

WRONG! What if the child-context of `EvaluationContext` is a `MiniBatchContext`?! The check will fail, we defer the `tilde_observe` to `MiniBatchContext` which in turn defers it to `PriorContext`, resulting in `logp` of 0. What we really need to do is check the type of the leaf context, i.e. the result of "unwrapping" the wrapped contexts recursively. This is similar to a bunch of issues (e.g. dispatch for `Adjoint` of a `CuArray`) which the community has with "lack" of support for wrapped array-types (JuliaLang/julia#31563).

So this PR addresses that by introducing a `WrappedContext` with the following methods:

- `unwrap(ctx)`: recursively unwraps the `WrappedContext`, eventually returning the "leaf-context".
- `unwrappedtype(ctx)`: immediately returns `LeafCtx` from `WrappedContext{LeafCtx}`. This is useful as we cannot always rely on the Julia compiler to be able to infer the return type of `unwrap` for deeper nested wrappings.
- `childcontext(ctx)`: returns the child-context. Useful for traversing nested wrapped contexts.
- `rewrap(parent, leaf)`: this will essentially reconstruct `parent` with `leaf` as the leaf-context.

This for example solves the above issue for `EvaluationContext` since we then can do:

Similar issues also show up for other wrapped contexts, and thus will be sorted out through this.
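A minimal sketch of the leaf-type check described above, using toy contexts. The names `JointLeaf`, `PriorLeaf`, `Scaled`, `child_is_prior` and `leaf_is_prior` are hypothetical and only for illustration; `WrappedContext` is simplified here to have no `LeafCtx` type parameter, and none of this is the PR's actual code.

```julia
abstract type AbstractContext end
abstract type WrappedContext <: AbstractContext end  # simplified: no LeafCtx parameter

struct JointLeaf <: AbstractContext end  # hypothetical leaf that scores everything
struct PriorLeaf <: AbstractContext end  # hypothetical leaf that ignores observations

# Hypothetical wrapper, loosely modelled on a mini-batch style scaling context.
struct Scaled{C<:AbstractContext} <: WrappedContext
    scale::Float64
    ctx::C
end

childcontext(ctx::WrappedContext) = ctx.ctx

# Leaves unwrap to themselves; wrappers recurse into their child.
unwrap(ctx::AbstractContext) = ctx
unwrap(ctx::WrappedContext) = unwrap(childcontext(ctx))

# Looking only at the direct child misses a PriorLeaf hidden under another wrapper.
child_is_prior(ctx::WrappedContext) = childcontext(ctx) isa PriorLeaf

# Looking at the leaf finds it at any nesting depth.
leaf_is_prior(ctx::AbstractContext) = unwrap(ctx) isa PriorLeaf

ctx = Scaled(2.0, Scaled(0.5, PriorLeaf()))
```

With this `ctx`, the one-level check reports no prior context while the leaf check does, which is the failure mode described for `EvaluationContext` wrapping a `MiniBatchContext`.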
NOTE: there's still an issue with `MiniBatchContext` in the above impl of `EvaluationContext`: `MiniBatchContext` will apply the weighting to variables that we `assume`. We could add a `filter` to `rewrap` so that if any of the `WrappedContext` in a nested wrapping are of the `filtertype`, then we don't include these in the rewrapped result. Is there a better solution to this?

EDIT: To be a bit more specific about what I mean in my last "NOTE", we can do:

which would allow us to specify "hey, if you find any of these wrapped contexts along the way, please just skip them". E.g.

And then we also add this parameter to `rewrap`.

EDIT 2: Also to clarify why we need it for `ContextualModel` in the linked comment above introducing `ConditionedContext`: in a `ContextualModel` we still want to allow the user to pass some other context if they so desire. To allow this we need the ability to `rewrap`.
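As a rough illustration of `rewrap` together with the "skip these wrappers" idea from the NOTE, here is a self-contained sketch on toy types. The `skip` keyword, the `setchild` helper and the types `SamplingLeaf`, `EvaluationLeaf`, `Scaled` and `Prefixed` are all assumptions made for this example; they are not part of DynamicPPL, and the PR's real signature may differ.

```julia
abstract type AbstractContext end
abstract type WrappedContext <: AbstractContext end

struct SamplingLeaf <: AbstractContext end    # hypothetical "sampling-mode" leaf
struct EvaluationLeaf <: AbstractContext end  # hypothetical "evaluation-mode" leaf

struct Scaled{C<:AbstractContext} <: WrappedContext
    scale::Float64
    ctx::C
end

struct Prefixed{C<:AbstractContext} <: WrappedContext
    prefix::Symbol
    ctx::C
end

childcontext(ctx::WrappedContext) = ctx.ctx

# Rebuild a single wrapper around a new child, keeping the wrapper's own fields.
setchild(ctx::Scaled, child) = Scaled(ctx.scale, child)
setchild(ctx::Prefixed, child) = Prefixed(ctx.prefix, child)

# Reaching a leaf: replace it with the new leaf.
rewrap(::AbstractContext, leaf; skip=()) = leaf

# Reaching a wrapper: rewrap the child first, then either drop this wrapper
# (if its type is listed in `skip`) or rebuild it around the new child.
function rewrap(parent::WrappedContext, leaf; skip=())
    child = rewrap(childcontext(parent), leaf; skip=skip)
    return any(T -> parent isa T, skip) ? child : setchild(parent, child)
end

ctx = Prefixed(:a, Scaled(2.0, SamplingLeaf()))
evalctx = rewrap(ctx, EvaluationLeaf())                  # same wrappers, new leaf
noscale = rewrap(ctx, EvaluationLeaf(); skip=(Scaled,))  # additionally drops `Scaled`
```

One design consequence worth noting: every wrapper type needs its own `setchild` method (or something equivalent, such as a Functors.jl-style reconstruction), which is the per-context cost of supporting `rewrap`.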