diff --git a/src/weights.jl b/src/weights.jl index 80c50796..274213dd 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -683,11 +683,11 @@ function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) end """ - mean(f, A::AbstractArray, w::AbstractWeights[, dims::Int]) + mean(f, A::AbstractArray, w::AbstractWeights[; dims]) Compute the weighted mean of array `A`, after transforming it'S contents with the function `f`, with weight vector `w` (of type -`AbstractWeights`). If `dim` is provided, compute the +`AbstractWeights`). If `dims` is provided, compute the weighted mean along dimension `dims`. # Examples @@ -698,13 +698,13 @@ w = rand(n) mean(√, x, weights(w)) ``` """ -mean(f, A::AbstractArray, w::AbstractWeights; dims::Union{Colon,Int}=:) = - _mean(f.(A), w, dims) +mean(f, A::AbstractArray, w::AbstractWeights; kwargs...) = + mean(collect(Iterators.map(f, A)), w; kwargs...) function mean(f, A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) a = (dims === :) ? length(A) : size(A, dims) a != length(w) && throw(DimensionMismatch("Inconsistent array dimension.")) - return mean(f.(A), dims=dims) + return mean(collect(Iterators.map(f, A)), dims=dims) end ##### Weighted quantile ##### diff --git a/test/weights.jl b/test/weights.jl index a5b6c066..8c855fe2 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -275,22 +275,48 @@ end @test mean(√, 1:3, f([1.0, 1.0, 0.5])) ≈ 1.3120956 @test mean(√, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 1.60824421 + 0.88948688im + @test mean(log, [1:3;], f([1.0, 1.0, 0.5])) ≈ 0.49698133 + @test mean(log, 1:3, f([1.0, 1.0, 0.5])) ≈ 0.49698133 + @test mean(log, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 1.155407982 + 1.03678427im + + @test mean(x -> x^2, [1:3;], f([1.0, 1.0, 0.5])) ≈ 3.8 + @test mean(x -> x^2, 1:3, f([1.0, 1.0, 0.5])) ≈ 3.8 + @test mean(x -> x^2, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ -5.0 + 16.0im + for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) - @test mean(√, a, f(wt), dims=1) ≈ sum(sqrt.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) - @test mean(√, a, f(wt), dims=2) ≈ sum(sqrt.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) - @test mean(√, a, f(wt), dims=3) ≈ sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) + @test mean(√, a, f(wt); dims=1) ≈ sum(sqrt.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) + @test mean(√, a, f(wt); dims=2) ≈ sum(sqrt.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) + @test mean(√, a, f(wt); dims=3) ≈ sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) @test_throws ErrorException mean(√, a, f(wt), dims=4) + + @test mean(log, a, f(wt); dims=1) ≈ sum(log.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) + @test mean(log, a, f(wt); dims=2) ≈ sum(log.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) + @test mean(log, a, f(wt); dims=3) ≈ sum(log.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) + @test_throws ErrorException mean(log, a, f(wt), dims=4) + + @test mean(x -> x^2, a, f(wt); dims=1) ≈ sum((a.^2).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) + @test mean(x -> x^2, a, f(wt); dims=2) ≈ sum((a.^2).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) + @test mean(x -> x^2, a, f(wt); dims=3) ≈ sum((a.^2).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) + @test_throws ErrorException mean(log, a, f(wt), dims=4) end b = reshape(1.0:9.0, 3, 3) w = UnitWeights{Float64}(3) @test mean(√, b, w; dims=1) ≈ reshape(w, :, 3) * sqrt.(b) / sum(w) @test mean(√, b, w; dims=2) ≈ sqrt.(b) * w / sum(w) + @test mean(log, b, w; dims=1) ≈ reshape(w, :, 3) * log.(b) / sum(w) + @test mean(log, b, w; dims=2) ≈ log.(b) * w / sum(w) + @test mean(x -> x^2, b, w; dims=1) ≈ reshape(w, :, 3) * (b.^2) / sum(w) + @test mean(x -> x^2, b, w; dims=2) ≈ (b.^2) * w / sum(w) c = 1.0:9.0 w = UnitWeights{Float64}(9) @test mean(√, c, w) ≈ sum(sqrt.(c)) / length(c) @test_throws DimensionMismatch mean(√, c, UnitWeights{Float64}(6)) + @test mean(log, c, w) ≈ sum(log.(c)) / length(c) + @test_throws DimensionMismatch mean(log, c, UnitWeights{Float64}(6)) + @test mean(x -> x^2, c, w) ≈ sum(c.^2) / length(c) + @test_throws DimensionMismatch mean(x -> x^2, c, UnitWeights{Float64}(6)) end @testset "Quantile fweights" begin