Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve support for custom Number types #50

Merged
merged 11 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ julia = "1"
[extras]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Dates", "Documenter"]
test = ["Dates", "Documenter", "Unitful"]
49 changes: 23 additions & 26 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,22 @@

Inverse of `sqrt(x)` for non-negative `x`.
"""
square(x) = x^2
function square(x::Real)
x < zero(x) && throw(DomainError(x, "`square` is defined as the inverse of `sqrt` and can only be evaluated for non-negative values"))
function square(x)
oschulz marked this conversation as resolved.
Show resolved Hide resolved
oschulz marked this conversation as resolved.
Show resolved Hide resolved
if is_real_type(typeof(x)) && x < zero(x)
throw(DomainError(x, "`square` is defined as the inverse of `sqrt` and can only be evaluated for non-negative values"))
aplavin marked this conversation as resolved.
Show resolved Hide resolved
end
return x^2
end


function invpow2(x::Real, p::Integer)
oschulz marked this conversation as resolved.
Show resolved Hide resolved
if x ≥ zero(x) || isodd(p)
copysign(abs(x)^inv(p), x)
function invpow2(x::Number, p::Real)
aplavin marked this conversation as resolved.
Show resolved Hide resolved
if is_real_type(typeof(x))
x ≥ zero(x) ? x^inv(p) : # x > 0 - trivially invertible
isinteger(p) && isodd(Integer(p)) ? copysign(abs(x)^inv(p), x) : # p odd - invertible even for x < 0
throw(DomainError(x, "inverse for x^$p is not defined at $x"))
else
throw(DomainError(x, "inverse for x^$p is not defined at $x"))
end
end
function invpow2(x::Real, p::Real)
if x ≥ zero(x)
x^inv(p)
else
throw(DomainError(x, "inverse for x^$p is not defined at $x"))
end
end
function invpow2(x, p::Real)
# complex x^p is only invertible for p = 1/n
if isinteger(inv(p))
x^inv(p)
else
throw(DomainError(x, "inverse for x^$p is not defined at $x"))
# complex x^p is invertible only for p = 1/n
isinteger(inv(p)) ? x^inv(p) : throw(DomainError(x, "inverse for x^$p is not defined at $x"))
end
end

Expand All @@ -50,12 +39,12 @@ function invlog1(b::Real, x::Real)
throw(DomainError(x, "inverse for log($b, x) is not defined at $x"))
end
end
invlog1(b, x) = b^x
invlog1(b::Number, x::Number) = b^x

invlog2(b, x) = x^inv(b)
invlog2(b::Number, x::Number) = x^inv(b)


function invdivrem((q, r), divisor)
function invdivrem((q, r)::NTuple{2,Number}, divisor::Number)
res = muladd(q, divisor, r)
if abs(r) ≤ abs(divisor) && (iszero(r) || sign(r) == sign(res))
res
Expand All @@ -64,10 +53,18 @@ function invdivrem((q, r), divisor)
end
end

function invfldmod((q, r), divisor)
function invfldmod((q, r)::NTuple{2,Number}, divisor::Number)
if abs(r) ≤ abs(divisor) && (iszero(r) || sign(r) == sign(divisor))
muladd(q, divisor, r)
else
throw(DomainError((q, r), "inverse for fldmod(x) is not defined at this point"))
end
end


# check if T is a real-Number type
# this is not the same as T <: Real which immediately excludes custom Number subtypes such as unitful numbers
# also, isreal(x) != is_real_type(typeof(x)): the former is true for complex numbers with zero imaginary part
is_real_type(@nospecialize _::Type{<:Real}) = true
is_real_type(::Type{T}) where {T<:Number} = real(T) == T
is_real_type(_) = false
aplavin marked this conversation as resolved.
Show resolved Hide resolved
20 changes: 20 additions & 0 deletions test/test_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using Test
using InverseFunctions
using Unitful
using Dates


Expand Down Expand Up @@ -79,14 +80,19 @@ end
end

# ensure that inverses have domains compatible with original functions
@test_throws DomainError inverse(sqrt)(-1.0)
InverseFunctions.test_inverse(sqrt, complex(-1.0))
InverseFunctions.test_inverse(sqrt, complex(1.0))
@test_throws DomainError inverse(Base.Fix1(*, 0))
@test_throws DomainError inverse(Base.Fix2(^, 0))
@test_throws DomainError inverse(Base.Fix1(log, -2))(5)
@test_throws DomainError inverse(Base.Fix1(log, 2))(-5)
InverseFunctions.test_inverse(inverse(Base.Fix1(log, 2)), complex(-5))
@test_throws DomainError inverse(Base.Fix2(^, 0.5))(-5)
@test_throws DomainError inverse(Base.Fix2(^, 0.51))(complex(-5))
@test_throws DomainError inverse(Base.Fix2(^, 2))(complex(-5))
InverseFunctions.test_inverse(Base.Fix2(^, 0.5), complex(-5))
InverseFunctions.test_inverse(Base.Fix2(^, -1), complex(-5.))
@test_throws DomainError inverse(Base.Fix2(^, 2))(-5)
@test_throws DomainError inverse(Base.Fix1(^, 2))(-5)
@test_throws DomainError inverse(Base.Fix1(^, -2))(3)
Expand Down Expand Up @@ -130,6 +136,20 @@ end
end
end

@testset "unitful" begin
# the majority of inverse just propagate to underlying mathematical functions and don't have any issues with unitful numbers
# only those that behave treat real numbers differently have to be tested here
x = rand()u"m"
InverseFunctions.test_inverse(sqrt, x)
@test_throws DomainError inverse(sqrt)(-x)

InverseFunctions.test_inverse(Base.Fix2(^, 2), x)
@test_throws DomainError inverse(Base.Fix2(^, 2))(-x)
InverseFunctions.test_inverse(Base.Fix2(^, 3), x)
InverseFunctions.test_inverse(Base.Fix2(^, 3), -x)
InverseFunctions.test_inverse(Base.Fix2(^, -3.5), x)
end

@testset "dates" begin
InverseFunctions.test_inverse(Dates.date2epochdays, Date(2020, 1, 2); compare = ===)
InverseFunctions.test_inverse(Dates.datetime2epochms, DateTime(2020, 1, 2, 12, 34, 56); compare = ===)
Expand Down
Loading