Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WrappedContext and separation between sampling and evaluation #249

Closed
wants to merge 42 commits into from

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented May 20, 2021

This PR introduces a abstract type WrappedContext <: AbstractContext which contexts that wrap other contexts should subtype, e.g. MiniBatchContext.

We also introduce clear separation between sampling and evaluation.

Proposal (as of [2021-05-24])

  • Additions
    • abstract type WrappedContext{LeafCtx} <: AbstractContext
      • Methods:
        • unwrap(ctx): recursively unwraps the WrappedContext, eventually returning the "leaf-context".
        • unwrappedtype(ctx): immediately returns LeafCtx from WrappedContext{LeafCtx}. This is useful as we cannot always rely on the Julia compiler to be able to infer the returntype of unwrap for deeper nested wrappings.
        • childcontext(ctx): returns the child-context. Useful for traversing a nested wrapped contexts.
        • rewrap(parent, leaf): this will essentially reconstruct parent with leaf as the leaf-context.
          • E.g. have a ctx isa MiniBatchContext{..., <:SamplingContext} which is currently in "sampling-mode", and want to switch to evaluation-mode: rewrap(ctx, EvaluationContext()).
      • Concrete subtypes:
        • PriorContext
        • LikelihoodContext
        • MinibatchContext
        • Etc.
    • abstract type PrimitiveContext <: AbstractContext
      • Concrete subtypes:
        • EvaluationContext: only evaluation of the model is to take place, i.e. no sampling of variables.
        • SamplingContext: sampling and evaluation of model takes place.
    • Notes:
      • Only PrimitiveContext should call the underlying assume/observe methods. WrapperContext should never do so.
      • Allows WrapperContext to change between the "fundamental" modes, e.g. sampling and evaluation, by rewrap-ing accordingly.
      • Precedence of application of contexts is "natural" in the sense that the inner context has precedence over outer context, with the highest precendence always given to PrimitiveContext.
  • Changes
    • Remove unnecessary complexity in ~ implementation.
      • Current calling hierarchy for a ~ statement is:
        • tilde_assume -> tilde(rng, ...) -> _tilde(rng, ...) -> assume
        • tilde_observe -> tilde(...) -> _tilde(...) -> observe
        • Similarly for dot_tilde_assume and dot_tilde_observe.
      • This is super-confusing and difficult to debug.
      • _tilde is currently only used for NamedDist to allow overriding the variable-name used for a particular ~ statement.
      • Propose the following changes:
        • Remove _tilde and handle NamedDist before calling tilde_assume, etc. by using a unpack_right_vns (and unpack_right_left_vns for dot-statements) (thanks to @devmotion)
        • Rename tilde_assume (tilde_observe) and to tilde_assume! (tilde_observe!), and tilde(rng, ...) (tilde(...)) to tilde_assume(rng, ...) (tilde_observe(...)).
          • tilde_assume! simply calls tilde_assume followed by acclogp(varinfo, result_from_tilde_assume), so the ! here is to indicate that it's mutating the logp field in VarInfo.
      • Choices of names is up for discussion!

Things to consider

  • Do we want the separation between WrappedContext and PrimitiveContext, or should they all just be AbstractContext?
  • Do we want SamplingContext to be an inner-most context as in this PR, or an outer-most context as in https://github.com/TuringLang/DynamicPPL.jl/compare/dw/samplingcontext ?
    • The difference can be boiled down to the choice between:
      1. Rewrap with SamplingContext at every call to tilde_assume and then choose not to rewrap if we only want to evaluate.
      2. Don't do anything until we want to switch from sampling-mode to evaluation-mode, or vice-versa, at which point we call rewrap.
  • Should we make rng and sampler part of SamplingContext, as proposed in Make Sampler a field of DefaultContext (renamed to SamplingContext) #80 and implemented in https://github.com/TuringLang/DynamicPPL.jl/compare/dw/samplingcontext ?
    • TOR: I'm personally in favour of this, as long as it doesn't have any unforseen consequences. It's easy enough to incorporate in the current PR.
      • [2021-05-31] This is now part of the PR.
  • Add additional value/left argument to tilde_assume (not tilde_assume!) and assume, as in dot_tilde_assume, which is initialized to nothing by tilde_assume!.
    • Motivation:
      • assume in most cases does not have to touch VarInfo at all, but instead just works with the value passed to it ⟹ simplifies assume implementations.
      • Allows overriding of values used in ~ without touching VarInfo! Useful for several reasons:
        • Something like a ConditionContext would not need to mutate VarInfo to do it's job, but instead just replace value::Nothing with the value it wants.
        • Could avoid linearization of variables by using (soon-to-come) ConditionContext with a NamedTuple.
        • Computing gradients can be done using a (soon-to-come) ConditionContext with NamedTuple or ComponentArray, thus avoiding having to go through VarInfo ⟹ improves performance (no need to pass through reconstruct nor getindex(::VarInfo, ...))
      • Allows fixing certain variables to values during sampling by having tilde_assume(..., ctx::SamplingContext, ...) call assume(rng, ...) if value::Nothing, but call assume(...) which only computes logpdf if value is not nothing.
        • ⟹ Can tell DPPL to not update certain variables when sampling.
  • Introduce ContextualModel as explained in Generalization of abstract model functions AbstractPPL.jl#10 (comment).
    • TOR: This might be best left for another PR.

Motivation

Currently we have the following contexts in DynamicPPL:

  • DefaultContext (should this be renamed to JointContext?)
  • PriorContext
  • LikelihoodContext
  • MiniBatchContext
  • PrefixContext
    These fall into two categories:
  1. Those which "shortcuts" the computation, e.g. the observe tilde with PriorContext simply returns 0 rather than calling _tilde(sampler, right, left, vi) as DefaultContext would. In this we have:
    • DefaultContext
    • PriorContext
    • LikelihoodContext
  2. Those which alters the result of the computation, e.g. the observe tilde of MiniBatchContext defers the actual computation to it's wrapped context by calling tilde(ctx.ctx, ...) and then multiplies the result with a scalar. In this we have:
    • MiniBatchContext
    • PrefixContext

AFAIK it's unlikely that we'll add much more to the first category, but for the second category we have lots of possibilities that are likely to make into DPPL at some point, e.g. EvaluationContext (#242) and ConditionContext (TuringLang/AbstractPPL.jl#10 (comment)). Therefore, having a nice way of dealing with nested wrapped contexts (i.e. nested contexts in the second category above) is useful. And without a nice way of handling this, one will quickly run into issues.

An immediate issue is EvaluationContext. The idea of EvalutionContext is that we provide values for all the parameters used in the model, thus making tilde_assume "redundant". The implementation becomes very easy since all we have to do is overload tilde_assume to instead extract the corresponding value from the given NamedTuple rather than sample it, call tilde_observe with this value, and finally return the extracted value:

function tilde_assume(
rng, ctx::EvaluationContext, sampler, right, vn, inds, vi::SimpleVarInfo{<:NamedTuple}
)
value = _getindex(getfield(vi.θ, getsym(vn)), inds)
# Contexts which have different behavior between `assume` and `observe` we need
# to replace with `DefaultContext` here, otherwise the observation-only
# behavior will be applied to `assume`.
# FIXME: The below doesn't necessarily work for nested contexts, e.g. if `ctx.ctx.ctx isa PriorContext`.
# This is a broader issue though, which should probably be fixed by introducing a `WrapperContext`.
if ctx.ctx isa PriorContext
tilde_observe(LikelihoodContext(), sampler, right, value, vn, inds, vi)
elseif ctx.ctx isa LikelihoodContext
# Need to make it so that this isn't computed.
tilde_observe(PriorContext(), sampler, right, value, vn, inds, vi)
else
tilde_observe(ctx, sampler, right, value, vn, inds, vi)
end
return value
end

But, as you can see in the lines above, we need to be careful when we're handling special contexts, e.g. PriorContext, since calling tilde_observe with a PriorContext is a no-op which it shouldn't be when we're calling it from tilde_assume! Okay, so we make a check whether the child-context is a PriorContext and all is good.

WRONG! What if the child-context of EvaluationContext is a MiniBatchContext?! The check will fail, we defer the tilde_observe to MiniBatchContext which in turn defers it to PriorContext, resulting in logp of 0. What we really need to do is check the type of the leaf context, i.e. the result of "unwrapping" the wrapped contexts recursively. This is similar to a bunch of issues (e.g. dispatch for Adjoint of a CuArray) which the community has with "lack" of support for wrapped array-types (JuliaLang/julia#31563).

So this PR addresses that by introducing a WrappedContext with the following methods:

  • unwrap(ctx): recursively unwraps the WrappedContext, eventually returning the "leaf-context".
  • unwrappedtype(ctx): immediately returns LeafCtx from WrappedContext{LeafCtx}. This is useful as we cannot always rely on the Julia compiler to be able to infer the returntype of unwrap for deeper nested wrappings.
  • childcontext(ctx): returns the child-context. Useful for traversing a nested wrapped contexts.
  • rewrap(parent, leaf): this will essentially reconstruct parent with leaf as the leaf-context.

This for example solves the above issue for EvaluationContext since we then can do:

function tilde_assume!(
    rng, ctx::EvaluationContext, sampler, right, vn, inds, vi::SimpleVarInfo{<:NamedTuple}
)
    value = _getindex(getfield(vi.θ, getsym(vn)), inds)

    # Contexts which have different behavior between `assume` and `observe` we need
    # to replace with `DefaultContext` here, otherwise the observation-only
    # behavior will be applied to `assume`.
    new_ctx = if unwrappedtype(ctx) <: PriorContext
        rewrap(ctx.ctx, LikelihoodContext(unwrap(ctx).vars))
    elseif unwrappedtype(ctx) <: LikelihoodContext
        # Need to make it so that this isn't computed.
        rewrap(ctx.ctx, PriorContext(unwrap(ctx).vars))
    else
        ctx.ctx
    end
    tilde_observe(new_ctx, sampler, right, value, vn, inds, vi)
    return value
end

Similar issues also shows up for other wrapped contexts, and thus will be sorted out through this.

NOTE: there's still an issue with MiniBatchContext in the above impl of EvaluationContext: MiniBatchContext will apply the weighting to variables that we assume. We could add a filter to rewrap so that if any of the WrappedContext in a nested wrapping are of the filtertype, then we don't include these in the rewrapped result. Is there a better solution to this?

EDIT: To be a bit more specific about what I mean in my last "NOTE", we can do:

childcontext(ctx::AbstractContext) = childcontext(ctx, Union{})
childcontext(ctx::AbstractContext, filtertype::Type) = nothing
function childcontext(ctx::WrappedContext, filtertype::Type)
    # If `child` is of `filtertype`, then we recurse futher.
    if ctx.ctx isa filtertype
        return childcontext(ctx.ctx, filtertype)
    else
        return ctx.ctx
    end
end

which would allow us to specify "hey, if you find any of these wrapped contexts along the way, please just skip them". E.g.

ctx = PrefixContext{:c}(MiniBatchContext(1.0, PrefixContext{:a}(PrefixContext{:b}())))
DynamicPPL.childcontext(ctx)
# MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}(1.0, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}(DefaultContext()))
DynamicPPL.childcontext(ctx, MiniBatchContext)
# PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}(DefaultContext())

And then we also add this parameter to rewrap.

EDIT 2: Also to clarify why we need it for ContextualModel in the linked comment above introducing ConditionedContext: in a ContextualModel we still want to allow the user to pass some other context if they so desire. To allow this we need the ability to rewrap.

@torfjelde torfjelde requested a review from devmotion May 20, 2021 07:25
@devmotion
Copy link
Member

DefaultContext (should this be renamed to JointContext?)

IIRC it is not necessarily the log joint, some samplers require a special log density evaluation by default (which also doesn't belong to any of the other categories). Probably this could be solved though with a SamplingContext, as suggested before.

What we really need to do is check the type of the leaf context, i.e. the result of "unwrapping" the wrapped contexts recursively.

This seems reasonable but I wonder: Do we actually need another context? Isn't it sufficient to have an interface that allows to unwrap etc. and implement it for the existing (and possibly new) contexts?

@devmotion
Copy link
Member

BTW I think also with unwrap the checks should not be hard-coded in the function body of the EvaluationContext but one should just call another function that takes the unwrapped context as an argument and allows to dispatch on its type.

@devmotion
Copy link
Member

Another observation: it also seems that one of the main problems with the EvaluationContext is that PriorContext etc. mix both sampling and evaluation. So maybe the correct/better/additional fix would be to introduce both a SamplingContext and an EvaluationContext where the former always replaces existing variables and the latter only performs evaluation. And then one could use PriorContext etc. in both to specify how the log densities are accumulated. It seems in this way one maybe can avoid calling tilde_observe in an tilde_assume statement (which seems wrong, regardless of all other problems).

@torfjelde
Copy link
Member Author

torfjelde commented May 20, 2021

IIRC it is not necessarily the log joint, some samplers require a special log density evaluation by default (which also doesn't belong to any of the other categories). Probably this could be solved though with a SamplingContext, as suggested before.

Ah, good point 👍

This seems reasonable but I wonder: Do we actually need another context? Isn't it sufficient to have an interface that allows to unwrap etc. and implement it for the existing (and possibly new) contexts?

Yes because you might want to dispatch based on the LeafCtx. If you just implement rewrap, etc. there are two issues:

  1. We need to use if typeof(unwrap(ctx)) statements rather than multiple dispatch.
  2. This is not guaranteed to be inferrable by the compiler (in contrast to dispatch on LeafCtx). This is the biggest issue.

Another observation: it also seems that one of the main problems with the EvaluationContext is that PriorContext etc. mix both sampling and evaluation. So maybe the correct/better/additional fix would be to introduce both a SamplingContext and an EvaluationContext where the former always replaces existing variables and the latter only performs evaluation. And then one could use PriorContext etc. in both to specify how the log densities are accumulated. It seems in this way one maybe can avoid calling tilde_observe in an tilde_assume statement (which seems wrong, regardless of all other problems).

I like this. I was thinking that the EvaluationContext would be a step in this direction, but I think you're right: doing it right away might solve these issues, i.e. not needing to call observe from assume (and yes I agree it's bad, haha; I'm just lazy 🙃 )

I'll make an attempt at adding SampleContext and EvaluationContext to this PR. IMO those also benefit from the WrappedContext stuff.

@devmotion
Copy link
Member

We need to use if typeof(unwrap(ctx)) statements rather than multiple dispatch.

Do we? I thought one would just call something like do_actually_something(unwrap(ctx), all_other_args...) and then dispatch on the unwrapped context. As I wrote somewhere in my opinion it would be good to avoid the hardcoded type checks anyway.

@torfjelde
Copy link
Member Author

Do we?

Yeah okay fair, if introduce a bunch of _method_that_does_the_same_thing_but_is_private we don't have to (I'm maaaybe exaggerating here 🙃 ), but we still have the type-inferrability issue though which is the main one 😕

@devmotion
Copy link
Member

Do you have an example where type inference breaks? I would have assumed that it's not a problem for the compiler, at least not with reasonably deep contexts (eg there are a lot of recursive definitions in SciML which work fine).

@torfjelde
Copy link
Member Author

Do you have an example where type inference breaks? I would have assumed that it's not a problem for the compiler, at least not with reasonably deep contexts (eg there are a lot of recursive definitions in SciML which work fine).

I don't have an example right now, but I can send you some pictures of my battlescars from implementing compositions in Bijectors.jl!
But maybe you're right. I seem to remember recursions of depth ~20 would break down; maybe that's not a realistic scenario here?

I don't like not adding it to the type though, as it's more difficult to add back later if we realize we need vs. removing it if we realize we don't need it. But I agree, we shouldn't do it without reason. I'll do some checks 👍

@torfjelde
Copy link
Member Author

torfjelde commented May 20, 2021

I'll do some checks

Seems like our if @generated in constructor of PriorContext was never hit before and instead we were always using the recursive version; it even has a bug in it that we never saw because it was never hit (because we didn't try deep enough nesting; in fact our tests has a case that is 3 levels deep I think, and I the issue with 4 it seems).

But if you remove the if @generated so that we only have the recursive definition, we get the following:

julia> @code_warntype PrefixContext{:c}(MiniBatchContext(1.0, PrefixContext{:a}(PrefixContext{:b}())))
Variables
  #self#::Type{PrefixContext{:c, C, LeafCtx} where {C, LeafCtx}}
  ctx::MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}

Body::PrefixContext{:c, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}
1%1 = $(Expr(:static_parameter, 1))::Core.Const(:c)
│   %2 = DynamicPPL.typeof(ctx)::Core.Const(MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext})
│   %3 = Core.apply_type(DynamicPPL.PrefixContext, %1, %2, $(Expr(:static_parameter, 2)))::Core.Const(PrefixContext{:c, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext})
│   %4 = Core.fieldtype(%3, 1)::Core.Const(MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext})
│   %5 = Base.convert(%4, ctx)::MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}%6 = %new(%3, %5)::PrefixContext{:c, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}
└──      return %6

julia> @code_warntype PrefixContext{:d}(PrefixContext{:c}(MiniBatchContext(1.0, PrefixContext{:a}(PrefixContext{:b}()))))
Variables
  #self#::Type{PrefixContext{:d, C, LeafCtx} where {C, LeafCtx}}
  ctx::PrefixContext{:c, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}

Body::PrefixContext{_A, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext} where _A
1%1 = DynamicPPL.Symbol($(Expr(:static_parameter, 2)), DynamicPPL.PREFIX_SEPARATOR, $(Expr(:static_parameter, 1)))::Symbol%2 = Core.apply_type(DynamicPPL.PrefixContext, %1)::Type{PrefixContext{_A, C, LeafCtx} where {C, LeafCtx}} where _A
│   %3 = Base.getproperty(ctx, :ctx)::MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}%4 = (%2)(%3)::PrefixContext{_A, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext} where _A
└──      return %4

That is not a lot of nesting before it fails 😕

@torfjelde
Copy link
Member Author

But for unwrap it seems to do very well actually:

julia> ctx = MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, MiniBatchContext(1.0, PrefixContext{:a}(PrefixContext{:b}())))))))))))));

julia> @code_warntype DynamicPPL.unwrap(ctx)
Variables
  #self#::Core.Const(DynamicPPL.unwrap)
  ctx::MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}

Body::DefaultContext
1%1 = Base.getproperty(ctx, :ctx)::MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, MiniBatchContext{Float64, PrefixContext{Symbol("b.a"), DefaultContext, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}, DefaultContext}
│   %2 = DynamicPPL.unwrap(%1)::Core.Const(DefaultContext())
└──      return %2

Hmmmmmmm. Not sure what to think.

@devmotion
Copy link
Member

devmotion commented May 20, 2021

That is not a lot of nesting before it fails confused

If the bug is fixed, both examples infer correctly it seems (with an additional DefaultContext in the inner-most layer). Maybe the compiler figures out that the generated part fails and then falls back to the recursive definition?

Edit: Even the MiniBatch example infers correctly.

@torfjelde
Copy link
Member Author

torfjelde commented May 20, 2021

Edit: Even the MiniBatch example infers correctly.

Yeah, because I guess it starts using @generated when the recursive can't be inferred anymore?

EDIT: Just for the record, in the example where it failed I explicitly removed the if @generated and only used the recursive def to demonstrate the limitation of the recursive definition.

@torfjelde
Copy link
Member Author

But even so, isn't having LeafCtx in the type much easier on the compiler?

@devmotion
Copy link
Member

Probably in some extreme cases but I'm not sure if this is relevant (already the examples above seem a bit constructed). Moreover, the compiler has to specialize all methods for different leaf types, so possibly there is a lot of unnecessary recompilation due to the additional type parameters whereas actually one would need specialize on the leaf type only in certain functions.

@devmotion
Copy link
Member

Another argument for not having a WrapperContext is the fact that there is no multiple inheritance in Julia, so it limits possible type structures and relations. In general, I just have the feeling that this is more an interface or trait rather than a type thing.

@torfjelde
Copy link
Member Author

torfjelde commented May 20, 2021

Moreover, the compiler has to specialize all methods for different leaf types, so possibly there is a lot of unnecessary recompilation due to the additional type parameters whereas actually one would need specialize on the leaf type only in certain functions.

Hmm, but we're already specializing on the child-context so at least shouldn't make things harder? Also there are cases where we don't actually need to unwrap, but instead only want unwrappedtype in which case you wouldn't need to compile some methods.

Another argument for not having a WrapperContext is the fact that there is no multiple inheritance in Julia, so it limits possible type structures and relations. In general, I just have the feeling that this is more an interface or trait rather than a type thing.

See I'm kind of the opposite opinion because generally how you work with a WrappedContext vs. any AbstractContext is different: in the former you almost always just defer the computation to unwrap while in the latter you don't. E.g. you can do

f(ctx::AbstractContext) = # actual computation ...
f(ctx::WrappedContext) = f(unwrap(ctx))

without worrying about StackOverflow, AmbuigityError, etc. How would you do this if you don't have a separation between WrappedContext and just any other AbstractContext?

@devmotion
Copy link
Member

Hmm, but we're already specializing on the child-context so at least shouldn't make things harder?

While it's true that it amount of methods that have to be compiled does not increase if the information is redundant, additional type parameters still put more stress on the compiler. It has to be carried around everywhere and, as far as I know, the compiler doesn't realize that it does not increase the information content.

Also there are cases where we don't actually need to unwrap, but instead only want unwrappedtype in which case you wouldn't need to compile some methods.

I'm not completely sure if I understand you correctly, but it seems this would just motivate to have both unwrap and unwrappedtype as part of the interface, possibly with unwrappedtype(context) = typeof(unwrap(context)) as a fallback.

How would you do this if you don't have a separation between WrappedContext and just any other AbstractContext?

You just specialize f on the type of interest or possibly on a Union of context types. It's not necessary that they share the same super type to be able to add this separation and to use unwrap.

@devmotion
Copy link
Member

devmotion commented May 20, 2021

Regarding the last point, one can even add a trait iswrappercontext(::AbstractContext) and use it in the definition of f if one wants to provide an easy default for all wrapped type that might only be defined at a later point.

That's common practice e.g. for iterators (HasEltype etc.).

@torfjelde
Copy link
Member Author

torfjelde commented May 20, 2021

While it's true that it amount of methods that have to be compiled does not increase if the information is redundant, additional type parameters still put more stress on the compiler. It has to be carried around everywhere and, as far as I know, the compiler doesn't realize that it does not increase the information content.

But say I implement a method for AbstractArray{<:Real}, the compiler doesn't need to specialize on the N parameter, right? And this by itself makes life easier for the compiler, no? I get what you're saying, but it seems "intuitive" that it's worth it because carrying around that information allows you to take these shortcuts.

possibly with unwrappedtype(context) = typeof(unwrap(context)) as a fallback.

IMHO it wouldn't be a fallback; it would be the implementation. Other than simple rules for 1-level nesting, there's no other good way to implement this on a case-by-case basis for nested WrappedContext, right? And this is what I mean: in the implementation

unwrappedtype(context) = typeof(unwrap(context))

you need to compile unwrap for typeof(context), which can be a deeply nested context which the compiler then needs to specialize on. Here the compiler needs to specialize on this particular sequence of nested contexts, of which there are faaar more than just numbe of contexts (in particular, number of non-wrapped contexts). In constrast, the current impl of unwrappedtype does not require specialization on each of these typeof(context); it only requires specialization on the LeafCtx parameteric type (for which there will be very few; the larger group will be the WrapperContext).

Also, just to note: we're not really putting any restrictions on the relationships between the contexts here. We can always just specialize the methods for any particular AbstractContext if need be. Though I agree, if we start having a large number of contexts which does not fit into either WrappedContext nor "LeafCtx" (not actually a type, but you get the idea), then the structure loses it's use. But as far as I can tell, this is not going to be the case. Instead we'll have a lot of WrappedContext and a couple leaf-contexts.

You just specialize f on the type of interest or possibly on a Union of context types. It's not necessary that they share the same super type to be able to add this separation and to use unwrap.

Yeaaah but this can quickly become unwieldy though + a user can't add to the union => they need to manually overload that method for their particular WrappedContext rather than just being able to use the default impl.

Regarding the last point, one can even add a trait iswrappercontext(::AbstractContext) and use it in the definition of f if one wants to provide an easy default for all wrapped type that might only be defined at a later point

Sure! But as you've pointed out, dispatch is preferable to if statements:) Or do mean something that should return an IsWrapped type and then we dispatch on this? I still don't like this. To me it seems like we're adding unnecessary complexity, i.e. increasing number of methods by a factor of 2, for no gain whatsoever 😕

Also there is discussion regarding these two approaches in JuliaLang/julia#31563. Of course there's probably going to be greater latency with overloads of arrays, but this comment (JuliaLang/julia#31563 (comment)) indicates that the "intuition" that a WrappedContext would improve compilation seems to be valid?

@devmotion
Copy link
Member

Haha I see you're not convinced by Julia's (sometimes a bit excessive, I guess) use of traits 😄

Sure! But as you've pointed out, dispatch is preferable to if statements:) Or do mean something that should return an IsWrapped type and then we dispatch on this? I still don't like this.

I think one has to separate two things here:

I don't like the pattern

function f(...)
    ...
    y = g(...)
    if y isa SomeType
        ...
    elseif y isa SomeOtherType
        ...
    else
        ...
    end
    ...
end

since one has to reimplement the complete function f if one adds some additional behaviour (i.e., a different return type) to g. This seems bad and is sometimes not even possible if it is not possible to specialize on the arguments of f in such a way that it covers all cases where the new behaviour of g would be triggered. If instead one uses

function f(...)
    ...
    y = g(...)
    z = h(y, ...) # or just h(y, ...)
    ...
end

then one has to implement only h for the new type of y without having to reimplement the whole function f.

However, in the wrapper example

f(context::AbstractContext) = ...
f(context::WrapperContext) = ...

the situation is different. As mentioned, of course, one can always specialize f on the type of the context, even if there is no WrapperContext supertype. But then it is difficult for new wrapper types to opt in and get the wrapper behaviour for free. So an idiomatic Julia approach (as done e.g. in the iterator interface) would be to define

struct IsWrapper end
struct IsNoWrapper end

iswrapper(context::AbstractContext) = IsNoWrapper()
iswrapper(context::PrefixContext) = IsWrapper()

f(context::AbstractContext) = f(context, iswrapper(context))
function f(context::AbstractContext, ::IsWrapper)
    ....
end
function f(context::AbstractContext, ::IsNoWrapper)
    ...
end

Then new contexts can declare if they are a wrapper or not and then benefit from the default implementation, and, of course, it is still possible to just implement f(::MyContext) directly. It can be a bit cumbersome to introduce these new singleton types and so alternatively one could just write

iswrapper(context::AbstractContext) = false
iswrapper(context::PrefixContext) = true

f(context::AbstractContext) = f(context, Val{iswrapper(context)}())
function f(context::AbstractContext, iswrapper::Val{false})
    ....
end
function f(context::AbstractContext, iswrapper::Val{true})
    ...
end

And if there are really only these two possibilities - either it is a wrapper or not - then this can be shortened to

iswrapper(context::AbstractContext) = false
iswrapper(context::PrefixContext) = true

function f(context::AbstractContext)
    if iswrapper(context)
        ...
    else
        ...
    end
end

The compiler will elide the unused branch automatically, so this is equivalent to the much more verbose trait implementation above.

The main different between this if statement and the one in the example above that I do not like is that in the trait-approach it is easy to add support for new types whereas above you have to reimplement f. Of course, one has to use the slightly more verbose trait-dispatch-approach above if one wants to allow users to add different return values (or types) to iswrapper.

To me it seems like we're adding unnecessary complexity, i.e. increasing number of methods by a factor of 2, for no gain whatsoever confused

The main advantage is that you do not enforce super types. Imagine there is some other orthogonal behaviour that should be shared among different contexts such as if they sample values or just evaluate them. Then you can't encode both the wrapper behaviour and this other behaviour in the type hierarchy and so you have to decide for one of the default implementations and reimplement the other one for your context. Instead, if you use functions such as iswrapper or unwrap or unwrappedtype to customize the behaviour of contexts, it is easy to combine these with other traits. So it can actually help to reduce the amount of code that you have to write.

Additionally, I'd argue that in the iswrapper example, in particular in the short version, not many additional methods are needed - just one iswrapper definition instead of the <: WrapperContext supertype definition.

@torfjelde
Copy link
Member Author

torfjelde commented May 21, 2021

Whelp, I can't say no to the interface-approach now that you've written such a long and nice explanation 😞

But I'm with you on the pattern and am also aware of it (though it was still useful to read what you wrote above; very nicely laid out all the different approaches). My point is that I want the separation between any other context and a WrapperContext because I struggle to see how it makes sense to have one that is actually is both. BUT I'm also a sucker for generalization, so I'm fine with doing what you suggested:)

Another question, do you have any concrete ideas on how things should look internally with a SampleContext and EvaluateContext? To me there seems like there are a couple of approaches:

  1. Change compiler to be aware of SampleContext and EvaluateContext, e.g. when using EvaluateContext we don't need to return the value from tilde_assume but instead could just make tilde_assume accumulate the logpdf (which for example would make it possible to use Zygote even when when the original model impl is mutating).
  2. Don't make changes to compiler, but instead only change what is done internally in tilde_assume.

The issue that I'm struggling a bit to resolev right now is that assume is what mutates the VarInfo, thus this is what should change if we introduce a difference between SampleContext and EvaluateContext, but assume doesn't have access to the context! So what to do?

Immediate idea of solution is to introduce assume_logdensity and assume_sample or something, but this also seems "weird" plus would be very breaking. EDIT: Maybe not so breaking after all. I mean, it would be breaking, but most samplers doesn't use assume even so might be easy to fix where needed.

Also, I'm guessing we're going to leave assume as is i.e. it still returns both the value and the logpdf, and not introduce two different assume (one for computing logpdf and one for sampling), right?

@devmotion
Copy link
Member

Just a quick comment:

but assume doesn't have access to the context!

IIRC this is only the case since we drop the context when forwarding tilde_assume (or was it tilde?) to assume.

@torfjelde
Copy link
Member Author

Just a quick comment:

but assume doesn't have access to the context!

IIRC this is only the case since we drop the context when forwarding tilde_assume (or was it tilde?) to assume.

Yes, this is because we drop it when we forward to _tilde which it seems is used to get custom behavior for certain distributions, e.g. NamedDistribution, which then calls assume. But I think it should stay that way so that devs of samplers don't have to touch the context, no?

@torfjelde
Copy link
Member Author

Seems better to just let them work with two different versions of assume (say, one with. rng and one without?)

@torfjelde
Copy link
Member Author

torfjelde commented May 21, 2021

Also, say we overload _tilde to also take the context so we can then decide whether to call assume(...) or assume(rng, ...); then we need to ensure that SampleContext and EvaluteContext always is the inner-most context, i.e. PriorContext, LikelihoodContext, etc. needs to now become wrapped-contexts too and only SampleContext and EvaluateContext are allowed to be leaves.

@devmotion
Copy link
Member

I wonder, do we need a dedicated evaluation context at all? Could just by default every context be an evaluation context, and if you want to sample you have to combine it with a sampler in a SampleContext? Initially, this could even be completely internal without changing the user-facing Model API if we just wrap the provided sampler and context in a SamplingContext. This would also make it easy to ensure that SamplingContext is always the outer-most context.

@torfjelde
Copy link
Member Author

I think you might like the following improvement in my proposal

I do like this! But this is like a more complicated rewrap method, no? I guess my question is then: what's the argument of these "primitive" contexts, i.e. SamplingContext and EvaluationContext, being outer-most context rather than inner-most?

I prefer inner-most because:

  • Simpler rewrap/reconstruct functionality.
  • Make difference between the primitive contexts and the rest of the contexts clear, e.g. we'll throw an argument error if the user tries to construct a context with the leaf-context not being either SamplingContext or EvaluationContext while in the outer-most approach, this (I think) would have to be included in the model-definition (e.g. a check that the outer-most context is a "primitive" context) + it's not possible to make a context without ensuring that eventually unwraps to a primitive context.
  • It seems more "natural" to always have the last thing being applied being the "primitive" context (obviously subjective).

So it seems the major (only?) point that you might be still be unsatisfied with would be the addlogp! calls that one has to include when implementing tilde_assume etc?

I'm also unsatisfied with these being the outer-most context rather than inner-most. But yes, the above too.

In my opinion, it is simpler and bugs should be easier to spot - at most tilde_assume etc. or assume is incorrect, otherwise any part of the pipeline could be broken.

IMO this is a bit unfair though. You're technically not wrong, but in reality, essentially no contexts will actually touch tilde_assume, and thus as a result, the bug will almost always be in tilde (or w/e we rename to).

Similarly for primitive contexts, the bug will always be in tilde_primitive, not in tilde or tilde_assume.

So downstream packages would not have to worry about acclogp!

You mean if they only want to implement samplers? Sure, but it's the case if you want to implement contexts. I think we really should allow custom contexts to be implemented "easily" (though we should still be careful of adding them to DPPL), as there are some really neat stuff you could do with it.

I still wonder if it would be possible to include test utilities for contexts and other parts of DynamicPPL that can be used to check implementations more easily, both in DynamicPPL and downstream packages

I mean, I like this idea irregardless of this PR and our discussions!:)

I don't think the proposal is less flexible, at least I haven't encountered anything that does not work anymore so far.

Crap, I meant to changes this but clearly forgot to; sorry. You're right, it's not "less flexible" in what you can do with it; of course, if you're overloading tilde_assume you can do whatever you want. What I meant was more like "less nimble", i.e. making changes and adding new behaviour is more difficult and requires more effort.

Seperate thing

Btw, another question I would like to hear your thoughts on, independent of the ongoing tilde_assume vs. tilde_assume, tilde, _tilde: should we allow overriding values used without mutating VarInfo? I'm thinking (and have currently implemented) the following:

  • EvaluationContext: will not make any changes to VarInfo (other than logp which unfortunately has to change, AFAIK), but it can still allow replacing values using different contexts, e.g. ConditionContext.
  • SamplingContext: mutates the variables in VarInfo.

This is achieved as follows:

  • Add a value or left argument to tilde, or in your case, tilde_assume (this already exists for dot_tilde_assume).
    • By default this is nothing, but every subsequent call to tilde / tilde_assume is allowed to replace this with a value, with the last change finally reaching assume.
      • In my implementation, this is initialized to nothing in tilde_assume so that in every subsequent call to tilde this might be changed depending on the a context. This means that the inner-most context takes the highest precedence.
  • assume has two different implementations for the two different "primitive" contexts; one for each of the cases where value or left is nothing and not noting.
    • EvaluationContext:
      • nothing: extracts the variable from VarInfo and passes this to assume as the value.
      • not nothing: passes this value to assume as the value.
    • SamplingContext:
      • nothing: samples a new value and updates VarInfo accordingly, passing this new value to assume as the value.
      • not nothing: updates VarInfo to now hold the passed in value, and passes the same value to assume as the value.
  • NOTE: we can either have assume with one less argument indicating that it should be sampled, or we can just dispatch on left::Nothing. I think I prefer the former (though the current impl has the latter).

This has a couple of nice features IMO:

  • Can completely avoid the machinery VarInfo since the value passed will only hit the logpdf computation and that's it. E.g. taking the gradient wrt. θ in model(varinfo, sampler, Prior(θ)) becomes much cheaper. Full gradients can then be implemented very efficiently by simply using a ConditionContext when we eventually get that.
  • Most implementers of assume that only requires evaluation of the model now never even have to touch the VarInfo since the value will always be provided as an argument!

@devmotion
Copy link
Member

I do like this! But this is like a more complicated rewrap method, no? I guess my question is then: what's the argument of these "primitive" contexts, i.e. SamplingContext and EvaluationContext, being outer-most context rather than inner-most?

I prefer inner-most because:

* Simpler `rewrap`/`reconstruct` functionality.

* Make difference between the primitive contexts and the rest of the contexts clear, e.g. we'll throw an argument error if the user tries to construct a context with the leaf-context not being either `SamplingContext` or `EvaluationContext` while in the outer-most approach, this (I think) would have to be included in the model-definition (e.g. a check that the outer-most context is a "primitive" context) + it's not possible to make a context without ensuring that eventually unwraps to a primitive context.

* It seems more "natural" to always have the last thing being applied being the "primitive" context (obviously subjective).

I think the order of how the contexts are applied isn't significantly different in the proposals, is it?

  • LikelihoodContext etc. (what you grouped as EvaluationContext) are primitive contexts and can't wrap any other contexts.
  • SamplingContext is just an internal wrapper that takes an RNG, sampler, and context and puts them into one context. If the wrapped context is a non-primitive context itself, the wrapped context is applied after the SamplingContext with the child context of the wrapped context - recursively, until the SamplingContext contains only a primitive context such as LikelihoodContext (the inner-most context).
  • Other wrappers such as MinibatchContext or PrefixContext usually apply the child context first (possibly with some modifications such as with additional prefixes) and then perform additional work if needed (e.g. scale the loglikelihood in the MinibatchContext).

Could you explain what you mean with "Simpler rewrap/reconstruct functionality"? I just took two of the functions you proposed above (childcontext and rewrap) and put them into one function since it seemed they belong together, similar to ParameterHandling.flatten or Functors.functor. But the functionality is the same as you suggested.

Make difference between the primitive contexts and the rest of the contexts clear, e.g. we'll throw an argument error if the user tries to construct a context with the leaf-context not being either SamplingContext or EvaluationContext while in the outer-most approach, this (I think) would have to be included in the model-definition (e.g. a check that the outer-most context is a "primitive" context) + it's not possible to make a context without ensuring that eventually unwraps to a primitive context.

I don't think any error checking is needed at construction time. SamplingContexts are only constructed internally anyway and the implementation ensures that they are always applied first, as if they would be a leaf context. And since implementation ensures that the primitive contexts are always applied first if they are wrapped any weird or incorrectly context setup will just error when it is applied. Maybe I miss something, but it seems sufficient if every context ensures that it is doing the correct thing and applies possible child contexts in the correct order when actually performing the computations.

@torfjelde
Copy link
Member Author

torfjelde commented May 24, 2021

I think the order of how the contexts are applied isn't significantly different in the proposals, is it?

  • LikelihoodContext etc. (what you grouped as EvaluationContext) are primitive contexts and can't wrap any other contexts.
  • SamplingContext is just an internal wrapper that takes an RNG, sampler, and context and puts them into one context. If the wrapped context is a non-primitive context itself, the wrapped context is applied after the SamplingContext with the child context of the wrapped context - recursively, until the SamplingContext contains only a primitive context such as LikelihoodContext (the inner-most context).
  • Other wrappers such as MinibatchContext or PrefixContext usually apply the child context first (possibly with some modifications such as with additional prefixes) and then perform additional work if needed (e.g. scale the loglikelihood in the MinibatchContext).

"wrapped context is applied after the SamplingContext with the child context of the wrapped context"; this is different, no? In what I suggest all contexts are applied before SamplingContext.
BUT I believe this wouldn't matter for functionality (maybe this is what you meant actually? I.e. precedence of application is the same), e.g. in ConditionContext we want to intercept SamplingContext and convert into EvaluationContext for certain variables. If ConditionContext is part of a larger chain of wrapped contexts, this is achieved in the two different implementations as:

  • My proposal: defer to child-context and call rewrap only when ConditionContext is hit.
  • Your proposal: rewrap with SamplingContext at every step, but when we hit ConditionContext we stop rewrapping child-contexts in SamplingContext?

So same functionality, but IMO the first one seems easier/more intuitive.

Also, how does the wrapper-contexts know if they're sampling or not? I.e. how can a wrapper-context say "stop wrapping this in SamplingContext"? Custom implementation of unwrap_childcontext? In my proposal (it's not "my" proposal btw, there's a lot of ideas from you in there too, I just use it now because it's annoying to write "the proposal in this PR") the wrapper-context just changes this behavior by calling rewrap once.

Btw, where did the acclogp! go?

Could you explain what you mean with "Simpler rewrap/reconstruct functionality"? I just took two of the functions you proposed above (childcontext and rewrap) and put them into one function since it seemed they belong together, similar to ParameterHandling.flatten or Functors.functor. But the functionality is the same as you suggested.

IMO childcontext would find it's place even in your implementation, as it's just a convenient way of extracting the wrapped context. My implementation could also do away with the childcontext and instead put this in rewrap, but I made it separate because it seemed conceptually nicer to me.

And yeah, functionality is the same. But do you really think that rewrapping at every step along the way and then potentially stopping at some later point using the Functors.jl approach is easier to follow/more intuitive than "change leaf, i.e. call rewrap, if you want to change between sampling, evaluation, etc."? What if we want another context that does a similar thing to SamplingContext? If we did the same approach and someone decides to use both together, things just break, right? In contrast, if we made them PrimitiveContext this would not even be possible.

Btw, for the record: we can remove the _tilde or tilde_primitive by just making the impl of tilde for EvaluationContext and SamplingContext the impl of their current tilde_primitive.
And then, as I mentioned before, we can remove tilde if we want to write out acclogp! all the time. So we could get down to the same number of methods that you suggest without issues. I'm not entirely certain where I stand on this point, e.g. I kind of like to see tilde_primitive in my error vs. tilde with a "primitive" context, but I also agree that it's annoying with more methods. So I'm very open to changing the number of methods used.

I don't think any error checking is needed at construction time. SamplingContexts are only constructed internally anyway and the implementation ensures that they are always applied first, as if they would be a leaf context.

But like, why though? This doesn't seem like a good argument when there is an approach that could make it so that it never happens, right?

And since implementation ensures that the primitive contexts are always applied first if they are wrapped any weird or incorrectly context setup will just error when it is applied. Maybe I miss something, but it seems sufficient if every context ensures that it is doing the correct thing and applies possible child contexts in the correct order when actually performing the computations.

But bugs will happen, even if we only use it internally. So if there's a way to ensure that a particular error never occurs, that seems nice, no?

@torfjelde
Copy link
Member Author

Yo! What about renaming tilde_assume to tilde_assume! and then tilde to tilde_assume? It's kind of cheating because we're not really respecting the Bang-convention in the rest of the repo, but it sort of makes sense given that in one we mutate the logp and in the other we dont.

@torfjelde torfjelde changed the title WrappedContext WrappedContext and separation between sampling and evaluation May 25, 2021
src/context_implementations.jl Outdated Show resolved Hide resolved
src/context_implementations.jl Outdated Show resolved Hide resolved
src/context_implementations/minibatch.jl Outdated Show resolved Hide resolved
src/context_implementations/minibatch.jl Outdated Show resolved Hide resolved
Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent work - I left many comments below.


Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value.

Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`.
"""
function tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
value, logp = tilde(rng, ctx, sampler, right, vn, inds, vi)
function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice improvement!

"""
tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)

Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value.

Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be updated?

Suggested change
Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`.
Falls back to `tilde_assume(rng, ctx, sampler, right, vn, inds, vi)`.

function _tilde(rng, sampler, right, vn::VarName, vi)
return assume(rng, sampler, right, vn, vi)
function tilde_assume(rng, ctx::SamplingContext, sampler, right, left, vn, inds, vi)
return assume(rng, sampler, right, left, vn, inds, vi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be a good idea to wrap sampler and rng into SamplingContext. It makes function signature more consistent, and code clean.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plan is to move sampler and rng into SamplingContext yeah, but we still want to unwrap those arguments before reaching assume. We don't want developers outside of DPPL to have to use contexts; contexts should be something they never really see. Contexts should also never change the behavior of assume.

function _tilde(rng, sampler, right::NamedDist, vn::VarName, vi)
return _tilde(rng, sampler, right.dist, right.name, vi)
function tilde_assume(
rng, ctx::EvaluationContext, sampler, right, left::Nothing, vn, inds, vi
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, we can wrap sampler into EvaluationContext or remove it (if it's not needed).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we end up putting the sampler and rng into SamplingContext, I agree 👍

For removing it: def something we want to do at some point, but right now it's a bit difficult I think.

function tilde_observe(
ctx::Union{SamplingContext,EvaluationContext}, sampler, right, left, vi
)
return observe(sampler, right, left, vi)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to assume, might be good to wrap sampler into ctx, and pass ctx through the call stack instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't want to pass ctx to observe nor assume. Agree with simplifying tilde_observe and tilde_observe! by removing sampler from the signature though.

@@ -0,0 +1,54 @@
function tilde_assume(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe split the code into multiple files in future PRs, so that we can keep track of actual code changes in this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general I agree with you, but I did this because it was already starting to look pretty nasty in the review-section 😕 So many changes to the diffs didn't "match" up.

@@ -4,80 +4,49 @@
The `DefaultContext` is used by default to compute log the joint probability of the data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer have a DefaultContext.

include("contexts/prior.jl")
include("contexts/likelihood.jl")
include("contexts/minibatch.jl")
include("contexts/prefix.jl")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have two sub folders for contexts related code at the moment, i.e. contexts/ and contexts_implementaiton. I suggest that we merge these two sub folders into one contexts/. Also, we should also merge contexts/prior.jl and context_implementations/prior.jl into one file contexts/prior.jl.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because of inclusion order, e.g. we need the definition of some of the contexts in model.jl but at that point we cannot yet define impls of tilde_assume, etc., so that's why we have this split.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could possibly be simplified though! Just figured I'd stick with what we were already for this PR.

@@ -1,13 +1,15 @@
# Context version
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe consider moving this file into contexts/loglikelihoods.jl.

@@ -1399,90 +1399,3 @@ function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)

return indices
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a replacement mechanism for setval_and_resample!?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's no longer needed, since the _and_resample! part is now just defined by using a SamplingContext.

src/loglikelihoods.jl Outdated Show resolved Hide resolved
src/loglikelihoods.jl Outdated Show resolved Hide resolved
src/loglikelihoods.jl Outdated Show resolved Hide resolved
src/loglikelihoods.jl Outdated Show resolved Hide resolved
@devmotion
Copy link
Member

devmotion commented May 31, 2021

Some notes for myself for the meeting (no need to reply in advance, will probably also be extended and modified):

General comment:

  • PR mixes different changes that do not necessarily belong together (e.g., SamplingContext does not necessarily require changes of the tilde pipeline)
    • (+/-) potentially reduces number of breaking releases - however, e.g., SamplingContext could be introduced in a non-breaking way
    • (-) makes it more difficult to review and discuss one change and design choice separately

Do we want SamplingContext to be an inner-most context as in this PR, or an outer-most context as in dw (compare) ?

  • (?) in this PR one can specify an RNG, a sampler, and a SamplingContext separately - only wrap the RNG and sampler internally and don't expose SamplingContext? Maybe problematic if it is the inner-most context?
    • Regarding the comments below: Currently, one can call model(SomeRNG, VarInfo(), SomeSampler, SamplingContext(SomeOtherRNG, SomeOtherSampler)), which seems undesired. Instead, if SamplingContext is only constructed internally, one could define model(SomeRNG, VarInfo(), SomeSampler, NoSamplingContext) and only construct SamplingContext(SomeRNG, SomeSampler, NoSamplingContext) internally.
  • (?) sometimes outer contexts might want to disable sampling, e.g., when using user-provided values instead - seems difficult to change leaf context (primitive) dynamically and to propagate it to outer contexts?
    • Regarding comments below: propagation to outer contexts is problematic in both approaches.

abstract type WrappedContext{LeafCtx} <: AbstractContext

  • (?) should LikelihoodContext etc. actually be a WrappedContext?
    • Regarding comments below: the suggestion is to perform sampling + evaluate the log-likelihood for SamplingContext{<:LikelihoodContext} and only evaluate the likelihood otherwise. More generall, one would add implementations for OtherContext{<:LikelihoodContext} instead of LikelihoodContext{<:OtherContext}. I.e., it would provide the same functionality as the approach with WrappedContexts.
  • (?) traits such as unwrap and rewrap might be sufficient?
  • (+) dispatch on leaf context
  • (+) introduces some structure for contexts that wrap other contexts
  • (-) restricts types of contexts (no multiple inheritance)
  • (-) type of leaf context might not be sufficient (instance contains information about type and specific properties)
  • (-) dispatching is not sufficient if some possibly nested wrapped contexts use a, possibly different, leaf type dynamically - of course, can be achieved with rewrap but type parameter is not sufficient/incorrect

abstract type PrimitiveContext <: AbstractContext

  • (+) structure and guarantees
  • (-) restrictive: does not allow use of arbitrary contexts and type hierarchies

EvaluationContext: only evaluation of the model is to take place, i.e. no sampling of variables.

  • (?) not needed if it is the default behaviour for LikelihoodContext etc
  • (-) still contains a sampler field (remove it?)

Rename tilde_assume (tilde_observe) and to tilde_assume! (tilde_observe!), and tilde(rng, ...) (tilde(...)) to tilde_assume(rng, ...) (tilde_observe(...)).

tilde_assume! simply calls tilde_assume followed by acclogp(varinfo, result_from_tilde_assume), so the ! here is to indicate that it's mutating the logp field in VarInfo.

  • (+) Not necessary to add acclogp! to implementations of tilde_assume
  • (-) If wrappers are supposed to work with arbitrary user-defined contexts, they have to either assume that wrapped contexts implement tilde_assume etc. as well instead of tilde_assume! or they have to implement tilde_assume!
    • E.g., in this PR the implementation for PriorContextPrefixContext and MinibatchContext seems unnecessarily restrictive since it only covers the tilde_assume entrypoints
    • Alternative: only use tilde_assume and assume etc. - the first one uses acclogp!, the second only returns the log density; most wrappers would implement tilde_assume etc. and only end points such as LikelihoodContext would implement assume

Add additional value/left argument to tilde_assume (not tilde_assume!) and assume, as in dot_tilde_assume, which is initialized to nothing by tilde_assume!

  • (?) it seems dot_tilde_assume should use left instead of nothing as the default?
  • (?) maybe not enough information (currently) propagated to the outer-most tilde_assume or tilde_assume! where the value would be saved (e.g., require not only values but also indices and distributions which could be modified in inner contexts)
  • (+) easy to intercept and modify values without touching or writing to VarInfo
  • (-) tilde_assume entrypoint does not cover contexts that implement tilde_assume!
  • (-) if dot_tilde_assume uses left, maybe nothing with assume would create inconsistencies (maybe does not matter)

Alternatively, it might be sufficient (?) to be able to pass a value only to the non-wrapper contexts. An example with the non-WrapperContext approach:

struct ConditionContext{names,V<:NamedTuple{names},C<:AbstractContext} <: AbstractContext
   values::V
   context::C
end

function unwrap_childcontext(context::ConditionContext{names}) where {names}
    child = context.context
    function reconstruct_conditioncontext(c::AbstractContext)
        return ConditionContext(context.values, c)
    end
    return child, reconstruct_conditioncontext
end

# drop `SamplingContext` if not needed
function tilde_assume(context::SamplingContext{<:ConditionContext{names}}, right, vn::VarName{sym}, inds, vi) where {sym,names}
    return if sym in names
        tilde_assume(context.context, right, vn, inds, vi)
    else
        c, reconstruct_context = unwrap_childcontext(context)
        child_of_c, _ = unwrap_childcontext(c)
        tilde_assume(reconstruct_context(child_of_c), right, vn, inds, vi)    
    end
end

# propagate `ConditionContext` to leaf contexts such as `JointContext` etc.
function tilde_assume(context::ConditionContext{names}, right, vn::VarName{sym}, inds, vi) where {names,sym}
    c, reconstruct_context = unwrap_childcontext(context)
    child_of_c, reconstruct_c = unwrap_childcontext(c) 
    return if child_of_c === nothing
        if sym in names
            # pass the value to the leaf context
            tilde_assume(c, right, vn, inds, vi, context.values[sym])
        else
            # pass no value
            tilde_assume(c, right, vn, inds, vi)
        end
    else
        tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi)    
    end
end

@torfjelde
Copy link
Member Author

torfjelde commented May 31, 2021

Some notes for myself for the meeting (no need to reply in advance, will probably also be extended and modified):

I'm going to do it anyways, even if just to organize my own thoughts on your notes:)

Do we want SamplingContext to be an inner-most context as in this PR, or an outer-most context as in dw (compare) ?

  • (?) in this PR one can specify an RNG, a sampler, and a SamplingContext separately - only wrap the RNG and sampler internally and don't expose SamplingContext? Maybe problematic if it is the inner-most context?

A bit uncertain what you mean here. AFAIK we wrap the RNG and sampler in SamplingContext from wherever, no?

  • (?) sometimes outer contexts might want to disable sampling, e.g., when using user-provided values instead - seems difficult to change leaf context (primitive) dynamically and to propagate it to outer contexts?

You mean situations where you want to do something like

function tilde_assume(ctx::MyContext, ...)
    f(...)
    value, logp = tilde_assume(childcontext(ctx), ...)
    g(...)

    return value
end

where the "primitive" context can be changed within the nested tilde_assume and thus the change does not propagate to the above scope, meaning that g won't know?

A couple of notes on this:

  1. Do we have examples of cases where we want this, i.e. a post-hook which depends on whether we sample or evaluate?
  2. This could be resolved in a couple of ways:
    1. Return the primitive context used from tilde_assume.
    2. [PARTIAL RESOLUTION] Allow rewrap using WrappedContext, i.e. you essentially insert a context at the end of the context-queue, thus allowing contexts to add more context to the tail of the context-queue.
  3. How does this work for the outer-most approach though? AFAIK this suffers from the same problem (unless we also return the primitive context). Say we have a sequence of nested contexts A(B(C())) where B decides to not sample and A has a "post-assume"-hook a la g above, and so we get the following call-stack:
    1. tilde_assume(::SamplingContext{<:A}, ...)
    2. tilde_assume(::A{<:SamplingContext}, ...)
      • From A's perspective, we're still sampling despite this not being the case, right?
    3. tilde_assume(::SamplingContext{<:B}, ...)
    4. tilde_assume(::C, ...)

Or am I misunderstanding something here?

abstract type WrappedContext{LeafCtx} <: AbstractContext

  • (?) should LikelihoodContext etc. actually be a `WrappedContext?

IMO yes, because:

  1. Minimize the number of primitive contexts.
  2. Gives us the possibility of allowing sampling the latent variables while only computing the likelihood.

It also opens up for the opportunity of being a bit more explicit about what should be sampled, etc., e.g. we could use LikelihoodContext as a filter to only touch observe statements but we're still allowed to combine it with SamplingContext to "force" them to be treated as missing, i.e. it can be used to only sample "observations". This and (2) can be boiled down to: IMO there are cases where one could imagine it making sense to have some sampling despite using a LikelihoodContext.

  • (?) traits such as unwrap and rewrap might be sufficient?

Technically, yes, but:

  1. It still seems like a bit "much" to have to pass this around everywhere.
  2. If we're going with the approach in this PR, IMO it makes sense to have a clear separation between what is considered a primitive context and what is not. If we go with the outer-most approach for SamplingContext, I agree that the type-hierarchy is probably unnecessary 👍
  • (-) not possible for wrapped contexts to use a, possibly different, leaf type dynamically

A bit confused here; why can't this be achieved through rewrap?

EvaluationContext: only evaluation of the model is to take place, i.e. no sampling of variables.

  • (?) not needed if it is the default behaviour for LikelihoodContext etc

True, but as mentioned above, I think there are cases where you'd like to also sample.

  • (-) still contains a sampler field (remove it?)

Very much in favour of removing it if possible. I added/kept it because AFAIK it's still being used for filtering based on variables so figured it was necessary to ensure that logp computations were performed correctly.

Rename tilde_assume (tilde_observe) and to tilde_assume! (tilde_observe!), and tilde(rng, ...) (tilde(...)) to tilde_assume(rng, ...) (tilde_observe(...)).
tilde_assume! simply calls tilde_assume followed by acclogp(varinfo, result_from_tilde_assume), so the ! here is to indicate that it's mutating the logp field in VarInfo.

  • (-) If wrappers are supposed to work with arbitrary user-defined contexts, they have to either assume that wrapped contexts implement tilde_assume etc. as well instead of tilde_assume! or they have to implement tilde_assume!

    • E.g., in this PR the implementation for PriorContext and MinibatchContext seems unnecessarily restrictive since it only covers the tilde_assume entrypoints
    • Alternative: only use tilde_assume and assume etc. - the first one uses acclogp!, the second only returns the log density; most wrappers would implement tilde_assume etc. and only end points such as LikelihoodContext would implement assume

How do you now which context is the "first one" though? Also, are you suggesting that LikelihoodContext should be passed to assume or just that it should be the only context calling assume?

EDIT: Ah, you mean the "inner-most" one/"primitive" context. Yeah actually this might be a great idea! Don't see a reason why this wouldn't work:)

EDIT 2: Eeeh actually it's bit awkward with contexts such as MiniBatch. I see what you've done btw, but I'm not certain I like it. It's sort of weird + it leads to a lot of unnecessary setlogp! and acclogp! calls that are completely unnecessary in nested contexts.

Btw, one thing I don't think I've heard your thoughts on yet @devmotion: what do you think about the suggestion of passing a value argument through the tilde_assume and assume statements, giving us the ability to completely override the value used + the implementer of assume often doesn't even have to touch vi?

@torfjelde
Copy link
Member Author

Just to summarize, we decided to split the work into several different PRs:

  1. [Merged by Bors] - Simplification of tilde-callstack #252 simplifies the tilde-callstack as in this PR.
  2. Introduction of SamplingContext #253 introduces SamplingContext using the outer-most approach.
  3. WrappedContext post introduction of SamplingContext #254 introduces WrappedContext in a post Introduction of SamplingContext #253 world. This is currently just a draft though with the aim of figuring out if it's worth it or not, as the consensus was that the benefit of the approach was unclear/not significant enough compared to the less intrusive approach in Introduction of SamplingContext #253 .

The final part is then passing an additional value/left argument to the tilde-statements to allow overriding values without touching VarInfo.

Add additional value/left argument to tilde_assume (not tilde_assume!) and assume, as in dot_tilde_assume, which is initialized to nothing by tilde_assume!

  • (?) it seems dot_tilde_assume should use left instead of nothing as the default?

AFAIK, the left argument isn't necessary for dot_tilde_assume. As mentioned before, the assignments in the dot_assume are redundant, as this mutation occurs in the model-scope anyways.
Hence, we could just "drop" it, i.e. make it nothing + the reason why we make it nothing is so that you could do stuff like sample the value if it is not present, i.e. is nothing; if we make it anything but nothing then we need some other way of figuring out whether or not to sample.

  • (?) maybe not enough information (currently) propagated to the outer-most tilde_assume or tilde_assume! where the value would be saved (e.g., require not only values but also indices and distributions which could be modified in inner contexts)

You mean to be returned? If so, yeah I agree. But I would say that this would fall under the "allowing post-hooks" in the tilde-statments, i.e. maybe something to think about after we have the override-feature as it could be trivially added by just returning more information.

  • (-) tilde_assume entrypoint does not cover contexts that implement tilde_assume!

We could insert the nothing in the model though, if the nothing approach is the one we want to go with.

  • (-) if dot_tilde_assume uses left, maybe nothing with assume would create inconsistencies (maybe does not matter)

As mentioned above, I'm pretty certain we can just remove left now though, i.e. we don't need it for anything else and so could use it for the same purpose as we then would in assume.

Alternatively, it might be sufficient (?) to be able to pass a value only to the non-wrapper contexts. An example with the non-WrapperContext approach:

struct ConditionContext{names,V<:NamedTuple{names},C<:AbstractContext} <: AbstractContext
   values::V
   context::C
end

function unwrap_childcontext(context::ConditionContext{names}) where {names}
    child = context.context
    function reconstruct_conditioncontext(c::AbstractContext)
        return ConditionContext(context.values, c)
    end
    return child, reconstruct_conditioncontext
end

# drop `SamplingContext` if not needed
function tilde_assume(context::SamplingContext{<:ConditionContext{names}}, right, vn::VarName{sym}, inds, vi) where {sym,names}
    return if sym in names
        tilde_assume(context.context, right, vn, inds, vi)
    else
        c, reconstruct_context = unwrap_childcontext(context)
        child_of_c, _ = unwrap_childcontext(c)
        tilde_assume(reconstruct_context(child_of_c), right, vn, inds, vi)    
    end
end

# propagate `ConditionContext` to leaf contexts such as `JointContext` etc.
function tilde_assume(context::ConditionContext{names}, right, vn::VarName{sym}, inds, vi) where {names,sym}
    c, reconstruct_context = unwrap_childcontext(context)
    child_of_c, reconstruct_c = unwrap_childcontext(c) 
    return if child_of_c === nothing
        if sym in names
            # pass the value to the leaf context
            tilde_assume(c, right, vn, inds, vi, context.values[sym])
        else
            # pass no value
            tilde_assume(c, right, vn, inds, vi)
        end
    else
        tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi)    
    end
end

Won't this lead to StackOverflowError if we have nested ConditionContext? Not that it would be recommended to do something like condition(condition(model, (a = 1, )), (a = 2, )), but one can imagine scenarios where a package returns a conditioned model and the user wants to override a particular value. (Btw, this is the sort of thing I had in mind when I said "what if you want more contexts that act similarily to SamplingContext?" earlier in our discussions.)

@torfjelde
Copy link
Member Author

Closing this in favour of the PRs listed above.

@torfjelde torfjelde closed this Jun 7, 2021
bors bot pushed a commit that referenced this pull request Jun 9, 2021
This PR introduces the simplification of the tilde-callstack as discussed in #249.

Copy-pasted from there:
- Remove unnecessary complexity in `~` implementation.
  - Current calling hierarchy for a `~` statement is:
    - `tilde_assume` -> `tilde(rng, ...)` -> `_tilde(rng, ...)` -> `assume`
    - `tilde_observe` -> `tilde(...)` -> `_tilde(...)` -> `observe`
    - Similarly for `dot_tilde_assume` and `dot_tilde_observe`.
  - This is super-confusing and difficult to debug.
  - `_tilde` is currently only used for `NamedDist` to allow overriding the variable-name used for a particular `~` statement.
  - Propose the following changes:
    - Remove `_tilde` and handle `NamedDist` _before_ calling `tilde_assume`, etc. by using a `unpack_right_vns` (and `unpack_right_left_vns` for dot-statements) (thanks to @devmotion)
    - Rename `tilde_assume` (`tilde_observe`) and to `tilde_assume!` (`tilde_observe!`), and `tilde(rng, ...)` (`tilde(...)`) to `tilde_assume(rng, ...)` (`tilde_observe(...)`).
      - `tilde_assume!` simply calls `tilde_assume` followed by `acclogp(varinfo, result_from_tilde_assume)`, so the `!` here is to indicate that it's mutating the `logp` field in `VarInfo`.


Co-authored-by: Hong Ge <hg344@cam.ac.uk>
@yebai yebai deleted the tor/wrappedcontext branch January 31, 2022 20:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants