diff --git a/src/constraint_tree.jl b/src/constraint_tree.jl index 20bd744..d8637d4 100644 --- a/src/constraint_tree.jl +++ b/src/constraint_tree.jl @@ -63,7 +63,7 @@ $(TYPEDFIELDS) """ Base.@kwdef struct ConstraintTree "Sorted dictionary of elements of the constraint tree." - elems::SortedDict{Symbol,Union{Constraint,QConstraint,ConstraintTree}} + elems::SortedDict{Symbol,Union{Constraint,QConstraint,ConstraintTree}} = SortedDict() ConstraintTree(x::SortedDict{Symbol,Union{Constraint,QConstraint,ConstraintTree}}) = new(x) @@ -108,11 +108,15 @@ function Base.getproperty(x::ConstraintTree, sym::Symbol) elems(x)[sym] end +Base.isempty(x::ConstraintTree) = isempty(elems(x)) + +Base.length(x::ConstraintTree) = length(elems(x)) + Base.keys(x::ConstraintTree) = keys(elems(x)) -Base.values(x::ConstraintTree) = values(elems(x)) +Base.haskey(x::ConstraintTree, sym::Symbol) = haskey(elems(x), sym) -Base.length(x::ConstraintTree) = length(elems(x)) +Base.values(x::ConstraintTree) = values(elems(x)) Base.iterate(x::ConstraintTree) = iterate(elems(x)) Base.iterate(x::ConstraintTree, st) = iterate(elems(x), st) @@ -121,6 +125,8 @@ Base.eltype(x::ConstraintTree) = eltype(elems(x)) Base.propertynames(x::ConstraintTree) = keys(x) +Base.hasproperty(x::ConstraintTree, sym::Symbol) = haskey(x, sym) + Base.getindex(x::ConstraintTree, sym::Symbol) = getindex(elems(x), sym) # diff --git a/src/solution_tree.jl b/src/solution_tree.jl index 0c1b979..ec348c1 100644 --- a/src/solution_tree.jl +++ b/src/solution_tree.jl @@ -26,7 +26,9 @@ SolutionTree(cs, vals) $(TYPEDFIELDS) """ Base.@kwdef struct SolutionTree - elems::SortedDict{Symbol,Union{Float64,SolutionTree}} + elems::SortedDict{Symbol,Union{Float64,SolutionTree}} = SortedDict() + + SolutionTree(x...) = new(x...) SolutionTree(x::Constraint, vars::AbstractVector{Float64}) = value_product(x.value, vars) SolutionTree(x::QConstraint, vars::AbstractVector{Float64}) = @@ -58,11 +60,15 @@ function Base.getproperty(x::SolutionTree, sym::Symbol) elems(x)[sym] end +Base.isempty(x::SolutionTree) = isempty(elems(x)) + +Base.length(x::SolutionTree) = length(elems(x)) + Base.keys(x::SolutionTree) = keys(elems(x)) -Base.values(x::SolutionTree) = values(elems(x)) +Base.haskey(x::SolutionTree, sym::Symbol) = haskey(elems(x), sym) -Base.length(x::SolutionTree) = length(elems(x)) +Base.values(x::SolutionTree) = values(elems(x)) Base.iterate(x::SolutionTree) = iterate(elems(x)) Base.iterate(x::SolutionTree, st) = iterate(elems(x), st) @@ -71,4 +77,6 @@ Base.eltype(x::SolutionTree) = eltype(elems(x)) Base.propertynames(x::SolutionTree) = keys(elems(x)) +Base.hasproperty(x::SolutionTree, sym::Symbol) = haskey(x, sym) + Base.getindex(x::SolutionTree, sym::Symbol) = getindex(elems(x), sym) diff --git a/test/misc.jl b/test/misc.jl index 7ed7882..3f6de63 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -41,6 +41,13 @@ end ct1 = C.variables(keys = [:a, :b]) ct2 = C.variables(keys = [:c, :d]) + @test isempty(C.ConstraintTree()) + @test !isempty(ct1) + @test haskey(ct1, :a) + @test hasproperty(ct1, :a) + @test !haskey(ct1, :c) + @test !hasproperty(ct1, :c) + @test collect(propertynames(ct1)) == [:a, :b] @test [k for (k, _) in ct2] == [:c, :d] @test eltype(ct2) == Pair{Symbol,C.ConstraintTreeElem} @@ -54,6 +61,16 @@ end ct = C.variables(keys = [:a, :b]) @test_throws BoundsError C.SolutionTree(ct, [1.0]) st = C.SolutionTree(ct, [123.0, 321.0]) + + @test isempty(C.SolutionTree()) + @test isempty(C.SolutionTree(C.ConstraintTree(), Float64[])) + @test !isempty(st) + @test haskey(st, :a) + @test hasproperty(st, :a) + @test !haskey(st, :c) + @test !hasproperty(st, :c) + + @test length(ct) == length(st) @test st.a == 123.0 @test st[:b] == 321.0 @test collect(propertynames(st)) == [:a, :b]