Commit

More tutorials
gdalle committed Dec 8, 2023
1 parent b398357 commit 0979801
Showing 12 changed files with 295 additions and 125 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -3,6 +3,7 @@ DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -45,7 +45,8 @@ pages = [
     "Home" => "index.md",
     "Tutorials" => [
         "Basics" => joinpath("examples", "basics.md"),
-        "Distributions" => joinpath("examples", "distributions.md"),
+        "Interfaces" => joinpath("examples", "interfaces.md"),
+        "Autodiff" => joinpath("examples", "autodiff.md"),
         "Periodic" => joinpath("examples", "periodic.md"),
         "Controlled" => joinpath("examples", "controlled.md"),
     ],
59 changes: 59 additions & 0 deletions examples/autodiff.jl
@@ -0,0 +1,59 @@
# # Autodiff

#=
Here we show how to compute gradients of the observation sequence loglikelihood with respect to various parameters.
=#

using DensityInterface
using Distributions
using Enzyme
using ForwardDiff
using HiddenMarkovModels
using HiddenMarkovModels: test_coherent_algorithms #src
using LinearAlgebra
using Random: Random, AbstractRNG
using StatsAPI
using Test #src

#-

rng = Random.default_rng()
Random.seed!(rng, 63);

# ## Forward mode

#=
Since all of our code is type-generic, it is amenable to forward-mode automatic differentiation with ForwardDiff.jl.
=#

init = [0.8, 0.2]
trans = [0.7 0.3; 0.3 0.7]
means = [-1.0, 1.0]
dists = Normal.(means)
hmm = HMM(init, trans, dists);

_, obs_seq = rand(rng, hmm, 10);

#-

f1(new_init) = logdensityof(HMM(new_init, trans, dists), obs_seq)
ForwardDiff.gradient(f1, init)

#-

f2(new_trans) = logdensityof(HMM(init, new_trans, dists), obs_seq)
ForwardDiff.gradient(f2, trans)

#-

f3(new_means) = logdensityof(HMM(init, trans, Normal.(new_means)), obs_seq)
ForwardDiff.gradient(f3, means)

# ## Reverse mode

#=
In the presence of many parameters, reverse-mode automatic differentiation of the loglikelihood will be much more efficient.
This requires using Enzyme.jl and the mutating `forward!` function.
=#
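As an illustrative sketch (not part of the original commit): exact Enzyme.jl call signatures may vary between versions, and for simplicity we differentiate `logdensityof` (via the `f3` and `means` defined above) instead of the mutating `forward!`; the shadow vector name `dmeans` is our own.

```julia
## Shadow vector that Enzyme fills with the gradient of f3 w.r.t. the means
dmeans = zero(means)

## Reverse-mode pass: `Active` marks the scalar return value,
## `Duplicated` pairs `means` with its gradient accumulator
Enzyme.autodiff(Enzyme.Reverse, f3, Enzyme.Active, Enzyme.Duplicated(means, dmeans))

dmeans  # gradient of the loglikelihood with respect to the means
```

Unlike the ForwardDiff calls above, this pattern scales well when the parameter vector is large, since a single reverse sweep yields the whole gradient.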

# ## Gradient descent for estimation
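This section is left empty in the commit; as a purely hypothetical illustration of what gradient-based estimation could look like, one could run plain gradient ascent on the Normal means using the `f3` loglikelihood defined above (starting point and step size are arbitrary assumptions).

```julia
## Hypothetical sketch: gradient ascent on the means,
## maximizing the observation loglikelihood f3 defined above
means_est = [0.0, 0.5]  # arbitrary starting point (assumption)
lr = 0.1                # step size (assumption)
for _ in 1:100
    g = ForwardDiff.gradient(f3, means_est)
    means_est .+= lr .* g  # ascend the loglikelihood
end
means_est
```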
6 changes: 5 additions & 1 deletion examples/basics.jl
@@ -1,5 +1,9 @@
 # # Basics
 
+#=
+Here, we show how to use the essential ingredients of the package.
+=#
+
 using Distributions
 using HiddenMarkovModels
 using HiddenMarkovModels: test_coherent_algorithms #src
@@ -195,5 +199,5 @@ baum_welch(hmm_guess, obs_seq_concat; seq_ends);
 
 # ## Tests #src
 
-control_seq, seq_ends = fill(nothing, 1000), 10:10:1000 #src
+control_seq, seq_ends = fill(nothing, 1000), 100:10:1000 #src
 test_coherent_algorithms(rng, hmm, hmm_guess; control_seq, seq_ends, atol=1e-1) #src
2 changes: 1 addition & 1 deletion examples/controlled.jl
@@ -84,7 +84,7 @@ hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess);
 
 #-
 
-control_seqs = [[randn(rng, 3) for t in 1:rand(100:200)] for k in 1:100];
+control_seqs = [[randn(rng, 3) for t in 1:rand(100:200)] for k in 1:10];
 obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs];
 
 obs_seq = reduce(vcat, obs_seqs)
112 changes: 0 additions & 112 deletions examples/distributions.jl

This file was deleted.

