Skip to content

Commit

Permalink
Give PR write permissions to benchmark workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Feb 24, 2024
1 parent 559b6dc commit 2746f09
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ on:
jobs:
Benchmark:
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
if: contains(github.event.pull_request.labels.*.name, 'run benchmark')
steps:
- uses: actions/checkout@v2
Expand Down
16 changes: 8 additions & 8 deletions libs/HMMBenchmark/src/hiddenmarkovmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,38 +43,38 @@ function build_benchmarkables(
if "forward" in algos
benchs["forward"] = @benchmarkable begin
forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 10
end evals = 1 samples = 100
end
if "forward!" in algos
benchs["forward!"] = @benchmarkable begin
forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 10 setup = (
end evals = 1 samples = 100 setup = (
f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
)
end

if "viterbi" in algos
benchs["viterbi"] = @benchmarkable begin
viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 10
end evals = 1 samples = 100
end
if "viterbi!" in algos
benchs["viterbi!"] = @benchmarkable begin
viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 10 setup = (
end evals = 1 samples = 100 setup = (
v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
)
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 10
end evals = 1 samples = 100
end
if "forward_backward!" in algos
benchs["forward_backward!"] = @benchmarkable begin
forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 10 setup = (
end evals = 1 samples = 100 setup = (
fb_storage = initialize_forward_backward(
$hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
)
Expand All @@ -92,7 +92,7 @@ function build_benchmarkables(
atol=-Inf,
loglikelihood_increasing=false,
)
end evals = 1 samples = 10
end evals = 1 samples = 100
end
if "baum_welch!" in algos
benchs["baum_welch!"] = @benchmarkable begin
Expand All @@ -107,7 +107,7 @@ function build_benchmarkables(
atol=-Inf,
loglikelihood_increasing=false,
)
end evals = 1 samples = 10 setup = (
end evals = 1 samples = 100 setup = (
hmm_guess = build_model($implem, $instance, $params);
fb_storage = initialize_forward_backward(
hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
Expand Down
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/dynamax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function HMMBenchmark.build_benchmarkables(
filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0))))
benchs["forward"] = @benchmarkable begin
$(filter_vmap)($dyn_params, $obs_tens_jax_py)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "viterbi" in algos
Expand All @@ -57,7 +57,7 @@ function HMMBenchmark.build_benchmarkables(
)
benchs["viterbi"] = @benchmarkable begin
$(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "forward_backward" in algos
Expand All @@ -66,7 +66,7 @@ function HMMBenchmark.build_benchmarkables(
)
benchs["forward_backward"] = @benchmarkable begin
$(smoother_vmap)($dyn_params, $obs_tens_jax_py)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "baum_welch" in algos
Expand All @@ -78,7 +78,7 @@ function HMMBenchmark.build_benchmarkables(
num_iters=$bw_iter,
verbose=false,
)
end evals = 1 samples = 10 setup = (
end evals = 1 samples = 100 setup = (
tup = build_model($implem, $instance, $params);
hmm_guess = tup[1];
dyn_params_guess = tup[2];
Expand Down
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/hmmbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,29 +41,29 @@ function HMMBenchmark.build_benchmarkables(
@threads for k in eachindex($obs_mats)
HMMBase.forward($hmm, $obs_mats[k])
end
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "viterbi" in algos
benchs["viterbi"] = @benchmarkable begin
@threads for k in eachindex($obs_mats)
HMMBase.viterbi($hmm, $obs_mats[k])
end
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
@threads for k in eachindex($obs_mats)
HMMBase.posteriors($hmm, $obs_mats[k])
end
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "baum_welch" in algos
benchs["baum_welch"] = @benchmarkable begin
HMMBase.fit_mle($hmm, $obs_mat_concat; maxiter=$bw_iter, tol=-Inf)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

return benchs
Expand Down
8 changes: 4 additions & 4 deletions libs/HMMComparison/src/hmmlearn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,25 @@ function HMMBenchmark.build_benchmarkables(
if "forward" in algos
benchs["forward"] = @benchmarkable begin
$(hmm.score)($obs_mat_concat_py, $obs_mat_len_py)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "viterbi" in algos
benchs["viterbi"] = @benchmarkable begin
$(hmm.decode)($obs_mat_concat_py, $obs_mat_len_py)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
$(hmm.predict_proba)($obs_mat_concat_py, $obs_mat_len_py)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "baum_welch" in algos
benchs["baum_welch"] = @benchmarkable begin
hmm_guess.fit($obs_mat_concat_py, $obs_mat_len_py)
end evals = 1 samples = 10 setup = (
end evals = 1 samples = 100 setup = (
hmm_guess = build_model($implem, $instance, $params)
)
end
Expand Down
6 changes: 3 additions & 3 deletions libs/HMMComparison/src/pomegranate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,19 @@ function HMMBenchmark.build_benchmarkables(
if "forward" in algos
benchs["forward"] = @benchmarkable begin
$(hmm.forward)($obs_tens_torch_py)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
$(hmm.forward_backward)($obs_tens_torch_py)
end evals = 1 samples = 10
end evals = 1 samples = 100
end

if "baum_welch" in algos
benchs["baum_welch"] = @benchmarkable begin
hmm_guess.fit($obs_tens_torch_py)
end evals = 1 samples = 10 setup = (
end evals = 1 samples = 100 setup = (
hmm_guess = build_model($implem, $instance, $params)
)
end
Expand Down

0 comments on commit 2746f09

Please sign in to comment.