diff --git a/Project.toml b/Project.toml index 76da5a77e..eedd177a9 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "CausalityTools" uuid = "5520caf5-2dd7-5c5d-bfcb-a00e56ac49f7" authors = ["Kristian Agasøster Haaga ", "Tor Einar Møller ", "George Datseris "] repo = "https://github.com/kahaaga/CausalityTools.jl.git" -version = "2.2.1" +version = "2.3.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -36,7 +36,7 @@ Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a" [compat] Accessors = "^0.1.28" Combinatorics = "1" -ComplexityMeasures = "^2.6" # contains important bugfix +ComplexityMeasures = "^2.6" DSP = "0.7" DelayEmbeddings = "2.6" Distances = "^0.10" diff --git a/changelog.md b/changelog.md index 9356864ca..cb292f098 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,17 @@ # Changelog +## 2.3.0 + +- Significant speed-ups for `OCE` by sorting on maximal measure, thus avoiding + unnecessary significance tests. +- Default parameters for `OCE` default lag parameter have changed. Now, `τmax = 1`, since + that is the only case considered in the original paper. We also use the + `MesnerShalisi` CMI estimator for the conditional step, because in contrast to + the `FPVP` estimator, it has been shown to be consistent. +- Source code for `OCE` has been drastically simplified by merging the pairwise + and conditional parent finding steps. +- `OCE` result can now be converted to a `SimpleDiGraph` from Graphs.jl. + ## 2.2.1 - `infer_graph` now accepts the `verbose` keyword. diff --git a/docs/src/examples/examples_graphs.md b/docs/src/examples/examples_graphs.md index 86c7043e0..30aec61bf 100644 --- a/docs/src/examples/examples_graphs.md +++ b/docs/src/examples/examples_graphs.md @@ -8,6 +8,7 @@ for the conditional steps. ```@example causalgraph_oce using CausalityTools +using Graphs using Random rng = MersenneTwister(1234) @@ -16,12 +17,16 @@ sys = system(Logistic4Chain(; rng)) x, y, z, w = columns(trajectory(sys, 400, Ttr = 10000)) # Independence tests for unconditional and conditional stages. -utest = SurrogateTest(MIShannon(), KSG2(k = 5)) -ctest = LocalPermutationTest(CMIShannon(), FPVP(k = 5)) +utest = SurrogateTest(MIShannon(), KSG2(k = 3, w = 1); rng, nshuffles = 150) +ctest = LocalPermutationTest(CMIShannon(), MesnerShalisi(k = 3, w = 1); rng, nshuffles = 150) # Infer graph -alg = OCE(; utest, ctest, α = 0.05, τmax = 3) -infer_graph(alg, [x, y, z, w]) +alg = OCE(; utest, ctest, α = 0.05, τmax = 1) +parents = infer_graph(alg, [x, y, z, w]) + +# Convert to graph and inspect edges +g = SimpleDiGraph(parents) +collect(edges(g)) ``` The algorithm nicely recovers the true causal directions. diff --git a/src/CausalityTools.jl b/src/CausalityTools.jl index e97ed3eef..0d5207f14 100644 --- a/src/CausalityTools.jl +++ b/src/CausalityTools.jl @@ -56,28 +56,19 @@ module CausalityTools # Update messages: using Scratch display_update = true - version_number = "2.0.0" + version_number = "2.3.0" update_name = "update_v$(version_number)" update_message = """ \nUpdate message: CausalityTools v$(version_number)\n - - An overall overhaul of the API and documentation. See the online documentation - for a full overview. - - There are now three conceptual levels of funcationality: 1) association measures, - 2) independence testing based on these measures, and 3) causal graph inference. - - Other changes: - - A plethora of new methods and estimators for information theoretic quantities have - been added. See the online documentation for an overview. - - Syntax for many methods have changed. Estimators, which - also contains analysis parameters, are now always the first argument. - - All information-based methods in the DynamicalSystems.jl organization that are - more complex than those in `ComplexityMeasures.jl` have been moved to CausalityTools.jl. - This include `mutualinfo`, `condmutualinfo` and `transferentropy`. - - TransferEntropy.jl has been discontinued, and all its functionality has been moved to - CausalityTools.jl. `conditional_mutualinfo` has been renamed to `condmutualinfo`. - - The `Kraskov1` and `Kraskov2` mutual information estimators have been renamed to - `KraskovStögbauerGrassberger1` (`KSG1` for short) and - `KraskovStögbauerGrassberger2` (`KSG2` for short). + - Significant speed-ups for `OCE` by sorting on maximal measure, thus avoiding + unnecessary significance tests. + - Default parameters for `OCE` default lag parameter have changed. Now, `τmax = 1`, since + that is the only case considered in the original paper. We also use the + `MesnerShalisi` CMI estimator for the conditional step, because in contrast to + the `FPVP` estimator, it has been shown to be consistent. + - Source code for `OCE` has been drastically simplified by merging the pairwise + and conditional parent finding steps. + - `OCE` result can now be converted to a `SimpleDiGraph` from Graphs.jl. """ if display_update diff --git a/src/causal_graphs/causal_graphs.jl b/src/causal_graphs/causal_graphs.jl index 23253f5d9..1b9468a58 100644 --- a/src/causal_graphs/causal_graphs.jl +++ b/src/causal_graphs/causal_graphs.jl @@ -1,3 +1,8 @@ +import Graphs.SimpleGraphs: SimpleDiGraph +import Graphs: edges +export SimpleDiGraph +export edges + include("api.jl") # Concrete implementations diff --git a/src/causal_graphs/oce/OCE.jl b/src/causal_graphs/oce/OCE.jl index 23f525e10..9c87e80c7 100644 --- a/src/causal_graphs/oce/OCE.jl +++ b/src/causal_graphs/oce/OCE.jl @@ -1,11 +1,13 @@ +using Graphs: add_edge! +using Graphs.SimpleGraphs: SimpleDiGraph + export OCE """ OCE <: GraphAlgorithm OCE(; utest::IndependenceTest = SurrogateTest(MIShannon(), KSG2(k = 3, w = 3)), - ctest::C = LocalPermutationTest(CMIShannon(), FPVP(k = 3, w = 3)), - τmax::T = 5, α = 0.05 - ) + ctest::C = LocalPermutationTest(CMIShannon(), MesnerShalisi(k = 3, w = 3)), + τmax::T = 1, α = 0.05) The optimal causation entropy (OCE) algorithm for causal discovery (Sun et al., 2015)[^Sun2015]. @@ -24,13 +26,13 @@ The OCE algorithm has three steps to determine the parents of a variable `xᵢ`. where `P` is the set of parent nodes found in the previous steps. `τmax` indicates the maximum lag `τ` between the target variable `xᵢ(0)` and -its potential parents `xⱼ(-τ)`. +its potential parents `xⱼ(-τ)`. Sun et al. 2015's method is based on `τmax = 1`. ## Returns When used with [`infer_graph`](@ref), it returns a vector `p`, where `p[i]` are the -parents for each input variable. In the future, this will return a labelled, directed -graph with all the detected associations. +parents for each input variable. This result can be converted to a `SimpleDiGraph` +from Graphs.jl (see [example](@ref oce_example)). ## Examples @@ -41,15 +43,18 @@ graph with all the detected associations. causation entropy. SIAM Journal on Applied Dynamical Systems, 14(1), 73-106. """ Base.@kwdef struct OCE{U, C, T} <: GraphAlgorithm - utest::U = SurrogateTest(MIShannon(), KSG2(k = 3, w = 3)) - ctest::C = LocalPermutationTest(CMIShannon(), FPVP(k = 3, w = 3)) - τmax::T = 5 + utest::U = SurrogateTest(MIShannon(), KSG2(k = 3, w = 3), nshuffles = 100) + ctest::C = LocalPermutationTest(CMIShannon(), MesnerShalisi(k = 3, w = 3), nshuffles = 100) + τmax::T = 1 α = 0.05 end function infer_graph(alg::OCE, x; verbose = true) - parents = select_parents(alg, x; verbose) - return parents + return select_parents(alg, x; verbose) +end + +function infer_graph(alg::OCE, x::AbstractDataset; verbose = true) + return infer_graph(alg, columns(x); verbose) end """ @@ -61,17 +66,8 @@ parents of each `xᵢ ∈ x`, assuming that `x` must be integer-indexable, i.e. """ function select_parents(alg::OCE, x; verbose = false) - # Preliminary parents - τs = Iterators.flatten([-1:-1:-alg.τmax |> collect for xᵢ in x]) |> collect - js = Iterators.flatten([fill(i, alg.τmax) for i in eachindex(x)]) |> collect - embeddings = [genembed(xᵢ, -1:-1:-alg.τmax) for xᵢ in x] - T = typeof(1.0) - 𝒫s = Vector{Vector{T}}(undef, 0) - for emb in embeddings - append!(𝒫s, columns(emb)) - end # Find the parents of each variable. - parents = [select_parents(alg, τs, js, 𝒫s, x, k; verbose) for k in eachindex(x)] + parents = [select_parents(alg, x, k; verbose) for k in eachindex(x)] return parents end @@ -89,14 +85,29 @@ function selected(o::OCESelectedParents) return join(["x$(js[i])($(τs[i]))" for i in eachindex(js)], ", ") end - function Base.show(io::IO, x::OCESelectedParents) s = ["x$(x.parents_js[i])($(x.parents_τs[i]))" for i in eachindex(x.parents)] all = "x$(x.i)(0) ← $(join(s, ", "))" show(io, all) end -function select_parents(alg::OCE, τs, js, 𝒫s, x, i::Int; verbose = false) +function SimpleDiGraph(v::Vector{<:CausalityTools.OCESelectedParents}) + N = length(v) + g = SimpleDiGraph(N) + for k = 1:N + parents = v[k] + for (j, τ) in zip(parents.parents_js, parents.parents_τs) + if j != k # avoid self-loops + add_edge!(g, j, k) + end + end + end + return g +end + +function select_parents(alg::OCE, x, i::Int; verbose = false) + τs, js, 𝒫s = prepare_embeddings(alg, x, i) + verbose && println("\nInferring parents for x$i(0)...") # Account for the fact that the `𝒫ⱼ ∈ 𝒫s` are embedded. This means that some points are # lost from the `xᵢ`s. @@ -110,158 +121,168 @@ function select_parents(alg::OCE, τs, js, 𝒫s, x, i::Int; verbose = false) # 1. Can we find a significant pairwise association? verbose && println("˧ Querying pairwise associations...") - significant_pairwise = select_first_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose) + significant_pairwise = select_parent!(alg, parents, τs, js, 𝒫s, xᵢ, i; verbose) if significant_pairwise verbose && println("˧ Querying new variables conditioned on already selected variables...") # 2. Continue until there are no more significant conditional pairwise associations significant_cond = true - k = 0 while significant_cond - k += 1 - significant_cond = select_conditional_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose) + significant_cond = select_parent!(alg, parents, τs, js, 𝒫s, xᵢ, i; verbose) end ################################################################### # Backward elimination ################################################################### - if !(length(parents.parents) >= 2) - return parents - end + backwards_eliminate!(alg, parents, xᵢ, i; verbose) + end + return parents +end - verbose && println("˧ Backwards elimination...") - eliminate = true - ks_remaining = Set(1:length(parents.parents)) - while eliminate && length(ks_remaining) >= 2 - for k in ks_remaining - eliminate = backwards_eliminate!(parents, alg, xᵢ, k; verbose) - if eliminate - filter!(x -> x == k, ks_remaining) - end - end - end +function prepare_embeddings(alg::OCE, x, i) + # Preliminary parents + τs = Iterators.flatten([-1:-1:-alg.τmax |> collect for xᵢ in x]) |> collect + js = Iterators.flatten([fill(i, alg.τmax) for i in eachindex(x)]) |> collect + embeddings = [genembed(xᵢ, -1:-1:-alg.τmax) for xᵢ in x] + T = typeof(1.0) + 𝒫s = Vector{Vector{T}}(undef, 0) + for emb in embeddings + append!(𝒫s, columns(emb)) end - return parents + return τs, js, 𝒫s end -# Pairwise associations -function select_first_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose = false) - M = length(𝒫s) - if isempty(𝒫s) - return false +function select_parent!(alg::OCE, parents, τs, js, 𝒫s, xᵢ, i::Int; verbose = true) + # Have any parents been identified yet? + pairwise = isempty(parents.parents) + + # If there are no potential parents to pick from, return immediately. + isempty(𝒫s) && return false + + # Configure estimation and independence testing function calls, which differ in the + # number of arguments depending on whether we're doing the pairwise or conditional case. + if !pairwise + P = StateSpaceSet(parents.parents...) + f = (measure, est, xᵢ, Pⱼ) -> estimate(measure, est, xᵢ, Pⱼ, P) + findep = (test, xᵢ, Pix) -> independence(test, xᵢ, Pix, P) + else + f = (measure, est, xᵢ, Pⱼ) -> estimate(measure, est, xᵢ, Pⱼ) + findep = (test, xᵢ, Pix) -> independence(test, xᵢ, Pix) end - # Association measure values and the associated p-values - Is, pvals = zeros(M), zeros(M) - for (i, Pj) in enumerate(𝒫s) - test = independence(alg.utest, xᵢ, Pj) - Is[i] = test.m - pvals[i] = pvalue(test) + # Compute the measure without significance testing first. This avoids unnecessary + # independence testing, which takes a lot of time. + Is = zeros(length(𝒫s)) + for (i, Pⱼ) in enumerate(𝒫s) + Is[i] = f(alg.utest.measure, alg.utest.est, xᵢ, Pⱼ) end - if all(pvals .>= alg.α) - s = ["x$i(0) ⫫ x$j(t$τ) | ∅)" for (τ, j) in zip(τs, js)] - verbose && println("\t$(join(s, "\n\t"))") - return false + # Sort variables according to maximal measure and select the first lagged variable that + # gives significant association with the target variable. + maximize_sortidxs = sortperm(Is, rev = true) + n_checked = 0 + n_potential_vars = length(𝒫s) + while n_checked < n_potential_vars + n_checked += 1 + ix = maximize_sortidxs[n_checked] + if Is[ix] > 0 + # findep takes into account the conditioning set too if it is non-empty. + result = findep(alg.utest, xᵢ, 𝒫s[ix]) + if pvalue(result) < alg.α + if verbose && !pairwise + println("\tx$i(0) !⫫ x$(js[ix])($(τs[ix])) | $(selected(parents))") + elseif verbose && pairwise + println("\tx$i(0) !⫫ x$(js[ix])($(τs[ix])) | ∅") + end + push!(parents.parents, 𝒫s[ix]) + push!(parents.parents_js, js[ix]) + push!(parents.parents_τs, τs[ix]) + deleteat!(𝒫s, ix) + deleteat!(js, ix) + deleteat!(τs, ix) + return true + end + end end - # Select the variable that has the highest significant association with xᵢ. - # "Significant" means a p-value strictly less than the significance level α. - Imax = maximum(Is[pvals .< alg.α]) - idx = findfirst(x -> x == Imax, Is) - - if Is[idx] > 0 - verbose && println("\tx$i(0) !⫫ x$(js[idx])($(τs[idx])) | ∅") - push!(parents.parents, 𝒫s[idx]) - push!(parents.parents_js, js[idx]) - push!(parents.parents_τs, τs[idx]) - deleteat!(𝒫s, idx) - deleteat!(js, idx) - deleteat!(τs, idx) - return true - else + # If we reach this stage, no variables have been selected. Print an informative message. + if verbose && !pairwise + # No more associations were found + s = ["x$i(1) ⫫ x$j($τ) | $(selected(parents)))" for (τ, j) in zip(τs, js)] + println("\t$(join(s, "\n\t"))") + end + if verbose && pairwise s = ["x$i(0) ⫫ x$j($τ) | ∅)" for (τ, j) in zip(τs, js)] - verbose && println("\t$(join(s, "\n\t"))") - return false + println("\t$(join(s, "\n\t"))") end + return false end -function select_conditional_parent!(parents, alg, τs, js, 𝒫s, xᵢ, i; verbose) - if isempty(𝒫s) - return false - end +""" + backwards_eliminate!(alg::OCE, parents::OCESelectedParents, x, i; verbose) - P = StateSpaceSet(parents.parents...) - M = length(𝒫s) - Is = zeros(M) - pvals = zeros(M) - for (i, Pj) in enumerate(𝒫s) - test = independence(alg.ctest, xᵢ, Pj, P) - Is[i] = test.m - pvals[i] = pvalue(test) - end - # Select the variable that has the highest significant association with xᵢ. - # "Significant" means a p-value strictly less than the significance level α. - if all(pvals .>= alg.α) - s = ["x$i(0) ⫫ x$j($τ) | $(selected(parents))" for (τ, j) in zip(τs, js)] - verbose && println("\t$(join(s, "\n\t"))") - return false - end - Imax = maximum(Is[pvals .< alg.α]) - idx = findfirst(x -> x == Imax, Is) - - if Is[idx] > 0 - τ = τs[idx] - j = js[idx] - verbose && println("\tx$i(0) !⫫ x$j($τ) | $(selected(parents))") - push!(parents.parents, 𝒫s[idx]) - push!(parents.parents_js, js[idx]) - push!(parents.parents_τs, τs[idx]) - deleteat!(𝒫s, idx) - deleteat!(τs, idx) - deleteat!(js, idx) - return true - else - s = ["x$i(1) ⫫ x$j($τ) | $(selected(parents)))" for (τ, j) in zip(τs, js)] - verbose && println("\t$(join(s, "\n\t"))") - return false +Algorithm 2.2 in Sun et al. (2015). Perform backward elimination for the `i`-th variable +in `x`, given the previously inferred `parents`, which were deduced using parameters in +`alg`. Modifies `parents` in-place. +""" +function backwards_eliminate!(alg::OCE, parents::OCESelectedParents, xᵢ, i::Int; verbose) + length(parents.parents) < 2 && return parents + + verbose && println("˧ Backwards elimination...") + n_initial = length(parents.parents_js) + q = 0 + variable_was_eliminated = true + while variable_was_eliminated && length(parents.parents_js) >= 2 && q < n_initial + q += 1 + variable_was_eliminated = eliminate_loop!(alg, parents, xᵢ, i; verbose) end + return parents end -function backwards_eliminate!(parents, alg, xᵢ, k; verbose = false) +""" + eliminate_loop!(alg::OCE, parents::OCESelectedParents, xᵢ; verbose = false) + +Inner portion of algorithm 2.2 in Sun et al. (2015). This method is called in an external +while-loop that handles the variable elimination step in their line 3. +""" +function eliminate_loop!(alg::OCE, parents::OCESelectedParents, xᵢ, i; verbose = false) + isempty(parents.parents) && return false M = length(parents.parents) P = parents.parents - Pj = P[k] - remaining_idxs = setdiff(1:M, k) - remaining = StateSpaceSet(P...)[:, remaining_idxs] - test = independence(alg.ctest, xᵢ, Pj, remaining) - - if verbose - τ, j = parents.parents_τs[k], parents.parents_js[k] # Variable currently considered - τs = parents.parents_τs - js = parents.parents_js - src_var = "x$j($τ)" - targ_var = "x$(js[k])($(τs[k]))" - cond_var = join(["x$(js[i])($(τs[i]))" for i in remaining_idxs], ", ") + variable_was_eliminated = false + for k in eachindex(P) + Pj = P[k] + remaining_idxs = setdiff(1:M, k) + remaining = StateSpaceSet(P[remaining_idxs]...) + test = independence(alg.ctest, xᵢ, Pj, remaining) + + if verbose + τ, j = parents.parents_τs[k], parents.parents_js[k] # Variable currently considered + τs = parents.parents_τs + js = parents.parents_js + src_var = "x$j($τ)" + targ_var = "x$i(0)" + cond_var = join(["x$(js[r])($(τs[r]))" for r in remaining_idxs], ", ") + + if test.pvalue >= alg.α + outcome_msg = "Removing x$(j)($τ) from parent set" + println("\t$src_var ⫫ $targ_var | $cond_var → $outcome_msg") + else + outcome_msg = "Keeping x$(j)($τ) in parent set" + println("\t$src_var !⫫ $targ_var | $cond_var → $outcome_msg") + end + end + # A parent became independent of the target conditional on the remaining parents if test.pvalue >= alg.α - outcome_msg = "Removing x$(j)($τ) from parent set" - println("\t$src_var ⫫ $targ_var | $cond_var → $outcome_msg") - else - outcome_msg = "Keeping x$(j)($τ) in parent set" - println("\t$src_var !⫫ $targ_var | $cond_var → $outcome_msg") + deleteat!(parents.parents, k) + deleteat!(parents.parents_js, k) + deleteat!(parents.parents_τs, k) + variable_was_eliminated = true + break end end - # If p-value >= α, then we can't reject the null, i.e. the statistic I is - # indistinguishable from zero, so we claim independence and remove the variable. - if test.pvalue >= alg.α - deleteat!(parents.parents, k) - deleteat!(parents.parents_js, k) - deleteat!(parents.parents_τs, k) - return true - else - return false - end + return variable_was_eliminated end diff --git a/test/causal_graphs/oce.jl b/test/causal_graphs/oce.jl index 0aded76b8..8a5225ae1 100644 --- a/test/causal_graphs/oce.jl +++ b/test/causal_graphs/oce.jl @@ -1,15 +1,40 @@ using CausalityTools +using CausalityTools: OCESelectedParents using Test using StableRNGs +using Graphs.SimpleGraphs: SimpleEdge rng = StableRNG(123) sys = system(Logistic4Chain(; rng)) -X = columns(trajectory(sys, 350, Ttr = 10000)) - -parents = infer_graph(OCE(τmax = 2), X) -@test all(x ∉ parents[1].parents_js for x in (2, 3, 4)) -@test all(x ∉ parents[2].parents_js for x in (3, 4)) -@test all(x ∉ parents[3].parents_js for x in (4)) -@test 1 ∈ parents[2].parents_js -@test 2 ∈ parents[3].parents_js -@test 3 ∈ parents[4].parents_js +X = columns(trajectory(sys, 60, Ttr = 10000)) +utest = SurrogateTest(MIShannon(), KSG1(k = 10, w = 1); rng, nshuffles = 30) +ctest = LocalPermutationTest(CMIShannon(), MesnerShalisi(k = 10, w = 1); rng, nshuffles = 30) +alg = OCE(; utest, ctest, τmax = 2) +parents = infer_graph(alg, X; verbose = true) +@test parents isa Vector{<:OCESelectedParents} +@test SimpleDiGraph(parents) isa SimpleDiGraph + +rng = StableRNG(123) +sys = system(Logistic2Bidir(; rng)) +X = columns(trajectory(sys, 300, Ttr = 10000)) +utest = SurrogateTest(MIShannon(), KSG1(k = 10, w = 1); rng, nshuffles = 100) +ctest = LocalPermutationTest(CMIShannon(), MesnerShalisi(k = 10, w = 1); rng, nshuffles = 100) +parents = infer_graph(OCE(; utest, ctest, τmax = 1), X; verbose = true) +@test parents isa Vector{<:OCESelectedParents} +g = SimpleDiGraph(parents) +@test g isa SimpleDiGraph + +# "Analytical" test: check that we at least identify one true positive. There may +# be several false positives, but there's no way of telling a priori how many. +function at_least_one_true_positive(true_edges, estimated_graph) + estimated_edges = edges(estimated_graph) + at_least_one_tp = false + for e in true_edges + if e in estimated_edges + at_least_one_tp = true + end + end + return at_least_one_tp +end + +@test at_least_one_true_positive([SimpleEdge(1, 2), SimpleEdge(2, 1)], g)