Skip to content

Commit

Permalink
Minor modifications
Browse files Browse the repository at this point in the history
- Add keyword arguments for the weights
- Modified functions to use `Iterators.map`
- Add more tests
  • Loading branch information
itsdebartha committed Aug 22, 2023
1 parent 2acfbfa commit b8be0a5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
10 changes: 5 additions & 5 deletions src/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -683,11 +683,11 @@ function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
end

"""
mean(f, A::AbstractArray, w::AbstractWeights[, dims::Int])
mean(f, A::AbstractArray, w::AbstractWeights[; dims])
Compute the weighted mean of array `A`, after transforming it'S
contents with the function `f`, with weight vector `w` (of type
`AbstractWeights`). If `dim` is provided, compute the
`AbstractWeights`). If `dims` is provided, compute the
weighted mean along dimension `dims`.
# Examples
Expand All @@ -698,13 +698,13 @@ w = rand(n)
mean(√, x, weights(w))
```
"""
mean(f, A::AbstractArray, w::AbstractWeights; dims::Union{Colon,Int}=:) =
_mean(f.(A), w, dims)
mean(f, A::AbstractArray, w::AbstractWeights; kwargs...) =
mean(collect(Iterators.map(f, A)), w; kwargs...)

function mean(f, A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
a = (dims === :) ? length(A) : size(A, dims)
a != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return mean(f.(A), dims=dims)
return mean(collect(Iterators.map(f, A)), dims=dims)
end

##### Weighted quantile #####
Expand Down
32 changes: 29 additions & 3 deletions test/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,22 +275,48 @@ end
@test mean(, 1:3, f([1.0, 1.0, 0.5])) 1.3120956
@test mean(, [1 + 2im, 4 + 5im], f([1.0, 0.5])) 1.60824421 + 0.88948688im

@test mean(log, [1:3;], f([1.0, 1.0, 0.5])) 0.49698133
@test mean(log, 1:3, f([1.0, 1.0, 0.5])) 0.49698133
@test mean(log, [1 + 2im, 4 + 5im], f([1.0, 0.5])) 1.155407982 + 1.03678427im

@test mean(x -> x^2, [1:3;], f([1.0, 1.0, 0.5])) 3.8
@test mean(x -> x^2, 1:3, f([1.0, 1.0, 0.5])) 3.8
@test mean(x -> x^2, [1 + 2im, 4 + 5im], f([1.0, 0.5])) -5.0 + 16.0im

for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0])
@test mean(, a, f(wt), dims=1) sum(sqrt.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt)
@test mean(, a, f(wt), dims=2) sum(sqrt.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt)
@test mean(, a, f(wt), dims=3) sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt)
@test mean(, a, f(wt); dims=1) sum(sqrt.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt)
@test mean(, a, f(wt); dims=2) sum(sqrt.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt)
@test mean(, a, f(wt); dims=3) sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt)
@test_throws ErrorException mean(, a, f(wt), dims=4)

@test mean(log, a, f(wt); dims=1) sum(log.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt)
@test mean(log, a, f(wt); dims=2) sum(log.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt)
@test mean(log, a, f(wt); dims=3) sum(log.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt)
@test_throws ErrorException mean(log, a, f(wt), dims=4)

@test mean(x -> x^2, a, f(wt); dims=1) sum((a.^2).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt)
@test mean(x -> x^2, a, f(wt); dims=2) sum((a.^2).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt)
@test mean(x -> x^2, a, f(wt); dims=3) sum((a.^2).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt)
@test_throws ErrorException mean(log, a, f(wt), dims=4)
end

b = reshape(1.0:9.0, 3, 3)
w = UnitWeights{Float64}(3)
@test mean(, b, w; dims=1) reshape(w, :, 3) * sqrt.(b) / sum(w)
@test mean(, b, w; dims=2) sqrt.(b) * w / sum(w)
@test mean(log, b, w; dims=1) reshape(w, :, 3) * log.(b) / sum(w)
@test mean(log, b, w; dims=2) log.(b) * w / sum(w)
@test mean(x -> x^2, b, w; dims=1) reshape(w, :, 3) * (b.^2) / sum(w)
@test mean(x -> x^2, b, w; dims=2) (b.^2) * w / sum(w)

c = 1.0:9.0
w = UnitWeights{Float64}(9)
@test mean(, c, w) sum(sqrt.(c)) / length(c)
@test_throws DimensionMismatch mean(, c, UnitWeights{Float64}(6))
@test mean(log, c, w) sum(log.(c)) / length(c)
@test_throws DimensionMismatch mean(log, c, UnitWeights{Float64}(6))
@test mean(x -> x^2, c, w) sum(c.^2) / length(c)
@test_throws DimensionMismatch mean(x -> x^2, c, UnitWeights{Float64}(6))
end

@testset "Quantile fweights" begin
Expand Down

0 comments on commit b8be0a5

Please sign in to comment.