diff --git a/src/main/java/com/ezylang/evalex/functions/basic/AverageFunction.java b/src/main/java/com/ezylang/evalex/functions/basic/AverageFunction.java index 06e15e6d..c726ebd0 100644 --- a/src/main/java/com/ezylang/evalex/functions/basic/AverageFunction.java +++ b/src/main/java/com/ezylang/evalex/functions/basic/AverageFunction.java @@ -21,10 +21,10 @@ import com.ezylang.evalex.parser.Token; import java.math.BigDecimal; import java.math.MathContext; -import java.util.Arrays; /** - * Returns the average (arithmetic mean) of the numeric arguments. + * Returns the average (arithmetic mean) of the numeric arguments, with recursive support for arrays + * too. * * @author oswaldo.bapvic.jr */ @@ -35,12 +35,46 @@ public class AverageFunction extends AbstractMinMaxFunction { public EvaluationValue evaluate( Expression expression, Token functionToken, EvaluationValue... parameterValues) { MathContext mathContext = expression.getConfiguration().getMathContext(); - BigDecimal sum = - Arrays.stream(parameterValues) - .map(EvaluationValue::getNumberValue) - .reduce(BigDecimal.ZERO, BigDecimal::add); - BigDecimal count = BigDecimal.valueOf(parameterValues.length); - BigDecimal average = sum.divide(count, mathContext); + BigDecimal average = average(mathContext, parameterValues); return expression.convertValue(average); } + + private BigDecimal average(MathContext mathContext, EvaluationValue... parameterValues) { + SumAndCount aux = new SumAndCount(); + for (EvaluationValue parameter : parameterValues) { + aux = aux.plus(recursiveSumAndCount(parameter)); + } + + return aux.sum.divide(aux.count, mathContext); + } + + private SumAndCount recursiveSumAndCount(EvaluationValue parameter) { + SumAndCount aux = new SumAndCount(BigDecimal.ZERO, BigDecimal.ZERO); + if (parameter.isArrayValue()) { + for (EvaluationValue element : parameter.getArrayValue()) { + aux = aux.plus(recursiveSumAndCount(element)); + } + return aux; + } + return new SumAndCount(parameter.getNumberValue(), BigDecimal.ONE); + } + + private class SumAndCount { + private final BigDecimal sum; + private final BigDecimal count; + + private SumAndCount() { + this.sum = BigDecimal.ZERO; + this.count = BigDecimal.ZERO; + } + + private SumAndCount(BigDecimal sum, BigDecimal count) { + this.sum = sum; + this.count = count; + } + + private SumAndCount plus(SumAndCount other) { + return new SumAndCount(sum.add(other.sum), count.add(other.count)); + } + } } diff --git a/src/test/java/com/ezylang/evalex/functions/basic/AverageArrayTest.java b/src/test/java/com/ezylang/evalex/functions/basic/AverageArrayTest.java new file mode 100644 index 00000000..9d3e87f2 --- /dev/null +++ b/src/test/java/com/ezylang/evalex/functions/basic/AverageArrayTest.java @@ -0,0 +1,70 @@ +/* + Copyright 2012-2024 Udo Klimaschewski + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +package com.ezylang.evalex.functions.basic; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.ezylang.evalex.EvaluationException; +import com.ezylang.evalex.Expression; +import com.ezylang.evalex.parser.ParseException; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for the AVERAGE function with arrays. + * + * @author oswaldo.bapvic.jr + */ +class AverageArrayTest { + + @Test + void testAverageSingleArray() throws EvaluationException, ParseException { + Integer[] numbers = {1, 2, 3}; + + Expression expression = new Expression("AVERAGE(numbers)").with("numbers", numbers); + + assertThat(expression.evaluate().getNumberValue().doubleValue()).isEqualTo(2); + } + + @Test + void testAverageMultipleArray() throws EvaluationException, ParseException { + Integer[] numbers1 = {1, 2, 3}; + Integer[] numbers2 = {4, 5, 6}; + + Expression expression = + new Expression("AVERAGE(numbers1, numbers2)") + .with("numbers1", numbers1) + .with("numbers2", numbers2); + + assertThat(expression.evaluate().getNumberValue().doubleValue()).isEqualTo(3.5); + } + + @Test + void testAverageMixedArrayNumber() throws EvaluationException, ParseException { + Integer[] numbers = {1, 2, 3}; + + Expression expression = new Expression("AVERAGE(numbers, 4)").with("numbers", numbers); + + assertThat(expression.evaluate().getNumberValue().doubleValue()).isEqualTo(2.5); + } + + @Test + void testAverageNestedArray() throws EvaluationException, ParseException { + Integer[][] numbers = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; + + Expression expression = new Expression("AVERAGE(numbers)").with("numbers", numbers); + assertThat(expression.evaluate().getNumberValue().doubleValue()).isEqualTo(5); + } +}