Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend optimization methods #37

Merged
merged 13 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/MolecularEvolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ abstract type SimulationModel <: BranchModel end #Simulation models typically ca

abstract type StatePath end

abstract type UnivariateOpt end
abstract type UnivariateModifier end
abstract type UnivariateOpt <: UnivariateModifier end
abstract type UnivariateSampler <: UnivariateModifier end

abstract type LazyDirection end

Expand All @@ -39,6 +41,7 @@ include("core/algorithms/algorithms.jl")
include("core/sim_tree.jl")
include("models/models.jl")
include("utils/utils.jl")
include("bayes/bayes.jl")

#Optional dependencies
function __init__()
Expand Down Expand Up @@ -115,6 +118,8 @@ export
reroot!,
nni_optim!,
branchlength_optim!,
metropolis_sample,
copy_tree,

#util functions
one_hot_sample,
Expand All @@ -128,6 +133,7 @@ export
HKY85,
P_from_diagonalized_Q,
scale_cols_by_vec!,
BranchlengthSampler,

#things the user might overload
copy_partition_to!,
Expand Down
1 change: 1 addition & 0 deletions src/bayes/bayes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include("sampling.jl")
124 changes: 124 additions & 0 deletions src/bayes/sampling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
function metropolis_sample(
initial_tree::FelNode,
models::Vector{<:BranchModel},
num_of_samples;
bl_modifier::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1))
burn_in=1000,
sample_interval=10,
collect_LLs = false,
midpoint_rooting=false,
)

Samples tree topologies from a posterior distribution.

# Arguments
- `initial_tree`: An initial tree topology with the leaves populated with data, for the likelihood calculation.
- `models`: A list of branch models.
- `num_of_samples`: The number of tree samples drawn from the posterior.
- `bl_sampler`: Sampler used to drawn branchlengths from the posterior.
- `burn_in`: The number of samples discarded at the start of the Markov Chain.
- `sample_interval`: The distance between samples in the underlying Markov Chain (to reduce sample correlation).
- `collect_LLs`: Specifies if the function should return the log-likelihoods of the trees.
- `midpoint_rooting`: Specifies whether the drawn samples should be midpoint rerooted (Important! Should only be used for time-reversible branch models starting in equilibrium).

!!! note
The leaves of the initial tree should be populated with data and felsenstein! should be called on the initial tree before calling this function.

# Returns
- `samples`: The trees drawn from the posterior. Returns shallow tree copies, which needs to be repopulated before running felsenstein! etc.
- `sample_LLs`: The associated log-likelihoods of the tree (optional).
"""
function metropolis_sample(
initial_tree::FelNode,
models::Vector{<:BranchModel},
num_of_samples;
bl_sampler::UnivariateSampler = BranchlengthSampler(Normal(0,2), Normal(-1,1)),
burn_in=1000,
sample_interval=10,
collect_LLs = false,
midpoint_rooting=false,
ladderize = false,
)

# The prior over the (log) of the branchlengths should be specified in bl_sampler.
# Furthermore, a non-informative/uniform prior is assumed over the tree topolgies (excluding the branchlengths).

sample_LLs = []
samples = FelNode[]
tree = deepcopy(initial_tree)
iterations = burn_in + num_of_samples * sample_interval

softmax_sampler = x -> rand(Categorical(softmax(x)))
nossleinad marked this conversation as resolved.
Show resolved Hide resolved
for i=1:iterations

# Updates the tree topolgy and branchlengths.
nni_optim!(tree, x -> models, selection_rule = softmax_sampler)
branchlength_optim!(tree, x -> models, bl_modifier = bl_sampler)

if (i-burn_in) % sample_interval == 0 && i > burn_in

push!(samples, copy_tree(tree, true))

if collect_LLs
push!(sample_LLs, log_likelihood!(tree, models))
end

end

end

if midpoint_rooting
for (i,sample) in enumerate(samples)
node, len = midpoint(sample)
samples[i] = reroot!(node, dist_above_child=len)
end
end

if ladderize
for sample in samples
ladderize!(sample)
end
end

if collect_LLs
return samples, sample_LLs
end

return samples
end

# Below are some functions that help to assess the mixing by looking at the distance between leaf nodes.

"""
collect_leaf_dists(trees::Vector{<:AbstractTreeNode})

Returns a list of distance matrices containing the distance between the leaf nodes, which can be used to assess mixing.
"""
function collect_leaf_dists(trees::Vector{<:AbstractTreeNode})
distmats = []
for tree in trees
push!(distmats, leaf_distmat(tree))
end
return distmats
end

"""
leaf_distmat(tree)

Returns a matrix of the distances between the leaf nodes where the index on the columns and rows are sorted by the leaf names.
"""
function leaf_distmat(tree)

distmat, node_dic = MolecularEvolution.tree2distances(tree)

leaflist = getleaflist(tree)

sort!(leaflist, by = x-> x.name)

order = [node_dic[leaf] for leaf in leaflist]

return distmat[order, order]
end


31 changes: 16 additions & 15 deletions src/core/algorithms/branchlength_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function branchlength_optim!(
models,
partition_list,
tol;
bl_optimizer::UnivariateOpt = GoldenSectionOpt(),
bl_modifier::UnivariateModifier = GoldenSectionOpt(),
traversal = Iterators.reverse
) where {T <: Partition}

Expand Down Expand Up @@ -87,10 +87,11 @@ function branchlength_optim!(
temp_message = pop!(temp_messages)
model_list = models(node)
fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list)
opt = univariate_maximize(fun, 0 + tol, 1 - tol, unit_transform, bl_optimizer, tol)
if fun(opt) > fun(node.branchlength)
node.branchlength = opt
bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength)
if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt)
node.branchlength = bl
end

#Consider checking for improvement, and bailing if none.
#Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one]
for part in partition_list
Expand All @@ -109,9 +110,9 @@ function branchlength_optim!(
#-------------------
model_list = models(node)
fun = x -> branch_LL_up(x, temp_message, node, model_list, partition_list)
opt = univariate_maximize(fun, 0 + tol, 1 - tol, unit_transform, bl_optimizer, tol)
if fun(opt) > fun(node.branchlength)
node.branchlength = opt
bl = univariate_modifier(fun, bl_modifier; a=0+tol, b=1-tol, tol=tol, transform=unit_transform, curr_value=node.branchlength)
if fun(bl) > fun(node.branchlength) || !(bl_modifier isa UnivariateOpt)
node.branchlength = bl
end
#Consider checking for improvement, and bailing if none.
#Then we need to set the "message_to_set", which is node.parent.child_messages[but_the_right_one]
Expand Down Expand Up @@ -139,21 +140,21 @@ a function that takes a node, and returns a Vector{<:BranchModel} if you need th

# Keyword Arguments
- `partition_list=nothing`: (eg. 1:3 or [1,3,5]) lets you choose which partitions to run over (but you probably want to optimize branch lengths with all models, the default option).
- `tol=1e-5`: absolute tolerance for the `bl_optimizer`.
- `bl_optimizer=GoldenSectionOpt()`: univariate branchlength optimizer, has Brent's method as an option by setting bl_optimizer=BrentsMethodOpt().
- `tol=1e-5`: absolute tolerance for the `bl_modifier`.
- `bl_modifier=GoldenSectionOpt()`: can either be an optimizer or a sampler (subtype of UnivariateModifier). For optimization, in addition to golden section search, Brent's method can be used by setting bl_modifier=BrentsMethodOpt().
- `sort_tree=false`: determines if a [`lazysort!`](@ref) will be performed, which can reduce the number of temporary messages that have to be initialized.
- `traversal=Iterators.reverse`: a function that determines the traversal, permutes an iterable.
- `shuffle=false`: do a randomly shuffled traversal, overrides `traversal`.
"""
function branchlength_optim!(tree::FelNode, models; partition_list = nothing, tol = 1e-5, bl_modifier::UnivariateModifier = GoldenSectionOpt(), sort_tree = false, traversal = Iterators.reverse, shuffle = false)
    # User-facing entry point: sets up the temporary message stack and the partition
    # list, then forwards to the positional method that does the traversal.
    sort_tree && lazysort!(tree) #A lazysorted tree minimizes the amount of temp_messages needed
    temp_messages = [copy_message(tree.message)]

    # Default to every partition present on the root message.
    if partition_list === nothing
        partition_list = 1:length(tree.message)
    end

    # `shuffle` overrides `traversal` with a random permutation (sampling without replacement).
    return branchlength_optim!(temp_messages, tree, models, partition_list, tol, bl_modifier=bl_modifier, traversal = shuffle ? x -> sample(x, length(x), replace=false) : traversal)
end

#Overloading to allow for direct model and model vec inputs
Expand All @@ -162,18 +163,18 @@ branchlength_optim!(
models::Vector{<:BranchModel};
partition_list = nothing,
tol = 1e-5,
bl_optimizer::UnivariateOpt = GoldenSectionOpt(),
bl_modifier::UnivariateModifier = GoldenSectionOpt(),
sort_tree = false,
traversal = Iterators.reverse,
shuffle = false
) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_optimizer = bl_optimizer, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle)
) = branchlength_optim!(tree, x -> models, partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle)
# Convenience overload for a single model: wraps it in a one-element vector and
# forwards every keyword unchanged to the node -> models method.
branchlength_optim!(
    tree::FelNode,
    model::BranchModel;
    partition_list = nothing,
    tol = 1e-5,
    bl_modifier::UnivariateModifier = GoldenSectionOpt(),
    sort_tree = false,
    traversal = Iterators.reverse,
    shuffle = false
) = branchlength_optim!(tree, x -> [model], partition_list = partition_list, tol = tol, bl_modifier = bl_modifier, sort_tree = sort_tree, traversal = traversal, shuffle = shuffle)
Loading
Loading