Speed up viterbi #93

Merged · 8 commits · Feb 24, 2024
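For context, the Viterbi algorithm this PR speeds up is a dynamic program whose inner loop is a max-product update over transitions. Below is a minimal log-space sketch of one time step, purely for orientation; the function name and signature are hypothetical, not the package's implementation.

```julia
# Illustrative sketch of one Viterbi time step in log space.
# ϕ holds the best log-scores at time t, logtrans the log transition matrix,
# logb the observation loglikelihoods at time t+1, ψ the backpointers.
function viterbi_step!(ϕ_next, ψ, ϕ, logtrans, logb)
    n = length(ϕ)
    for j in 1:n
        best_i = 1
        best_val = ϕ[1] + logtrans[1, j]
        for i in 2:n  # best predecessor of state j under max-product
            val = ϕ[i] + logtrans[i, j]
            if val > best_val
                best_i, best_val = i, val
            end
        end
        ϕ_next[j] = best_val + logb[j]
        ψ[j] = best_i  # remembered for the final traceback
    end
    return nothing
end
```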
3 changes: 3 additions & 0 deletions .github/workflows/benchmark.yml
@@ -8,6 +8,9 @@ on:
 jobs:
   Benchmark:
     runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      pull-requests: write
     if: contains(github.event.pull_request.labels.*.name, 'run benchmark')
     steps:
       - uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "HiddenMarkovModels"
 uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
 authors = ["Guillaume Dalle"]
-version = "0.4.1"
+version = "0.5.0"

 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
4 changes: 2 additions & 2 deletions README.md
@@ -24,8 +24,8 @@ Then, you can create your first model as follows:

 ```julia
 using Distributions, HiddenMarkovModels
-init = [0.4, 0.6]
-trans = [0.9 0.1; 0.2 0.8]
+init = [0.6, 0.4]
+trans = [0.7 0.3; 0.2 0.8]
 dists = [Normal(-1.0), Normal(1.0)]
 hmm = HMM(init, trans, dists)
 ```
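As a quick sanity check of the updated README model, here is a hedged usage sketch; the destructuring of `rand` follows the pattern used in `examples/types.jl` below, and the exact return values of `viterbi` may vary between versions.

```julia
using Distributions, HiddenMarkovModels
init = [0.6, 0.4]
trans = [0.7 0.3; 0.2 0.8]
dists = [Normal(-1.0), Normal(1.0)]
hmm = HMM(init, trans, dists)

state_seq, obs_seq = rand(hmm, 100)             # simulate one trajectory
best_state_seq, loglik = viterbi(hmm, obs_seq)  # most likely state sequence
```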
2 changes: 1 addition & 1 deletion benchmark/Manifest.toml
@@ -169,7 +169,7 @@ version = "0.1.0"
 deps = ["ArgCheck", "ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI", "StatsFuns"]
 path = ".."
 uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
-version = "0.4.0"
+version = "0.5.0"
 weakdeps = ["Distributions"]

 [deps.HiddenMarkovModels.extensions]
14 changes: 12 additions & 2 deletions docs/src/api.md
@@ -94,14 +94,24 @@ HiddenMarkovModels.forward_backward!
 HiddenMarkovModels.baum_welch!
 ```

-## Misc
+## Miscellaneous

 ```@docs
 HiddenMarkovModels.valid_hmm
 HiddenMarkovModels.rand_prob_vec
 HiddenMarkovModels.rand_trans_mat
-HiddenMarkovModels.fit_in_sequence!
 ```

+## Internals
+
+```@docs
+HiddenMarkovModels.LightDiagNormal
+HiddenMarkovModels.LightCategorical
+HiddenMarkovModels.fit_in_sequence!
+HiddenMarkovModels.log_initialization
+HiddenMarkovModels.log_transition_matrix
+HiddenMarkovModels.mul_rows_cols!
+HiddenMarkovModels.argmaxplus_transmul!
+```
+
 ## Index
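Among the newly documented internals, `argmaxplus_transmul!` is the kernel most relevant to this PR: Viterbi needs, for every destination state, the best-scoring predecessor, which is a transposed matrix-vector product in the max-plus sense. A hedged reference sketch follows; the name and signature below are illustrative assumptions, not the package's exact API.

```julia
# Illustrative max-plus transposed multiply:
# y[j] = max_i trans[i, j] * x[i] and ind[j] = argmax_i trans[i, j] * x[i].
# Hypothetical reference version; the real argmaxplus_transmul! may differ.
function argmaxplus_transmul_sketch!(y, ind, trans, x)
    for j in eachindex(y)
        best_i = 1
        best_val = trans[1, j] * x[1]
        for i in 2:length(x)
            val = trans[i, j] * x[i]
            if val > best_val
                best_i, best_val = i, val
            end
        end
        y[j], ind[j] = best_val, best_i
    end
    return y, ind
end
```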
12 changes: 6 additions & 6 deletions examples/basics.jl
@@ -19,16 +19,16 @@ rng = StableRNG(63);
 # ## Model

 #=
-The package provides a versatile [`HMM`](@ref) type with three attributes:
-- a vector of state initialization probabilities
-- a matrix of state transition probabilities
-- a vector of observation distributions, one for each state
+The package provides a versatile [`HMM`](@ref) type with three main attributes:
+- a vector `init` of state initialization probabilities
+- a matrix `trans` of state transition probabilities
+- a vector `dists` of observation distributions, one for each state

 Any scalar- or vector-valued distribution from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) can be used for the last part, as well as [Custom distributions](@ref).
 =#

 init = [0.6, 0.4]
-trans = [0.7 0.3; 0.3 0.7]
+trans = [0.7 0.3; 0.2 0.8]
 dists = [MvNormal([-0.5, -0.8], I), MvNormal([0.5, 0.8], I)]
 hmm = HMM(init, trans, dists)
@@ -142,7 +142,7 @@ Since it is a local optimization procedure, it requires a starting point that is
 =#

 init_guess = [0.5, 0.5]
-trans_guess = [0.6 0.4; 0.4 0.6]
+trans_guess = [0.6 0.4; 0.3 0.7]
 dists_guess = [MvNormal([-0.4, -0.7], I), MvNormal([0.4, 0.7], I)]
 hmm_guess = HMM(init_guess, trans_guess, dists_guess);
17 changes: 9 additions & 8 deletions examples/controlled.jl
@@ -32,14 +32,15 @@ struct ControlledGaussianHMM{T} <: AbstractHMM
 end

 #=
 In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$.
+Controls must be provided to both `transition_matrix` and `obs_distributions` even if they are only used by one.
 =#

 function HMMs.initialization(hmm::ControlledGaussianHMM)
     return hmm.init
 end

-function HMMs.transition_matrix(hmm::ControlledGaussianHMM)
+function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector)
     return hmm.trans
 end
@@ -54,8 +55,8 @@ In this case, the transition matrix does not depend on the control.
 # ## Simulation

 d = 3
-init = [0.8, 0.2]
-trans = [0.7 0.3; 0.3 0.7]
+init = [0.6, 0.4]
+trans = [0.7 0.3; 0.2 0.8]
 dist_coeffs = [-ones(d), ones(d)]
 hmm = ControlledGaussianHMM(init, trans, dist_coeffs);
@@ -122,9 +123,9 @@ end
 Now we put it to the test.
 =#

-init_guess = [0.7, 0.3]
-trans_guess = [0.6 0.4; 0.4 0.6]
-dist_coeffs_guess = [-0.7 * ones(d), 0.7 * ones(d)]
+init_guess = [0.5, 0.5]
+trans_guess = [0.6 0.4; 0.3 0.7]
+dist_coeffs_guess = [-1.1 * ones(d), 1.1 * ones(d)]
 hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess);

 #-
@@ -136,7 +137,7 @@ first(loglikelihood_evolution), last(loglikelihood_evolution)
 How did we perform?
 =#

-cat(transition_matrix(hmm_est), transition_matrix(hmm); dims=3)
+cat(hmm_est.trans, hmm.trans; dims=3)

 #-
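A short usage note on the interface change above: `transition_matrix` now takes the control as a second argument, which is why the final comparison switched to direct field access. A hedged sketch, reusing the `ControlledGaussianHMM` struct and methods defined in this example (assumed to be in scope):

```julia
using HiddenMarkovModels
const HMMs = HiddenMarkovModels

# Assumes the ControlledGaussianHMM definition from this example is in scope.
d = 3
hmm = ControlledGaussianHMM([0.6, 0.4], [0.7 0.3; 0.2 0.8], [-ones(d), ones(d)])

control = randn(d)
transition_matrix(hmm, control)   # dispatches to the new two-argument method
obs_distributions(hmm, control)   # the observation model is control-dependent too
hmm.trans                         # direct field access, as in the `cat` line above
```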
6 changes: 3 additions & 3 deletions examples/interfaces.jl
@@ -82,7 +82,7 @@ Let's put it to the test.
 =#

 init = [0.6, 0.4]
-trans = [0.7 0.3; 0.3 0.7]
+trans = [0.7 0.3; 0.2 0.8]
 dists = [StuffDist(-1.0), StuffDist(+1.0)]
 hmm = HMM(init, trans, dists);

@@ -104,8 +104,8 @@ If we implement `fit!`, Baum-Welch also works seamlessly.
 =#

 init_guess = [0.5, 0.5]
-trans_guess = [0.6 0.4; 0.4 0.6]
-dists_guess = [StuffDist(-0.7), StuffDist(+0.7)]
+trans_guess = [0.6 0.4; 0.3 0.7]
+dists_guess = [StuffDist(-1.1), StuffDist(+1.1)]
 hmm_guess = HMM(init_guess, trans_guess, dists_guess);

 #-
6 changes: 3 additions & 3 deletions examples/temporal.jl
@@ -55,7 +55,7 @@ end
 # ## Simulation

 init = [0.6, 0.4]
-trans_per = ([0.7 0.3; 0.3 0.7], [0.3 0.7; 0.7 0.3])
+trans_per = ([0.7 0.3; 0.2 0.8], [0.3 0.7; 0.8 0.2])
 dists_per = ([Normal(-1.0), Normal(-2.0)], [Normal(+1.0), Normal(+2.0)])
 hmm = PeriodicHMM(init, trans_per, dists_per);

@@ -152,8 +152,8 @@ Now let's test our procedure with a reasonable guess.
 =#

 init_guess = [0.7, 0.3]
-trans_per_guess = ([0.6 0.4; 0.4 0.6], [0.4 0.6; 0.6 0.4])
-dists_per_guess = ([Normal(-0.7), Normal(-1.7)], [Normal(+0.7), Normal(+1.7)])
+trans_per_guess = ([0.6 0.4; 0.3 0.7], [0.4 0.6; 0.7 0.3])
+dists_per_guess = ([Normal(-1.1), Normal(-2.1)], [Normal(+1.1), Normal(+2.1)])
 hmm_guess = PeriodicHMM(init_guess, trans_per_guess, dists_per_guess);

 #=
15 changes: 11 additions & 4 deletions examples/types.jl
@@ -6,6 +6,7 @@ Here we explain why playing with different number and array types can be useful

 using Distributions
 using HiddenMarkovModels
+using HiddenMarkovModels: log_transition_matrix #src
 using HMMTest #src
 using LinearAlgebra
 using LogarithmicNumbers
@@ -40,7 +41,7 @@ To give an example, let us first generate some data from a vanilla HMM.
 =#

 init = [0.6, 0.4]
-trans = [0.7 0.3; 0.3 0.7]
+trans = [0.7 0.3; 0.2 0.8]
 dists = [Normal(-1.0), Normal(1.0)]
 hmm = HMM(init, trans, dists)
 state_seq, obs_seq = rand(rng, hmm, 100);
@@ -57,6 +58,10 @@ hmm_uncertain = HMM(init, trans, dists_guess)
 Every quantity we compute with this new HMM will have propagated uncertainties around it.
 =#

+logdensityof(hmm, obs_seq)
+
+#-
+
 logdensityof(hmm_uncertain, obs_seq)

 #=
@@ -129,7 +134,7 @@ trans_guess = sparse([
     0 0.6 0.4
     0.4 0 0.6
 ])
-dists_guess = [Normal(1.2), Normal(2.2), Normal(3.2)]
+dists_guess = [Normal(1.1), Normal(2.1), Normal(3.1)]
 hmm_guess = HMM(init_guess, trans_guess, dists_guess);

 #-
@@ -149,10 +154,12 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St

 # ## Tests #src

+@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src
+
 seq_ends = cumsum(rand(rng, 100:200, 100)); #src
-control_seqs = fill(nothing, length(seq_ends)); #src
+control_seq = fill(nothing, last(seq_ends)); #src
 test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
-test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src
+test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false, atol=0.08) #src
 test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
 # https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src
 @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src
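The new `nnz` test compares the sparsity pattern before and after taking logs. Presumably `log_transition_matrix` transforms only the stored entries, so structural zeros (conceptually a log-probability of negative infinity) are never materialized; here is a hedged sketch of that idea, not the package's actual code:

```julia
using SparseArrays

# Hypothetical sparsity-preserving elementwise log: only stored entries are
# transformed, so nnz(logtrans) == nnz(trans) by construction.
function log_sparse_sketch(trans::SparseMatrixCSC)
    logtrans = copy(trans)
    nonzeros(logtrans) .= log.(nonzeros(logtrans))
    return logtrans
end
```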
16 changes: 8 additions & 8 deletions libs/HMMBenchmark/src/hiddenmarkovmodels.jl
@@ -43,38 +43,38 @@ function build_benchmarkables(
     if "forward" in algos
         benchs["forward"] = @benchmarkable begin
             forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
     if "forward!" in algos
         benchs["forward!"] = @benchmarkable begin
             forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
             f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
         )
     end

     if "viterbi" in algos
         benchs["viterbi"] = @benchmarkable begin
             viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
     if "viterbi!" in algos
         benchs["viterbi!"] = @benchmarkable begin
             viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
             v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
         )
     end

     if "forward_backward" in algos
         benchs["forward_backward"] = @benchmarkable begin
             forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
     if "forward_backward!" in algos
         benchs["forward_backward!"] = @benchmarkable begin
             forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
             fb_storage = initialize_forward_backward(
                 $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
             )
@@ -92,7 +92,7 @@ function build_benchmarkables(
                 atol=-Inf,
                 loglikelihood_increasing=false,
             )
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end
     if "baum_welch!" in algos
         benchs["baum_welch!"] = @benchmarkable begin
@@ -107,7 +107,7 @@ function build_benchmarkables(
                 atol=-Inf,
                 loglikelihood_increasing=false,
             )
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
             hmm_guess = build_model($implem, $instance, $params);
             fb_storage = initialize_forward_backward(
                 hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/dynamax.jl
@@ -48,7 +48,7 @@ function HMMBenchmark.build_benchmarkables(
         filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0))))
         benchs["forward"] = @benchmarkable begin
             $(filter_vmap)($dyn_params, $obs_tens_jax_py)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end

     if "viterbi" in algos
@@ -57,7 +57,7 @@ function HMMBenchmark.build_benchmarkables(
         )
         benchs["viterbi"] = @benchmarkable begin
             $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end

     if "forward_backward" in algos
@@ -66,7 +66,7 @@ function HMMBenchmark.build_benchmarkables(
         )
         benchs["forward_backward"] = @benchmarkable begin
             $(smoother_vmap)($dyn_params, $obs_tens_jax_py)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end

     if "baum_welch" in algos
@@ -78,7 +78,7 @@ function HMMBenchmark.build_benchmarkables(
                 num_iters=$bw_iter,
                 verbose=false,
             )
-        end evals = 1 samples = 10 setup = (
+        end evals = 1 samples = 100 setup = (
             tup = build_model($implem, $instance, $params);
             hmm_guess = tup[1];
             dyn_params_guess = tup[2];
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/hmmbase.jl
@@ -41,29 +41,29 @@ function HMMBenchmark.build_benchmarkables(
             @threads for k in eachindex($obs_mats)
                 HMMBase.forward($hmm, $obs_mats[k])
             end
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end

     if "viterbi" in algos
         benchs["viterbi"] = @benchmarkable begin
             @threads for k in eachindex($obs_mats)
                 HMMBase.viterbi($hmm, $obs_mats[k])
             end
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end

     if "forward_backward" in algos
         benchs["forward_backward"] = @benchmarkable begin
             @threads for k in eachindex($obs_mats)
                 HMMBase.posteriors($hmm, $obs_mats[k])
             end
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
     end

     if "baum_welch" in algos
         benchs["baum_welch"] = @benchmarkable begin
             HMMBase.fit_mle($hmm, $obs_mat_concat; maxiter=$bw_iter, tol=-Inf)
-        end evals = 1 samples = 10
+        end evals = 1 samples = 100
    end

     return benchs
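To tie the benchmarking changes together: every suite now takes 100 samples with one evaluation each instead of 10, across HMMBenchmark, dynamax and HMMBase alike. A minimal standalone sketch of one such measurement, with a made-up model and sequence length for illustration:

```julia
using BenchmarkTools, Distributions, HiddenMarkovModels

hmm = HMM([0.6, 0.4], [0.7 0.3; 0.2 0.8], [Normal(-1.0), Normal(1.0)])
state_seq, obs_seq = rand(hmm, 1_000)

# One evaluation per sample, 100 samples, matching the updated suites.
b = @benchmarkable viterbi($hmm, $obs_seq) evals = 1 samples = 100
run(b)
```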