Skip to content

Commit

Permalink
Merge branch 'main' into stack-based-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
nossleinad authored Jun 19, 2024
2 parents 0c962e0 + 27f1539 commit 4a39c6d
Show file tree
Hide file tree
Showing 17 changed files with 551 additions and 33 deletions.
58 changes: 57 additions & 1 deletion docs/src/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,60 @@ Coming soon.

## Continuous models

## Compound models
## Compound models

## Lazy models

### LazyPartition

```@docs; canonical=false
LazyPartition
```

#### Examples

##### Example 1: Initializing for an upward pass
Now, we show how to wrap the `CodonPartition`s from [Example 3: FUBAR](@ref) with `LazyPartition`:

You simply go from initializing messages like this:
```julia
initial_partition = CodonPartition(Int64(length(seqs[1])/3))
initial_partition.state .= eq_freqs
populate_tree!(tree,initial_partition,seqnames,seqs)
```

To this:
```julia
initial_partition = CodonPartition(Int64(length(seqs[1])/3))
initial_partition.state .= eq_freqs
lazy_initial_partition = LazyPartition{CodonPartition}()
populate_tree!(tree,lazy_initial_partition,seqnames,seqs)
lazyprep!(tree, initial_partition)
```

With this slight modification, we go from initializing and using 554 partitions to only 6 during the subsequent `log_likelihood!` and `felsenstein!` calls. No significant decrease in performance was observed after this switch.

##### Example 2: Initializing for a downward pass
Now, we show how to wrap the `GaussianPartition`s from [Quick example: Likelihood calculations under phylogenetic Brownian motion:](@ref) with `LazyPartition`:

You simply go from initializing messages like this:
```julia
internal_message_init!(tree, GaussianPartition())
```

To this (technically, we only add one line of code):
```julia
initial_partition = GaussianPartition()
lazy_initial_partition = LazyPartition{GaussianPartition}()
internal_message_init!(tree, lazy_initial_partition)
lazyprep!(tree, initial_partition, direction=LazyDown(isleafnode))
```
!!! note
    Here, we provide a direction for `lazyprep!`. The direction is an instance of `LazyDown`, which is initialized with the `isleafnode` function. The function passed to `LazyDown` dictates whether a node saves its sampled observation after a down pass. If you use `direction=LazyDown()`, every node saves its observation.

#### Surrounding LazyPartition
```@docs; canonical=false
lazyprep!
LazyUp
LazyDown
```
3 changes: 3 additions & 0 deletions src/MolecularEvolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ abstract type StatePath end

abstract type UnivariateOpt end

abstract type LazyDirection end

#include("core/core.jl")
include("core/nodes/nodes.jl")
include("core/algorithms/algorithms.jl")
Expand Down Expand Up @@ -131,6 +133,7 @@ export
copy_partition_to!,
copy_partition,
copy_message,
partition_from_template,
equilibrium_message,
sample_partition!,
obs2partition!,
Expand Down
22 changes: 11 additions & 11 deletions src/core/algorithms/ancestors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
if run_fel_up
felsenstein!(tree, model_func, partition_list = partition_list)
Expand All @@ -56,7 +56,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
depth_first_reconstruction(
tree,
Expand All @@ -76,7 +76,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
depth_first_reconstruction(
tree,
Expand All @@ -91,7 +91,7 @@ end

#For marginal reconstructions
function reconstruct_marginal_node!(
node_message_dict::Dict{FelNode,Vector{Partition}},
node_message_dict::Dict{FelNode,Vector{<:Partition}},
node::FelNode,
model_array::Vector{<:BranchModel},
partition_list,
Expand All @@ -109,7 +109,7 @@ end

export marginal_state_dict
"""
marginal_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
marginal_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())
Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and
returns a dictionary mapping nodes to their marginal reconstructions (ie. P(state|all observations,model)). A subset of partitions can be specified by partition_list,
Expand All @@ -119,7 +119,7 @@ function marginal_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand All @@ -133,7 +133,7 @@ end
#For joint max reconstructions
export dependent_reconstruction!
function dependent_reconstruction!(
node_message_dict::Dict{FelNode,Vector{Partition}},
node_message_dict::Dict{FelNode,Vector{<:Partition}},
node::FelNode,
model_array::Vector{<:BranchModel},
partition_list;
Expand Down Expand Up @@ -173,7 +173,7 @@ reconstruct_cascading_max_node!(node_message_dict, node, model_array, partition_
)
export cascading_max_state_dict
"""
cascading_max_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
cascading_max_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())
Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and
returns a dictionary mapping nodes to their inferred ancestors under the following scheme: the state that maximizes the marginal likelihood is selected at the root,
Expand All @@ -184,7 +184,7 @@ function cascading_max_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand All @@ -206,7 +206,7 @@ conditioned_sample_node!(node_message_dict, node, model_array, partition_list) =
)
export endpoint_conditioned_sample_state_dict
"""
endpoint_conditioned_sample_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
endpoint_conditioned_sample_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())
Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and draws samples under the model
conditioned on the leaf observations. These samples are stored in the node_message_dict, which is returned. A subset of partitions can be specified by partition_list, and a
Expand All @@ -216,7 +216,7 @@ function endpoint_conditioned_sample_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand Down
6 changes: 3 additions & 3 deletions src/core/algorithms/branchlength_optim.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#Model list should be a list of P matrices.
function branch_LL_up(
bl::Real,
temp_message::Vector{Partition},
temp_message::Vector{<:Partition},
node::FelNode,
model_list::Vector{<:BranchModel},
partition_list,
Expand All @@ -22,8 +22,8 @@ end
#I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength.
#This is for cases where the user stores node-level parameters and wants to optimize them.
function branchlength_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models,
partition_list,
Expand Down
4 changes: 2 additions & 2 deletions src/core/algorithms/generative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ function sample_down!(node::FelNode, models, partition_list)
if isroot(node)
forward!(node.message[part], node.parent_message[part], model_list[part], node)
else
forward!(node.message[part], node.parent.message[part], model_list[part], node)
forward!(node.message[part], node.parent.message[part], model_list[part], node) #node.parent['.' vs. '_']message[part]
end
sample_partition!(node.message[part])
end
if !isleafnode(node)
for child in reverse(node.children)
for child in reverse(node.children) #We push! in reverse order because of LazyPartition, so that lazysort! is optimal for both felsenstein! and sample_down!
push!(stack, child)
end
end
Expand Down
12 changes: 6 additions & 6 deletions src/core/algorithms/nni_optim.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@


function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models,
partition_list;
Expand Down Expand Up @@ -72,8 +72,8 @@ end

#Unsure if this is the best choice to handle the model,models, and model_func stuff.
function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models::Vector{<:BranchModel},
partition_list;
Expand All @@ -89,8 +89,8 @@ function nni_optim!(
)
end
function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
model::BranchModel,
partition_list;
Expand Down
4 changes: 2 additions & 2 deletions src/core/nodes/AbstractTreeNode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,15 @@ function gettreefromnewick(str, T::DataType; tagged = false, disable_binarize =
i += 1
elseif c == ';'
try_apply_char_arr(currnode, char_arr)
return (tagged ? (currnode, tag_dict) : currnode)
break
else
push!(char_arr, c)
#println(char_arr)
i += 1
end
end

binarize!(currnode)
!disable_binarize && binarize!(currnode)

return (tagged ? (currnode, tag_dict) : currnode)
end
Expand Down
9 changes: 9 additions & 0 deletions src/models/compound_models/swm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ function copy_partition(src::SWMPartition{PType}) where {PType <: MultiSiteParti
return SWMPartition{PType}(copy_partition.(src.parts), copy(src.weights), src.sites, src.states, src.models)
end

"""
    partition_from_template(partition_template::SWMPartition{PType})

Build a new `SWMPartition{PType}` shaped like `partition_template`.

Each entry of `partition_template.parts` is replaced by the result of calling
`partition_from_template` on it, and `weights` is copied. The `sites`, `states` and
`models` fields are passed through to the new partition as they are.

Note: the sub-partitions come from their own `partition_from_template` methods, which
may leave their contents uninitialized (`undef`); fill them before reading.
"""
function partition_from_template(
    partition_template::SWMPartition{PType},
) where {PType<:MultiSitePartition}
    # Delegate to the per-part overloads (this is the "indirect" use of undef).
    fresh_parts = partition_from_template.(partition_template.parts)
    fresh_weights = copy(partition_template.weights)
    return SWMPartition{PType}(
        fresh_parts,
        fresh_weights,
        partition_template.sites,
        partition_template.states,
        partition_template.models,
    )
end

function combine!(dest::SWMPartition{PType},src::SWMPartition{PType}) where {PType<:MultiSitePartition}
for i in 1:length(dest.parts)
combine!(dest.parts[i], src.parts[i])
Expand Down
6 changes: 6 additions & 0 deletions src/models/discrete_models/discrete_partitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ function copy_partition(src::T) where {T<:DiscretePartition}
return T(copy(src.state), src.states, src.sites, copy(src.scaling))
end

"""
    partition_from_template(partition_template::T) where {T<:DiscretePartition}

Build a new partition of the same concrete type `T` and the same `states` and `sites`
as `partition_template`.

The `states × sites` state matrix and the length-`sites` scaling vector are allocated
with `undef`, so their contents are uninitialized and must be written before being read.
"""
function partition_from_template(partition_template::T) where {T<:DiscretePartition}
    nstates = partition_template.states
    nsites = partition_template.sites
    # Allocate without initializing: nothing is copied from the template's arrays.
    state_buffer = Matrix{Float64}(undef, nstates, nsites)
    scaling_buffer = Vector{Float64}(undef, nsites)
    return T(state_buffer, nstates, nsites, scaling_buffer)
end

#I should add a constructor that constructs a DiscretePartition from an existing array.
mutable struct CustomDiscretePartition <: DiscretePartition
state::Array{Float64,2}
Expand Down
Loading

0 comments on commit 4a39c6d

Please sign in to comment.