From d2e84c10a71d6daefdbfa2d983011341ad0558ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Sun, 14 Jul 2024 17:02:08 +0200 Subject: [PATCH 01/12] WIP --- src/MolecularEvolution.jl | 4 +- src/core/algorithms/branchlength_optim.jl | 109 +++++++++ src/core/algorithms/nni_optim.jl | 286 +++++++++++++++++++++- src/utils/misc.jl | 23 ++ 4 files changed, 420 insertions(+), 2 deletions(-) diff --git a/src/MolecularEvolution.jl b/src/MolecularEvolution.jl index 5afd96d..3726f4b 100644 --- a/src/MolecularEvolution.jl +++ b/src/MolecularEvolution.jl @@ -29,7 +29,9 @@ abstract type SimulationModel <: BranchModel end #Simulation models typically ca abstract type StatePath end -abstract type UnivariateOpt end +abstract type UnivariateModifier end +abstract type UnivariateOpt <: UnivariateModifier end +abstract type UnivariateSampler <: UnivariateModifier end abstract type LazyDirection end diff --git a/src/core/algorithms/branchlength_optim.jl b/src/core/algorithms/branchlength_optim.jl index 6984336..410ba61 100644 --- a/src/core/algorithms/branchlength_optim.jl +++ b/src/core/algorithms/branchlength_optim.jl @@ -125,3 +125,112 @@ branchlength_optim!( tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt() ) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer) + +function branchlength_optim_v2!( + temp_message::Vector{<:Partition}, + message_to_set::Vector{<:Partition}, + node::FelNode, + models, + partition_list, + tol; + modifier::UnivariateModifier = GoldenSectionOpt(), +) + + #This bit of code should be identical to the regular downward pass... + #------------------- + if !isleafnode(node) + model_list = models(node) + for part in partition_list + forward!(temp_message[part], node.parent_message[part], model_list[part], node) + end + for i = 1:length(node.children) + new_temp = copy_message(temp_message) #Need to think of how to avoid this allocation. Same as in felsenstein_down + sib_inds = sibling_inds(node.children[i]) + for part in partition_list + combine!( + (node.children[i]).parent_message[part], + [mess[part] for mess in node.child_messages[sib_inds]], + true, + ) + combine!( + (node.children[i]).parent_message[part], + [temp_message[part]], + false, + ) + end + #But calling branchlength_optim recursively... + branchlength_optim_v2!( + new_temp, + node.child_messages[i], + node.children[i], + models, + partition_list, + tol, + modifier=modifier + ) + end + #Then combine node.child_messages into node.message... + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + end + end + #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. + #------------------- + if !isroot(node) + + model_list = models(node) + fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) + bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_branchlength=node.branchlength) + + if fun(bl) <= fun(node.branchlength) && typeof(modifier) <: UnivariateOpt + else + node.branchlength = bl + end + + #Consider checking for improvement, and bailing if none. + #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] + for part in partition_list + backward!(message_to_set[part], node.message[part], model_list[part], node) + end + end + #For debugging: + #println("$(node.nodeindex):$(node.branchlength)") +end + +#BM: Check if running felsenstein_down! makes a difference. +""" + branchlength_optim_v2!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier = GoldenSectionOpt()) + +Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages. +Requires felsenstein!() to have been run first. +models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or +a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. +partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models). +tol is the absolute tolerance for the modifier which defaults to golden section search, and has Brent's method as an option by setting modifier=BrentsMethodOpt(). +""" +function branchlength_optim_v2!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt()) + temp_message = copy_message(tree.message) + message_to_set = copy_message(tree.message) + + if partition_list === nothing + partition_list = 1:length(tree.message) + end + + branchlength_optim_v2!(temp_message, message_to_set, tree, models, partition_list, tol, modifier=modifier) +end + +#Overloading to allow for direct model and model vec inputs +branchlength_optim_v2!( + tree::FelNode, + models::Vector{<:BranchModel}; + partition_list = nothing, + tol = 1e-5, + modifier::UnivariateModifier = GoldenSectionOpt(), +) = branchlength_optim_v2!(tree, x -> models, partition_list = partition_list, tol = tol, modifier=modifier) +branchlength_optim_v2!( + tree::FelNode, + model::BranchModel; + partition_list = nothing, + tol = 1e-5, + modifier::UnivariateModifier = GoldenSectionOpt(), +) = branchlength_optim_v2!(tree, x -> [model], partition_list = partition_list, tol = tol, modifier=modifier) \ No newline at end of file diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index c239923..9b16d56 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -7,7 +7,7 @@ function nni_optim!( models, partition_list; acc_rule = (x, y) -> x > y, -) +) model_list = models(node) @@ -247,3 +247,287 @@ function nni_optim!( acc_rule = acc_rule, ) end + +function nni_optim_v2!( + temp_message::Vector{<:Partition}, + message_to_set::Vector{<:Partition}, + node::FelNode, + models, + partition_list; + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) + + model_list = models(node) + + if isleafnode(node) + return + end + + #This bit of code should be identical to the regular downward pass... + #------------------- + + for part in partition_list + forward!(temp_message[part], node.parent_message[part], model_list[part], node) + end + @assert length(node.children) <= 2 + for i = 1:length(node.children) + new_temp = copy_message(temp_message) #Need to think of how to avoid this allocation. Same as in felsenstein_down + sib_inds = sibling_inds(node.children[i]) + for part in partition_list + combine!( + (node.children[i]).parent_message[part], + [mess[part] for mess in node.child_messages[sib_inds]], + true, + ) + combine!((node.children[i]).parent_message[part], [temp_message[part]], false) + end + #But calling branchlength_optim recursively... + nni_optim_v2!( + new_temp, + node.child_messages[i], + node.children[i], + models, + partition_list; + acc_rule = acc_rule, + sampler = sampler, + ) + end + #Then combine node.child_messages into node.message... + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + end + + #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. + #------------------- + if !isroot(node) + nnid, exceed_sib, exceed_child = do_nni_v2( + node, + temp_message, + models; + partition_list = partition_list, + acc_rule = acc_rule, + sampler = sampler, + ) + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + backward!(message_to_set[part], node.message[part], model_list[part], node) + combine!( + node.parent.message[part], + [mess[part] for mess in node.parent.child_messages], + true, + ) + end + end +end +#Unsure if this is the best choice to handle the model,models, and model_func stuff. +function nni_optim_v2!( + temp_message::Vector{<:Partition}, + message_to_set::Vector{<:Partition}, + node::FelNode, + models::Vector{<:BranchModel}, + partition_list; + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) + nni_optim_v2!( + temp_message, + message_to_set, + node, + x -> models, + partition_list, + acc_rule = acc_rule, + sampler = sampler, + ) +end +function nni_optim_v2!( + temp_message::Vector{<:Partition}, + message_to_set::Vector{<:Partition}, + node::FelNode, + model::BranchModel, + partition_list; + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) + nni_optim_v2!( + temp_message, + message_to_set, + node, + x -> [model], + partition_list, + acc_rule = acc_rule, + sampler = sampler, + ) +end + +function do_nni_v2( + node, + temp_message, + models::F; + partition_list = 1:length(node.message), + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), + ) where {F<:Function} + + if length(node.children) == 0 || node.parent === nothing + return false + else + temp_message2 = copy_message(temp_message) + model_list = models(node) + #current score + for part in partition_list + backward!(temp_message[part], node.message[part], model_list[part], node) + combine!(temp_message[part], [node.parent_message[part]], false) + end + #@toggleable_function assert_message_consistency(node, models, p = 0.01) + + curr_LL = sum([total_LL(temp_message[part]) #+ + #total_LL(node.message[part]) + + #total_LL(node.parent_message[part]) + for part in partition_list]) + + changed = false + nni_LLs = [curr_LL] + nni_configs = [(0,0)] + + for sib_ind in + [x for x in 1:length(node.parent.children) if node.parent.children[x] != node] + switch_LL = 0.0 + for child_ind = 1:length(node.children) + for part in partition_list + #move the sibling message, after upward propogation, to temp_message to work with it + combine!( + temp_message[part], + [node.parent.child_messages[sib_ind][part]], + true, + ) + + #combine this message, with all child messages of node except the index replaced + combine!( + temp_message[part], + [ + mess[part] for + (i, mess) in enumerate(node.child_messages) if i != child_ind + ], + false, + ) + + #prop up the message on the node up to its parent + backward!( + temp_message2[part], + temp_message[part], + model_list[part], + node, + ) + + #combine the message of the moved child + combine!( + temp_message2[part], + [node.child_messages[child_ind][part]], + false, + ) + + #we now have both parts of message, propogated to the parent of node + #propogate it up one more step, then merge it with parent_message of parent + backward!( + temp_message[part], + temp_message2[part], + model_list[part], + node.parent, + ) + combine!(temp_message[part], [node.parent.parent_message[part]], false) + end + + + LL = sum([total_LL(temp_message[part]) for part in partition_list]) + + push!(nni_LLs, LL) + push!(nni_configs, (sib_ind, child_ind)) + + end + end + + changed, sampled_config_ind = sampler(nni_LLs) + sampled_config_LL = nni_LLs[sampled_config_ind] + (sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind] + changed = acc_rule(sampled_config_LL, curr_LL) && changed + + + # println("changed: ", changed, " inds, ", (sampled_sib_ind, sampled_child_ind), " sampled config ind, ", sampled_config_ind) + # println(" nni_config, ", nni_configs) + # println("softmax ", softmax(nni_LLs)) + + #do the actual move here + if !(changed) + return false, sampled_sib_ind, sampled_child_ind + else + sib = node.parent.children[sampled_sib_ind] + child = node.children[sampled_child_ind] + + child.parent = node.parent + sib.parent = node + + node.children[sampled_child_ind] = sib + node.parent.children[sampled_sib_ind] = child + + node.parent.child_messages[sampled_sib_ind], node.child_messages[sampled_child_ind] = + node.child_messages[sampled_child_ind], node.parent.child_messages[sampled_sib_ind] + + return true, sampled_sib_ind, sampled_child_ind + end + end +end + +""" + nni_optim_v2!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) + +Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. +Requires felsenstein!() to have been run first. +models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or +a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. +partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models). +acc_rule allows you to specify a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. +sampler allows you to randomly select a NNI based on the vector of log likelihoods of the possible interchanges, note that the current log likelihood is at index 1. +""" +function nni_optim_v2!( + tree::FelNode, + models; + partition_list = nothing, + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) + temp_message = copy_message(tree.message) + message_to_set = copy_message(tree.message) + + if partition_list === nothing + partition_list = 1:length(tree.message) + end + + nni_optim_v2!( + temp_message, + message_to_set, + tree, + models, + partition_list, + acc_rule = acc_rule, + sampler = sampler, + ) +end + +#Overloading to allow for direct model and model vec inputs +nni_optim_v2!( + tree::FelNode, + models::Vector{<:BranchModel}; + partition_list = nothing, + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) = nni_optim_v2!(tree, x -> models, partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) +nni_optim_v2!( + tree::FelNode, + model::BranchModel; + partition_list = nothing, + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) = nni_optim_v2!(tree, x -> [model], partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) + + diff --git a/src/utils/misc.jl b/src/utils/misc.jl index 24d7e33..deabf8e 100644 --- a/src/utils/misc.jl +++ b/src/utils/misc.jl @@ -314,3 +314,26 @@ function write_nexus(fname::String,tree::FelNode) n.name = old_names[i] end end + +struct SimpleBranchlengthPeturbation <: UnivariateSampler + sigma +end + +function univariate_modifier(fun, modifier::UnivariateOpt; a=0, b=0, transform=unit_transform, tol=10e-5, kwargs...) + return univariate_maximize(fun, a + tol, b - tol, unit_transform, modifier, tol) +end + +function univariate_modifier(fun, modifier::UnivariateSampler; curr_branchlength=0, kwargs...) + return univariate_sampler(fun, modifier, curr_branchlength) +end + +function univariate_sampler(fun, modifier::SimpleBranchlengthPeturbation, curr_branchlength) + noise = modifier.sigma*rand(Normal(0,1)) + log_prior(x) = pdf(Normal(-1,1), x) + proposal = exp(log(curr_branchlength)+noise) + if rand() <= exp(fun(proposal)+log_prior(proposal)-fun(curr_branchlength)-log_prior(curr_branchlength)) + return proposal + else + return curr_branchlength + end +end \ No newline at end of file From c28534ceed1bdeea7414dc76490f7e6332e34419 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Tue, 23 Jul 2024 16:39:41 +0200 Subject: [PATCH 02/12] WIP --- src/MolecularEvolution.jl | 1 + src/bayes/bayes.jl | 1 + src/bayes/sampling.jl | 121 ++++++++++++++++++++++ src/core/algorithms/branchlength_optim.jl | 39 ++++--- src/core/algorithms/nni_optim.jl | 52 +++++----- src/core/nodes/FelNode.jl | 25 +++++ src/utils/misc.jl | 15 --- 7 files changed, 193 insertions(+), 61 deletions(-) create mode 100644 src/bayes/bayes.jl create mode 100644 src/bayes/sampling.jl diff --git a/src/MolecularEvolution.jl b/src/MolecularEvolution.jl index 3726f4b..a282bac 100644 --- a/src/MolecularEvolution.jl +++ b/src/MolecularEvolution.jl @@ -41,6 +41,7 @@ include("core/algorithms/algorithms.jl") include("core/sim_tree.jl") include("models/models.jl") include("utils/utils.jl") +include("bayes/bayes.jl") #Optional dependencies function __init__() diff --git a/src/bayes/bayes.jl b/src/bayes/bayes.jl new file mode 100644 index 0000000..f9201c5 --- /dev/null +++ b/src/bayes/bayes.jl @@ -0,0 +1 @@ +include("sampling.jl") \ No newline at end of file diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl new file mode 100644 index 0000000..ebfbaf6 --- /dev/null +++ b/src/bayes/sampling.jl @@ -0,0 +1,121 @@ + +export sample_posterior_phylo_topologies +""" + function sample_posterior_phylo_topologies( + initial_tree::FelNode, + models::Vector{<:BranchModel}, + num_of_samples; + burn_in=1000, + sample_interval=10, + collect_LLs = false, + midpoint_rooting=false, + ) + +Samples tree topologies from a posterior distribution. + +# Arguments +- `initial_tree`: An initial topology with (important!) the leaves populated with data, for the likelihood calculation. +- `models`: A list of branch models. +- `num_of_samples`: The number of tree samples drawn from the posterior. +- `burn_in`: The number of samples discarded at the start of the Markov Chain. +- `sample_interval`: The distance between samples in the underlying Markov Chain (to reduce sample correlation). +- `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees. +- `midpoint_rooting`: Specifies whether the drawn samples should be midpoint rerooted (Important! Should only be used for time-reversible branch models starting in equilibrium). + +# Returns +- `samples`: The trees drawn from the posterior. +- `sample_LLs`: The associated log-likelihoods of the tree (optional). +""" +function sample_posterior_phylo_topologies( + initial_tree::FelNode, + models::Vector{<:BranchModel}, + num_of_samples; + burn_in=1000, + sample_interval=10, + collect_LLs = false, + midpoint_rooting=false, + ladderize = false, + ) + + sample_LLs = [] + samples = FelNode[] + tree = deepcopy(initial_tree) + iterations = burn_in + num_of_samples * sample_interval + + modifier = BranchlengthPerturbation(2.0,0,0) + + softmax_sampler = x -> (sample = rand(Categorical(softmax(x))); changed = sample != 1; (changed, sample)) + + for i=1:iterations + + # Updates the tree topolgy and branchlengths using Gibbs sampling. + nni_optim!(tree, models, acc_rule = (x,y) -> true, sampler = softmax_sampler) + branchlength_optim!(tree, models, modifier=modifier) + + if (i-burn_in) % sample_interval == 0 && i > burn_in + + push!(samples, shallow_copy_tree(tree)) + + if collect_LLs + push!(sample_LLs, log_likelihood!(tree, models)) + end + + end + + #REMOVE BEFORE PR + if i % 1000 == 0 || i == iterations + println(floor(i/iterations * 100)) + end + end + + if midpoint_rooting + for (i,sample) in enumerate(samples) + node, len = midpoint(sample) + samples[i] = reroot!(node, dist_above_child=len) + end + end + + if ladderize + for sample in samples + ladderize!(sample) + end + end + + if collect_LLs + return samples, sample_LLs + end + + return samples +end + +function softmax(x) + exp_x = exp.(x .- maximum(x)) # For numerical stability + return exp_x ./ sum(exp_x) +end + +mutable struct BranchlengthPerturbation <: UnivariateSampler + sigma + accepts + rejects +end + +""" + univariate_sampler(LL, modifier::BranchlengthPeturbation, curr_branchlength) + +A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. +""" +function univariate_sampler(LL, modifier::BranchlengthPerturbation, curr_branchlength) + # The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero. + log_prior(x) = logpdf(Normal(-1,1),log(x + 1e-12)) + # Adding additive normal symmetrical noise in the log(branchlength) domain to ensure the proposal function is symmetric. + noise = modifier.sigma*rand(Normal(0,1)) + proposal = exp(log(curr_branchlength)+noise) + # The standard Metropolis acceptance criterion. + if rand() <= exp(LL(proposal)+log_prior(proposal)-LL(curr_branchlength)-log_prior(curr_branchlength)) + modifier.accepts = modifier.accepts + 1 + return proposal + else + modifier.rejects = modifier.rejects + 1 + return curr_branchlength + end +end \ No newline at end of file diff --git a/src/core/algorithms/branchlength_optim.jl b/src/core/algorithms/branchlength_optim.jl index 410ba61..e80fd72 100644 --- a/src/core/algorithms/branchlength_optim.jl +++ b/src/core/algorithms/branchlength_optim.jl @@ -21,7 +21,7 @@ end #I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength. #This is for cases where the user stores node-level parameters and wants to optimize them. -function branchlength_optim!( +function branchlength_optim_old!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, @@ -54,7 +54,7 @@ function branchlength_optim!( ) end #But calling branchlength_optim recursively... - branchlength_optim!( + branchlength_optim_old!( new_temp, node.child_messages[i], node.children[i], @@ -90,7 +90,7 @@ end #BM: Check if running felsenstein_down! makes a difference. """ - branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt()) + branchlength_optim_old!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt()) Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages. Requires felsenstein!() to have been run first. @@ -99,7 +99,7 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models). tol is the absolute tolerance for the bl_optimizer which defaults to golden section search, and has Brent's method as an option by setting bl_optimizer=BrentsMethodOpt(). """ -function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt()) +function branchlength_optim_old!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt()) temp_message = copy_message(tree.message) message_to_set = copy_message(tree.message) @@ -107,26 +107,26 @@ function branchlength_optim!(tree::FelNode, models; partition_list = nothing, to partition_list = 1:length(tree.message) end - branchlength_optim!(temp_message, message_to_set, tree, models, partition_list, tol, bl_optimizer=bl_optimizer) + branchlength_optim_old!(temp_message, message_to_set, tree, models, partition_list, tol, bl_optimizer=bl_optimizer) end #Overloading to allow for direct model and model vec inputs -branchlength_optim!( +branchlength_optim_old!( tree::FelNode, models::Vector{<:BranchModel}; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt() -) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer) -branchlength_optim!( +) = branchlength_optim_old!(tree, x -> models, partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer) +branchlength_optim_old!( tree::FelNode, model::BranchModel; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt() -) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer) +) = branchlength_optim_old!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer) -function branchlength_optim_v2!( +function branchlength_optim!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, @@ -159,7 +159,7 @@ function branchlength_optim_v2!( ) end #But calling branchlength_optim recursively... - branchlength_optim_v2!( + branchlength_optim!( new_temp, node.child_messages[i], node.children[i], @@ -182,8 +182,7 @@ function branchlength_optim_v2!( fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_branchlength=node.branchlength) - if fun(bl) <= fun(node.branchlength) && typeof(modifier) <: UnivariateOpt - else + if fun(bl) > fun(node.branchlength) || !(modifier isa UnivariateOpt) node.branchlength = bl end @@ -199,7 +198,7 @@ end #BM: Check if running felsenstein_down! makes a difference. """ - branchlength_optim_v2!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier = GoldenSectionOpt()) + branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier = GoldenSectionOpt()) Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages. Requires felsenstein!() to have been run first. @@ -208,7 +207,7 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models). tol is the absolute tolerance for the modifier which defaults to golden section search, and has Brent's method as an option by setting modifier=BrentsMethodOpt(). """ -function branchlength_optim_v2!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt()) +function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt()) temp_message = copy_message(tree.message) message_to_set = copy_message(tree.message) @@ -216,21 +215,21 @@ function branchlength_optim_v2!(tree::FelNode, models; partition_list = nothing, partition_list = 1:length(tree.message) end - branchlength_optim_v2!(temp_message, message_to_set, tree, models, partition_list, tol, modifier=modifier) + branchlength_optim!(temp_message, message_to_set, tree, models, partition_list, tol, modifier=modifier) end #Overloading to allow for direct model and model vec inputs -branchlength_optim_v2!( +branchlength_optim!( tree::FelNode, models::Vector{<:BranchModel}; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt(), -) = branchlength_optim_v2!(tree, x -> models, partition_list = partition_list, tol = tol, modifier=modifier) -branchlength_optim_v2!( +) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, modifier=modifier) +branchlength_optim!( tree::FelNode, model::BranchModel; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt(), -) = branchlength_optim_v2!(tree, x -> [model], partition_list = partition_list, tol = tol, modifier=modifier) \ No newline at end of file +) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, modifier=modifier) \ No newline at end of file diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index 9b16d56..ac02def 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -1,6 +1,6 @@ -function nni_optim!( +function nni_optim_old!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, @@ -34,7 +34,7 @@ function nni_optim!( combine!((node.children[i]).parent_message[part], [temp_message[part]], false) end #But calling branchlength_optim recursively... - nni_optim!( + nni_optim_old!( new_temp, node.child_messages[i], node.children[i], @@ -51,7 +51,7 @@ function nni_optim!( #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. #------------------- if !isroot(node) - nnid, exceed_sib, exceed_child = do_nni( + nnid, exceed_sib, exceed_child = do_nni_old( node, temp_message, models; @@ -71,7 +71,7 @@ function nni_optim!( end #Unsure if this is the best choice to handle the model,models, and model_func stuff. -function nni_optim!( +function nni_optim_old!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, @@ -79,7 +79,7 @@ function nni_optim!( partition_list; acc_rule = (x, y) -> x > y, ) - nni_optim!( + nni_optim_old!( temp_message, message_to_set, node, @@ -88,7 +88,7 @@ function nni_optim!( acc_rule = acc_rule, ) end -function nni_optim!( +function nni_optim_old!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, @@ -96,7 +96,7 @@ function nni_optim!( partition_list; acc_rule = (x, y) -> x > y, ) - nni_optim!( + nni_optim_old!( temp_message, message_to_set, node, @@ -106,7 +106,7 @@ function nni_optim!( ) end -function do_nni( +function do_nni_old( node, temp_message, models::F; @@ -216,7 +216,7 @@ function do_nni( end """ - nni_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) + nni_optim_old!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. Requires felsenstein!() to have been run first. @@ -225,7 +225,7 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models). acc_rule allows you to specify a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. """ -function nni_optim!( +function nni_optim_old!( tree::FelNode, models; partition_list = nothing, @@ -238,7 +238,7 @@ function nni_optim!( partition_list = 1:length(tree.message) end - nni_optim!( + nni_optim_old!( temp_message, message_to_set, tree, @@ -248,7 +248,7 @@ function nni_optim!( ) end -function nni_optim_v2!( +function nni_optim!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, @@ -283,7 +283,7 @@ function nni_optim_v2!( combine!((node.children[i]).parent_message[part], [temp_message[part]], false) end #But calling branchlength_optim recursively... - nni_optim_v2!( + nni_optim!( new_temp, node.child_messages[i], node.children[i], @@ -301,7 +301,7 @@ function nni_optim_v2!( #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. #------------------- if !isroot(node) - nnid, exceed_sib, exceed_child = do_nni_v2( + nnid, exceed_sib, exceed_child = do_nni( node, temp_message, models; @@ -321,7 +321,7 @@ function nni_optim_v2!( end end #Unsure if this is the best choice to handle the model,models, and model_func stuff. -function nni_optim_v2!( +function nni_optim!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, @@ -330,7 +330,7 @@ function nni_optim_v2!( acc_rule = (x, y) -> x > y, sampler = (x) -> (true, argmax(x[2:end]) + 1), ) - nni_optim_v2!( + nni_optim!( temp_message, message_to_set, node, @@ -340,7 +340,7 @@ function nni_optim_v2!( sampler = sampler, ) end -function nni_optim_v2!( +function nni_optim!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, @@ -349,7 +349,7 @@ function nni_optim_v2!( acc_rule = (x, y) -> x > y, sampler = (x) -> (true, argmax(x[2:end]) + 1), ) - nni_optim_v2!( + nni_optim!( temp_message, message_to_set, node, @@ -360,7 +360,7 @@ function nni_optim_v2!( ) end -function do_nni_v2( +function do_nni( node, temp_message, models::F; @@ -479,7 +479,7 @@ function do_nni_v2( end """ - nni_optim_v2!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) + nni_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. Requires felsenstein!() to have been run first. @@ -489,7 +489,7 @@ partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over acc_rule allows you to specify a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. sampler allows you to randomly select a NNI based on the vector of log likelihoods of the possible interchanges, note that the current log likelihood is at index 1. """ -function nni_optim_v2!( +function nni_optim!( tree::FelNode, models; partition_list = nothing, @@ -503,7 +503,7 @@ function nni_optim_v2!( partition_list = 1:length(tree.message) end - nni_optim_v2!( + nni_optim!( temp_message, message_to_set, tree, @@ -515,19 +515,19 @@ function nni_optim_v2!( end #Overloading to allow for direct model and model vec inputs -nni_optim_v2!( +nni_optim!( tree::FelNode, models::Vector{<:BranchModel}; partition_list = nothing, acc_rule = (x, y) -> x > y, sampler = (x) -> (true, argmax(x[2:end]) + 1), -) = nni_optim_v2!(tree, x -> models, partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) -nni_optim_v2!( +) = nni_optim!(tree, x -> models, partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) +nni_optim!( tree::FelNode, model::BranchModel; partition_list = nothing, acc_rule = (x, y) -> x > y, sampler = (x) -> (true, argmax(x[2:end]) + 1), -) = nni_optim_v2!(tree, x -> [model], partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) +) = nni_optim!(tree, x -> [model], partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) diff --git a/src/core/nodes/FelNode.jl b/src/core/nodes/FelNode.jl index 2cf10f2..ca0eea0 100644 --- a/src/core/nodes/FelNode.jl +++ b/src/core/nodes/FelNode.jl @@ -98,3 +98,28 @@ function mixed_type_equilibrium_message( end return out_mess end + +""" + shallow_copy_tree(root::FelNode)::FelNode + + Returns a copy of the a tree with only the names and branchlengths. +""" +function shallow_copy_tree(root::FelNode)::FelNode + + new_root = FelNode(root.branchlength, root.name) + stack = [(root, new_root)] + + while !isempty(stack) + + original_node, copied_node = pop!(stack) + + for child in original_node.children + new_child = FelNode(child.branchlength, child.name) + push!(copied_node.children, new_child) + new_child.parent = copied_node + push!(stack, (child, new_child)) + end + end + + return new_root +end \ No newline at end of file diff --git a/src/utils/misc.jl b/src/utils/misc.jl index deabf8e..c399c5d 100644 --- a/src/utils/misc.jl +++ b/src/utils/misc.jl @@ -315,10 +315,6 @@ function write_nexus(fname::String,tree::FelNode) end end -struct SimpleBranchlengthPeturbation <: UnivariateSampler - sigma -end - function univariate_modifier(fun, modifier::UnivariateOpt; a=0, b=0, transform=unit_transform, tol=10e-5, kwargs...) return univariate_maximize(fun, a + tol, b - tol, unit_transform, modifier, tol) end @@ -326,14 +322,3 @@ end function univariate_modifier(fun, modifier::UnivariateSampler; curr_branchlength=0, kwargs...) return univariate_sampler(fun, modifier, curr_branchlength) end - -function univariate_sampler(fun, modifier::SimpleBranchlengthPeturbation, curr_branchlength) - noise = modifier.sigma*rand(Normal(0,1)) - log_prior(x) = pdf(Normal(-1,1), x) - proposal = exp(log(curr_branchlength)+noise) - if rand() <= exp(fun(proposal)+log_prior(proposal)-fun(curr_branchlength)-log_prior(curr_branchlength)) - return proposal - else - return curr_branchlength - end -end \ No newline at end of file From c129675af38dc1d054ba8132345bad550e9e3e56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Wed, 24 Jul 2024 15:06:07 +0200 Subject: [PATCH 03/12] WIP --- src/bayes/sampling.jl | 80 +++--- src/core/algorithms/branchlength_optim.jl | 5 +- src/core/algorithms/nni_optim.jl | 307 ++++++++++++++++++++-- src/core/nodes/FelNode.jl | 62 +++-- src/utils/misc.jl | 11 +- src/utils/simple_optim.jl | 4 + src/utils/simple_sample.jl | 21 ++ src/utils/utils.jl | 1 + 8 files changed, 417 insertions(+), 74 deletions(-) create mode 100644 src/utils/simple_sample.jl diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl index ebfbaf6..f8f2c5b 100644 --- a/src/bayes/sampling.jl +++ b/src/bayes/sampling.jl @@ -1,7 +1,7 @@ -export sample_posterior_phylo_topologies +export metropolis_sample """ - function sample_posterior_phylo_topologies( + function metropolis_sample( initial_tree::FelNode, models::Vector{<:BranchModel}, num_of_samples; @@ -26,7 +26,7 @@ Samples tree topologies from a posterior distribution. - `samples`: The trees drawn from the posterior. - `sample_LLs`: The associated log-likelihoods of the tree (optional). """ -function sample_posterior_phylo_topologies( +function metropolis_sample( initial_tree::FelNode, models::Vector{<:BranchModel}, num_of_samples; @@ -42,19 +42,19 @@ function sample_posterior_phylo_topologies( tree = deepcopy(initial_tree) iterations = burn_in + num_of_samples * sample_interval - modifier = BranchlengthPerturbation(2.0,0,0) - - softmax_sampler = x -> (sample = rand(Categorical(softmax(x))); changed = sample != 1; (changed, sample)) + bl_modifier = BranchlengthPerturbation(2.0) + #old_softmax_sampler = x -> (sample = rand(Categorical(softmax(x))); changed = sample != 1; (changed, sample)) + softmax_sampler = x -> rand(Categorical(softmax(x))) for i=1:iterations # Updates the tree topolgy and branchlengths using Gibbs sampling. - nni_optim!(tree, models, acc_rule = (x,y) -> true, sampler = softmax_sampler) - branchlength_optim!(tree, models, modifier=modifier) + nni_optim!(tree, models, nni_config_sampler = softmax_sampler) + branchlength_optim!(tree, models, modifier=bl_modifier) if (i-burn_in) % sample_interval == 0 && i > burn_in - push!(samples, shallow_copy_tree(tree)) + push!(samples, copy_tree(tree, true)) if collect_LLs push!(sample_LLs, log_likelihood!(tree, models)) @@ -88,34 +88,50 @@ function sample_posterior_phylo_topologies( return samples end -function softmax(x) - exp_x = exp.(x .- maximum(x)) # For numerical stability - return exp_x ./ sum(exp_x) -end - -mutable struct BranchlengthPerturbation <: UnivariateSampler - sigma - accepts - rejects -end - -""" - univariate_sampler(LL, modifier::BranchlengthPeturbation, curr_branchlength) - -A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. -""" -function univariate_sampler(LL, modifier::BranchlengthPerturbation, curr_branchlength) +function branchlength_metropolis(LL, modifier, curr_value) # The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero. log_prior(x) = logpdf(Normal(-1,1),log(x + 1e-12)) # Adding additive normal symmetrical noise in the log(branchlength) domain to ensure the proposal function is symmetric. noise = modifier.sigma*rand(Normal(0,1)) - proposal = exp(log(curr_branchlength)+noise) + proposal = exp(log(curr_value)+noise) # The standard Metropolis acceptance criterion. - if rand() <= exp(LL(proposal)+log_prior(proposal)-LL(curr_branchlength)-log_prior(curr_branchlength)) - modifier.accepts = modifier.accepts + 1 + if rand() <= exp(LL(proposal)+log_prior(proposal)-LL(curr_value)-log_prior(curr_value)) + modifier.acc_ratio[1] = modifier.acc_ratio[1] + 1 return proposal else - modifier.rejects = modifier.rejects + 1 - return curr_branchlength + modifier.acc_ratio[2] = modifier.acc_ratio[2] + 1 + return curr_value + end +end + +export collect_leaf_dists +""" + collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) +""" +function collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) + distmats = [] + for tree in trees + push!(distmats, leaf_distmat(tree)) end -end \ No newline at end of file + return distmats +end + +""" + leaf_distmat(tree) + +Returns a matrix of the distances between the leaf nodes where the index on the columns and rows are sorted by the leaf names. +""" +function leaf_distmat(tree) + + distmat, node_dic = MolecularEvolution.tree2distances(tree) + + leaflist = getleaflist(tree) + + sort!(leaflist, by = x-> x.name) + + order = [node_dic[leaf] for leaf in leaflist] + + return distmat[order, order] +end + + diff --git a/src/core/algorithms/branchlength_optim.jl b/src/core/algorithms/branchlength_optim.jl index e80fd72..e9f8a16 100644 --- a/src/core/algorithms/branchlength_optim.jl +++ b/src/core/algorithms/branchlength_optim.jl @@ -180,7 +180,7 @@ function branchlength_optim!( model_list = models(node) fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_branchlength=node.branchlength) + bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) if fun(bl) > fun(node.branchlength) || !(modifier isa UnivariateOpt) node.branchlength = bl @@ -232,4 +232,5 @@ branchlength_optim!( partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt(), -) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, modifier=modifier) \ No newline at end of file +) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, modifier=modifier) + diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index ac02def..7083fbd 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -255,7 +255,7 @@ function nni_optim!( models, partition_list; acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), + nni_config_sampler = (x) -> argmax(x), ) model_list = models(node) @@ -290,7 +290,7 @@ function nni_optim!( models, partition_list; acc_rule = acc_rule, - sampler = sampler, + nni_config_sampler = nni_config_sampler, ) end #Then combine node.child_messages into node.message... @@ -307,7 +307,7 @@ function nni_optim!( models; partition_list = partition_list, acc_rule = acc_rule, - sampler = sampler, + nni_config_sampler = nni_config_sampler, ) for part in partition_list combine!(node.message[part], [mess[part] for mess in node.child_messages], true) @@ -328,7 +328,7 @@ function nni_optim!( models::Vector{<:BranchModel}, partition_list; acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), + nni_config_sampler = (x) -> argmax(x), ) nni_optim!( temp_message, @@ -337,7 +337,7 @@ function nni_optim!( x -> models, partition_list, acc_rule = acc_rule, - sampler = sampler, + nni_config_sampler = nni_config_sampler, ) end function nni_optim!( @@ -347,7 +347,7 @@ function nni_optim!( model::BranchModel, partition_list; acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), + nni_config_sampler = (x) -> argmax(x), ) nni_optim!( temp_message, @@ -356,11 +356,288 @@ function nni_optim!( x -> [model], partition_list, acc_rule = acc_rule, - sampler = sampler, + nni_config_sampler = nni_config_sampler, ) end function do_nni( + node, + temp_message, + models::F; + partition_list = 1:length(node.message), + acc_rule = (x, y) -> x > y, + nni_config_sampler = (x) -> argmax(x), + ) where {F<:Function} + + if length(node.children) == 0 || node.parent === nothing + return false + else + temp_message2 = copy_message(temp_message) + model_list = models(node) + #current score + for part in partition_list + backward!(temp_message[part], node.message[part], model_list[part], node) + combine!(temp_message[part], [node.parent_message[part]], false) + end + #@toggleable_function assert_message_consistency(node, models, p = 0.01) + + curr_LL = sum([total_LL(temp_message[part]) #+ + #total_LL(node.message[part]) + + #total_LL(node.parent_message[part]) + for part in partition_list]) + + changed = false + nni_LLs = [curr_LL] + nni_configs = [(0,0)] + + for sib_ind in + [x for x in 1:length(node.parent.children) if node.parent.children[x] != node] + switch_LL = 0.0 + for child_ind = 1:length(node.children) + for part in partition_list + #move the sibling message, after upward propogation, to temp_message to work with it + combine!( + temp_message[part], + [node.parent.child_messages[sib_ind][part]], + true, + ) + + #combine this message, with all child messages of node except the index replaced + combine!( + temp_message[part], + [ + mess[part] for + (i, mess) in enumerate(node.child_messages) if i != child_ind + ], + false, + ) + + #prop up the message on the node up to its parent + backward!( + temp_message2[part], + temp_message[part], + model_list[part], + node, + ) + + #combine the message of the moved child + combine!( + temp_message2[part], + [node.child_messages[child_ind][part]], + false, + ) + + #we now have both parts of message, propogated to the parent of node + #propogate it up one more step, then merge it with parent_message of parent + backward!( + temp_message[part], + temp_message2[part], + model_list[part], + node.parent, + ) + combine!(temp_message[part], [node.parent.parent_message[part]], false) + end + + + LL = sum([total_LL(temp_message[part]) for part in partition_list]) + + push!(nni_LLs, LL) + push!(nni_configs, (sib_ind, child_ind)) + + end + end + + sampled_config_ind = nni_config_sampler(nni_LLs) + changed = sampled_config_ind != 1 + (sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind] + + #do the actual move here + if !(changed) + return false, sampled_sib_ind, sampled_child_ind + else + sib = node.parent.children[sampled_sib_ind] + child = node.children[sampled_child_ind] + + child.parent = node.parent + sib.parent = node + + node.children[sampled_child_ind] = sib + node.parent.children[sampled_sib_ind] = child + + node.parent.child_messages[sampled_sib_ind], node.child_messages[sampled_child_ind] = + node.child_messages[sampled_child_ind], node.parent.child_messages[sampled_sib_ind] + + return true, sampled_sib_ind, sampled_child_ind + end + end +end + +""" + nni_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) + +Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. +Requires felsenstein!() to have been run first. +models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or +a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. +partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models). +acc_rule allows you to specify a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. +nni_config_sampler allows you to randomly select a NNI based on the vector of log likelihoods of the possible interchanges, note that the current log likelihood is at index 1. +""" +function nni_optim!( + tree::FelNode, + models; + partition_list = nothing, + acc_rule = (x, y) -> x > y, + nni_config_sampler = (x) -> argmax(x), +) + temp_message = copy_message(tree.message) + message_to_set = copy_message(tree.message) + + if partition_list === nothing + partition_list = 1:length(tree.message) + end + + nni_optim!( + temp_message, + message_to_set, + tree, + models, + partition_list, + acc_rule = acc_rule, + nni_config_sampler = nni_config_sampler, + ) +end + +#Overloading to allow for direct model and model vec inputs +nni_optim!( + tree::FelNode, + models::Vector{<:BranchModel}; + partition_list = nothing, + acc_rule = (x, y) -> x > y, + nni_config_sampler = (x) -> argmax(x), +) = nni_optim!(tree, x -> models, partition_list=partition_list, acc_rule=acc_rule, nni_config_sampler = nni_config_sampler) +nni_optim!( + tree::FelNode, + model::BranchModel; + partition_list = nothing, + acc_rule = (x, y) -> x > y, + nni_config_sampler = (x) -> argmax(x), +) = nni_optim!(tree, x -> [model], partition_list=partition_list, acc_rule=acc_rule, nni_config_sampler = nni_config_sampler) + + +function nni_optim_temp!( + temp_message::Vector{<:Partition}, + message_to_set::Vector{<:Partition}, + node::FelNode, + models, + partition_list; + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) + + model_list = models(node) + + if isleafnode(node) + return + end + + #This bit of code should be identical to the regular downward pass... + #------------------- + + for part in partition_list + forward!(temp_message[part], node.parent_message[part], model_list[part], node) + end + @assert length(node.children) <= 2 + for i = 1:length(node.children) + new_temp = copy_message(temp_message) #Need to think of how to avoid this allocation. Same as in felsenstein_down + sib_inds = sibling_inds(node.children[i]) + for part in partition_list + combine!( + (node.children[i]).parent_message[part], + [mess[part] for mess in node.child_messages[sib_inds]], + true, + ) + combine!((node.children[i]).parent_message[part], [temp_message[part]], false) + end + #But calling branchlength_optim recursively... + nni_optim_temp!( + new_temp, + node.child_messages[i], + node.children[i], + models, + partition_list; + acc_rule = acc_rule, + sampler = sampler, + ) + end + #Then combine node.child_messages into node.message... + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + end + + #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. + #------------------- + if !isroot(node) + nnid, exceed_sib, exceed_child = do_nni_temp( + node, + temp_message, + models; + partition_list = partition_list, + acc_rule = acc_rule, + sampler = sampler, + ) + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + backward!(message_to_set[part], node.message[part], model_list[part], node) + combine!( + node.parent.message[part], + [mess[part] for mess in node.parent.child_messages], + true, + ) + end + end +end +#Unsure if this is the best choice to handle the model,models, and model_func stuff. +function nni_optim_temp!( + temp_message::Vector{<:Partition}, + message_to_set::Vector{<:Partition}, + node::FelNode, + models::Vector{<:BranchModel}, + partition_list; + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) + nni_optim_temp!( + temp_message, + message_to_set, + node, + x -> models, + partition_list, + acc_rule = acc_rule, + sampler = sampler, + ) +end +function nni_optim_temp!( + temp_message::Vector{<:Partition}, + message_to_set::Vector{<:Partition}, + node::FelNode, + model::BranchModel, + partition_list; + acc_rule = (x, y) -> x > y, + sampler = (x) -> (true, argmax(x[2:end]) + 1), +) + nni_optim_temp!( + temp_message, + message_to_set, + node, + x -> [model], + partition_list, + acc_rule = acc_rule, + sampler = sampler, + ) +end + +function do_nni_temp( node, temp_message, models::F; @@ -479,7 +756,7 @@ function do_nni( end """ - nni_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) + nni_optim_temp!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. Requires felsenstein!() to have been run first. @@ -489,7 +766,7 @@ partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over acc_rule allows you to specify a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. sampler allows you to randomly select a NNI based on the vector of log likelihoods of the possible interchanges, note that the current log likelihood is at index 1. """ -function nni_optim!( +function nni_optim_temp!( tree::FelNode, models; partition_list = nothing, @@ -503,7 +780,7 @@ function nni_optim!( partition_list = 1:length(tree.message) end - nni_optim!( + nni_optim_temp!( temp_message, message_to_set, tree, @@ -515,19 +792,17 @@ function nni_optim!( end #Overloading to allow for direct model and model vec inputs -nni_optim!( +nni_optim_temp!( tree::FelNode, models::Vector{<:BranchModel}; partition_list = nothing, acc_rule = (x, y) -> x > y, sampler = (x) -> (true, argmax(x[2:end]) + 1), -) = nni_optim!(tree, x -> models, partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) -nni_optim!( +) = nni_optim_temp!(tree, x -> models, partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) +nni_optim_temp!( tree::FelNode, model::BranchModel; partition_list = nothing, acc_rule = (x, y) -> x > y, sampler = (x) -> (true, argmax(x[2:end]) + 1), -) = nni_optim!(tree, x -> [model], partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) - - +) = nni_optim_temp!(tree, x -> [model], partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) diff --git a/src/core/nodes/FelNode.jl b/src/core/nodes/FelNode.jl index ca0eea0..8cac51b 100644 --- a/src/core/nodes/FelNode.jl +++ b/src/core/nodes/FelNode.jl @@ -99,27 +99,55 @@ function mixed_type_equilibrium_message( return out_mess end +export copy_tree """ - shallow_copy_tree(root::FelNode)::FelNode - - Returns a copy of the a tree with only the names and branchlengths. -""" -function shallow_copy_tree(root::FelNode)::FelNode - - new_root = FelNode(root.branchlength, root.name) - stack = [(root, new_root)] + function copy_tree(root::FelNode, shallow_copy=false) +Returns a untangled copy of the a tree. Optionally, the flag `shallow_copy` can be used to obtained a copy of the tree with only the names and branchlengths. +""" +function copy_tree(root::FelNode, shallow_copy=false) + + root_copy = FelNode(root.branchlength, root.name) + stack = [(root, root_copy)] + while !isempty(stack) + node, node_copy = pop!(stack) - original_node, copied_node = pop!(stack) + if !shallow_copy - for child in original_node.children - new_child = FelNode(child.branchlength, child.name) - push!(copied_node.children, new_child) - new_child.parent = copied_node - push!(stack, (child, new_child)) + if isdefined(node, :nodeindex) + node_copy.nodeindex = node.nodeindex + end + if isdefined(node, :seqindex) + node_copy.seqindex = node.seqindex + end + if isdefined(node, :state_path) + node_copy.state_path = deepcopy(node.state_path) + end + if isdefined(node, :branch_params) + node_copy.branch_params = copy(node.branch_params) + end + if isdefined(node, :node_data) + node_copy.node_data = deepcopy(node.node_data) + end + if isdefined(node, :message) + node_copy.message = copy_message(node.message) + end + if isdefined(node, :parent_message) + node_copy.parent_message = copy_message(node.parent_message) + end + if isdefined(node, :child_messages) + node_copy.child_messages = [copy_message(msg) for msg in node.child_messages] + end + end + + for child in node.children + child_copy = FelNode(child.branchlength, child.name) + push!(stack, (child, child_copy)) + child_copy.parent = node_copy + push!(node_copy.children, child_copy) end end - - return new_root -end \ No newline at end of file + + return root_copy +end diff --git a/src/utils/misc.jl b/src/utils/misc.jl index c399c5d..ad6ea80 100644 --- a/src/utils/misc.jl +++ b/src/utils/misc.jl @@ -315,10 +315,7 @@ function write_nexus(fname::String,tree::FelNode) end end -function univariate_modifier(fun, modifier::UnivariateOpt; a=0, b=0, transform=unit_transform, tol=10e-5, kwargs...) - return univariate_maximize(fun, a + tol, b - tol, unit_transform, modifier, tol) -end - -function univariate_modifier(fun, modifier::UnivariateSampler; curr_branchlength=0, kwargs...) - return univariate_sampler(fun, modifier, curr_branchlength) -end +function softmax(x) + exp_x = exp.(x .- maximum(x)) # For numerical stability + return exp_x ./ sum(exp_x) +end \ No newline at end of file diff --git a/src/utils/simple_optim.jl b/src/utils/simple_optim.jl index c2dd5d4..4945a34 100644 --- a/src/utils/simple_optim.jl +++ b/src/utils/simple_optim.jl @@ -15,6 +15,10 @@ end struct GoldenSectionOpt <: UnivariateOpt end struct BrentsMethodOpt <: UnivariateOpt end +function univariate_modifier(fun, modifier::UnivariateOpt; a=0, b=1, transform=unit_transform, tol=10e-5, kwargs...) + return univariate_maximize(fun, a + tol, b - tol, unit_transform, modifier, tol) +end + """ Golden section search. diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl new file mode 100644 index 0000000..d0202e8 --- /dev/null +++ b/src/utils/simple_sample.jl @@ -0,0 +1,21 @@ + +function univariate_modifier(f, modifier::UnivariateSampler; curr_value=nothing, kwargs...) + return univariate_sampler(f, modifier, curr_value) +end + +struct BranchlengthPerturbation <: UnivariateSampler + sigma + #The first entry in acc_ratio holds the number of accepted proposals and the second entry holds the number of rejected proposals. + acc_ratio + BranchlengthPerturbation(sigma) = new(sigma, [0,0]) +end + +""" + univariate_sampler(LL, modifier::BranchlengthPeturbation, curr_branchlength) + +A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. +""" +function univariate_sampler(LL, modifier::BranchlengthPerturbation, curr_branchlength) + return branchlength_metropolis(LL, modifier, curr_branchlength) +end + diff --git a/src/utils/utils.jl b/src/utils/utils.jl index cd69ae9..253395b 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,4 +1,5 @@ include("misc.jl") include("simple_optim.jl") +include("simple_sample.jl") include("tree_hash.jl") #fasta_io.jl is optionally included with Requires.jl in MolecularEvolution.jl From 64cc880ce27b750eed7c6adbe8ba2e7ee3da1977 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Mon, 29 Jul 2024 12:30:24 +0200 Subject: [PATCH 04/12] WIP --- src/bayes/sampling.jl | 8 +++++--- src/utils/simple_sample.jl | 9 +++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl index f8f2c5b..800e842 100644 --- a/src/bayes/sampling.jl +++ b/src/bayes/sampling.jl @@ -42,7 +42,7 @@ function metropolis_sample( tree = deepcopy(initial_tree) iterations = burn_in + num_of_samples * sample_interval - bl_modifier = BranchlengthPerturbation(2.0) + bl_modifier = BranchlengthSampler(Normal(0,2), Normal(-1,1)) #old_softmax_sampler = x -> (sample = rand(Categorical(softmax(x))); changed = sample != 1; (changed, sample)) softmax_sampler = x -> rand(Categorical(softmax(x))) @@ -81,6 +81,8 @@ function metropolis_sample( end end + println("acc_ratio = ", bl_modifier.acc_ratio[1]/sum(bl_modifier.acc_ratio)) + if collect_LLs return samples, sample_LLs end @@ -90,9 +92,9 @@ end function branchlength_metropolis(LL, modifier, curr_value) # The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero. - log_prior(x) = logpdf(Normal(-1,1),log(x + 1e-12)) + log_prior(x) = logpdf(modifier.log_bl_prior,log(x + 1e-12)) # Adding additive normal symmetrical noise in the log(branchlength) domain to ensure the proposal function is symmetric. - noise = modifier.sigma*rand(Normal(0,1)) + noise = rand(modifier.log_bl_proposal) proposal = exp(log(curr_value)+noise) # The standard Metropolis acceptance criterion. if rand() <= exp(LL(proposal)+log_prior(proposal)-LL(curr_value)-log_prior(curr_value)) diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl index d0202e8..ae211e0 100644 --- a/src/utils/simple_sample.jl +++ b/src/utils/simple_sample.jl @@ -3,11 +3,12 @@ function univariate_modifier(f, modifier::UnivariateSampler; curr_value=nothing, return univariate_sampler(f, modifier, curr_value) end -struct BranchlengthPerturbation <: UnivariateSampler - sigma +struct BranchlengthSampler <: UnivariateSampler #The first entry in acc_ratio holds the number of accepted proposals and the second entry holds the number of rejected proposals. acc_ratio - BranchlengthPerturbation(sigma) = new(sigma, [0,0]) + log_bl_proposal + log_bl_prior + BranchlengthSampler(log_bl_proposal,log_bl_prior) = new([0,0],log_bl_proposal,log_bl_prior) end """ @@ -15,7 +16,7 @@ end A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. """ -function univariate_sampler(LL, modifier::BranchlengthPerturbation, curr_branchlength) +function univariate_sampler(LL, modifier::BranchlengthSampler, curr_branchlength) return branchlength_metropolis(LL, modifier, curr_branchlength) end From a31bf41438e94c5edf4641eca453d60763bacd21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Mon, 12 Aug 2024 15:27:21 +0200 Subject: [PATCH 05/12] WIP --- src/bayes/sampling.jl | 7 +- src/core/algorithms/algorithms.jl | 5 + src/core/algorithms/branchlength_optim.jl | 231 ++++--- src/core/algorithms/nni_optim.jl | 767 +++++++--------------- 4 files changed, 366 insertions(+), 644 deletions(-) diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl index 800e842..c60c407 100644 --- a/src/bayes/sampling.jl +++ b/src/bayes/sampling.jl @@ -47,9 +47,9 @@ function metropolis_sample( #old_softmax_sampler = x -> (sample = rand(Categorical(softmax(x))); changed = sample != 1; (changed, sample)) softmax_sampler = x -> rand(Categorical(softmax(x))) for i=1:iterations - + # Updates the tree topolgy and branchlengths using Gibbs sampling. - nni_optim!(tree, models, nni_config_sampler = softmax_sampler) + nni_optim_iter!(tree, x -> models, nni_selection_rule = softmax_sampler) branchlength_optim!(tree, models, modifier=bl_modifier) if (i-burn_in) % sample_interval == 0 && i > burn_in @@ -81,7 +81,7 @@ function metropolis_sample( end end - println("acc_ratio = ", bl_modifier.acc_ratio[1]/sum(bl_modifier.acc_ratio)) + #println("acc_ratio = ", bl_modifier.acc_ratio[1]/sum(bl_modifier.acc_ratio)) if collect_LLs return samples, sample_LLs @@ -106,7 +106,6 @@ function branchlength_metropolis(LL, modifier, curr_value) end end -export collect_leaf_dists """ collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) """ diff --git a/src/core/algorithms/algorithms.jl b/src/core/algorithms/algorithms.jl index eb12de1..070f222 100644 --- a/src/core/algorithms/algorithms.jl +++ b/src/core/algorithms/algorithms.jl @@ -4,3 +4,8 @@ include("lls.jl") include("nni_optim.jl") include("ancestors.jl") include("generative.jl") + +#Maybe we should use safepop! for LazyPartition too? +function safepop!(temp_messages::Vector{Vector{T}}, temp_message::Vector{T}) where T <: Partition + return isempty(temp_messages) ? copy_message(temp_message) : pop!(temp_messages) +end diff --git a/src/core/algorithms/branchlength_optim.jl b/src/core/algorithms/branchlength_optim.jl index e9f8a16..4499e38 100644 --- a/src/core/algorithms/branchlength_optim.jl +++ b/src/core/algorithms/branchlength_optim.jl @@ -19,113 +19,116 @@ function branch_LL_up( return tot_LL end -#I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength. -#This is for cases where the user stores node-level parameters and wants to optimize them. -function branchlength_optim_old!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, +function branchlength_optim_iter!( + temp_messages::Vector{Vector{T}}, + tree::FelNode, models, partition_list, tol; - bl_optimizer::UnivariateOpt = GoldenSectionOpt() -) - - #This bit of code should be identical to the regular downward pass... - #------------------- - if !isleafnode(node) - model_list = models(node) - for part in partition_list - forward!(temp_message[part], node.parent_message[part], model_list[part], node) - end - for i = 1:length(node.children) - new_temp = copy_message(temp_message) #Need to think of how to avoid this allocation. Same as in felsenstein_down - sib_inds = sibling_inds(node.children[i]) + modifier::UnivariateModifier = GoldenSectionOpt(), + traversal = Iterators.reverse +) where {T <: Partition} + + stack = [(pop!(temp_messages), tree, 1, 1, true, true)] + while !isempty(stack) + temp_message, node, ind, lastind, first, down = pop!(stack) + #We start out with a regular downward pass... + #(except for some extra bookkeeping to track if node is visited for the first time) + #------------------- + if !isleafnode(node) + if down + if first + model_list = models(node) + for part in partition_list + forward!( + temp_message[part], + node.parent_message[part], + model_list[part], + node, + ) + end + #Temp must be constant between iterations for a node during down... + child_iter = traversal(1:length(node.children)) + lastind = Base.first(child_iter) #(which is why we track the last child to be visited during down) + push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #... but not up + for i = child_iter #Iterative reverse <=> Recursive non-reverse, also optimal for lazysort!?? + push!(stack, (temp_message, node, i, lastind, false, true)) + end + end + if !first + sib_inds = sibling_inds(node.children[ind]) + for part in partition_list + combine!( + (node.children[ind]).parent_message[part], + [mess[part] for mess in node.child_messages[sib_inds]], + true, + ) + combine!( + (node.children[ind]).parent_message[part], + [temp_message[part]], + false, + ) + end + #But calling branchlength_optim! recursively... (the iterative equivalent) + push!(stack, (safepop!(temp_messages, temp_message), node.children[ind], ind, lastind, true, true)) #first + down combination => safepop! + ind == lastind && push!(temp_messages, temp_message) #We no longer need constant temp + end + end + if !down + #Then combine node.child_messages into node.message... + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + end + #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. + #------------------- + if !isroot(node) + temp_message = pop!(temp_messages) + model_list = models(node) + fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) + bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) + if fun(bl) > fun(node.branchlength) || !(modifier isa UnivariateOpt) + node.branchlength = bl + end + + #Consider checking for improvement, and bailing if none. + #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] + for part in partition_list + backward!( + node.parent.child_messages[ind][part], + node.message[part], + model_list[part], + node, + ) + end + push!(temp_messages, temp_message) + end + end + else + #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. + #------------------- + model_list = models(node) + fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) + bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) + if fun(bl) > fun(node.branchlength) || !(modifier isa UnivariateOpt) + node.branchlength = bl + end + #Consider checking for improvement, and bailing if none. + #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] for part in partition_list - combine!( - (node.children[i]).parent_message[part], - [mess[part] for mess in node.child_messages[sib_inds]], - true, - ) - combine!( - (node.children[i]).parent_message[part], - [temp_message[part]], - false, + backward!( + node.parent.child_messages[ind][part], + node.message[part], + model_list[part], + node, ) end - #But calling branchlength_optim recursively... - branchlength_optim_old!( - new_temp, - node.child_messages[i], - node.children[i], - models, - partition_list, - tol, - bl_optimizer=bl_optimizer - ) - end - #Then combine node.child_messages into node.message... - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + push!(temp_messages, temp_message) end end - #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. - #------------------- - if !isroot(node) - model_list = models(node) - fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - opt = univariate_maximize(fun, 0 + tol, 1 - tol, unit_transform, bl_optimizer, tol) - if fun(opt) > fun(node.branchlength) - node.branchlength = opt - end - #Consider checking for improvement, and bailing if none. - #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] - for part in partition_list - backward!(message_to_set[part], node.message[part], model_list[part], node) - end - end - #For debugging: - #println("$(node.nodeindex):$(node.branchlength)") end -#BM: Check if running felsenstein_down! makes a difference. -""" - branchlength_optim_old!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt()) - -Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages. -Requires felsenstein!() to have been run first. -models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or -a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. -partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models). -tol is the absolute tolerance for the bl_optimizer which defaults to golden section search, and has Brent's method as an option by setting bl_optimizer=BrentsMethodOpt(). -""" -function branchlength_optim_old!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_optimizer::UnivariateOpt = GoldenSectionOpt()) - temp_message = copy_message(tree.message) - message_to_set = copy_message(tree.message) - - if partition_list === nothing - partition_list = 1:length(tree.message) - end - - branchlength_optim_old!(temp_message, message_to_set, tree, models, partition_list, tol, bl_optimizer=bl_optimizer) -end - -#Overloading to allow for direct model and model vec inputs -branchlength_optim_old!( - tree::FelNode, - models::Vector{<:BranchModel}; - partition_list = nothing, - tol = 1e-5, - bl_optimizer::UnivariateOpt = GoldenSectionOpt() -) = branchlength_optim_old!(tree, x -> models, partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer) -branchlength_optim_old!( - tree::FelNode, - model::BranchModel; - partition_list = nothing, - tol = 1e-5, - bl_optimizer::UnivariateOpt = GoldenSectionOpt() -) = branchlength_optim_old!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_optimizer=bl_optimizer) - +#I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength. +#This is for cases where the user stores node-level parameters and wants to optimize them. function branchlength_optim!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, @@ -133,7 +136,7 @@ function branchlength_optim!( models, partition_list, tol; - modifier::UnivariateModifier = GoldenSectionOpt(), + modifier::UnivariateModifier = GoldenSectionOpt() ) #This bit of code should be identical to the regular downward pass... @@ -177,15 +180,12 @@ function branchlength_optim!( #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. #------------------- if !isroot(node) - model_list = models(node) - fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) + fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) - if fun(bl) > fun(node.branchlength) || !(modifier isa UnivariateOpt) node.branchlength = bl end - #Consider checking for improvement, and bailing if none. #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] for part in partition_list @@ -196,11 +196,27 @@ function branchlength_optim!( #println("$(node.nodeindex):$(node.branchlength)") end +#= +and maybe there can be an alternative call that uses keyword arguments like shuffle = true, for the common cases +(that just calls the traversal with the right passed in function) +=# +function branchlength_optim_iter!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse) + sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed + temp_messages = [copy_message(tree.message)] + + if partition_list === nothing + partition_list = 1:length(tree.message) + end + + branchlength_optim_iter!(temp_messages, tree, models, partition_list, tol, modifier=modifier, traversal = traversal) +end + #BM: Check if running felsenstein_down! makes a difference. """ - branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier = GoldenSectionOpt()) + branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt()) -Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages. +Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages. +Alternativly, the branches may be sampled (instead of optimized) by letting the modifier be some subtype of UnivariateSampler. Requires felsenstein!() to have been run first. models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. @@ -224,13 +240,12 @@ branchlength_optim!( models::Vector{<:BranchModel}; partition_list = nothing, tol = 1e-5, - modifier::UnivariateModifier = GoldenSectionOpt(), + modifier::UnivariateModifier = GoldenSectionOpt() ) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, modifier=modifier) branchlength_optim!( tree::FelNode, model::BranchModel; partition_list = nothing, tol = 1e-5, - modifier::UnivariateModifier = GoldenSectionOpt(), -) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, modifier=modifier) - + modifier::UnivariateModifier = GoldenSectionOpt() +) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, modifier=modifier) \ No newline at end of file diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index 7083fbd..46a830e 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -1,261 +1,250 @@ - - -function nni_optim_old!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, +#= +About clades getting skipped: +- the iterative implementation perfectly mimics the recursive one (they can both skip clades) +- some nnis can lead to some clades not getting optimized and some getting optimized multiple times +- I could push "every other" during first down and use lastind to know if a clade's been visisted, if a sibling clade's not been visited, I'll simply not fel-up yet but continue down +- +- Sanity checks: compare switch_LL with log_likelihood! of deepcopied tree with said switch +full_traversal passed the sanity check +=# + +#After a do_nni, we have to update parent_message if we want to continue down (assume that temp_message is the forwarded parent.parent_message) +function update_parent_message!( node::FelNode, - models, - partition_list; - acc_rule = (x, y) -> x > y, -) - - model_list = models(node) - - if isleafnode(node) - return - end - - #This bit of code should be identical to the regular downward pass... - #------------------- - + temp_message::Vector{<:Partition}; + partition_list = 1:length(node.message), +) + sib_inds = sibling_inds(node) for part in partition_list - forward!(temp_message[part], node.parent_message[part], model_list[part], node) - end - @assert length(node.children) <= 2 - for i = 1:length(node.children) - new_temp = copy_message(temp_message) #Need to think of how to avoid this allocation. Same as in felsenstein_down - sib_inds = sibling_inds(node.children[i]) - for part in partition_list - combine!( - (node.children[i]).parent_message[part], - [mess[part] for mess in node.child_messages[sib_inds]], - true, - ) - combine!((node.children[i]).parent_message[part], [temp_message[part]], false) - end - #But calling branchlength_optim recursively... - nni_optim_old!( - new_temp, - node.child_messages[i], - node.children[i], - models, - partition_list; - acc_rule = acc_rule, + combine!( + node.parent_message[part], + [mess[part] for mess in node.parent.child_messages[sib_inds]], + true, ) - end - #Then combine node.child_messages into node.message... - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - end - - #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. - #------------------- - if !isroot(node) - nnid, exceed_sib, exceed_child = do_nni_old( - node, - temp_message, - models; - partition_list = partition_list, - acc_rule = acc_rule, + combine!( + node.parent_message[part], + [temp_message[part]], + false, ) - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - backward!(message_to_set[part], node.message[part], model_list[part], node) - combine!( - node.parent.message[part], - [mess[part] for mess in node.parent.child_messages], - true, - ) - end end end -#Unsure if this is the best choice to handle the model,models, and model_func stuff. -function nni_optim_old!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, - models::Vector{<:BranchModel}, - partition_list; - acc_rule = (x, y) -> x > y, -) - nni_optim_old!( - temp_message, - message_to_set, - node, - x -> models, - partition_list, - acc_rule = acc_rule, - ) -end -function nni_optim_old!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, - model::BranchModel, +function nni_optim_full_traversal!( + temp_messages::Vector{Vector{T}}, + tree::FelNode, + models, partition_list; - acc_rule = (x, y) -> x > y, -) - nni_optim_old!( - temp_message, - message_to_set, - node, - x -> [model], - partition_list, - acc_rule = acc_rule, - ) -end - -function do_nni_old( - node, - temp_message, - models::F; - partition_list = 1:length(node.message), - acc_rule = (x, y) -> x > y, -) where {F<:Function} - if length(node.children) == 0 || node.parent === nothing - return false - else - temp_message2 = copy_message(temp_message) - model_list = models(node) - #current score - for part in partition_list - backward!(temp_message[part], node.message[part], model_list[part], node) - combine!(temp_message[part], [node.parent_message[part]], false) + nni_selection_rule = (x) -> argmax(x), + traversal = Iterators.reverse +) where {T <: Partition} + + #Consider a NamedTuple/struct + stack = [(pop!(temp_messages), tree, 1, 1, true, true)] + while !isempty(stack) + temp_message, node, ind, lastind, first, down = pop!(stack) + #We start out with a regular downward pass... + #(except for some extra bookkeeping to track if node is visited for the first time) + #------------------- + if isleafnode(node) + push!(temp_messages, temp_message) + continue end - #@toggleable_function assert_message_consistency(node, models, p = 0.01) - - curr_LL = sum([total_LL(temp_message[part]) #+ - #total_LL(node.message[part]) + - #total_LL(node.parent_message[part]) - for part in partition_list]) - - max_LL = -Inf - exceeded, exceed_sib, exceed_child = (false, 0, 0) - - for sib_ind in - [x for x in 1:length(node.parent.children) if node.parent.children[x] != node] - switch_LL = 0.0 - for child_ind = 1:length(node.children) + if down + if first + model_list = models(node) for part in partition_list - #move the sibling message, after upward propogation, to temp_message to work with it - combine!( + forward!( temp_message[part], - [node.parent.child_messages[sib_ind][part]], + node.parent_message[part], + model_list[part], + node, + ) + end + @assert length(node.children) <= 2 + #Temp must be constant between iterations for a node during down... + child_iter = traversal(1:length(node.children)) + lastind = Base.first(child_iter) #(which is why we track the last child to be visited during down) + push!(stack, (Vector{T}(), node, ind, lastind, true, false)) #... but not up + for i = child_iter #Iterative reverse <=> Recursive non-reverse, also optimal for lazysort!?? + push!(stack, (temp_message, node, i, lastind, false, true)) + end + end + if !first + sib_inds = sibling_inds(node.children[ind]) + for part in partition_list + combine!( + (node.children[ind]).parent_message[part], + [mess[part] for mess in node.child_messages[sib_inds]], true, ) - - #combine this message, with all child messages of node except the index replaced combine!( - temp_message[part], - [ - mess[part] for - (i, mess) in enumerate(node.child_messages) if i != child_ind - ], + (node.children[ind]).parent_message[part], + [temp_message[part]], false, ) - - #prop up the message on the node up to its parent - backward!( - temp_message2[part], - temp_message[part], - model_list[part], + end + #But calling nni_optim! recursively... (the iterative equivalent) + push!(stack, (safepop!(temp_messages, temp_message), node.children[ind], ind, lastind, true, true)) #first + down combination => safepop! + ind == lastind && push!(temp_messages, temp_message) #We no longer need constant temp + end + end + if !down + #Then combine node.child_messages into node.message... + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + end + #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. + #------------------- + if !isroot(node) + temp_message = pop!(temp_messages) + model_list = models(node) + if first #We only do_nni first up + nnid, exceed_sib, exceed_child = do_nni( node, + temp_message, + models; + partition_list = partition_list, + nni_selection_rule = nni_selection_rule, ) - - #combine the message of the moved child + if nnid && last(last(stack)) #We nnid a sibling that hasn't been visited (then, down would be true in the next iter)... + #... and now we want to continue down the nnid sibling (now a child to node) + push!(temp_messages, temp_message) + temp_message = Base.first(last(stack)) #The forwarded parent.parent_message + #First we update the parent_message... + update_parent_message!( + node, + temp_message; + partition_list = partition_list, + ) + #... then we forward the updated parent_message (this resembles a first down) + model_list = models(node) + for part in partition_list + forward!( + temp_message[part], + node.parent_message[part], + model_list[part], + node, + ) + end + pop!(stack) + push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #When we're going up a second time, we no longer need a temp + push!(stack, (temp_message, node, exceed_child, exceed_child, false, true)) #Go to the "new" child - the "new" lastind + continue #Don't fel-up yet + end + end + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + backward!(node.parent.child_messages[ind][part], node.message[part], model_list[part], node) combine!( - temp_message2[part], - [node.child_messages[child_ind][part]], - false, + node.parent.message[part], + [mess[part] for mess in node.parent.child_messages], + true, ) + end + push!(temp_messages, temp_message) + end + end + end +end - #we now have both parts of message, propogated to the parent of node - #propogate it up one more step, then merge it with parent_message of parent - backward!( +function nni_optim_iter!( + temp_messages::Vector{Vector{T}}, + tree::FelNode, + models, + partition_list; + nni_selection_rule = (x) -> argmax(x), + traversal = Iterators.reverse +) where {T <: Partition} + + #Consider a NamedTuple/struct + stack = [(pop!(temp_messages), tree, 1, 1, true, true)] + while !isempty(stack) + temp_message, node, ind, lastind, first, down = pop!(stack) + #We start out with a regular downward pass... + #(except for some extra bookkeeping to track if node is visited for the first time) + #------------------- + if isleafnode(node) + push!(temp_messages, temp_message) + continue + end + if down + if first + model_list = models(node) + for part in partition_list + forward!( temp_message[part], - temp_message2[part], + node.parent_message[part], model_list[part], - node.parent, + node, ) - combine!(temp_message[part], [node.parent.parent_message[part]], false) end - - switch_LL = sum([total_LL(temp_message[part]) for part in partition_list]) - - - if switch_LL > max_LL - exceed_sib = sib_ind - exceed_child = child_ind - max_LL = switch_LL + @assert length(node.children) <= 2 + #Temp must be constant between iterations for a node during down... + child_iter = traversal(1:length(node.children)) + lastind = Base.first(child_iter) #(which is why we track the last child to be visited during down) + push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #... but not up + for i = child_iter #Iterative reverse <=> Recursive non-reverse, also optimal for lazysort!?? + push!(stack, (temp_message, node, i, lastind, false, true)) + end + end + if !first + sib_inds = sibling_inds(node.children[ind]) + for part in partition_list + combine!( + (node.children[ind]).parent_message[part], + [mess[part] for mess in node.child_messages[sib_inds]], + true, + ) + combine!( + (node.children[ind]).parent_message[part], + [temp_message[part]], + false, + ) end + #But calling nni_optim! recursively... (the iterative equivalent) + push!(stack, (safepop!(temp_messages, temp_message), node.children[ind], ind, lastind, true, true)) #first + down combination => safepop! + ind == lastind && push!(temp_messages, temp_message) #We no longer need constant temp end end - - exceeded = acc_rule(max_LL, curr_LL) - - #do the actual move here, switching exceed child and exceed sib - if !(exceeded) - return false, exceed_sib, exceed_child - else - sib = node.parent.children[exceed_sib] - child = node.children[exceed_child] - - child.parent = node.parent - sib.parent = node - - node.children[exceed_child] = sib - node.parent.children[exceed_sib] = child - - node.parent.child_messages[exceed_sib], node.child_messages[exceed_child] = - node.child_messages[exceed_child], node.parent.child_messages[exceed_sib] - - return true, exceed_sib, exceed_child + if !down + #Then combine node.child_messages into node.message... + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + end + #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. + #------------------- + if !isroot(node) + temp_message = pop!(temp_messages) + model_list = models(node) + nnid, exceed_sib, exceed_child = do_nni( + node, + temp_message, + models; + partition_list = partition_list, + nni_selection_rule = nni_selection_rule, + ) + for part in partition_list + combine!(node.message[part], [mess[part] for mess in node.child_messages], true) + backward!(node.parent.child_messages[ind][part], node.message[part], model_list[part], node) + combine!( + node.parent.message[part], + [mess[part] for mess in node.parent.child_messages], + true, + ) + end + push!(temp_messages, temp_message) + end end end end -""" - nni_optim_old!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) - -Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. -Requires felsenstein!() to have been run first. -models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or -a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. -partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models). -acc_rule allows you to specify a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. -""" -function nni_optim_old!( - tree::FelNode, - models; - partition_list = nothing, - acc_rule = (x, y) -> x > y, -) - temp_message = copy_message(tree.message) - message_to_set = copy_message(tree.message) - - if partition_list === nothing - partition_list = 1:length(tree.message) - end - - nni_optim_old!( - temp_message, - message_to_set, - tree, - models, - partition_list, - acc_rule = acc_rule, - ) -end - function nni_optim!( temp_message::Vector{<:Partition}, message_to_set::Vector{<:Partition}, node::FelNode, models, partition_list; - acc_rule = (x, y) -> x > y, - nni_config_sampler = (x) -> argmax(x), + nni_selection_rule = (x) -> argmax(x), ) model_list = models(node) @@ -289,8 +278,7 @@ function nni_optim!( node.children[i], models, partition_list; - acc_rule = acc_rule, - nni_config_sampler = nni_config_sampler, + nni_selection_rule = nni_selection_rule, ) end #Then combine node.child_messages into node.message... @@ -306,8 +294,7 @@ function nni_optim!( temp_message, models; partition_list = partition_list, - acc_rule = acc_rule, - nni_config_sampler = nni_config_sampler, + nni_selection_rule = nni_selection_rule, ) for part in partition_list combine!(node.message[part], [mess[part] for mess in node.child_messages], true) @@ -320,6 +307,7 @@ function nni_optim!( end end end + #Unsure if this is the best choice to handle the model,models, and model_func stuff. function nni_optim!( temp_message::Vector{<:Partition}, @@ -327,8 +315,7 @@ function nni_optim!( node::FelNode, models::Vector{<:BranchModel}, partition_list; - acc_rule = (x, y) -> x > y, - nni_config_sampler = (x) -> argmax(x), + nni_selection_rule = (x) -> argmax(x), ) nni_optim!( temp_message, @@ -336,8 +323,7 @@ function nni_optim!( node, x -> models, partition_list, - acc_rule = acc_rule, - nni_config_sampler = nni_config_sampler, + nni_selection_rule = nni_selection_rule, ) end function nni_optim!( @@ -346,8 +332,7 @@ function nni_optim!( node::FelNode, model::BranchModel, partition_list; - acc_rule = (x, y) -> x > y, - nni_config_sampler = (x) -> argmax(x), + nni_selection_rule = (x) -> argmax(x), ) nni_optim!( temp_message, @@ -355,8 +340,7 @@ function nni_optim!( node, x -> [model], partition_list, - acc_rule = acc_rule, - nni_config_sampler = nni_config_sampler, + nni_selection_rule = nni_selection_rule, ) end @@ -365,8 +349,7 @@ function do_nni( temp_message, models::F; partition_list = 1:length(node.message), - acc_rule = (x, y) -> x > y, - nni_config_sampler = (x) -> argmax(x), + nni_selection_rule = (x) -> argmax(x), ) where {F<:Function} if length(node.children) == 0 || node.parent === nothing @@ -447,7 +430,7 @@ function do_nni( end end - sampled_config_ind = nni_config_sampler(nni_LLs) + sampled_config_ind = nni_selection_rule(nni_LLs) changed = sampled_config_ind != 1 (sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind] @@ -472,306 +455,43 @@ function do_nni( end end -""" - nni_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) - -Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. -Requires felsenstein!() to have been run first. -models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or -a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. -partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models). -acc_rule allows you to specify a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. -nni_config_sampler allows you to randomly select a NNI based on the vector of log likelihoods of the possible interchanges, note that the current log likelihood is at index 1. -""" -function nni_optim!( - tree::FelNode, - models; - partition_list = nothing, - acc_rule = (x, y) -> x > y, - nni_config_sampler = (x) -> argmax(x), -) - temp_message = copy_message(tree.message) - message_to_set = copy_message(tree.message) +function nni_optim_full_traversal!(tree::FelNode, models; partition_list = nothing, nni_selection_rule = (x) -> argmax(x), sort_tree = false, traversal = Iterators.reverse) + sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed + temp_messages = [copy_message(tree.message)] if partition_list === nothing partition_list = 1:length(tree.message) end - nni_optim!( - temp_message, - message_to_set, - tree, - models, - partition_list, - acc_rule = acc_rule, - nni_config_sampler = nni_config_sampler, - ) + nni_optim_full_traversal!(temp_messages, tree, models, partition_list, nni_selection_rule = nni_selection_rule, traversal = traversal) end -#Overloading to allow for direct model and model vec inputs -nni_optim!( - tree::FelNode, - models::Vector{<:BranchModel}; - partition_list = nothing, - acc_rule = (x, y) -> x > y, - nni_config_sampler = (x) -> argmax(x), -) = nni_optim!(tree, x -> models, partition_list=partition_list, acc_rule=acc_rule, nni_config_sampler = nni_config_sampler) -nni_optim!( - tree::FelNode, - model::BranchModel; - partition_list = nothing, - acc_rule = (x, y) -> x > y, - nni_config_sampler = (x) -> argmax(x), -) = nni_optim!(tree, x -> [model], partition_list=partition_list, acc_rule=acc_rule, nni_config_sampler = nni_config_sampler) - - -function nni_optim_temp!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, - models, - partition_list; - acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), -) - - model_list = models(node) - - if isleafnode(node) - return - end - - #This bit of code should be identical to the regular downward pass... - #------------------- - - for part in partition_list - forward!(temp_message[part], node.parent_message[part], model_list[part], node) - end - @assert length(node.children) <= 2 - for i = 1:length(node.children) - new_temp = copy_message(temp_message) #Need to think of how to avoid this allocation. Same as in felsenstein_down - sib_inds = sibling_inds(node.children[i]) - for part in partition_list - combine!( - (node.children[i]).parent_message[part], - [mess[part] for mess in node.child_messages[sib_inds]], - true, - ) - combine!((node.children[i]).parent_message[part], [temp_message[part]], false) - end - #But calling branchlength_optim recursively... - nni_optim_temp!( - new_temp, - node.child_messages[i], - node.children[i], - models, - partition_list; - acc_rule = acc_rule, - sampler = sampler, - ) - end - #Then combine node.child_messages into node.message... - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - end +function nni_optim_iter!(tree::FelNode, models; partition_list = nothing, nni_selection_rule = (x) -> argmax(x), sort_tree = false, traversal = Iterators.reverse) + sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed + temp_messages = [copy_message(tree.message)] - #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. - #------------------- - if !isroot(node) - nnid, exceed_sib, exceed_child = do_nni_temp( - node, - temp_message, - models; - partition_list = partition_list, - acc_rule = acc_rule, - sampler = sampler, - ) - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - backward!(message_to_set[part], node.message[part], model_list[part], node) - combine!( - node.parent.message[part], - [mess[part] for mess in node.parent.child_messages], - true, - ) - end + if partition_list === nothing + partition_list = 1:length(tree.message) end -end -#Unsure if this is the best choice to handle the model,models, and model_func stuff. -function nni_optim_temp!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, - models::Vector{<:BranchModel}, - partition_list; - acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), -) - nni_optim_temp!( - temp_message, - message_to_set, - node, - x -> models, - partition_list, - acc_rule = acc_rule, - sampler = sampler, - ) -end -function nni_optim_temp!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, - model::BranchModel, - partition_list; - acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), -) - nni_optim_temp!( - temp_message, - message_to_set, - node, - x -> [model], - partition_list, - acc_rule = acc_rule, - sampler = sampler, - ) -end - -function do_nni_temp( - node, - temp_message, - models::F; - partition_list = 1:length(node.message), - acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), - ) where {F<:Function} - - if length(node.children) == 0 || node.parent === nothing - return false - else - temp_message2 = copy_message(temp_message) - model_list = models(node) - #current score - for part in partition_list - backward!(temp_message[part], node.message[part], model_list[part], node) - combine!(temp_message[part], [node.parent_message[part]], false) - end - #@toggleable_function assert_message_consistency(node, models, p = 0.01) - - curr_LL = sum([total_LL(temp_message[part]) #+ - #total_LL(node.message[part]) + - #total_LL(node.parent_message[part]) - for part in partition_list]) - - changed = false - nni_LLs = [curr_LL] - nni_configs = [(0,0)] - - for sib_ind in - [x for x in 1:length(node.parent.children) if node.parent.children[x] != node] - switch_LL = 0.0 - for child_ind = 1:length(node.children) - for part in partition_list - #move the sibling message, after upward propogation, to temp_message to work with it - combine!( - temp_message[part], - [node.parent.child_messages[sib_ind][part]], - true, - ) - - #combine this message, with all child messages of node except the index replaced - combine!( - temp_message[part], - [ - mess[part] for - (i, mess) in enumerate(node.child_messages) if i != child_ind - ], - false, - ) - - #prop up the message on the node up to its parent - backward!( - temp_message2[part], - temp_message[part], - model_list[part], - node, - ) - - #combine the message of the moved child - combine!( - temp_message2[part], - [node.child_messages[child_ind][part]], - false, - ) - - #we now have both parts of message, propogated to the parent of node - #propogate it up one more step, then merge it with parent_message of parent - backward!( - temp_message[part], - temp_message2[part], - model_list[part], - node.parent, - ) - combine!(temp_message[part], [node.parent.parent_message[part]], false) - end - - - LL = sum([total_LL(temp_message[part]) for part in partition_list]) - - push!(nni_LLs, LL) - push!(nni_configs, (sib_ind, child_ind)) - - end - end - - changed, sampled_config_ind = sampler(nni_LLs) - sampled_config_LL = nni_LLs[sampled_config_ind] - (sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind] - changed = acc_rule(sampled_config_LL, curr_LL) && changed - - - # println("changed: ", changed, " inds, ", (sampled_sib_ind, sampled_child_ind), " sampled config ind, ", sampled_config_ind) - # println(" nni_config, ", nni_configs) - # println("softmax ", softmax(nni_LLs)) - #do the actual move here - if !(changed) - return false, sampled_sib_ind, sampled_child_ind - else - sib = node.parent.children[sampled_sib_ind] - child = node.children[sampled_child_ind] - - child.parent = node.parent - sib.parent = node - - node.children[sampled_child_ind] = sib - node.parent.children[sampled_sib_ind] = child - - node.parent.child_messages[sampled_sib_ind], node.child_messages[sampled_child_ind] = - node.child_messages[sampled_child_ind], node.parent.child_messages[sampled_sib_ind] - - return true, sampled_sib_ind, sampled_child_ind - end - end + nni_optim_iter!(temp_messages, tree, models, partition_list, nni_selection_rule = nni_selection_rule, traversal = traversal) end """ - nni_optim_temp!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) + nni_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. Requires felsenstein!() to have been run first. models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models). -acc_rule allows you to specify a function that takes the current and proposed log likelihoods, and if true is returned the move is accepted. -sampler allows you to randomly select a NNI based on the vector of log likelihoods of the possible interchanges, note that the current log likelihood is at index 1. +nni_selection_rule lets you choose which nni swap to do (including no swap) based on the log likelihoods of the different nni configurations. """ -function nni_optim_temp!( +function nni_optim!( tree::FelNode, models; partition_list = nothing, - acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), + nni_selection_rule = (x) -> argmax(x), ) temp_message = copy_message(tree.message) message_to_set = copy_message(tree.message) @@ -780,29 +500,12 @@ function nni_optim_temp!( partition_list = 1:length(tree.message) end - nni_optim_temp!( + nni_optim!( temp_message, message_to_set, tree, models, partition_list, - acc_rule = acc_rule, - sampler = sampler, + nni_selection_rule = nni_selection_rule, ) -end - -#Overloading to allow for direct model and model vec inputs -nni_optim_temp!( - tree::FelNode, - models::Vector{<:BranchModel}; - partition_list = nothing, - acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), -) = nni_optim_temp!(tree, x -> models, partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) -nni_optim_temp!( - tree::FelNode, - model::BranchModel; - partition_list = nothing, - acc_rule = (x, y) -> x > y, - sampler = (x) -> (true, argmax(x[2:end]) + 1), -) = nni_optim_temp!(tree, x -> [model], partition_list=partition_list, acc_rule=acc_rule, sampler = sampler) +end \ No newline at end of file From a94ef1c7cba1ba4a476ba67864e0a1da4a85ca48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Mon, 12 Aug 2024 16:18:45 +0200 Subject: [PATCH 06/12] WIP --- src/bayes/sampling.jl | 33 ++++++++++++--------------------- src/utils/simple_sample.jl | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl index c60c407..ddefa3f 100644 --- a/src/bayes/sampling.jl +++ b/src/bayes/sampling.jl @@ -5,6 +5,7 @@ export metropolis_sample initial_tree::FelNode, models::Vector{<:BranchModel}, num_of_samples; + bl_modifier::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1)) burn_in=1000, sample_interval=10, collect_LLs = false, @@ -17,6 +18,7 @@ Samples tree topologies from a posterior distribution. - `initial_tree`: An initial topology with (important!) the leaves populated with data, for the likelihood calculation. - `models`: A list of branch models. - `num_of_samples`: The number of tree samples drawn from the posterior. +- `bl_sampler`: Sampler used to drawn branchlengths from the posterior. - `burn_in`: The number of samples discarded at the start of the Markov Chain. - `sample_interval`: The distance between samples in the underlying Markov Chain (to reduce sample correlation). - `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees. @@ -30,6 +32,7 @@ function metropolis_sample( initial_tree::FelNode, models::Vector{<:BranchModel}, num_of_samples; + bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1)), burn_in=1000, sample_interval=10, collect_LLs = false, @@ -37,20 +40,20 @@ function metropolis_sample( ladderize = false, ) + # The prior over the (log) of the branchlengths should be specified in bl_sampler. + # Furthermore, a non-informative/uniform prior is assumed over the tree topolgies (excluding the branchlengths). + sample_LLs = [] samples = FelNode[] tree = deepcopy(initial_tree) iterations = burn_in + num_of_samples * sample_interval - - bl_modifier = BranchlengthSampler(Normal(0,2), Normal(-1,1)) - - #old_softmax_sampler = x -> (sample = rand(Categorical(softmax(x))); changed = sample != 1; (changed, sample)) + softmax_sampler = x -> rand(Categorical(softmax(x))) for i=1:iterations - # Updates the tree topolgy and branchlengths using Gibbs sampling. + # Updates the tree topolgy and branchlengths. nni_optim_iter!(tree, x -> models, nni_selection_rule = softmax_sampler) - branchlength_optim!(tree, models, modifier=bl_modifier) + branchlength_optim_iter!(tree, x -> models, modifier=bl_sampler) if (i-burn_in) % sample_interval == 0 && i > burn_in @@ -90,24 +93,12 @@ function metropolis_sample( return samples end -function branchlength_metropolis(LL, modifier, curr_value) - # The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero. - log_prior(x) = logpdf(modifier.log_bl_prior,log(x + 1e-12)) - # Adding additive normal symmetrical noise in the log(branchlength) domain to ensure the proposal function is symmetric. - noise = rand(modifier.log_bl_proposal) - proposal = exp(log(curr_value)+noise) - # The standard Metropolis acceptance criterion. - if rand() <= exp(LL(proposal)+log_prior(proposal)-LL(curr_value)-log_prior(curr_value)) - modifier.acc_ratio[1] = modifier.acc_ratio[1] + 1 - return proposal - else - modifier.acc_ratio[2] = modifier.acc_ratio[2] + 1 - return curr_value - end -end +# Below are some functions that help to assess the mixing by looking at the distance between leaf nodes. """ collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) + + Returns a list of distance matrices (containing the distance between the leaf nodes) which can be used to assess mixing. """ function collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) distmats = [] diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl index ae211e0..95de114 100644 --- a/src/utils/simple_sample.jl +++ b/src/utils/simple_sample.jl @@ -20,3 +20,19 @@ function univariate_sampler(LL, modifier::BranchlengthSampler, curr_branchlength return branchlength_metropolis(LL, modifier, curr_branchlength) end +function branchlength_metropolis(LL, modifier, curr_value) + # The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero. + log_prior(x) = logpdf(modifier.log_bl_prior,log(x + 1e-12)) + # Adding additive normal symmetrical noise in the log(branchlength) domain to ensure the proposal function is symmetric. + noise = rand(modifier.log_bl_proposal) + proposal = exp(log(curr_value)+noise) + # The standard Metropolis acceptance criterion. + if rand() <= exp(LL(proposal)+log_prior(proposal)-LL(curr_value)-log_prior(curr_value)) + modifier.acc_ratio[1] = modifier.acc_ratio[1] + 1 + return proposal + else + modifier.acc_ratio[2] = modifier.acc_ratio[2] + 1 + return curr_value + end +end + From 461800960f65c39e4e4319b7b8e9eef3eb77fba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Mon, 19 Aug 2024 14:42:14 +0200 Subject: [PATCH 07/12] WIP --- src/bayes/sampling.jl | 4 +- src/core/algorithms/branchlength_optim.jl | 137 +++++------------- src/core/algorithms/nni_optim.jl | 167 ++++++---------------- 3 files changed, 78 insertions(+), 230 deletions(-) diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl index ddefa3f..ec804ab 100644 --- a/src/bayes/sampling.jl +++ b/src/bayes/sampling.jl @@ -52,8 +52,8 @@ function metropolis_sample( for i=1:iterations # Updates the tree topolgy and branchlengths. - nni_optim_iter!(tree, x -> models, nni_selection_rule = softmax_sampler) - branchlength_optim_iter!(tree, x -> models, modifier=bl_sampler) + nni_optim!(tree, x -> models, nni_selection_rule = softmax_sampler) + branchlength_optim!(tree, x -> models, bl_modifier = bl_sampler) if (i-burn_in) % sample_interval == 0 && i > burn_in diff --git a/src/core/algorithms/branchlength_optim.jl b/src/core/algorithms/branchlength_optim.jl index 4499e38..79c95d0 100644 --- a/src/core/algorithms/branchlength_optim.jl +++ b/src/core/algorithms/branchlength_optim.jl @@ -19,13 +19,15 @@ function branch_LL_up( return tot_LL end -function branchlength_optim_iter!( +#I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength. +#This is for cases where the user stores node-level parameters and wants to optimize them. +function branchlength_optim!( temp_messages::Vector{Vector{T}}, tree::FelNode, models, partition_list, tol; - modifier::UnivariateModifier = GoldenSectionOpt(), + bl_modifier::UnivariateModifier = GoldenSectionOpt(), traversal = Iterators.reverse ) where {T <: Partition} @@ -85,11 +87,11 @@ function branchlength_optim_iter!( temp_message = pop!(temp_messages) model_list = models(node) fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) - if fun(bl) > fun(node.branchlength) || !(modifier isa UnivariateOpt) + bl = univariate_modifier(fun, bl_modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) + if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt) node.branchlength = bl end - + #Consider checking for improvement, and bailing if none. #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] for part in partition_list @@ -108,8 +110,8 @@ function branchlength_optim_iter!( #------------------- model_list = models(node) fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) - if fun(bl) > fun(node.branchlength) || !(modifier isa UnivariateOpt) + bl = univariate_modifier(fun, bl_modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) + if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt) node.branchlength = bl end #Consider checking for improvement, and bailing if none. @@ -127,111 +129,32 @@ function branchlength_optim_iter!( end end -#I need to add a version of this that takes a generic optimizer function and uses that instead of golden_section_maximize on just the branchlength. -#This is for cases where the user stores node-level parameters and wants to optimize them. -function branchlength_optim!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, - models, - partition_list, - tol; - modifier::UnivariateModifier = GoldenSectionOpt() -) - - #This bit of code should be identical to the regular downward pass... - #------------------- - if !isleafnode(node) - model_list = models(node) - for part in partition_list - forward!(temp_message[part], node.parent_message[part], model_list[part], node) - end - for i = 1:length(node.children) - new_temp = copy_message(temp_message) #Need to think of how to avoid this allocation. Same as in felsenstein_down - sib_inds = sibling_inds(node.children[i]) - for part in partition_list - combine!( - (node.children[i]).parent_message[part], - [mess[part] for mess in node.child_messages[sib_inds]], - true, - ) - combine!( - (node.children[i]).parent_message[part], - [temp_message[part]], - false, - ) - end - #But calling branchlength_optim recursively... - branchlength_optim!( - new_temp, - node.child_messages[i], - node.children[i], - models, - partition_list, - tol, - modifier=modifier - ) - end - #Then combine node.child_messages into node.message... - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - end - end - #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. - #------------------- - if !isroot(node) - model_list = models(node) - fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - bl = univariate_modifier(fun, modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) - if fun(bl) > fun(node.branchlength) || !(modifier isa UnivariateOpt) - node.branchlength = bl - end - #Consider checking for improvement, and bailing if none. - #Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one] - for part in partition_list - backward!(message_to_set[part], node.message[part], model_list[part], node) - end - end - #For debugging: - #println("$(node.nodeindex):$(node.branchlength)") -end - -#= -and maybe there can be an alternative call that uses keyword arguments like shuffle = true, for the common cases -(that just calls the traversal with the right passed in function) -=# -function branchlength_optim_iter!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse) - sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed - temp_messages = [copy_message(tree.message)] - - if partition_list === nothing - partition_list = 1:length(tree.message) - end - - branchlength_optim_iter!(temp_messages, tree, models, partition_list, tol, modifier=modifier, traversal = traversal) -end - #BM: Check if running felsenstein_down! makes a difference. """ - branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt()) + branchlength_optim!(tree::FelNode, models; ) -Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages. -Alternativly, the branches may be sampled (instead of optimized) by letting the modifier be some subtype of UnivariateSampler. +Uses golden section search, or optionally Brent's method, to optimize all branches recursively, maintaining the integrity of the messages. Requires felsenstein!() to have been run first. models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. -partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models). -tol is the absolute tolerance for the modifier which defaults to golden section search, and has Brent's method as an option by setting modifier=BrentsMethodOpt(). + +# Keyword Arguments +- `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option). +- `tol=1e-5`: absolute tolerance for the `bl_modifier`. +- `bl_modifier=GoldenSectionOpt()`: can either be a optimizer or a sampler (subtype of UnivariateModifier). For optimization, in addition to golden section search, Brent's method can be used by setting bl_modifier=BrentsMethodOpt(). +- `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized. +- `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable. +- `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`. """ -function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, modifier::UnivariateModifier = GoldenSectionOpt()) - temp_message = copy_message(tree.message) - message_to_set = copy_message(tree.message) +function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse, shuffle = false) + sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed + temp_messages = [copy_message(tree.message)] if partition_list === nothing partition_list = 1:length(tree.message) end - branchlength_optim!(temp_message, message_to_set, tree, models, partition_list, tol, modifier=modifier) + branchlength_optim!(temp_messages, tree, models, partition_list, tol, bl_modifier=bl_modifier, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal) end #Overloading to allow for direct model and model vec inputs @@ -240,12 +163,18 @@ branchlength_optim!( models::Vector{<:BranchModel}; partition_list = nothing, tol = 1e-5, - modifier::UnivariateModifier = GoldenSectionOpt() -) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, modifier=modifier) + bl_modifier::UnivariateModifier = GoldenSectionOpt(), + sort_tree = false, + traversal = Iterators.reverse, + shuffle = false +) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle) branchlength_optim!( tree::FelNode, model::BranchModel; partition_list = nothing, tol = 1e-5, - modifier::UnivariateModifier = GoldenSectionOpt() -) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, modifier=modifier) \ No newline at end of file + bl_modifier::UnivariateModifier = GoldenSectionOpt(), + sort_tree = false, + traversal = Iterators.reverse, + shuffle = false +) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle) diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index 46a830e..02c235d 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -147,7 +147,7 @@ function nni_optim_full_traversal!( end end -function nni_optim_iter!( +function nni_optim!( temp_messages::Vector{Vector{T}}, tree::FelNode, models, @@ -238,109 +238,40 @@ function nni_optim_iter!( end end -function nni_optim!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, - models, - partition_list; - nni_selection_rule = (x) -> argmax(x), -) - - model_list = models(node) - - if isleafnode(node) - return - end - - #This bit of code should be identical to the regular downward pass... - #------------------- - - for part in partition_list - forward!(temp_message[part], node.parent_message[part], model_list[part], node) - end - @assert length(node.children) <= 2 - for i = 1:length(node.children) - new_temp = copy_message(temp_message) #Need to think of how to avoid this allocation. Same as in felsenstein_down - sib_inds = sibling_inds(node.children[i]) - for part in partition_list - combine!( - (node.children[i]).parent_message[part], - [mess[part] for mess in node.child_messages[sib_inds]], - true, - ) - combine!((node.children[i]).parent_message[part], [temp_message[part]], false) - end - #But calling branchlength_optim recursively... - nni_optim!( - new_temp, - node.child_messages[i], - node.children[i], - models, - partition_list; - nni_selection_rule = nni_selection_rule, - ) - end - #Then combine node.child_messages into node.message... - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - end - - #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. - #------------------- - if !isroot(node) - nnid, exceed_sib, exceed_child = do_nni( - node, - temp_message, - models; - partition_list = partition_list, - nni_selection_rule = nni_selection_rule, - ) - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - backward!(message_to_set[part], node.message[part], model_list[part], node) - combine!( - node.parent.message[part], - [mess[part] for mess in node.parent.child_messages], - true, - ) - end - end -end - #Unsure if this is the best choice to handle the model,models, and model_func stuff. function nni_optim!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, + temp_messages::Vector{Vector{T}}, + tree::FelNode, models::Vector{<:BranchModel}, partition_list; nni_selection_rule = (x) -> argmax(x), -) + traversal = Iterators.reverse, +) where {T <: Partition} nni_optim!( - temp_message, - message_to_set, - node, + temp_messages, + tree, x -> models, partition_list, nni_selection_rule = nni_selection_rule, + traversal = traversal, ) end function nni_optim!( - temp_message::Vector{<:Partition}, - message_to_set::Vector{<:Partition}, - node::FelNode, + temp_messages::Vector{Vector{T}}, + tree::FelNode, model::BranchModel, partition_list; nni_selection_rule = (x) -> argmax(x), -) + traversal = Iterators.reverse, + +) where {T <: Partition} nni_optim!( - temp_message, - message_to_set, - node, + temp_messages, + tree, x -> [model], partition_list, nni_selection_rule = nni_selection_rule, + traversal = traversal, ) end @@ -350,12 +281,11 @@ function do_nni( models::F; partition_list = 1:length(node.message), nni_selection_rule = (x) -> argmax(x), - ) where {F<:Function} - +) where {F<:Function} if length(node.children) == 0 || node.parent === nothing return false else - temp_message2 = copy_message(temp_message) + temp_message2 = copy_message(temp_message) #Make use of temp_messages here model_list = models(node) #current score for part in partition_list @@ -368,11 +298,14 @@ function do_nni( #total_LL(node.message[part]) + #total_LL(node.parent_message[part]) for part in partition_list]) - - changed = false + + change = false nni_LLs = [curr_LL] nni_configs = [(0,0)] + max_LL = -Inf + exceeded, exceed_sib, exceed_child = (false, 0, 0) + for sib_ind in [x for x in 1:length(node.parent.children) if node.parent.children[x] != node] switch_LL = 0.0 @@ -421,21 +354,20 @@ function do_nni( combine!(temp_message[part], [node.parent.parent_message[part]], false) end - LL = sum([total_LL(temp_message[part]) for part in partition_list]) + push!(nni_LLs, LL) push!(nni_configs, (sib_ind, child_ind)) - end end sampled_config_ind = nni_selection_rule(nni_LLs) - changed = sampled_config_ind != 1 + change = sampled_config_ind != 1 (sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind] - #do the actual move here - if !(changed) + #do the actual move here, switching exceed child and exceed sib + if !(change) return false, sampled_sib_ind, sampled_child_ind else sib = node.parent.children[sampled_sib_ind] @@ -455,57 +387,44 @@ function do_nni( end end -function nni_optim_full_traversal!(tree::FelNode, models; partition_list = nothing, nni_selection_rule = (x) -> argmax(x), sort_tree = false, traversal = Iterators.reverse) - sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed - temp_messages = [copy_message(tree.message)] - - if partition_list === nothing - partition_list = 1:length(tree.message) - end - - nni_optim_full_traversal!(temp_messages, tree, models, partition_list, nni_selection_rule = nni_selection_rule, traversal = traversal) -end - -function nni_optim_iter!(tree::FelNode, models; partition_list = nothing, nni_selection_rule = (x) -> argmax(x), sort_tree = false, traversal = Iterators.reverse) - sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed - temp_messages = [copy_message(tree.message)] - - if partition_list === nothing - partition_list = 1:length(tree.message) - end - - nni_optim_iter!(temp_messages, tree, models, partition_list, nni_selection_rule = nni_selection_rule, traversal = traversal) -end - """ - nni_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5) + nni_optim!(tree::FelNode, models; ) Considers local branch swaps for all branches recursively, maintaining the integrity of the messages. Requires felsenstein!() to have been run first. models can either be a single model (if the messages on the tree contain just one Partition) or an array of models, if the messages have >1 Partition, or a function that takes a node, and returns a Vector{<:BranchModel} if you need the models to vary from one branch to another. -partition_list (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models). -nni_selection_rule lets you choose which nni swap to do (including no swap) based on the log likelihoods of the different nni configurations. + +# Keyword Arguments +- `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models, the default option). +- `nni_selection_rule = (x) -> argmax(x)`: a function that takes the current and proposed log likelihoods and selects a nni configuration. Note that the current log likelihood is stored at x[1]. +- `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized. +- `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable. +- `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`. """ function nni_optim!( tree::FelNode, models; partition_list = nothing, nni_selection_rule = (x) -> argmax(x), + sort_tree = false, + traversal = Iterators.reverse, + shuffle = false ) - temp_message = copy_message(tree.message) - message_to_set = copy_message(tree.message) + sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed + temp_messages = [copy_message(tree.message)] if partition_list === nothing partition_list = 1:length(tree.message) end + #Need to decide here between nni_optim and nni_optim_full_traversal nni_optim!( - temp_message, - message_to_set, + temp_messages, tree, models, partition_list, nni_selection_rule = nni_selection_rule, + traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal ) end \ No newline at end of file From e544aac83b40ab22bab4a5d52166ff6319eef21e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Mon, 19 Aug 2024 16:07:53 +0200 Subject: [PATCH 08/12] WIP --- src/MolecularEvolution.jl | 3 +++ src/bayes/sampling.jl | 10 +--------- src/utils/simple_sample.jl | 20 +++++++++++--------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/MolecularEvolution.jl b/src/MolecularEvolution.jl index a282bac..b433ce0 100644 --- a/src/MolecularEvolution.jl +++ b/src/MolecularEvolution.jl @@ -118,6 +118,8 @@ export reroot!, nni_optim!, branchlength_optim!, + metropolis_sample, + copy_tree, #util functions one_hot_sample, @@ -131,6 +133,7 @@ export HKY85, P_from_diagonalized_Q, scale_cols_by_vec!, + BranchlengthSampler, #things the user might overload copy_partition_to!, diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl index ec804ab..26cb8a7 100644 --- a/src/bayes/sampling.jl +++ b/src/bayes/sampling.jl @@ -1,5 +1,3 @@ - -export metropolis_sample """ function metropolis_sample( initial_tree::FelNode, @@ -65,10 +63,6 @@ function metropolis_sample( end - #REMOVE BEFORE PR - if i % 1000 == 0 || i == iterations - println(floor(i/iterations * 100)) - end end if midpoint_rooting @@ -84,8 +78,6 @@ function metropolis_sample( end end - #println("acc_ratio = ", bl_modifier.acc_ratio[1]/sum(bl_modifier.acc_ratio)) - if collect_LLs return samples, sample_LLs end @@ -98,7 +90,7 @@ end """ collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) - Returns a list of distance matrices (containing the distance between the leaf nodes) which can be used to assess mixing. + Returns a list of distance matrices containing the distance between the leaf nodes which can be used to assess mixing. """ function collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) distmats = [] diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl index 95de114..27fef0a 100644 --- a/src/utils/simple_sample.jl +++ b/src/utils/simple_sample.jl @@ -2,15 +2,6 @@ function univariate_modifier(f, modifier::UnivariateSampler; curr_value=nothing, kwargs...) return univariate_sampler(f, modifier, curr_value) end - -struct BranchlengthSampler <: UnivariateSampler - #The first entry in acc_ratio holds the number of accepted proposals and the second entry holds the number of rejected proposals. - acc_ratio - log_bl_proposal - log_bl_prior - BranchlengthSampler(log_bl_proposal,log_bl_prior) = new([0,0],log_bl_proposal,log_bl_prior) -end - """ univariate_sampler(LL, modifier::BranchlengthPeturbation, curr_branchlength) @@ -20,6 +11,14 @@ function univariate_sampler(LL, modifier::BranchlengthSampler, curr_branchlength return branchlength_metropolis(LL, modifier, curr_branchlength) end +struct BranchlengthSampler <: UnivariateSampler + #The first entry in acc_ratio holds the number of accepted proposals and the second entry holds the number of rejected proposals. + acc_ratio + log_bl_proposal + log_bl_prior + BranchlengthSampler(log_bl_proposal,log_bl_prior) = new([0,0],log_bl_proposal,log_bl_prior) +end + function branchlength_metropolis(LL, modifier, curr_value) # The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero. log_prior(x) = logpdf(modifier.log_bl_prior,log(x + 1e-12)) @@ -36,3 +35,6 @@ function branchlength_metropolis(LL, modifier, curr_value) end end + + + From fc02a8b8db4ebc32d18eb392f84ef34c2b429c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Thu, 22 Aug 2024 13:23:50 +0200 Subject: [PATCH 09/12] WIP --- src/bayes/sampling.jl | 6 +- src/core/algorithms/nni_optim.jl | 117 ++----------------------------- src/core/nodes/FelNode.jl | 1 - src/utils/simple_sample.jl | 18 +++-- 4 files changed, 22 insertions(+), 120 deletions(-) diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl index 26cb8a7..86c6d82 100644 --- a/src/bayes/sampling.jl +++ b/src/bayes/sampling.jl @@ -10,7 +10,7 @@ midpoint_rooting=false, ) -Samples tree topologies from a posterior distribution. +Samples tree topologies from a posterior distribution. felsenstein! should be called on the initial tree before calling this function. # Arguments - `initial_tree`: An initial topology with (important!) the leaves populated with data, for the likelihood calculation. @@ -90,7 +90,7 @@ end """ collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) - Returns a list of distance matrices containing the distance between the leaf nodes which can be used to assess mixing. + Returns a list of distance matrices containing the distance between the leaf nodes, which can be used to assess mixing. """ function collect_leaf_dists(trees::Vector{<:AbstractTreeNode}) distmats = [] @@ -103,7 +103,7 @@ end """ leaf_distmat(tree) -Returns a matrix of the distances between the leaf nodes where the index on the columns and rows are sorted by the leaf names. + Returns a matrix of the distances between the leaf nodes where the index on the columns and rows are sorted by the leaf names. """ function leaf_distmat(tree) diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index 02c235d..e1c66a4 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -1,13 +1,3 @@ -#= -About clades getting skipped: -- the iterative implementation perfectly mimics the recursive one (they can both skip clades) -- some nnis can lead to some clades not getting optimized and some getting optimized multiple times -- I could push "every other" during first down and use lastind to know if a clade's been visisted, if a sibling clade's not been visited, I'll simply not fel-up yet but continue down -- -- Sanity checks: compare switch_LL with log_likelihood! of deepcopied tree with said switch -full_traversal passed the sanity check -=# - #After a do_nni, we have to update parent_message if we want to continue down (assume that temp_message is the forwarded parent.parent_message) function update_parent_message!( node::FelNode, @@ -29,7 +19,7 @@ function update_parent_message!( end end -function nni_optim_full_traversal!( +function nni_optim!( temp_messages::Vector{Vector{T}}, tree::FelNode, models, @@ -99,7 +89,7 @@ function nni_optim_full_traversal!( temp_message = pop!(temp_messages) model_list = models(node) if first #We only do_nni first up - nnid, exceed_sib, exceed_child = do_nni( + nnid, sampled_sib_ind, sampled_child_ind = do_nni( node, temp_message, models; @@ -128,7 +118,7 @@ function nni_optim_full_traversal!( end pop!(stack) push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #When we're going up a second time, we no longer need a temp - push!(stack, (temp_message, node, exceed_child, exceed_child, false, true)) #Go to the "new" child - the "new" lastind + push!(stack, (temp_message, node, sampled_child_ind, sampled_child_ind, false, true)) #Go to the "new" child - the "new" lastind continue #Don't fel-up yet end end @@ -147,97 +137,6 @@ function nni_optim_full_traversal!( end end -function nni_optim!( - temp_messages::Vector{Vector{T}}, - tree::FelNode, - models, - partition_list; - nni_selection_rule = (x) -> argmax(x), - traversal = Iterators.reverse -) where {T <: Partition} - - #Consider a NamedTuple/struct - stack = [(pop!(temp_messages), tree, 1, 1, true, true)] - while !isempty(stack) - temp_message, node, ind, lastind, first, down = pop!(stack) - #We start out with a regular downward pass... - #(except for some extra bookkeeping to track if node is visited for the first time) - #------------------- - if isleafnode(node) - push!(temp_messages, temp_message) - continue - end - if down - if first - model_list = models(node) - for part in partition_list - forward!( - temp_message[part], - node.parent_message[part], - model_list[part], - node, - ) - end - @assert length(node.children) <= 2 - #Temp must be constant between iterations for a node during down... - child_iter = traversal(1:length(node.children)) - lastind = Base.first(child_iter) #(which is why we track the last child to be visited during down) - push!(stack, (Vector{T}(), node, ind, lastind, false, false)) #... but not up - for i = child_iter #Iterative reverse <=> Recursive non-reverse, also optimal for lazysort!?? - push!(stack, (temp_message, node, i, lastind, false, true)) - end - end - if !first - sib_inds = sibling_inds(node.children[ind]) - for part in partition_list - combine!( - (node.children[ind]).parent_message[part], - [mess[part] for mess in node.child_messages[sib_inds]], - true, - ) - combine!( - (node.children[ind]).parent_message[part], - [temp_message[part]], - false, - ) - end - #But calling nni_optim! recursively... (the iterative equivalent) - push!(stack, (safepop!(temp_messages, temp_message), node.children[ind], ind, lastind, true, true)) #first + down combination => safepop! - ind == lastind && push!(temp_messages, temp_message) #We no longer need constant temp - end - end - if !down - #Then combine node.child_messages into node.message... - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - end - #But now we need to optimize the current node, and then prop back up to set your parents children message correctly. - #------------------- - if !isroot(node) - temp_message = pop!(temp_messages) - model_list = models(node) - nnid, exceed_sib, exceed_child = do_nni( - node, - temp_message, - models; - partition_list = partition_list, - nni_selection_rule = nni_selection_rule, - ) - for part in partition_list - combine!(node.message[part], [mess[part] for mess in node.child_messages], true) - backward!(node.parent.child_messages[ind][part], node.message[part], model_list[part], node) - combine!( - node.parent.message[part], - [mess[part] for mess in node.parent.child_messages], - true, - ) - end - push!(temp_messages, temp_message) - end - end - end -end - #Unsure if this is the best choice to handle the model,models, and model_func stuff. function nni_optim!( temp_messages::Vector{Vector{T}}, @@ -303,12 +202,12 @@ function do_nni( nni_LLs = [curr_LL] nni_configs = [(0,0)] - max_LL = -Inf - exceeded, exceed_sib, exceed_child = (false, 0, 0) + + for sib_ind in [x for x in 1:length(node.parent.children) if node.parent.children[x] != node] - switch_LL = 0.0 + for child_ind = 1:length(node.children) for part in partition_list #move the sibling message, after upward propogation, to temp_message to work with it @@ -356,7 +255,6 @@ function do_nni( LL = sum([total_LL(temp_message[part]) for part in partition_list]) - push!(nni_LLs, LL) push!(nni_configs, (sib_ind, child_ind)) end @@ -366,7 +264,7 @@ function do_nni( change = sampled_config_ind != 1 (sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind] - #do the actual move here, switching exceed child and exceed sib + #do the actual move here, switching sampled_child_in and sampled_sib_ind if !(change) return false, sampled_sib_ind, sampled_child_ind else @@ -418,7 +316,6 @@ function nni_optim!( partition_list = 1:length(tree.message) end - #Need to decide here between nni_optim and nni_optim_full_traversal nni_optim!( temp_messages, tree, diff --git a/src/core/nodes/FelNode.jl b/src/core/nodes/FelNode.jl index cbcdcfb..0fcf370 100644 --- a/src/core/nodes/FelNode.jl +++ b/src/core/nodes/FelNode.jl @@ -99,7 +99,6 @@ function mixed_type_equilibrium_message( return out_mess end -export copy_tree """ function copy_tree(root::FelNode, shallow_copy=false) diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl index 27fef0a..16b004e 100644 --- a/src/utils/simple_sample.jl +++ b/src/utils/simple_sample.jl @@ -2,15 +2,12 @@ function univariate_modifier(f, modifier::UnivariateSampler; curr_value=nothing, kwargs...) return univariate_sampler(f, modifier, curr_value) end -""" - univariate_sampler(LL, modifier::BranchlengthPeturbation, curr_branchlength) -A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. """ -function univariate_sampler(LL, modifier::BranchlengthSampler, curr_branchlength) - return branchlength_metropolis(LL, modifier, curr_branchlength) -end + BranchlengthSampler + A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also holds the acceptance ratio acc_ratio (acc_ratio[1] stores the number of accepts, and acc_ratio[1] stores the number of rejects). +""" struct BranchlengthSampler <: UnivariateSampler #The first entry in acc_ratio holds the number of accepted proposals and the second entry holds the number of rejected proposals. acc_ratio @@ -19,6 +16,15 @@ struct BranchlengthSampler <: UnivariateSampler BranchlengthSampler(log_bl_proposal,log_bl_prior) = new([0,0],log_bl_proposal,log_bl_prior) end +""" + univariate_sampler(LL, modifier::BranchlengthPeturbation, curr_branchlength) + +A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. +""" +function univariate_sampler(LL, modifier::BranchlengthSampler, curr_branchlength) + return branchlength_metropolis(LL, modifier, curr_branchlength) +end + function branchlength_metropolis(LL, modifier, curr_value) # The prior distribution for the variable log(branchlength). A small perturbation of +1e-12 is added to enhance numerical stability near zero. log_prior(x) = logpdf(modifier.log_bl_prior,log(x + 1e-12)) From 185fcae896a18147ea283b4795e0ce67ddb78fa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Thu, 22 Aug 2024 13:32:33 +0200 Subject: [PATCH 10/12] WIP --- src/utils/simple_sample.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl index 16b004e..05bca3f 100644 --- a/src/utils/simple_sample.jl +++ b/src/utils/simple_sample.jl @@ -9,10 +9,9 @@ end A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also holds the acceptance ratio acc_ratio (acc_ratio[1] stores the number of accepts, and acc_ratio[1] stores the number of rejects). """ struct BranchlengthSampler <: UnivariateSampler - #The first entry in acc_ratio holds the number of accepted proposals and the second entry holds the number of rejected proposals. - acc_ratio - log_bl_proposal - log_bl_prior + acc_ratio::Vector{Int} + log_bl_proposal::Distribution + log_bl_prior::Distribution BranchlengthSampler(log_bl_proposal,log_bl_prior) = new([0,0],log_bl_proposal,log_bl_prior) end From fbf74a26bab107edef577e5a27eeefd24229cc44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Thu, 22 Aug 2024 13:44:51 +0200 Subject: [PATCH 11/12] WIP --- src/utils/simple_sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl index 05bca3f..eac9c2a 100644 --- a/src/utils/simple_sample.jl +++ b/src/utils/simple_sample.jl @@ -18,7 +18,7 @@ end """ univariate_sampler(LL, modifier::BranchlengthPeturbation, curr_branchlength) -A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. + A MCMC algorithm that draws the next sample of a Markov Chain that approximates the Posterior distrubution over the branchlengths. """ function univariate_sampler(LL, modifier::BranchlengthSampler, curr_branchlength) return branchlength_metropolis(LL, modifier, curr_branchlength) From 409c5bad0c8d93ae927805d0fa17a37505947aac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Theodor=20Bj=C3=B6rk?= Date: Sun, 1 Sep 2024 17:22:51 +0200 Subject: [PATCH 12/12] Updated the test and made some minor adjustments based on the feedback --- src/bayes/sampling.jl | 11 +++++++---- src/core/algorithms/branchlength_optim.jl | 4 ++-- src/core/algorithms/nni_optim.jl | 22 +++++++++++----------- src/core/nodes/FelNode.jl | 2 +- src/utils/simple_optim.jl | 2 +- src/utils/simple_sample.jl | 2 +- test/partition_selection.jl | 2 +- 7 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/bayes/sampling.jl b/src/bayes/sampling.jl index 86c6d82..2525fbb 100644 --- a/src/bayes/sampling.jl +++ b/src/bayes/sampling.jl @@ -10,10 +10,10 @@ midpoint_rooting=false, ) -Samples tree topologies from a posterior distribution. felsenstein! should be called on the initial tree before calling this function. +Samples tree topologies from a posterior distribution. # Arguments -- `initial_tree`: An initial topology with (important!) the leaves populated with data, for the likelihood calculation. +- `initial_tree`: An initial tree topology with the leaves populated with data, for the likelihood calculation. - `models`: A list of branch models. - `num_of_samples`: The number of tree samples drawn from the posterior. - `bl_sampler`: Sampler used to drawn branchlengths from the posterior. @@ -22,8 +22,11 @@ Samples tree topologies from a posterior distribution. felsenstein! should be ca - `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees. - `midpoint_rooting`: Specifies whether the drawn samples should be midpoint rerooted (Important! Should only be used for time-reversible branch models starting in equilibrium). +!!! note + The leaves of the initial tree should be populated with data and felsenstein! should be called on the initial tree before calling this function. + # Returns -- `samples`: The trees drawn from the posterior. +- `samples`: The trees drawn from the posterior. Returns shallow tree copies, which needs to be repopulated before running felsenstein! etc. - `sample_LLs`: The associated log-likelihoods of the tree (optional). """ function metropolis_sample( @@ -50,7 +53,7 @@ function metropolis_sample( for i=1:iterations # Updates the tree topolgy and branchlengths. - nni_optim!(tree, x -> models, nni_selection_rule = softmax_sampler) + nni_optim!(tree, x -> models, selection_rule = softmax_sampler) branchlength_optim!(tree, x -> models, bl_modifier = bl_sampler) if (i-burn_in) % sample_interval == 0 && i > burn_in diff --git a/src/core/algorithms/branchlength_optim.jl b/src/core/algorithms/branchlength_optim.jl index 79c95d0..fb20b21 100644 --- a/src/core/algorithms/branchlength_optim.jl +++ b/src/core/algorithms/branchlength_optim.jl @@ -87,7 +87,7 @@ function branchlength_optim!( temp_message = pop!(temp_messages) model_list = models(node) fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - bl = univariate_modifier(fun, bl_modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) + bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength) if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt) node.branchlength = bl end @@ -110,7 +110,7 @@ function branchlength_optim!( #------------------- model_list = models(node) fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list) - bl = univariate_modifier(fun, bl_modifier; a=0, b=1, tol=tol, transform=unit_transform, curr_value=node.branchlength) + bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength) if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt) node.branchlength = bl end diff --git a/src/core/algorithms/nni_optim.jl b/src/core/algorithms/nni_optim.jl index e1c66a4..ec7a39c 100644 --- a/src/core/algorithms/nni_optim.jl +++ b/src/core/algorithms/nni_optim.jl @@ -24,7 +24,7 @@ function nni_optim!( tree::FelNode, models, partition_list; - nni_selection_rule = (x) -> argmax(x), + selection_rule = x -> argmax(x), traversal = Iterators.reverse ) where {T <: Partition} @@ -94,7 +94,7 @@ function nni_optim!( temp_message, models; partition_list = partition_list, - nni_selection_rule = nni_selection_rule, + selection_rule = selection_rule, ) if nnid && last(last(stack)) #We nnid a sibling that hasn't been visited (then, down would be true in the next iter)... #... and now we want to continue down the nnid sibling (now a child to node) @@ -143,7 +143,7 @@ function nni_optim!( tree::FelNode, models::Vector{<:BranchModel}, partition_list; - nni_selection_rule = (x) -> argmax(x), + selection_rule = x -> argmax(x), traversal = Iterators.reverse, ) where {T <: Partition} nni_optim!( @@ -151,7 +151,7 @@ function nni_optim!( tree, x -> models, partition_list, - nni_selection_rule = nni_selection_rule, + selection_rule = selection_rule, traversal = traversal, ) end @@ -160,7 +160,7 @@ function nni_optim!( tree::FelNode, model::BranchModel, partition_list; - nni_selection_rule = (x) -> argmax(x), + selection_rule = x -> argmax(x), traversal = Iterators.reverse, ) where {T <: Partition} @@ -169,7 +169,7 @@ function nni_optim!( tree, x -> [model], partition_list, - nni_selection_rule = nni_selection_rule, + selection_rule = selection_rule, traversal = traversal, ) end @@ -179,7 +179,7 @@ function do_nni( temp_message, models::F; partition_list = 1:length(node.message), - nni_selection_rule = (x) -> argmax(x), + selection_rule = x -> argmax(x), ) where {F<:Function} if length(node.children) == 0 || node.parent === nothing return false @@ -260,7 +260,7 @@ function do_nni( end end - sampled_config_ind = nni_selection_rule(nni_LLs) + sampled_config_ind = selection_rule(nni_LLs) change = sampled_config_ind != 1 (sampled_sib_ind, sampled_child_ind) = nni_configs[sampled_config_ind] @@ -295,7 +295,7 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th # Keyword Arguments - `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize tree topology with all models, the default option). -- `nni_selection_rule = (x) -> argmax(x)`: a function that takes the current and proposed log likelihoods and selects a nni configuration. Note that the current log likelihood is stored at x[1]. +- `selection_rule = x -> argmax(x)`: a function that takes the current and proposed log likelihoods and selects a nni configuration. Note that the current log likelihood is stored at x[1]. - `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the amount of temporary messages that has to be initialized. - `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable. - `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`. @@ -304,7 +304,7 @@ function nni_optim!( tree::FelNode, models; partition_list = nothing, - nni_selection_rule = (x) -> argmax(x), + selection_rule = x -> argmax(x), sort_tree = false, traversal = Iterators.reverse, shuffle = false @@ -321,7 +321,7 @@ function nni_optim!( tree, models, partition_list, - nni_selection_rule = nni_selection_rule, + selection_rule = selection_rule, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal ) end \ No newline at end of file diff --git a/src/core/nodes/FelNode.jl b/src/core/nodes/FelNode.jl index 0fcf370..97119d7 100644 --- a/src/core/nodes/FelNode.jl +++ b/src/core/nodes/FelNode.jl @@ -102,7 +102,7 @@ end """ function copy_tree(root::FelNode, shallow_copy=false) -Returns a untangled copy of the a tree. Optionally, the flag `shallow_copy` can be used to obtained a copy of the tree with only the names and branchlengths. + Returns an untangled copy of the tree. Optionally, the flag `shallow_copy` can be used to obtain a copy of the tree with only the names and branchlengths. """ function copy_tree(root::FelNode, shallow_copy=false) diff --git a/src/utils/simple_optim.jl b/src/utils/simple_optim.jl index 4945a34..f559f1d 100644 --- a/src/utils/simple_optim.jl +++ b/src/utils/simple_optim.jl @@ -16,7 +16,7 @@ struct GoldenSectionOpt <: UnivariateOpt end struct BrentsMethodOpt <: UnivariateOpt end function univariate_modifier(fun, modifier::UnivariateOpt; a=0, b=1, transform=unit_transform, tol=10e-5, kwargs...) - return univariate_maximize(fun, a + tol, b - tol, unit_transform, modifier, tol) + return univariate_maximize(fun, a, b, unit_transform, modifier, tol) end """ diff --git a/src/utils/simple_sample.jl b/src/utils/simple_sample.jl index eac9c2a..56103c7 100644 --- a/src/utils/simple_sample.jl +++ b/src/utils/simple_sample.jl @@ -6,7 +6,7 @@ end """ BranchlengthSampler - A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also holds the acceptance ratio acc_ratio (acc_ratio[1] stores the number of accepts, and acc_ratio[1] stores the number of rejects). + A type that allows you to specify a additive proposal function in the log domain and a prior distrubution over the log of the branchlengths. It also holds the acceptance ratio acc_ratio (acc_ratio[1] stores the number of accepts, and acc_ratio[2] stores the number of rejects). """ struct BranchlengthSampler <: UnivariateSampler acc_ratio::Vector{Int} diff --git a/test/partition_selection.jl b/test/partition_selection.jl index 16bf0de..fb7fbb9 100644 --- a/test/partition_selection.jl +++ b/test/partition_selection.jl @@ -57,7 +57,7 @@ begin branchlength_optim!(tree, bm_models, partition_list = [1]) branchlength_optim!(tree, bm_models, partition_list = [2]) branchlength_optim!(tree, bm_models) - branchlength_optim!(tree, bm_models, bl_optimizer=BrentsMethodOpt()) + branchlength_optim!(tree, bm_models, bl_modifier=BrentsMethodOpt()) branchlength_optim!(tree, x -> bm_models, partition_list = [2]) branchlength_optim!(tree, x -> bm_models)