Skip to content

Commit

Permalink
Add progress meter for Surrogate and LocalPermutation tests (#343)
Browse files Browse the repository at this point in the history
* add progress meter as dependency

* add progress meter to surrogate test

* drastically simplify the SurrogateTest source

Length check should occur at the individual functions, not in the independence test!

TODO: the whole `independence` function can be simplified throughout the code base , but that's for another PR @kahaaga

* more informative error for independence

* add progress meter to local permutation

* make progress false by default

* bump version

* Add progress meter to transfer-entropy specific implementation too

---------

Co-authored-by: Kristian Agasøster Haaga <kahaaga@gmail.com>
  • Loading branch information
Datseris and kahaaga authored Oct 4, 2023
1 parent 477bd4c commit 2c6fc3b
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 55 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "CausalityTools"
uuid = "5520caf5-2dd7-5c5d-bfcb-a00e56ac49f7"
authors = ["Kristian Agasøster Haaga <kahaaga@gmail.com>", "Tor Einar Møller <temolle@gmail.com>", "George Datseris <datseris.george@gmail.com>"]
repo = "https://github.com/kahaaga/CausalityTools.jl.git"
version = "2.9.2"
version = "2.10.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -19,6 +19,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Neighborhood = "645ca80c-8b79-4109-87ea-e1f58159d116"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecurrenceAnalysis = "639c3291-70d9-5ea2-8c5b-839eba1ee399"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -46,6 +47,7 @@ HypothesisTests = "0.8, 1, 0.10, 0.11"
LabelledArrays = "1.6.7"
NearestNeighbors = "0.4"
Neighborhood = "0.2.2"
ProgressMeter = "1.7"
RecurrenceAnalysis = "2"
Reexport = "0.2, 1"
Scratch = "1"
Expand Down
7 changes: 6 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## 2.10

- Progress bars in some independence tests (surrogate, local permutation) can be
enabled by passing keyword `show_progress = true` in the test constructors.

## 2.9

### Bug fixes
Expand Down Expand Up @@ -65,7 +70,7 @@ indices).

## 2.2

- Added `MCR` and `RMCD` recurrence based association measures, along with
- Added `MCR` and `RMCD` recurrence based association measures, along with
the corresponding `mcr` and `rmcd` methods.

## 2.1
Expand Down
3 changes: 2 additions & 1 deletion src/independence_tests/independence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ Returns a test `summary`, whose type depends on `test`.
- [`JointDistanceDistributionTest`](@ref).
"""
function independence(test::IndependenceTest, x...)
throw(ArgumentError("No concrete implementation for $(typeof(test)) test yet"))
L = length(x)
throw(ArgumentError("No concrete implementation for $(typeof(test)) test with $(L) input variables."))
end

function pvalue_text_summary(test::IndependenceTestResult)
Expand Down
30 changes: 18 additions & 12 deletions src/independence_tests/local_permutation/LocalPermutationTest.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Random: shuffle!
using Random
import Statistics: quantile
import ProgressMeter

export LocalPermutationTest
export LocalPermutationTestResult
Expand Down Expand Up @@ -34,7 +35,8 @@ struct NeighborCloseness <: LocalPermutationClosenessSearch end
nshuffles::Int = 100,
rng = Random.default_rng(),
replace = true,
w::Int = 0)
w::Int = 0,
show_progress = false)
`LocalPermutationTest` is a generic conditional independence test
[Runge2018LocalPerm](@cite) for assessing whether two variables `X` and `Y` are
Expand Down Expand Up @@ -102,15 +104,16 @@ struct LocalPermutationTest{M, EST, C, R} <: IndependenceTest{M}
replace::Bool
closeness_search::C
w::Int # Theiler window
function LocalPermutationTest(measure::M, est::EST = nothing;
rng::R = Random.default_rng(),
kperm::Int = 10,
replace::Bool = true,
nshuffles::Int = 100,
closeness_search::C = NeighborCloseness(),
w::Int = 0) where {M, EST, C, R}
new{M, EST, C, R}(measure, est, rng, kperm, nshuffles, replace, closeness_search, w)
end
show_progress::Bool
end
function LocalPermutationTest(measure::M, est::EST = nothing;
rng::R = Random.default_rng(),
kperm::Int = 10,
replace::Bool = true,
nshuffles::Int = 100,
closeness_search::C = NeighborCloseness(),
w::Int = 0, show_progress = false) where {M, EST, C, R}
return LocalPermutationTest{M, EST, C, R}(measure, est, rng, kperm, nshuffles, replace, closeness_search, w, show_progress)
end

Base.show(io::IO, test::LocalPermutationTest) = print(io,
Expand Down Expand Up @@ -165,7 +168,6 @@ function independence(test::LocalPermutationTest, x, y, z)

X, Y, Z = StateSpaceSet(x), StateSpaceSet(y), StateSpaceSet(z)
@assert length(X) == length(Y) == length(Z)
N = length(X)
= estimate(measure, est, X, Y, Z)
Îs = permuted_Îs(X, Y, Z, measure, est, test)
p = count(Î .<= Îs) / nshuffles
Expand All @@ -177,7 +179,10 @@ end
# computing the test statistic.
function permuted_Îs(X, Y, Z, measure, est, test)
rng, kperm, nshuffles, replace, w = test.rng, test.kperm, test.nshuffles, test.replace, test.w

progress = ProgressMeter.Progress(nshuffles;
desc = "LocalPermutationTest:",
enabled = test.show_progress
)
N = length(X)
test.kperm < N || throw(ArgumentError("kperm must be smaller than input data length"))

Expand All @@ -194,6 +199,7 @@ function permuted_Îs(X, Y, Z, measure, est, test)
shuffle_without_replacement!(X̂, X, idxs_z, kperm, rng, Nᵢ, πs)
end
Îs[n] = estimate(measure, est, X̂, Y, Z)
ProgressMeter.next!(progress)
end

return Îs
Expand Down
6 changes: 5 additions & 1 deletion src/independence_tests/local_permutation/transferentropy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ end
# about the target variable is left untouched.
function permuted_Îs_te(S, T, T⁺, C, measure::TransferEntropy, est, test)
rng, kperm, nshuffles, replace, w = test.rng, test.kperm, test.nshuffles, test.replace, test.w

progress = ProgressMeter.Progress(nshuffles;
desc = "LocalPermutationTest:",
enabled = test.show_progress
)
N = length(S)
test.kperm < N || throw(ArgumentError("kperm must be smaller than input data length"))

Expand All @@ -65,6 +68,7 @@ function permuted_Îs_te(S, T, T⁺, C, measure::TransferEntropy, est, test)
shuffle_without_replacement!(Ŝ, S, idxs_C, kperm, rng, Nᵢ, πs)
end
Îs[n] = estimate(measure, est, Ŝ, T, T⁺, C)
ProgressMeter.next!(progress)
end
return Îs
end
62 changes: 23 additions & 39 deletions src/independence_tests/surrogate/SurrogateTest.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Random
using TimeseriesSurrogates
import ProgressMeter
export SurrogateTest
export SurrogateTestResult

Expand All @@ -9,6 +10,7 @@ export SurrogateTestResult
nshuffles::Int = 100,
surrogate = RandomShuffle(),
rng = Random.default_rng(),
show_progress = false,
)
A generic (conditional) independence test for assessing whether two variables `X` and `Y`
Expand Down Expand Up @@ -72,14 +74,14 @@ struct SurrogateTest{M, E, R, S} <: IndependenceTest{M}
rng::R
surrogate::S
nshuffles::Int

function SurrogateTest(measure::M, est::E = nothing;
rng::R = Random.default_rng(),
surrogate::S = RandomShuffle(),
nshuffles::Int = 100,
) where {M, E, R, S}
new{M, E, R, S}(measure, est, rng, surrogate, nshuffles)
end
show_progress::Bool
end
function SurrogateTest(measure::M, est::E = nothing;
rng::R = Random.default_rng(),
surrogate::S = RandomShuffle(),
nshuffles::Int = 100, show_progress = false
) where {M, E, R, S}
SurrogateTest{M, E, R, S}(measure, est, rng, surrogate, nshuffles, show_progress)
end


Expand Down Expand Up @@ -127,46 +129,28 @@ end
# Generic dispatch for any three-argument conditional independence measure where the
# third argument is to be conditioned on. This works naturally with e.g.
# conditional mutual information.
function independence(test::SurrogateTest, x, y, z)
(; measure, est, rng, surrogate, nshuffles) = test

# Make sure that the measure is compatible with the input data.
verify_number_of_inputs_vars(measure, 3)
function independence(test::SurrogateTest, x, args...)
# Setup (`args...` is either `y` or `y, z`)
(; measure, est, rng, surrogate, nshuffles, show_progress) = test
verify_number_of_inputs_vars(measure, 1+length(args))
SSSets = map(w -> StateSpaceSet(w), args)
estimation = x -> estimate(measure, est, x, SSSets...)
progress = ProgressMeter.Progress(nshuffles;
desc="SurrogateTest:", enabled=show_progress
)

X, Y, Z = StateSpaceSet(x), StateSpaceSet(y), StateSpaceSet(z)
@assert length(X) == length(Y) == length(Z)
N = length(x)
= estimate(measure,est, X, Y, Z)
# Estimate
= estimation(StateSpaceSet(x))
s = surrogenerator(x, surrogate, rng)
Îs = zeros(nshuffles)
for b in 1:nshuffles
Îs[b] = estimate(measure, est, s(), Y, Z)
Îs[b] = estimation(s())
ProgressMeter.next!(progress)
end
p = count(Î .<= Îs) / nshuffles

return SurrogateTestResult(3, Î, Îs, p, nshuffles)
end

function independence(test::SurrogateTest, x, y)
(; measure, est, rng, surrogate, nshuffles) = test

# Make sure that the measure is compatible with the input data.
verify_number_of_inputs_vars(measure, 2)

X, Y = StateSpaceSet(x), StateSpaceSet(y)
@assert length(X) == length(Y)
N = length(x)
= estimate(measure,est, X, Y)
sx = surrogenerator(x, surrogate, rng)
Îs = zeros(nshuffles)
for b in 1:nshuffles
Îs[b] = estimate(measure, est, sx(), y)
end
p = count(Î .<= Îs) / nshuffles

return SurrogateTestResult(2, Î, Îs, p, nshuffles)
end

# Concrete implementations
include("contingency.jl")
include("transferentropy.jl")
Expand Down

2 comments on commit 2c6fc3b

@kahaaga
Copy link
Member

@kahaaga kahaaga commented on 2c6fc3b Oct 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/92766

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.10.0 -m "<description of version>" 2c6fc3b2f53d3ae72e968e1142378d05db8a27b6
git push origin v2.10.0

Please sign in to comment.