diff --git a/.github/workflows/draft-pdf.yml b/.github/workflows/draft-pdf.yml deleted file mode 100644 index a3a0d94e..00000000 --- a/.github/workflows/draft-pdf.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: JOSS paper compilation -on: - push: - branches: - - joss* -jobs: - paper: - runs-on: ubuntu-latest - name: Paper Draft - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Build draft PDF - uses: openjournals/openjournals-draft-action@master - with: - journal: joss - # This should be the path to the paper within your repo. - paper-path: paper/paper.md - - name: Upload - uses: actions/upload-artifact@v1 - with: - name: paper - # This is the output path where Pandoc will write the compiled - # PDF. Note, this should be the same directory as the input - # paper.md - path: paper/paper.pdf \ No newline at end of file diff --git a/.gitignore b/.gitignore index f8e41336..954eac3b 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,4 @@ scratchpad.jl /docs/src/index.md /docs/src/examples/*.md -*.pdf -*.png \ No newline at end of file +*.pdf \ No newline at end of file diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index ca459dfe..8e19bf18 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -1,9 +1,14 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.0" +julia_version = "1.10.1" manifest_format = "2.0" project_hash = "e05ed926575e94b72904ad898b09f017dc14d96a" +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" @@ -34,9 +39,9 @@ version = "0.5.1" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "1287e3872d646eed95198457873249bd9f0caed2" +git-tree-sha1 = "892b245fdec1c511906671b6a5e1bafa38a727c1" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.20.1" +version = "1.22.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] @@ -50,9 +55,9 @@ version = "0.7.4" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "75bd5b6fc5089df449b5d35fa501c846c9b6549b" +git-tree-sha1 = "d2c021fbdde94f6cdaa799639adfeeaa17fd67f5" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.12.0" +version = "4.13.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -61,7 +66,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+1" +version = "1.1.0+0" [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" @@ -81,9 +86,9 @@ version = "1.6.1" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed" +git-tree-sha1 = "1fb174f0d48fe7d142e1109a10636bc1d14f5ac2" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.16" +version = "0.18.17" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -155,13 +160,13 @@ deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.HMMBenchmark]] -deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"] +deps = ["BenchmarkTools", "CSV", "DataFrames", "Distributions", "HiddenMarkovModels", "InteractiveUtils", "LinearAlgebra", "Pkg", "Random", "SparseArrays", "StableRNGs", "Statistics"] path = "../libs/HMMBenchmark" uuid = "557005d5-2e4a-43f9-8aa7-ba8df2d03179" version = "0.1.0" [[deps.HiddenMarkovModels]] -deps = ["ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI"] +deps = ["ArgCheck", "ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI", "StatsFuns"] path = ".." uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" version = "0.4.0" @@ -257,9 +262,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.26" +version = "0.3.27" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -309,7 +314,7 @@ version = "1.2.0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+2" +version = "0.3.23+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -469,9 +474,9 @@ version = "0.34.2" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" +git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.0" +version = "1.3.1" weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] diff --git a/libs/HMMBenchmark/Project.toml b/libs/HMMBenchmark/Project.toml index c9a9a128..48ccdc9c 100644 --- a/libs/HMMBenchmark/Project.toml +++ b/libs/HMMBenchmark/Project.toml @@ -9,6 +9,7 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/libs/HMMBenchmark/src/HMMBenchmark.jl b/libs/HMMBenchmark/src/HMMBenchmark.jl index 6bf86a9c..542ff6cd 100644 --- a/libs/HMMBenchmark/src/HMMBenchmark.jl +++ b/libs/HMMBenchmark/src/HMMBenchmark.jl @@ -1,6 +1,7 @@ module HMMBenchmark -using Base.Threads +using Base.Threads: Threads +using InteractiveUtils: InteractiveUtils using BenchmarkTools: @benchmarkable, BenchmarkGroup using CSV: CSV using DataFrames: DataFrame @@ -33,5 +34,6 @@ include("instance.jl") include("params.jl") include("hiddenmarkovmodels.jl") include("suite.jl") +include("setup.jl") end diff --git a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl index 76fab3fd..63ca122b 100644 --- a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl +++ b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl @@ -5,11 +5,11 @@ function build_model(::HiddenMarkovModelsImplem, instance::Instance, params::Par (; custom_dist, nb_states, obs_dim) = instance (; init, trans, means, stds) = params - if custom_dist - dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states] + if obs_dim == 1 + dists = [Normal(means[1, i], stds[1, i]) for i in 1:nb_states] else - if obs_dim == 1 - dists = [Normal(means[1, i], stds[1, i]) for i in 1:nb_states] + if custom_dist + dists = [LightDiagNormal(means[:, i], stds[:, i]) for i in 1:nb_states] else dists = [MvNormal(means[:, i], Diagonal(stds[:, i])) for i in 1:nb_states] end @@ -43,10 +43,12 @@ function build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 + end evals = 1 samples = 10 + end + if "forward!" in algos benchs["forward!"] = @benchmarkable begin forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 10 setup = ( f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) ) end @@ -54,10 +56,12 @@ function build_benchmarkables( if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 + end evals = 1 samples = 10 + end + if "viterbi!" in algos benchs["viterbi!"] = @benchmarkable begin viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 10 setup = ( v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) ) end @@ -65,10 +69,12 @@ function build_benchmarkables( if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 + end evals = 1 samples = 10 + end + if "forward_backward!" in algos benchs["forward_backward!"] = @benchmarkable begin forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 10 setup = ( fb_storage = initialize_forward_backward( $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends ) @@ -86,7 +92,9 @@ function build_benchmarkables( atol=-Inf, loglikelihood_increasing=false, ) - end evals = 1 samples = 100 + end evals = 1 samples = 10 + end + if "baum_welch!" in algos benchs["baum_welch!"] = @benchmarkable begin baum_welch!( fb_storage, @@ -99,7 +107,7 @@ function build_benchmarkables( atol=-Inf, loglikelihood_increasing=false, ) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 10 setup = ( hmm_guess = build_model($implem, $instance, $params); fb_storage = initialize_forward_backward( hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends diff --git a/libs/HMMBenchmark/src/setup.jl b/libs/HMMBenchmark/src/setup.jl index 83d76fe4..811939b8 100644 --- a/libs/HMMBenchmark/src/setup.jl +++ b/libs/HMMBenchmark/src/setup.jl @@ -1,7 +1,7 @@ -function print_julia_setup(; path) +function print_julia_setup(path) open(path, "w") do file redirect_stdout(file) do - versioninfo() + InteractiveUtils.versioninfo() println("\n# Multithreading\n") println("Julia threads = $(Threads.nthreads())") println("OpenBLAS threads = $(BLAS.get_num_threads())") diff --git a/libs/HMMComparison/experiments/performance.jl b/libs/HMMComparison/experiments/performance.jl new file mode 100644 index 00000000..fe6ed1ba --- /dev/null +++ b/libs/HMMComparison/experiments/performance.jl @@ -0,0 +1,62 @@ +using Pkg +Pkg.activate(joinpath(@__DIR__, "..")) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "HMMBenchmark")) + +@assert Base.Threads.nthreads() == 1 + +# see https://superfastpython.com/numpy-number-blas-threads/ +ENV["MKL_NUM_THREADS"] = 1 +ENV["NUMEXPR_NUM_THREADS"] = 1 +ENV["OMP_NUM_THREADS"] = 1 +ENV["OPENBLAS_NUM_THREADS"] = 1 +ENV["VECLIB_MAXIMUM_THREADS"] = 1 + +# see https://github.com/google/jax/issues/743 +ENV["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" + +using BenchmarkTools +using LinearAlgebra +using PythonCall # Python process starts now +using StableRNGs +using HMMComparison + +# see https://pytorch.org/docs/stable/generated/torch.set_num_threads.html +pyimport("torch").set_num_threads(1) + +rng = StableRNG(63) + +print_julia_setup(joinpath(@__DIR__, "results", "julia_setup.txt")) +print_python_setup(joinpath(@__DIR__, "results", "python_setup.txt")) + +implems = [ + HiddenMarkovModelsImplem(), # + HMMBaseImplem(), # + hmmlearnImplem(), # + pomegranateImplem(), # + dynamaxImplem(), # +] + +algos = ["forward", "viterbi", "forward_backward", "baum_welch"] + +instances = Instance[] + +for nb_states in 2:2:16 + push!( + instances, + Instance(; + custom_dist=true, + sparse=false, + nb_states=nb_states, + obs_dim=1, + seq_length=200, + nb_seqs=100, + bw_iter=5, + ), + ) +end + +SUITE = define_suite(rng, implems; instances, algos) + +results = BenchmarkTools.run(SUITE; verbose=true) +data = parse_results(results; path=joinpath(@__DIR__, "results", "results.csv")) diff --git a/libs/HMMComparison/test/plots.jl b/libs/HMMComparison/experiments/plots.jl similarity index 74% rename from libs/HMMComparison/test/plots.jl rename to libs/HMMComparison/experiments/plots.jl index 58371848..f922c15f 100644 --- a/libs/HMMComparison/test/plots.jl +++ b/libs/HMMComparison/experiments/plots.jl @@ -2,7 +2,7 @@ using DataFrames using Plots using HMMComparison -data = read_results(joinpath(@__DIR__, "results.csv")) +data = read_results(joinpath(@__DIR__, "results", "results.csv")) sort!(data, [:algo, :implem, :nb_states]) @@ -13,7 +13,7 @@ implems = [ "pomegranate", # "dynamax", # ] -algos = ["forward", "baum_welch"] +algos = ["viterbi", "forward", "forward_backward", "baum_welch"] markershapes = [:star5, :circle, :diamond, :hexagon, :pentagon, :utriangle] @@ -24,6 +24,7 @@ for algo in algos yscale=:log, xlabel="nb states", ylabel="runtime (s)", + xticks=unique(data[!, :nb_states]), legend=:outerright, margin=5Plots.mm, ) @@ -33,10 +34,6 @@ for algo in algos pl, subdata[!, :nb_states], subdata[!, :time_median] ./ 1e9; - yerror=( - (subdata[!, :time_median] .- subdata[!, :time_quantile25]) ./ 1e9, - (subdata[!, :time_quantile75] .- subdata[!, :time_median]) ./ 1e9, - ), label=implem, markershape=markershapes[i], markerstrokecolor=:auto, @@ -46,5 +43,5 @@ for algo in algos ) end display(pl) - savefig(pl, joinpath(@__DIR__, "$(algo).png")) + savefig(pl, joinpath(@__DIR__, "results", "$(algo).pdf")) end diff --git a/libs/HMMComparison/experiments/results/.gitkeep b/libs/HMMComparison/experiments/results/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/libs/HMMComparison/src/dynamax.jl b/libs/HMMComparison/src/dynamax.jl index faecdca3..d60e8ef3 100644 --- a/libs/HMMComparison/src/dynamax.jl +++ b/libs/HMMComparison/src/dynamax.jl @@ -47,8 +47,8 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0)))) benchs["forward"] = @benchmarkable begin - $(filter_vmap)($dyn_params, $obs_tens_jax_py).block_until_ready() - end evals = 1 samples = 100 + $(filter_vmap)($dyn_params, $obs_tens_jax_py) + end evals = 1 samples = 10 end if "viterbi" in algos @@ -56,17 +56,17 @@ function HMMBenchmark.build_benchmarkables( jax.vmap(hmm.most_likely_states; in_axes=pylist((pybuiltins.None, 0))) ) benchs["viterbi"] = @benchmarkable begin - $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py).block_until_ready() - end evals = 1 samples = 100 + $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py) + end evals = 1 samples = 10 end if "forward_backward" in algos smoother_vmap = jax.jit( - jax.vmap(hmm.smoother; in_axes=pylist((pybuiltins.None, 0))).block_until_ready() + jax.vmap(hmm.smoother; in_axes=pylist((pybuiltins.None, 0))) ) benchs["forward_backward"] = @benchmarkable begin $(smoother_vmap)($dyn_params, $obs_tens_jax_py) - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "baum_welch" in algos @@ -77,8 +77,8 @@ function HMMBenchmark.build_benchmarkables( $obs_tens_jax_py; num_iters=$bw_iter, verbose=false, - ).block_until_ready() - end evals = 1 samples = 100 setup = ( + ) + end evals = 1 samples = 10 setup = ( tup = build_model($implem, $instance, $params); hmm_guess = tup[1]; dyn_params_guess = tup[2]; diff --git a/libs/HMMComparison/src/hmmbase.jl b/libs/HMMComparison/src/hmmbase.jl index 05881746..2c260e89 100644 --- a/libs/HMMComparison/src/hmmbase.jl +++ b/libs/HMMComparison/src/hmmbase.jl @@ -24,7 +24,7 @@ function HMMBenchmark.build_benchmarkables( data::AbstractArray{<:Real,3}, algos::Vector{String}, ) - (; obs_dim, seq_length, nb_seqs, bw_iter) = instance + (; obs_dim, nb_seqs, bw_iter) = instance hmm = build_model(implem, instance, params) if obs_dim == 1 @@ -38,34 +38,32 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin - @threads for k in eachindex(obs_mats) - HMMBase.forward($hmm, $(obs_mats[k])) + @threads for k in eachindex($obs_mats) + HMMBase.forward($hmm, $obs_mats[k]) end - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin - @threads for k in eachindex(obs_mats) - HMMBase.viterbi($hmm, $(obs_mats[k])) + @threads for k in eachindex($obs_mats) + HMMBase.viterbi($hmm, $obs_mats[k]) end - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin - @threads for k in eachindex(obs_mats) - HMMBase.posteriors($hmm, $(obs_mats[k])) + @threads for k in eachindex($obs_mats) + HMMBase.posteriors($hmm, $obs_mats[k]) end - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin - @threads for k in eachindex(obs_mats) - HMMBase.fit_mle($hmm, $(obs_mats[k]); maxiter=$bw_iter, tol=-Inf) - end - end evals = 1 samples = 100 + HMMBase.fit_mle($hmm, $obs_mat_concat; maxiter=$bw_iter, tol=-Inf) + end evals = 1 samples = 10 end return benchs diff --git a/libs/HMMComparison/src/hmmlearn.jl b/libs/HMMComparison/src/hmmlearn.jl index 610c9129..58b1f053 100644 --- a/libs/HMMComparison/src/hmmlearn.jl +++ b/libs/HMMComparison/src/hmmlearn.jl @@ -45,25 +45,25 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin $(hmm.score)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin $(hmm.decode)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin $(hmm.predict_proba)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin hmm_guess.fit($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 10 setup = ( hmm_guess = build_model($implem, $instance, $params) ) end diff --git a/libs/HMMComparison/src/pomegranate.jl b/libs/HMMComparison/src/pomegranate.jl index d5939919..adf27f53 100644 --- a/libs/HMMComparison/src/pomegranate.jl +++ b/libs/HMMComparison/src/pomegranate.jl @@ -57,19 +57,19 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin $(hmm.forward)($obs_tens_torch_py) - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin $(hmm.forward_backward)($obs_tens_torch_py) - end evals = 1 samples = 100 + end evals = 1 samples = 10 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin hmm_guess.fit($obs_tens_torch_py) - end evals = 1 samples = 100 setup = ( + end evals = 1 samples = 10 setup = ( hmm_guess = build_model($implem, $instance, $params) ) end diff --git a/libs/HMMComparison/src/setup.jl b/libs/HMMComparison/src/setup.jl index c311e101..0a74bb34 100644 --- a/libs/HMMComparison/src/setup.jl +++ b/libs/HMMComparison/src/setup.jl @@ -1,8 +1,8 @@ -function print_python_setup(; path) +function print_python_setup(path) open(path, "w") do file redirect_stdout(file) do - println("Pytorch threads = $(torch.get_num_threads())") + println("Pytorch threads = $(pyimport("torch").get_num_threads())") println("\n# Python packages\n") end redirect_stderr(file) do diff --git a/libs/HMMComparison/test/performance.jl b/libs/HMMComparison/test/performance.jl deleted file mode 100644 index cd6643cf..00000000 --- a/libs/HMMComparison/test/performance.jl +++ /dev/null @@ -1,40 +0,0 @@ -using BenchmarkTools -using HMMComparison -using LinearAlgebra -using StableRNGs - -BLAS.set_num_threads(1) - -rng = StableRNG(63) - -implems = [ - HiddenMarkovModelsImplem(), # - HMMBaseImplem(), # - hmmlearnImplem(), # - pomegranateImplem(), # - dynamaxImplem(), # -] - -algos = ["forward", "viterbi", "forward_backward", "baum_welch"] - -instances = Instance[] - -for nb_states in 2:3:24 - push!( - instances, - Instance(; - custom_dist=true, - sparse=false, - nb_states=nb_states, - obs_dim=5, - seq_length=100, - nb_seqs=10, - bw_iter=10, - ), - ) -end - -SUITE = define_suite(rng, implems; instances, algos) - -results = BenchmarkTools.run(SUITE; verbose=true) -data = parse_results(results; path=joinpath(@__DIR__, "results.csv")) diff --git a/src/utils/lightcategorical.jl b/src/utils/lightcategorical.jl index 17627bea..f16f46dc 100644 --- a/src/utils/lightcategorical.jl +++ b/src/utils/lightcategorical.jl @@ -47,7 +47,9 @@ function DensityInterface.logdensityof(dist::LightCategorical, k::Integer) return dist.logp[k] end -function StatsAPI.fit!(dist::LightCategorical{T1}, x, w) where {T1} +function StatsAPI.fit!( + dist::LightCategorical{T1}, x::AbstractVector{<:Integer}, w::AbstractVector +) where {T1} @argcheck 1 <= minimum(x) <= maximum(x) <= length(dist.p) w_tot = sum(w) dist.p .= zero(T1) diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 17d114cb..8b0748d6 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -41,7 +41,9 @@ function Base.rand(rng::AbstractRNG, dist::LightDiagNormal{T1,T2}) where {T1,T2} return dist.σ .* randn(rng, T, length(dist)) .+ dist.μ end -function DensityInterface.logdensityof(dist::LightDiagNormal{T1,T2,T3}, x) where {T1,T2,T3} +function DensityInterface.logdensityof( + dist::LightDiagNormal{T1,T2,T3}, x::AbstractVector +) where {T1,T2,T3} l = zero(promote_type(T1, T2, T3, eltype(x))) l -= sum(dist.logσ) + log2π * length(x) / 2 @inbounds @simd for i in eachindex(x, dist.μ, dist.σ) @@ -50,7 +52,9 @@ function DensityInterface.logdensityof(dist::LightDiagNormal{T1,T2,T3}, x) where return l end -function StatsAPI.fit!(dist::LightDiagNormal{T1,T2}, x, w) where {T1,T2} +function StatsAPI.fit!( + dist::LightDiagNormal{T1,T2}, x::AbstractVector{<:AbstractVector}, w::AbstractVector +) where {T1,T2} w_tot = sum(w) dist.μ .= zero(T1) dist.σ .= zero(T2)