Skip to content

Commit

Permalink
Merge pull request #26 from nossleinad/message-type-declaration
Browse files Browse the repository at this point in the history
Be consistent with message type declaration
  • Loading branch information
nossleinad authored Jun 19, 2024
2 parents 8a8c9b0 + 42e0a35 commit 27f1539
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 20 deletions.
22 changes: 11 additions & 11 deletions src/core/algorithms/ancestors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
if run_fel_up
felsenstein!(tree, model_func, partition_list = partition_list)
Expand All @@ -56,7 +56,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
depth_first_reconstruction(
tree,
Expand All @@ -76,7 +76,7 @@ function depth_first_reconstruction(
run_fel_up = true,
run_fel_down = true,
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
depth_first_reconstruction(
tree,
Expand All @@ -91,7 +91,7 @@ end

#For marginal reconstructions
function reconstruct_marginal_node!(
node_message_dict::Dict{FelNode,Vector{Partition}},
node_message_dict::Dict{FelNode,Vector{<:Partition}},
node::FelNode,
model_array::Vector{<:BranchModel},
partition_list,
Expand All @@ -109,7 +109,7 @@ end

export marginal_state_dict
"""
marginal_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
marginal_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())
Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and
returns a dictionary mapping nodes to their marginal reconstructions (ie. P(state|all observations,model)). A subset of partitions can be specified by partition_list,
Expand All @@ -119,7 +119,7 @@ function marginal_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand All @@ -133,7 +133,7 @@ end
#For joint max reconstructions
export dependent_reconstruction!
function dependent_reconstruction!(
node_message_dict::Dict{FelNode,Vector{Partition}},
node_message_dict::Dict{FelNode,Vector{<:Partition}},
node::FelNode,
model_array::Vector{<:BranchModel},
partition_list;
Expand Down Expand Up @@ -173,7 +173,7 @@ reconstruct_cascading_max_node!(node_message_dict, node, model_array, partition_
)
export cascading_max_state_dict
"""
cascading_max_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
cascading_max_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())
Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and
returns a dictionary mapping nodes to their inferred ancestors under the following scheme: the state that maximizes the marginal likelihood is selected at the root,
Expand All @@ -184,7 +184,7 @@ function cascading_max_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand All @@ -206,7 +206,7 @@ conditioned_sample_node!(node_message_dict, node, model_array, partition_list) =
)
export endpoint_conditioned_sample_state_dict
"""
endpoint_conditioned_sample_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{Partition}}())
endpoint_conditioned_sample_state_dict(tree::FelNode, model; partition_list = 1:length(tree.message), node_message_dict = Dict{FelNode,Vector{<:Partition}}())
Takes in a tree and a model (which can be a single model, an array of models, or a function that maps FelNode->Array{<:BranchModel}), and draws samples under the model
conditions on the leaf observations. These samples are stored in the node_message_dict, which is returned. A subset of partitions can be specified by partition_list, and a
Expand All @@ -216,7 +216,7 @@ function endpoint_conditioned_sample_state_dict(
tree::FelNode,
model;
partition_list = 1:length(tree.message),
node_message_dict = Dict{FelNode,Vector{Partition}}(),
node_message_dict = Dict{FelNode,Vector{<:Partition}}(),
)
return depth_first_reconstruction(
tree,
Expand Down
6 changes: 3 additions & 3 deletions src/core/algorithms/branchlength_optim.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#Model list should be a list of P matrices.
function branch_LL_up(
bl::Real,
temp_message::Vector{Partition},
temp_message::Vector{<:Partition},
node::FelNode,
model_list::Vector{<:BranchModel},
partition_list,
Expand All @@ -22,8 +22,8 @@ end
#I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength.
#This is for cases where the user stores node-level parameters and wants to optimize them.
function branchlength_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models,
partition_list,
Expand Down
12 changes: 6 additions & 6 deletions src/core/algorithms/nni_optim.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@


function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models,
partition_list;
Expand Down Expand Up @@ -72,8 +72,8 @@ end

#Unsure if this is the best choice to handle the model,models, and model_func stuff.
function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
models::Vector{<:BranchModel},
partition_list;
Expand All @@ -89,8 +89,8 @@ function nni_optim!(
)
end
function nni_optim!(
temp_message::Vector{Partition},
message_to_set::Vector{Partition},
temp_message::Vector{<:Partition},
message_to_set::Vector{<:Partition},
node::FelNode,
model::BranchModel,
partition_list;
Expand Down
1 change: 1 addition & 0 deletions src/models/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ function partition_from_template(partition_template::T) where {T <: DiscretePart
end
=#

#Note: not enforcing a return type causes some unnecesarry conversions
copy_message(msg::Vector{<:Partition}) = [copy_partition(x) for x in msg]

#This is a function shared for all models - perhaps move this elsewhere
Expand Down
17 changes: 17 additions & 0 deletions test/message_type_stability.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
begin
# Single partition example
tree = sim_tree(n = 10)
GAA_partition = GappyAminoAcidPartition(5)
AA_freqs = [1 / GAA_partition.states for _ = 1:GAA_partition.states]
GAA_partition.state .= AA_freqs
internal_message_init!(tree, GAA_partition)
Q = gappy_Q_from_symmetric_rate_matrix(WAGmatrix, 1.0, AA_freqs)
model = DiagonalizedCTMC(Q)
sample_down!(tree, model)
felsenstein!(tree, model)

# These would previously break since Vector{GappyAminoAcidPartition} is not <: Vector{Partition}, for example.
branchlength_optim!(tree, model)
marginal_state_dict(tree, model)
nni_optim!(tree, model)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,8 @@ using Test
include("partition_selection.jl")
end

@testset "message_type_stability" begin
include("message_type_stability.jl")
end

end

0 comments on commit 27f1539

Please sign in to comment.