From 2c6fc3b2f53d3ae72e968e1142378d05db8a27b6 Mon Sep 17 00:00:00 2001 From: George Datseris Date: Wed, 4 Oct 2023 18:26:34 +0100 Subject: [PATCH] Add progress meter for Surrogate and LocalPermutation tests (#343) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add progress meter as dependency * add progress meter to surrogate test * drastically simplify the SurrogateTest source Length check should occur at the individual functions, not in the independence test! TODO: the whole `independence` function can be simplified throughout the code base , but that's for another PR @kahaaga * more informative error for independence * add progress meter to local permutation * make progress false by default * bump version * Add progress meter to transfer-entropy specific implementation too --------- Co-authored-by: Kristian Agasøster Haaga --- Project.toml | 4 +- changelog.md | 7 ++- src/independence_tests/independence.jl | 3 +- .../local_permutation/LocalPermutationTest.jl | 30 +++++---- .../local_permutation/transferentropy.jl | 6 +- .../surrogate/SurrogateTest.jl | 62 +++++++------------ 6 files changed, 57 insertions(+), 55 deletions(-) diff --git a/Project.toml b/Project.toml index af6d21f43..965cd3111 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "CausalityTools" uuid = "5520caf5-2dd7-5c5d-bfcb-a00e56ac49f7" authors = ["Kristian Agasøster Haaga ", "Tor Einar Møller ", "George Datseris "] repo = "https://github.com/kahaaga/CausalityTools.jl.git" -version = "2.9.2" +version = "2.10.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -19,6 +19,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Neighborhood = "645ca80c-8b79-4109-87ea-e1f58159d116" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecurrenceAnalysis = "639c3291-70d9-5ea2-8c5b-839eba1ee399" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -46,6 +47,7 @@ HypothesisTests = "0.8, 1, 0.10, 0.11" LabelledArrays = "1.6.7" NearestNeighbors = "0.4" Neighborhood = "0.2.2" +ProgressMeter = "1.7" RecurrenceAnalysis = "2" Reexport = "0.2, 1" Scratch = "1" diff --git a/changelog.md b/changelog.md index ab90ca926..8952ba547 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,10 @@ # Changelog +## 2.10 + +- Progress bars in some independence tests (surrogate, local permutation) can be + enabled by passing keyword `show_progress = true` in the test constructors. + ## 2.9 ### Bug fixes @@ -65,7 +70,7 @@ indices). ## 2.2 -- Added `MCR` and `RMCD` recurrence based association measures, along with +- Added `MCR` and `RMCD` recurrence based association measures, along with the corresponding `mcr` and `rmcd` methods. ## 2.1 diff --git a/src/independence_tests/independence.jl b/src/independence_tests/independence.jl index 653a8713f..6d3face39 100644 --- a/src/independence_tests/independence.jl +++ b/src/independence_tests/independence.jl @@ -28,7 +28,8 @@ Returns a test `summary`, whose type depends on `test`. - [`JointDistanceDistributionTest`](@ref). """ function independence(test::IndependenceTest, x...) - throw(ArgumentError("No concrete implementation for $(typeof(test)) test yet")) + L = length(x) + throw(ArgumentError("No concrete implementation for $(typeof(test)) test with $(L) input variables.")) end function pvalue_text_summary(test::IndependenceTestResult) diff --git a/src/independence_tests/local_permutation/LocalPermutationTest.jl b/src/independence_tests/local_permutation/LocalPermutationTest.jl index ceba58784..6f193c06a 100644 --- a/src/independence_tests/local_permutation/LocalPermutationTest.jl +++ b/src/independence_tests/local_permutation/LocalPermutationTest.jl @@ -1,6 +1,7 @@ using Random: shuffle! using Random import Statistics: quantile +import ProgressMeter export LocalPermutationTest export LocalPermutationTestResult @@ -34,7 +35,8 @@ struct NeighborCloseness <: LocalPermutationClosenessSearch end nshuffles::Int = 100, rng = Random.default_rng(), replace = true, - w::Int = 0) + w::Int = 0, + show_progress = false) `LocalPermutationTest` is a generic conditional independence test [Runge2018LocalPerm](@cite) for assessing whether two variables `X` and `Y` are @@ -102,15 +104,16 @@ struct LocalPermutationTest{M, EST, C, R} <: IndependenceTest{M} replace::Bool closeness_search::C w::Int # Theiler window - function LocalPermutationTest(measure::M, est::EST = nothing; - rng::R = Random.default_rng(), - kperm::Int = 10, - replace::Bool = true, - nshuffles::Int = 100, - closeness_search::C = NeighborCloseness(), - w::Int = 0) where {M, EST, C, R} - new{M, EST, C, R}(measure, est, rng, kperm, nshuffles, replace, closeness_search, w) - end + show_progress::Bool +end +function LocalPermutationTest(measure::M, est::EST = nothing; + rng::R = Random.default_rng(), + kperm::Int = 10, + replace::Bool = true, + nshuffles::Int = 100, + closeness_search::C = NeighborCloseness(), + w::Int = 0, show_progress = false) where {M, EST, C, R} + return LocalPermutationTest{M, EST, C, R}(measure, est, rng, kperm, nshuffles, replace, closeness_search, w, show_progress) end Base.show(io::IO, test::LocalPermutationTest) = print(io, @@ -165,7 +168,6 @@ function independence(test::LocalPermutationTest, x, y, z) X, Y, Z = StateSpaceSet(x), StateSpaceSet(y), StateSpaceSet(z) @assert length(X) == length(Y) == length(Z) - N = length(X) Î = estimate(measure, est, X, Y, Z) Îs = permuted_Îs(X, Y, Z, measure, est, test) p = count(Î .<= Îs) / nshuffles @@ -177,7 +179,10 @@ end # computing the test statistic. function permuted_Îs(X, Y, Z, measure, est, test) rng, kperm, nshuffles, replace, w = test.rng, test.kperm, test.nshuffles, test.replace, test.w - + progress = ProgressMeter.Progress(nshuffles; + desc = "LocalPermutationTest:", + enabled = test.show_progress + ) N = length(X) test.kperm < N || throw(ArgumentError("kperm must be smaller than input data length")) @@ -194,6 +199,7 @@ function permuted_Îs(X, Y, Z, measure, est, test) shuffle_without_replacement!(X̂, X, idxs_z, kperm, rng, Nᵢ, πs) end Îs[n] = estimate(measure, est, X̂, Y, Z) + ProgressMeter.next!(progress) end return Îs diff --git a/src/independence_tests/local_permutation/transferentropy.jl b/src/independence_tests/local_permutation/transferentropy.jl index b729694c1..ad6c02a66 100644 --- a/src/independence_tests/local_permutation/transferentropy.jl +++ b/src/independence_tests/local_permutation/transferentropy.jl @@ -45,7 +45,10 @@ end # about the target variable is left untouched. function permuted_Îs_te(S, T, T⁺, C, measure::TransferEntropy, est, test) rng, kperm, nshuffles, replace, w = test.rng, test.kperm, test.nshuffles, test.replace, test.w - + progress = ProgressMeter.Progress(nshuffles; + desc = "LocalPermutationTest:", + enabled = test.show_progress + ) N = length(S) test.kperm < N || throw(ArgumentError("kperm must be smaller than input data length")) @@ -65,6 +68,7 @@ function permuted_Îs_te(S, T, T⁺, C, measure::TransferEntropy, est, test) shuffle_without_replacement!(Ŝ, S, idxs_C, kperm, rng, Nᵢ, πs) end Îs[n] = estimate(measure, est, Ŝ, T, T⁺, C) + ProgressMeter.next!(progress) end return Îs end diff --git a/src/independence_tests/surrogate/SurrogateTest.jl b/src/independence_tests/surrogate/SurrogateTest.jl index 65e347ad9..4b325f458 100644 --- a/src/independence_tests/surrogate/SurrogateTest.jl +++ b/src/independence_tests/surrogate/SurrogateTest.jl @@ -1,5 +1,6 @@ using Random using TimeseriesSurrogates +import ProgressMeter export SurrogateTest export SurrogateTestResult @@ -9,6 +10,7 @@ export SurrogateTestResult nshuffles::Int = 100, surrogate = RandomShuffle(), rng = Random.default_rng(), + show_progress = false, ) A generic (conditional) independence test for assessing whether two variables `X` and `Y` @@ -72,14 +74,14 @@ struct SurrogateTest{M, E, R, S} <: IndependenceTest{M} rng::R surrogate::S nshuffles::Int - - function SurrogateTest(measure::M, est::E = nothing; - rng::R = Random.default_rng(), - surrogate::S = RandomShuffle(), - nshuffles::Int = 100, - ) where {M, E, R, S} - new{M, E, R, S}(measure, est, rng, surrogate, nshuffles) - end + show_progress::Bool +end +function SurrogateTest(measure::M, est::E = nothing; + rng::R = Random.default_rng(), + surrogate::S = RandomShuffle(), + nshuffles::Int = 100, show_progress = false + ) where {M, E, R, S} + SurrogateTest{M, E, R, S}(measure, est, rng, surrogate, nshuffles, show_progress) end @@ -127,46 +129,28 @@ end # Generic dispatch for any three-argument conditional independence measure where the # third argument is to be conditioned on. This works naturally with e.g. # conditional mutual information. -function independence(test::SurrogateTest, x, y, z) - (; measure, est, rng, surrogate, nshuffles) = test - - # Make sure that the measure is compatible with the input data. - verify_number_of_inputs_vars(measure, 3) +function independence(test::SurrogateTest, x, args...) + # Setup (`args...` is either `y` or `y, z`) + (; measure, est, rng, surrogate, nshuffles, show_progress) = test + verify_number_of_inputs_vars(measure, 1+length(args)) + SSSets = map(w -> StateSpaceSet(w), args) + estimation = x -> estimate(measure, est, x, SSSets...) + progress = ProgressMeter.Progress(nshuffles; + desc="SurrogateTest:", enabled=show_progress + ) - X, Y, Z = StateSpaceSet(x), StateSpaceSet(y), StateSpaceSet(z) - @assert length(X) == length(Y) == length(Z) - N = length(x) - Î = estimate(measure,est, X, Y, Z) + # Estimate + Î = estimation(StateSpaceSet(x)) s = surrogenerator(x, surrogate, rng) Îs = zeros(nshuffles) for b in 1:nshuffles - Îs[b] = estimate(measure, est, s(), Y, Z) + Îs[b] = estimation(s()) + ProgressMeter.next!(progress) end p = count(Î .<= Îs) / nshuffles - return SurrogateTestResult(3, Î, Îs, p, nshuffles) end -function independence(test::SurrogateTest, x, y) - (; measure, est, rng, surrogate, nshuffles) = test - - # Make sure that the measure is compatible with the input data. - verify_number_of_inputs_vars(measure, 2) - - X, Y = StateSpaceSet(x), StateSpaceSet(y) - @assert length(X) == length(Y) - N = length(x) - Î = estimate(measure,est, X, Y) - sx = surrogenerator(x, surrogate, rng) - Îs = zeros(nshuffles) - for b in 1:nshuffles - Îs[b] = estimate(measure, est, sx(), y) - end - p = count(Î .<= Îs) / nshuffles - - return SurrogateTestResult(2, Î, Îs, p, nshuffles) -end - # Concrete implementations include("contingency.jl") include("transferentropy.jl")