diff --git a/Project.toml b/Project.toml
index 984802d..a7fc4cc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,23 +1,21 @@
 name = "PersistenceDiagrams"
 uuid = "90b4794c-894b-4756-a0f8-5efeb5ddf7ae"
 authors = ["mtsch "]
-version = "0.9.6"
+version = "0.9.7"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Hungarian = "e91730f6-4275-51fb-a7a0-7064cfbd3b39"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
-PersistenceDiagramsBase = "b1ad91c1-539c-4ace-90bd-ea06abc420fa"
 RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
 ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
-Compat = "^3.10.0"
+Compat = "^3.10, 4"
 Hungarian = "0.6"
 MLJModelInterface = "1"
-PersistenceDiagramsBase = "^0.1.1"
 RecipesBase = "1"
 ScientificTypes = "3"
 Tables = "1"
@@ -25,6 +23,7 @@ julia = "1"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -32,4 +31,4 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "Documenter", "MLJBase", "SafeTestsets", "Suppressor", "Test"]
+test = ["Aqua", "DataFrames", "Documenter", "MLJBase", "SafeTestsets", "Suppressor", "Test"]
diff --git a/src/PersistenceDiagrams.jl b/src/PersistenceDiagrams.jl
index a06c80d..919add9 100644
--- a/src/PersistenceDiagrams.jl
+++ b/src/PersistenceDiagrams.jl
@@ -40,12 +40,14 @@ export PersistenceDiagram,
 
 using Compat
 using Hungarian
-using PersistenceDiagramsBase
 using RecipesBase
 using ScientificTypes
 using Statistics
 using Tables
 
+include("intervals.jl")
+include("diagrams.jl")
+include("tables.jl")
 include("matching.jl")
 include("persistencecurves.jl")
 
diff --git a/src/diagrams.jl b/src/diagrams.jl
new file mode 100644
index 0000000..671c58b
--- /dev/null
+++ b/src/diagrams.jl
@@ -0,0 +1,140 @@
+"""
+    PersistenceDiagram <: AbstractVector{PersistenceInterval}
+
+Type for representing persistence diagrams. Behaves exactly like a vector of
+`PersistenceInterval`s, but can have additional metadata attached to it. It supports pretty
+printing and plotting.
+
+Can be used as a table with any function that uses the
+[`Tables.jl`](https://github.com/JuliaData/Tables.jl) interface. Note that using it as a
+table will only keep interval endpoints and the `dim` and `threshold` attributes.
+
+# Example
+
+```jldoctest
+julia> diagram = PersistenceDiagram([(1, 3), (3, 4), (1, Inf)]; dim=1, custom_metadata=:a)
+3-element 1-dimensional PersistenceDiagram:
+ [1.0, 3.0)
+ [3.0, 4.0)
+ [1.0, ∞)
+
+julia> diagram[1]
+[1.0, 3.0)
+
+julia> sort(diagram; by=persistence, rev=true)
+3-element 1-dimensional PersistenceDiagram:
+ [1.0, ∞)
+ [1.0, 3.0)
+ [3.0, 4.0)
+
+julia> propertynames(diagram)
+(:intervals, :dim, :custom_metadata)
+
+julia> dim(diagram)
+1
+
+julia> diagram.custom_metadata
+:a
+```
+"""
+struct PersistenceDiagram <: AbstractVector{PersistenceInterval}
+    intervals::Vector{PersistenceInterval}
+    meta::NamedTuple
+end
+
+function PersistenceDiagram(intervals::Vector{PersistenceInterval}; kwargs...)
+    meta = (; kwargs...)
+    return PersistenceDiagram(intervals, meta)
+end
+function PersistenceDiagram(intervals::AbstractVector{PersistenceInterval}; kwargs...)
+    return PersistenceDiagram(collect(intervals); kwargs...)
+end
+function PersistenceDiagram(pairs::AbstractVector{<:Tuple}; kwargs...)
+    return PersistenceDiagram(PersistenceInterval.(pairs); kwargs...)
+end
+# Build a diagram from a Tables.jl-compatible `table` with `birth` and `death` columns.
+# `dim` and `threshold` columns become metadata and must be constant across rows; absent
+# columns are stored as `missing`. An empty table gives an empty diagram with no metadata.
+function PersistenceDiagram(table)
+    rows = Tables.rows(table)
+    if isempty(rows)
+        return PersistenceDiagram(PersistenceInterval[])
+    else
+        firstrow = first(rows)
+        dim = hasproperty(firstrow, :dim) ? firstrow.dim : missing
+        threshold = hasproperty(firstrow, :threshold) ? firstrow.threshold : missing
+        intervals = map(rows) do row
+            d = hasproperty(row, :dim) ? row.dim : missing
+            t = hasproperty(row, :threshold) ? row.threshold : missing
+            if !isequal(d, dim)
+                error("different `dim`s detected. Try splitting the table first.")
+            end
+            if !isequal(t, threshold)
+                error("different `threshold`s detected. Try splitting the table first.")
+            end
+            PersistenceInterval(row.birth, row.death)
+        end
+        return PersistenceDiagram(intervals; dim=dim, threshold=threshold)
+    end
+end
+
+function Base.show(io::IO, diag::PersistenceDiagram)
+    return summary(io, diag)
+end
+function Base.summary(io::IO, diag::PersistenceDiagram)
+    if haskey(diag.meta, :dim)
+        print(io, length(diag), "-element ", dim(diag), "-dimensional PersistenceDiagram")
+    else
+        print(io, length(diag), "-element PersistenceDiagram")
+    end
+end
+
+###
+### Array interface
+###
+Base.size(diag::PersistenceDiagram) = size(diag.intervals)
+Base.getindex(diag::PersistenceDiagram, i::Integer) = diag.intervals[i]
+Base.setindex!(diag::PersistenceDiagram, x, i::Integer) = diag.intervals[i] = x
+Base.firstindex(diag::PersistenceDiagram) = 1
+Base.lastindex(diag::PersistenceDiagram) = length(diag.intervals)
+
+function Base.similar(diag::PersistenceDiagram)
+    return PersistenceDiagram(similar(diag.intervals); diag.meta...)
+end
+function Base.similar(diag::PersistenceDiagram, dims::Tuple)
+    return PersistenceDiagram(similar(diag.intervals, dims); diag.meta...)
+end
+
+###
+### Meta
+###
+function Base.getproperty(diag::PersistenceDiagram, key::Symbol)
+    if hasfield(typeof(diag), key)
+        return getfield(diag, key)
+    elseif haskey(diag.meta, key)
+        return diag.meta[key]
+    else
+        error("$diag has no $key")
+    end
+end
+function Base.propertynames(diag::PersistenceDiagram, private::Bool=false)
+    if private
+        return tuple(:intervals, :meta, propertynames(diag.meta)...)
+    else
+        return tuple(:intervals, propertynames(diag.meta)...)
+    end
+end
+
+"""
+    threshold(diagram::PersistenceDiagram)
+
+Get the threshold of persistence diagram. Equivalent to `diagram.threshold`.
+"""
+threshold(diag::PersistenceDiagram) = diag.threshold
+
+"""
+    dim(diagram::PersistenceDiagram)
+
+Get the dimension of persistence diagram. Equivalent to `diagram.dim`.
+"""
+dim(diag::PersistenceDiagram) = diag.dim
diff --git a/src/intervals.jl b/src/intervals.jl
new file mode 100644
index 0000000..efcc552
--- /dev/null
+++ b/src/intervals.jl
@@ -0,0 +1,184 @@
+"""
+    PersistenceInterval
+
+Type for representing persistence intervals. It behaves exactly like a `Tuple{Float64,
+Float64}`, but can have meta data attached to it. The metadata is accessible with
+`getproperty` or the dot syntax.
+
+# Example
+
+```jldoctest
+julia> interval = PersistenceInterval(1, Inf; meta1=:a, meta2=:b)
+[1.0, ∞) with:
+ meta1: Symbol
+ meta2: Symbol
+
+julia> birth(interval), death(interval), persistence(interval)
+(1.0, Inf, Inf)
+
+julia> isfinite(interval)
+false
+
+julia> propertynames(interval)
+(:birth, :death, :meta1, :meta2)
+
+julia> interval.meta1
+:a
+```
+"""
+struct PersistenceInterval
+    birth::Float64
+    death::Float64
+    meta::NamedTuple
+end
+function PersistenceInterval(birth, death; kwargs...)
+    meta = (; kwargs...)
+    return PersistenceInterval(Float64(birth), Float64(death), meta)
+end
+function PersistenceInterval(t::Tuple{<:Any,<:Any}; kwargs...)
+    meta = (; kwargs...)
+    return PersistenceInterval(Float64(t[1]), Float64(t[2]), meta)
+end
+function PersistenceInterval(int::PersistenceInterval; kwargs...)
+    meta = (; kwargs...)
+    return PersistenceInterval(Float64(int[1]), Float64(int[2]), meta)
+end
+
+"""
+    birth(interval)
+
+Get the birth time of `interval`.
+"""
+birth(int::PersistenceInterval) = getfield(int, 1)
+"""
+    death(interval)
+
+Get the death time of `interval`.
+"""
+death(int::PersistenceInterval) = getfield(int, 2)
+"""
+    persistence(interval)
+
+Get the persistence of `interval`, which is equal to `death - birth`.
+"""
+persistence(int::PersistenceInterval) = death(int) - birth(int)
+
+"""
+    midlife(interval)
+
+Get the midlife of the `interval`, which is equal to `(birth + death) / 2`.
+"""
+midlife(int::PersistenceInterval) = (birth(int) + death(int)) / 2
+
+Base.isfinite(int::PersistenceInterval) = isfinite(death(int))
+
+###
+### Iteration
+###
+function Base.iterate(int::PersistenceInterval, i=1)
+    if i == 1
+        return birth(int), i + 1
+    elseif i == 2
+        return death(int), i + 1
+    else
+        return nothing
+    end
+end
+
+Base.length(::PersistenceInterval) = 2
+Base.IteratorSize(::Type{<:PersistenceInterval}) = Base.HasLength()
+Base.IteratorEltype(::Type{<:PersistenceInterval}) = Base.HasEltype()
+Base.eltype(::Type{<:PersistenceInterval}) = Float64
+
+function Base.getindex(int::PersistenceInterval, i)
+    if i == 1
+        return birth(int)
+    elseif i == 2
+        return death(int)
+    else
+        throw(BoundsError(int, i))
+    end
+end
+
+Base.firstindex(int::PersistenceInterval) = 1
+Base.lastindex(int::PersistenceInterval) = 2
+
+###
+### Equality and ordering
+###
+function Base.:(==)(int1::PersistenceInterval, int2::PersistenceInterval)
+    return birth(int1) == birth(int2) && death(int1) == death(int2)
+end
+Base.:(==)(int::PersistenceInterval, (b, d)::Tuple) = birth(int) == b && death(int) == d
+Base.:(==)((b, d)::Tuple, int::PersistenceInterval) = birth(int) == b && death(int) == d
+# `==` ignores metadata and matches plain tuples, so `hash` has to as well.
+Base.hash(int::PersistenceInterval, h::UInt) = hash((birth(int), death(int)), h)
+
+function Base.isless(int1::PersistenceInterval, int2::PersistenceInterval)
+    return (birth(int1), death(int1)) < (birth(int2), death(int2))
+end
+
+###
+### Printing
+###
+function Base.show(io::IO, int::PersistenceInterval)
+    b = round(birth(int); sigdigits=3)
+    d = isfinite(death(int)) ? round(death(int); sigdigits=3) : "∞"
+    return print(io, "[$b, $d)")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", int::PersistenceInterval)
+    b = round(birth(int); sigdigits=3)
+    d = isfinite(death(int)) ? round(death(int); sigdigits=3) : "∞"
+    print(io, "[$b, $d)")
+    if !isempty(int.meta)
+        print(io, " with:")
+        for (k, v) in zip(keys(int.meta), int.meta)
+            print(io, "\n ", k, ": ", summary(v))
+        end
+    end
+end
+
+###
+### Metadata
+###
+function Base.getproperty(int::PersistenceInterval, key::Symbol)
+    if hasfield(typeof(int), key)
+        return getfield(int, key)
+    elseif haskey(int.meta, key)
+        return int.meta[key]
+    else
+        error("interval $int has no $key")
+    end
+end
+function Base.propertynames(int::PersistenceInterval, private::Bool=false)
+    if private
+        return tuple(:birth, :death, propertynames(int.meta)..., :meta)
+    else
+        return (:birth, :death, propertynames(int.meta)...)
+    end
+end
+
+"""
+    representative(interval::PersistenceInterval)
+
+Get the representative (co)cycle attached to `interval`, if it has one.
+"""
+representative(int::PersistenceInterval) = int.representative
+
+"""
+    birth_simplex(interval::PersistenceInterval)
+
+Get the critical birth simplex of `interval`, if it has one.
+"""
+birth_simplex(int::PersistenceInterval) = int.birth_simplex
+
+"""
+    death_simplex(interval::PersistenceInterval)
+
+Get the critical death simplex of `interval`, if it has one.
+
+!!! note
+    An infinite interval's death simplex is `nothing`.
+"""
+death_simplex(int::PersistenceInterval) = int.death_simplex
diff --git a/src/persistencecurves.jl b/src/persistencecurves.jl
index f894941..53626c8 100644
--- a/src/persistencecurves.jl
+++ b/src/persistencecurves.jl
@@ -272,7 +272,7 @@ persistence diagrams. [arXiv preprint arXiv:1904.07768](https://arxiv.org/abs/19
 function Midlife(args...; kwargs...)
     return PersistenceCurve(midlife, sum, args...; kwargs...)
 end
-PersistenceDiagramsBase.midlife((b, d), _, _) = (b + d) / 2
+midlife((b, d), _, _) = (b + d) / 2
 
 """
     LifeEntropy
diff --git a/src/tables.jl b/src/tables.jl
new file mode 100644
index 0000000..ec927b8
--- /dev/null
+++ b/src/tables.jl
@@ -0,0 +1,35 @@
+# Row iterator backing the Tables.jl interface of a diagram. Every row is a `NamedTuple`
+# with the fields `birth`, `death`, `dim` and `threshold`.
+struct PersistenceDiagramRowIterator{D}
+    diagram::D
+end
+
+function Base.iterate(it::PersistenceDiagramRowIterator, i=1)
+    if i > length(it.diagram)
+        return nothing
+    else
+        int = it.diagram[i]
+        dim = get(it.diagram.meta, :dim, missing)
+        threshold = get(it.diagram.meta, :threshold, missing)
+        return (birth=int.birth, death=int.death, dim=dim, threshold=threshold), i + 1
+    end
+end
+
+Base.IteratorSize(::Type{<:PersistenceDiagramRowIterator}) = Base.HasLength()
+Base.length(it::PersistenceDiagramRowIterator) = length(it.diagram)
+
+Tables.istable(::Type{<:PersistenceDiagram}) = true
+Tables.rowaccess(::Type{<:PersistenceDiagram}) = true
+function Tables.rows(diagram::PersistenceDiagram)
+    return PersistenceDiagramRowIterator(diagram)
+end
+function Tables.schema(it::PersistenceDiagramRowIterator)
+    meta = it.diagram.meta
+    # A key that is absent or holds `missing` (the table constructor stores `missing` for
+    # absent columns) produces `missing` in every row, so the column type is `Missing`.
+    D = ismissing(get(meta, :dim, missing)) ? Missing : Int
+    T = ismissing(get(meta, :threshold, missing)) ? Missing : Float64
+    return Tables.Schema((:birth, :death, :dim, :threshold), (Float64, Float64, D, T))
+end
+
+Tables.materializer(::Type{<:PersistenceDiagram}) = PersistenceDiagram
diff --git a/test/diagrams.jl b/test/diagrams.jl
new file mode 100644
index 0000000..5e52668
--- /dev/null
+++ b/test/diagrams.jl
@@ -0,0 +1,239 @@
+using PersistenceDiagrams
+
+using Compat
+using DataFrames
+using Test
+
+@testset "PersistenceInterval" begin
+    int1 = PersistenceInterval(1, 2)
+    int2 = PersistenceInterval(1, Inf)
+    r = [1, 2, 6, 3]
+    int3 = PersistenceInterval(1, 2; birth_simplex=:σ, death_simplex=:τ, representative=r)
+    int4 = PersistenceInterval(int1; birth_simplex=:σ, death_simplex=:τ, representative=r)
+
+    @testset "Equality, order" begin
+        @test int1 ≠ int2
+        @test int1 == int3
+        @test int1 < int2
+        @test isless(int3, int2)
+        @test hash(int1) == hash(int3) == hash((1.0, 2.0))
+    end
+
+    @testset "Comparison with tuples" begin
+        @test int1 == (1, 2)
+        @test int2 == (1, Inf)
+        @test (1.0, 2.0) == int4
+        @test int1 == int3
+
+        @test PersistenceInterval((1, 2)) == int1
+    end
+
+    @testset "Birth, death, iteration" begin
+        @test int1[1] == birth(int1) == 1
+        @test int1[2] == death(int1) == 2
+        @test int3[1] == birth(int1) == 1
+        @test int4[2] == death(int1) == 2
+        @test isfinite(int1)
+        @test !isfinite(int2)
+        @test persistence(int1) == 1
+        @test persistence(int2) == Inf
+        @test midlife(int1) == 1.5
+        @test midlife(int2) == Inf
+
+        @test eltype(int1) ≡ Float64
+        @test eltype(PersistenceInterval) ≡ Float64
+        @test length(int1) == 2
+        @test collect(int1) == [1, 2]
+        @test tuple(int1...) ≡ (1.0, 2.0)
+        @test firstindex(int1) == 1
+        @test lastindex(int1) == 2
+        @test first(int1) == 1
+        @test last(int1) == 2
+        @test Base.IteratorSize(int1) == Base.HasLength()
+        @test Base.IteratorEltype(int1) == Base.HasEltype()
+
+        @test_throws BoundsError int2[0]
+        @test_throws BoundsError int2[3]
+    end
+
+    @testset "Metadata access" begin
+        @test int1.meta == NamedTuple()
+        @test birth_simplex(int3) == :σ
+        @test int4.birth_simplex == :σ
+        @test death_simplex(int4) == :τ
+        @test int3.death_simplex == :τ
+        @test representative(int3) == r
+
+        @test propertynames(int1) == (:birth, :death)
+        @test propertynames(int1, true) == (:birth, :death, :meta)
+        @test propertynames(int3) ==
+            (:birth, :death, :birth_simplex, :death_simplex, :representative)
+        @test propertynames(int3, true) ==
+            (:birth, :death, :birth_simplex, :death_simplex, :representative, :meta)
+
+        @test_throws ErrorException birth_simplex(int1)
+        @test_throws ErrorException death_simplex(int1)
+        @test_throws ErrorException representative(int1)
+        @test_throws ErrorException int3.something
+    end
+
+    @testset "Printing" begin
+        @test sprint(print, int1) == "[1.0, 2.0)"
+        @test sprint(print, int2) == "[1.0, ∞)"
+        @test sprint(print, int3) == "[1.0, 2.0)"
+
+        print_text_plain(io, x) = show(io, MIME"text/plain"(), x)
+        @test sprint(print_text_plain, int1) == "[1.0, 2.0)"
+        @test sprint(print_text_plain, int2) == "[1.0, ∞)"
+        @test sprint(print_text_plain, int3) ==
+            "[1.0, 2.0) with:\n" *
+            " birth_simplex: Symbol\n" *
+            " death_simplex: Symbol\n" *
+            " representative: 4-element $(typeof(r))"
+    end
+end
+
+@testset "PersistenceDiagram" begin
+    diagram1 = PersistenceDiagram(
+        [
+            PersistenceInterval(1, 3; a=1),
+            PersistenceInterval(3, 4; a=2),
+            PersistenceInterval(3, Inf; a=3),
+        ];
+        dim=1,
+    )
+    diagram2 = PersistenceDiagram([(1, 3), (3, 4), (3, Inf)]; threshold=0.3)
+    diagram3 = PersistenceDiagram(
+        view(
+            [
+                PersistenceInterval(1, 2),
+                PersistenceInterval(3, 4),
+                PersistenceInterval(3, Inf),
+            ],
+            1:3,
+        );
+        a=1,
+    )
+
+    @testset "A persistence diagram is an array" begin
+        @test diagram1[1] == (1, 3)
+        @test diagram2[2] == (3, 4)
+        @test diagram3[3] == (3, Inf)
+
+        @test_throws BoundsError diagram1[0]
+        @test_throws BoundsError diagram2[4]
+
+        @test diagram1 == diagram2
+        @test diagram1 == [(1, 3), (3, 4), (3, Inf)]
+
+        @test length(diagram1) == 3
+        @test firstindex(diagram2) == 1
+        @test lastindex(diagram3) == 3
+        @test length(diagram1) == 3
+        @test size(diagram2) == (3,)
+
+        @test first(diagram1) == (1, 3)
+        @test last(diagram2) == (3, Inf)
+
+        @test copy(diagram1) == diagram1
+        @test copy(diagram2).threshold == 0.3
+
+        @test similar(diagram1) isa typeof(diagram1)
+        @test similar(diagram1).dim == 1
+        @test similar(diagram3, (Base.OneTo(2),)) isa typeof(diagram3)
+        @test similar(diagram3, (Base.OneTo(2),)).a == 1
+
+        @test sort(diagram3) isa typeof(diagram3)
+        @test sort(diagram2; by=death, rev=true) == [(3, Inf), (3, 4), (1, 3)]
+        @test sort(diagram3; by=death, rev=true).a == 1
+    end
+
+    @testset "Metadata access" begin
+        @test dim(diagram1) == diagram1.dim == 1
+        @test threshold(diagram2) == diagram2.threshold == 0.3
+        @test diagram3.a == 1
+
+        @test propertynames(diagram1) == (:intervals, :dim)
+        @test propertynames(diagram1, true) == (:intervals, :meta, :dim)
+
+        @test_throws ErrorException diagram1.threshold
+        @test_throws ErrorException diagram2.dim
+
+        @test diagram1[1].a == 1
+        @test diagram1[2].a == 2
+        @test diagram1[3].a == 3
+    end
+
+    if VERSION ≥ v"1.5.0"
+        @testset "Printing" begin
+            print_text_plain(io, x) = show(io, MIME"text/plain"(), x)
+
+            @test sprint(print, diagram1) == "3-element 1-dimensional PersistenceDiagram"
+            @test sprint(print_text_plain, diagram1) ==
+                "3-element 1-dimensional PersistenceDiagram:\n" *
+                " [1.0, 3.0)\n" *
+                " [3.0, 4.0)\n" *
+                " [3.0, ∞)"
+
+            @test sprint(print, diagram2) == "3-element PersistenceDiagram"
+            @test sprint(print_text_plain, diagram2) ==
+                "3-element PersistenceDiagram:\n" *
+                " [1.0, 3.0)\n" *
+                " [3.0, 4.0)\n" *
+                " [3.0, ∞)"
+        end
+    end
+end
+
+@testset "Tables.jl interface" begin
+    diag1 = PersistenceDiagram(
+        [PersistenceInterval(1, 2), PersistenceInterval(1, 3)]; dim=0, threshold=4
+    )
+
+    df = DataFrame(diag1)
+    @test names(df) == ["birth", "death", "dim", "threshold"]
+    @test nrow(df) == 2
+    @test PersistenceDiagram(df) == diag1
+
+    diag2 = PersistenceDiagram(
+        [
+            PersistenceInterval(1, 2; a=nothing),
+            PersistenceInterval(1, 3; a=nothing),
+            PersistenceInterval(1, 4; a=1),
+        ];
+        dim=1,
+        b=2,
+    )
+
+    df = DataFrame(diag2)
+    @test names(df) == ["birth", "death", "dim", "threshold"]
+    @test nrow(df) == 3
+    @test all(ismissing, df.threshold)
+
+    table = Tables.columntable((threshold=[1, 1], birth=[0, 0], death=[0, 0]))
+    diagram = PersistenceDiagram(table)
+    @test ismissing(dim(diagram))
+    @test threshold(diagram) == 1
+
+    table = Tables.columntable((dim=[1, 1], birth=[0, 0], death=[0, 0]))
+    diagram = PersistenceDiagram(table)
+    @test ismissing(threshold(diagram))
+    @test dim(diagram) == 1
+    @test all(ismissing, DataFrame(diagram).threshold)
+
+    table = Tables.columntable((birth=[], death=[]))
+    @test isempty(PersistenceDiagram(table))
+
+    @test_throws ErrorException PersistenceDiagram(
+        Tables.columntable((birth=[1, 1], death=[2, 2], threshold=[1, 2]))
+    )
+    @test_throws ErrorException PersistenceDiagram(
+        Tables.columntable((birth=[1, 1], death=[2, 2], dim=[1, 2]))
+    )
+    @test_throws ErrorException PersistenceDiagram(
+        Tables.columntable((death=[2, 2], dim=[1, 1]))
+    )
+    @test_throws ErrorException PersistenceDiagram(
+        Tables.columntable((birth=[2, 2], dim=[1, 1]))
+    )
+end
diff --git a/test/doctests.jl b/test/doctests.jl
index 51d7985..15740dc 100644
--- a/test/doctests.jl
+++ b/test/doctests.jl
@@ -2,8 +2,8 @@ using Documenter
 using PersistenceDiagrams
 using Test
 
-if VERSION ≥ v"1.7-DEV" || VERSION < v"1.6-DEV"
-    @warn "Doctests were set up on Julia v1.6. Skipping."
+if VERSION ≥ v"1.8-DEV" || VERSION < v"1.7-DEV"
+    @warn "Doctests were set up on Julia v1.7. Skipping."
 else
     DocMeta.setdocmeta!(
         PersistenceDiagrams, :DocTestSetup, :(using PersistenceDiagrams); recursive=true
diff --git a/test/runtests.jl b/test/runtests.jl
index ca03c31..44748a5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,9 @@
 using SafeTestsets
 using Test
 
+@safetestset "diagrams" begin
+    include("diagrams.jl")
+end
 @safetestset "matching" begin
     include("matching.jl")
 end