Skip to content

Commit

Permalink
Fix benchmarks (#90)
Browse files Browse the repository at this point in the history
* Fix benchmarks

* Don't ignore png

* Remove JOSS

* Commit benchmark manifest

* Remove Pkg business

* Fix paths
  • Loading branch information
gdalle authored Feb 22, 2024
1 parent f4b750d commit 5ad5d81
Show file tree
Hide file tree
Showing 18 changed files with 151 additions and 139 deletions.
26 changes: 0 additions & 26 deletions .github/workflows/draft-pdf.yml

This file was deleted.

3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,4 @@ scratchpad.jl
/docs/src/index.md
/docs/src/examples/*.md

*.pdf
*.png
*.pdf
35 changes: 20 additions & 15 deletions benchmark/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.0"
julia_version = "1.10.1"
manifest_format = "2.0"
project_hash = "e05ed926575e94b72904ad898b09f017dc14d96a"

[[deps.ArgCheck]]
git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
version = "2.3.0"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"
Expand Down Expand Up @@ -34,9 +39,9 @@ version = "0.5.1"

[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
git-tree-sha1 = "1287e3872d646eed95198457873249bd9f0caed2"
git-tree-sha1 = "892b245fdec1c511906671b6a5e1bafa38a727c1"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.20.1"
version = "1.22.0"
weakdeps = ["SparseArrays"]

[deps.ChainRulesCore.extensions]
Expand All @@ -50,9 +55,9 @@ version = "0.7.4"

[[deps.Compat]]
deps = ["TOML", "UUIDs"]
git-tree-sha1 = "75bd5b6fc5089df449b5d35fa501c846c9b6549b"
git-tree-sha1 = "d2c021fbdde94f6cdaa799639adfeeaa17fd67f5"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.12.0"
version = "4.13.0"
weakdeps = ["Dates", "LinearAlgebra"]

[deps.Compat.extensions]
Expand All @@ -61,7 +66,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.5+1"
version = "1.1.0+0"

[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
Expand All @@ -81,9 +86,9 @@ version = "1.6.1"

[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed"
git-tree-sha1 = "1fb174f0d48fe7d142e1109a10636bc1d14f5ac2"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.16"
version = "0.18.17"

[[deps.DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
Expand Down Expand Up @@ -155,13 +160,13 @@ deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"

[[deps.HMMBenchmark]]
deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"]
deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "InteractiveUtils", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"]
path = "../libs/HMMBenchmark"
uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179"
version = "0.1.0"

[[deps.HiddenMarkovModels]]
deps = ["ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI"]
deps = ["ArgCheck", "ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI", "StatsFuns"]
path = ".."
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
version = "0.4.0"
Expand Down Expand Up @@ -257,9 +262,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa"
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.26"
version = "0.3.27"

[deps.LogExpFunctions.extensions]
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -309,7 +314,7 @@ version = "1.2.0"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+2"
version = "0.3.23+4"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down Expand Up @@ -469,9 +474,9 @@ version = "0.34.2"

[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a"
git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "1.3.0"
version = "1.3.1"
weakdeps = ["ChainRulesCore", "InverseFunctions"]

[deps.StatsFuns.extensions]
Expand Down
1 change: 1 addition & 0 deletions libs/HMMBenchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
4 changes: 3 additions & 1 deletion libs/HMMBenchmark/src/HMMBenchmark.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module HMMBenchmark

using Base.Threads
using Base.Threads: Threads
using InteractiveUtils: InteractiveUtils
using BenchmarkTools: @benchmarkable, BenchmarkGroup
using CSV: CSV
using DataFrames: DataFrame
Expand Down Expand Up @@ -33,5 +34,6 @@ include("instance.jl")
include("params.jl")
include("hiddenmarkovmodels.jl")
include("suite.jl")
include("setup.jl")

end
32 changes: 20 additions & 12 deletions libs/HMMBenchmark/src/hiddenmarkovmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ function build_model(::HiddenMarkovModelsImplem, instance::Instance, params::Par
(; custom_dist, nb_states, obs_dim) = instance
(; init, trans, means, stds) = params

if custom_dist
dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states]
if obs_dim == 1
dists = [Normal(means[1, i], stds[1, i]) for i in 1:nb_states]
else
if obs_dim == 1
dists = [Normal(means[1, i], stds[1, i]) for i in 1:nb_states]
if custom_dist
dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states]
else
dists = [MvNormal(means[:, i], Diagonal(stds[:, i])) for i in 1:nb_states]
end
Expand Down Expand Up @@ -43,32 +43,38 @@ function build_benchmarkables(
if "forward" in algos
benchs["forward"] = @benchmarkable begin
forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 10
end
if "forward!" in algos
benchs["forward!"] = @benchmarkable begin
forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 10 setup = (
f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
)
end

if "viterbi" in algos
benchs["viterbi"] = @benchmarkable begin
viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 10
end
if "viterbi!" in algos
benchs["viterbi!"] = @benchmarkable begin
viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 10 setup = (
v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
)
end

if "forward_backward" in algos
benchs["forward_backward"] = @benchmarkable begin
forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100
end evals = 1 samples = 10
end
if "forward_backward!" in algos
benchs["forward_backward!"] = @benchmarkable begin
forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 10 setup = (
fb_storage = initialize_forward_backward(
$hmm, $obs_seq, $control_seq; seq_ends=$seq_ends
)
Expand All @@ -86,7 +92,9 @@ function build_benchmarkables(
atol=-Inf,
loglikelihood_increasing=false,
)
end evals = 1 samples = 100
end evals = 1 samples = 10
end
if "baum_welch!" in algos
benchs["baum_welch!"] = @benchmarkable begin
baum_welch!(
fb_storage,
Expand All @@ -99,7 +107,7 @@ function build_benchmarkables(
atol=-Inf,
loglikelihood_increasing=false,
)
end evals = 1 samples = 100 setup = (
end evals = 1 samples = 10 setup = (
hmm_guess = build_model($implem, $instance, $params);
fb_storage = initialize_forward_backward(
hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends
Expand Down
4 changes: 2 additions & 2 deletions libs/HMMBenchmark/src/setup.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function print_julia_setup(; path)
function print_julia_setup(path)
open(path, "w") do file
redirect_stdout(file) do
versioninfo()
InteractiveUtils.versioninfo()
println("\n# Multithreading\n")
println("Julia threads = $(Threads.nthreads())")
println("OpenBLAS threads = $(BLAS.get_num_threads())")
Expand Down
62 changes: 62 additions & 0 deletions libs/HMMComparison/experiments/performance.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using Pkg
Pkg.activate(joinpath(@__DIR__, ".."))
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "HMMBenchmark"))

@assert Base.Threads.nthreads() == 1

# see https://superfastpython.com/numpy-number-blas-threads/
ENV["MKL_NUM_THREADS"] = 1
ENV["NUMEXPR_NUM_THREADS"] = 1
ENV["OMP_NUM_THREADS"] = 1
ENV["OPENBLAS_NUM_THREADS"] = 1
ENV["VECLIB_MAXIMUM_THREADS"] = 1

# see https://github.com/google/jax/issues/743
ENV["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"

using BenchmarkTools
using LinearAlgebra
using PythonCall # Python process starts now
using StableRNGs
using HMMComparison

# see https://pytorch.org/docs/stable/generated/torch.set_num_threads.html
pyimport("torch").set_num_threads(1)

rng = StableRNG(63)

print_julia_setup(joinpath(@__DIR__, "results", "julia_setup.txt"))
print_python_setup(joinpath(@__DIR__, "results", "python_setup.txt"))

implems = [
HiddenMarkovModelsImplem(), #
HMMBaseImplem(), #
hmmlearnImplem(), #
pomegranateImplem(), #
dynamaxImplem(), #
]

algos = ["forward", "viterbi", "forward_backward", "baum_welch"]

instances = Instance[]

for nb_states in 2:2:16
push!(
instances,
Instance(;
custom_dist=true,
sparse=false,
nb_states=nb_states,
obs_dim=1,
seq_length=200,
nb_seqs=100,
bw_iter=5,
),
)
end

SUITE = define_suite(rng, implems; instances, algos)

results = BenchmarkTools.run(SUITE; verbose=true)
data = parse_results(results; path=joinpath(@__DIR__, "results", "results.csv"))
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using DataFrames
using Plots
using HMMComparison

data = read_results(joinpath(@__DIR__, "results.csv"))
data = read_results(joinpath(@__DIR__, "results", "results.csv"))

sort!(data, [:algo, :implem, :nb_states])

Expand All @@ -13,7 +13,7 @@ implems = [
"pomegranate", #
"dynamax", #
]
algos = ["forward", "baum_welch"]
algos = ["viterbi", "forward", "forward_backward", "baum_welch"]

markershapes = [:star5, :circle, :diamond, :hexagon, :pentagon, :utriangle]

Expand All @@ -24,6 +24,7 @@ for algo in algos
yscale=:log,
xlabel="nb states",
ylabel="runtime (s)",
xticks=unique(data[!, :nb_states]),
legend=:outerright,
margin=5Plots.mm,
)
Expand All @@ -33,10 +34,6 @@ for algo in algos
pl,
subdata[!, :nb_states],
subdata[!, :time_median] ./ 1e9;
yerror=(
(subdata[!, :time_median] .- subdata[!, :time_quantile25]) ./ 1e9,
(subdata[!, :time_quantile75] .- subdata[!, :time_median]) ./ 1e9,
),
label=implem,
markershape=markershapes[i],
markerstrokecolor=:auto,
Expand All @@ -46,5 +43,5 @@ for algo in algos
)
end
display(pl)
savefig(pl, joinpath(@__DIR__, "$(algo).png"))
savefig(pl, joinpath(@__DIR__, "results", "$(algo).pdf"))
end
Empty file.
Loading

0 comments on commit 5ad5d81

Please sign in to comment.