diff --git a/Source/Parser/Expressions/IndexedVariableExpression.cs b/Source/Parser/Expressions/IndexedVariableExpression.cs index 46075f01..c021db9a 100644 --- a/Source/Parser/Expressions/IndexedVariableExpression.cs +++ b/Source/Parser/Expressions/IndexedVariableExpression.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using RATools.Data; +using System.Collections.Generic; using System.Text; namespace RATools.Parser.Expressions @@ -43,15 +44,16 @@ public override ExpressionBase GetValue(InterpreterScope scope) { StringBuilder builder; ExpressionBase container, index, result; + bool isReference; - GetContainerIndex(scope, out container, out index); + GetContainerIndex(scope, out container, out index, out isReference); switch (container.Type) { case ExpressionType.Dictionary: result = ((DictionaryExpression)container).GetEntry(index); if (result != null) - return result; + break; builder = new StringBuilder(); builder.Append("No entry in dictionary for key: "); @@ -60,7 +62,8 @@ public override ExpressionBase GetValue(InterpreterScope scope) case ExpressionType.Array: // ASSERT: index was validated in GetContainerIndex - return ((ArrayExpression)container).Entries[((IntegerConstantExpression)index).Value]; + result = ((ArrayExpression)container).Entries[((IntegerConstantExpression)index).Value]; + break; case ExpressionType.Error: return container; @@ -74,12 +77,22 @@ public override ExpressionBase GetValue(InterpreterScope scope) builder.Append(')'); return new ErrorExpression(builder.ToString(), Variable); } + + if (isReference && VariableReferenceExpression.CanReference(result.Type)) + { + builder = new StringBuilder(); + AppendString(builder); + result = new VariableReferenceExpression(new VariableDefinitionExpression(builder.ToString()), result); + } + + return result; } public ErrorExpression Assign(InterpreterScope scope, ExpressionBase newValue) { ExpressionBase container, index; - GetContainerIndex(scope, out container, out index); + bool isReference; + GetContainerIndex(scope, out container, out index, out isReference); switch (container.Type) { @@ -108,8 +121,11 @@ public ErrorExpression Assign(InterpreterScope scope, ExpressionBase newValue) return null; } - private void GetContainerIndex(InterpreterScope scope, out ExpressionBase container, out ExpressionBase index) + private void GetContainerIndex(InterpreterScope scope, + out ExpressionBase container, out ExpressionBase index, out bool isReference) { + isReference = false; + if (Index.Type == ExpressionType.FunctionCall) { var expression = (FunctionCallExpression)Index; @@ -131,7 +147,10 @@ private void GetContainerIndex(InterpreterScope scope, out ExpressionBase contai var variableReference = container as VariableReferenceExpression; if (variableReference != null) + { container = variableReference.Expression; + isReference = true; + } var array = container as ArrayExpression; if (array != null) diff --git a/Source/Parser/Expressions/VariableExpression.cs b/Source/Parser/Expressions/VariableExpression.cs index 79585463..454551e1 100644 --- a/Source/Parser/Expressions/VariableExpression.cs +++ b/Source/Parser/Expressions/VariableExpression.cs @@ -1,4 +1,5 @@ using RATools.Data; +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -89,7 +90,7 @@ public virtual ExpressionBase GetValue(InterpreterScope scope) // when a parameter is assigned to a variable that is an array or dictionary, // assume it has already been evaluated and pass it by reference. this is magnitudes // more performant, and allows the function to modify the data in the container. - if (value.Type == ExpressionType.Dictionary || value.Type == ExpressionType.Array) + if (VariableReferenceExpression.CanReference(value.Type)) { var reference = scope.GetVariableReference(Name); CopyLocation(reference); @@ -223,5 +224,18 @@ internal override void AppendString(StringBuilder builder) { builder.Append(Variable.Name); } + + internal static bool CanReference(ExpressionType type) + { + switch (type) + { + case ExpressionType.Dictionary: + case ExpressionType.Array: + return true; + + default: + return false; + } + } } } diff --git a/Tests/Parser/Functions/ArrayMapFunctionTests.cs b/Tests/Parser/Functions/ArrayMapFunctionTests.cs index c2d56907..3f483331 100644 --- a/Tests/Parser/Functions/ArrayMapFunctionTests.cs +++ b/Tests/Parser/Functions/ArrayMapFunctionTests.cs @@ -3,6 +3,7 @@ using RATools.Parser.Expressions; using RATools.Parser.Functions; using System.Linq; +using System.Text; namespace RATools.Parser.Tests.Functions { @@ -30,7 +31,7 @@ private static string Evaluate(string input) scope.Context = new AssignmentExpression(new VariableExpression("t"), expr); if (funcCall.Evaluate(scope, out expr)) { - var builder = new System.Text.StringBuilder(); + var builder = new StringBuilder(); expr.AppendString(builder); return builder.ToString(); } @@ -50,6 +51,27 @@ public void TestSimple() Is.EqualTo("[byte(0x000001), byte(0x000002), byte(0x000003)]")); } + [Test] + public void TestNested() + { + var scope = new InterpreterScope(AchievementScriptInterpreter.GetGlobalScope()); + var array = new ArrayExpression(); + array.Entries.Add(new IntegerConstantExpression(1)); + array.Entries.Add(new IntegerConstantExpression(2)); + array.Entries.Add(new IntegerConstantExpression(3)); + var dict = new DictionaryExpression(); + var key = new IntegerConstantExpression(0); + dict.Add(key, array); + scope.DefineVariable(new VariableDefinitionExpression("dict"), dict); + + var result = FunctionTests.Evaluate("array_map(dict[0], a => byte(a))", scope); + Assert.That(result, Is.Not.Null); + + var builder = new StringBuilder(); + result.AppendString(builder); + Assert.That(builder.ToString(), Is.EqualTo("[byte(0x000001), byte(0x000002), byte(0x000003)]")); + } + [Test] public void TestSingleElement() { diff --git a/Tests/Parser/Functions/ArrayPopFunctionTests.cs b/Tests/Parser/Functions/ArrayPopFunctionTests.cs index 110c7820..5fd938ce 100644 --- a/Tests/Parser/Functions/ArrayPopFunctionTests.cs +++ b/Tests/Parser/Functions/ArrayPopFunctionTests.cs @@ -18,39 +18,14 @@ public void TestDefinition() Assert.That(def.Parameters.ElementAt(0).Name, Is.EqualTo("array")); } - private ExpressionBase Evaluate(string input, InterpreterScope scope, string expectedError = null) + private static ExpressionBase Evaluate(string input, InterpreterScope scope) { - var funcDef = new ArrayPopFunction(); - - var expression = ExpressionBase.Parse(new PositionalTokenizer(Tokenizer.CreateTokenizer(input))); - Assert.That(expression, Is.InstanceOf()); - var funcCall = (FunctionCallExpression)expression; - - ExpressionBase error; - var parameterScope = funcCall.GetParameters(funcDef, scope, out error); - - if (expectedError == null) - { - Assert.That(error, Is.Null); - - ExpressionBase result; - Assert.That(funcDef.Evaluate(parameterScope, out result), Is.True); - return result; - } - else - { - if (error == null) - Assert.That(funcDef.Evaluate(parameterScope, out error), Is.False); - - Assert.That(error, Is.InstanceOf()); - - var parseError = (ErrorExpression)error; - while (parseError.InnerError != null) - parseError = parseError.InnerError; - Assert.That(parseError.Message, Is.EqualTo(expectedError)); + return FunctionTests.Evaluate(input, scope); + } - return null; - } + private static void AssertEvaluateError(string input, InterpreterScope scope, string expectedError) + { + FunctionTests.AssertEvaluateError(input, scope, expectedError); } [Test] @@ -81,12 +56,43 @@ public void TestSimple() Assert.That(array.Entries.Count, Is.EqualTo(0)); } + [Test] + public void TestNested() + { + var scope = new InterpreterScope(); + var array = new ArrayExpression(); + array.Entries.Add(new IntegerConstantExpression(1)); + array.Entries.Add(new IntegerConstantExpression(2)); + var dict = new DictionaryExpression(); + var key = new IntegerConstantExpression(0); + dict.Add(key, array); + scope.DefineVariable(new VariableDefinitionExpression("dict"), dict); + + var entry = Evaluate("array_pop(dict[0])", scope); + Assert.That(entry, Is.InstanceOf()); + Assert.That(((IntegerConstantExpression)entry).Value, Is.EqualTo(2)); + Assert.That(array.Entries.Count, Is.EqualTo(1)); + Assert.That(array.Entries[0], Is.InstanceOf()); + Assert.That(((IntegerConstantExpression)array.Entries[0]).Value, Is.EqualTo(1)); + + entry = Evaluate("array_pop(dict[0])", scope); + Assert.That(entry, Is.InstanceOf()); + Assert.That(((IntegerConstantExpression)entry).Value, Is.EqualTo(1)); + Assert.That(array.Entries.Count, Is.EqualTo(0)); + + // empty array always returns 0 + entry = Evaluate("array_pop(dict[0])", scope); + Assert.That(entry, Is.InstanceOf()); + Assert.That(((IntegerConstantExpression)entry).Value, Is.EqualTo(0)); + Assert.That(array.Entries.Count, Is.EqualTo(0)); + } + [Test] public void TestUndefined() { var scope = new InterpreterScope(); - Evaluate("array_pop(arr)", scope, "Unknown variable: arr"); + AssertEvaluateError("array_pop(arr)", scope, "Unknown variable: arr"); } [Test] @@ -97,7 +103,7 @@ public void TestDictionary() dict.Add(new IntegerConstantExpression(1), new StringConstantExpression("One")); scope.DefineVariable(new VariableDefinitionExpression("dict"), dict); - Evaluate("array_push(dict)", scope, "array: Cannot convert dictionary to array"); + AssertEvaluateError("array_pop(dict)", scope, "array: Cannot convert dictionary to array"); } [Test] diff --git a/Tests/Parser/Functions/ArrayPushFunctionTests.cs b/Tests/Parser/Functions/ArrayPushFunctionTests.cs index df7d91f8..4078245c 100644 --- a/Tests/Parser/Functions/ArrayPushFunctionTests.cs +++ b/Tests/Parser/Functions/ArrayPushFunctionTests.cs @@ -1,5 +1,4 @@ -using Jamiras.Components; -using NUnit.Framework; +using NUnit.Framework; using RATools.Data; using RATools.Parser.Expressions; using RATools.Parser.Expressions.Trigger; @@ -21,37 +20,14 @@ public void TestDefinition() Assert.That(def.Parameters.ElementAt(1).Name, Is.EqualTo("value")); } - private void Evaluate(string input, InterpreterScope scope, string expectedError = null) + private static void Evaluate(string input, InterpreterScope scope) { - var funcDef = new ArrayPushFunction(); - - var expression = ExpressionBase.Parse(new PositionalTokenizer(Tokenizer.CreateTokenizer(input))); - Assert.That(expression, Is.InstanceOf()); - var funcCall = (FunctionCallExpression)expression; - - ExpressionBase error; - var parameterScope = funcCall.GetParameters(funcDef, scope, out error); - - if (expectedError == null) - { - Assert.That(error, Is.Null); - - ExpressionBase result; - Assert.That(funcDef.Evaluate(parameterScope, out result), Is.True); - Assert.That(result, Is.Null); - } - else - { - if (error == null) - Assert.That(funcDef.Evaluate(parameterScope, out error), Is.False); - - Assert.That(error, Is.InstanceOf()); - - var parseError = (ErrorExpression)error; - while (parseError.InnerError != null) - parseError = parseError.InnerError; - Assert.That(parseError.Message, Is.EqualTo(expectedError)); - } + FunctionTests.Execute(input, scope); + } + + private static void AssertEvaluateError(string input, InterpreterScope scope, string expectedError) + { + FunctionTests.AssertEvaluateError(input, scope, expectedError); } [Test] @@ -74,12 +50,35 @@ public void TestSimple() Assert.That(((StringConstantExpression)array.Entries[1]).Value, Is.EqualTo("2")); } + [Test] + public void TestNested() + { + var scope = new InterpreterScope(); + var array = new ArrayExpression(); + var dict = new DictionaryExpression(); + var key = new IntegerConstantExpression(0); + dict.Add(key, array); + scope.DefineVariable(new VariableDefinitionExpression("dict"), dict); + + Evaluate("array_push(dict[0], 1)", scope); + Assert.That(array.Entries.Count, Is.EqualTo(1)); + Assert.That(array.Entries[0], Is.InstanceOf()); + Assert.That(((IntegerConstantExpression)array.Entries[0]).Value, Is.EqualTo(1)); + + Evaluate("array_push(dict[0], \"2\")", scope); + Assert.That(array.Entries.Count, Is.EqualTo(2)); + Assert.That(array.Entries[0], Is.InstanceOf()); + Assert.That(((IntegerConstantExpression)array.Entries[0]).Value, Is.EqualTo(1)); + Assert.That(array.Entries[1], Is.InstanceOf()); + Assert.That(((StringConstantExpression)array.Entries[1]).Value, Is.EqualTo("2")); + } + [Test] public void TestUndefined() { var scope = new InterpreterScope(); - Evaluate("array_push(arr, 1)", scope, "Unknown variable: arr"); + AssertEvaluateError("array_push(arr, 1)", scope, "Unknown variable: arr"); } [Test] @@ -89,10 +88,10 @@ public void TestDictionary() var dict = new DictionaryExpression(); scope.DefineVariable(new VariableDefinitionExpression("dict"), dict); - Evaluate("array_push(dict, 1)", scope, "array: Cannot convert dictionary to array"); + AssertEvaluateError("array_push(dict, 1)", scope, "array: Cannot convert dictionary to array"); } - private void AddHappyFunction(InterpreterScope scope) + private static void AddHappyFunction(InterpreterScope scope) { scope.AddFunction(UserFunctionDefinitionExpression.ParseForTest( "function happy(num1) => num1" diff --git a/Tests/Parser/Functions/FunctionTests.cs b/Tests/Parser/Functions/FunctionTests.cs new file mode 100644 index 00000000..1f30217f --- /dev/null +++ b/Tests/Parser/Functions/FunctionTests.cs @@ -0,0 +1,69 @@ +using Jamiras.Components; +using NUnit.Framework; +using RATools.Parser.Expressions; + +namespace RATools.Parser.Tests.Functions +{ + internal static class FunctionTests + { + private static ExpressionBase CallFunction(string input, InterpreterScope scope) + where T : FunctionDefinitionExpression, new() + { + var functionDefinition = new T(); + + var expression = ExpressionBase.Parse(new PositionalTokenizer(Tokenizer.CreateTokenizer(input))); + Assert.That(expression, Is.InstanceOf()); + var functionCall = (FunctionCallExpression)expression; + scope.Context = functionCall; + + Assert.That(functionCall.FunctionName.Name, Is.EqualTo(functionDefinition.Name.Name)); + + ExpressionBase result; + var parameterScope = functionCall.GetParameters(functionDefinition, scope, out result); + if (result == null) + { + if (!functionDefinition.Evaluate(parameterScope, out result) && result is not ErrorExpression) + result = new ErrorExpression("Failure without ErrorExpression"); + } + + return result; + } + + public static void Execute(string input, InterpreterScope scope) + where T : FunctionDefinitionExpression, new() + { + var result = CallFunction(input, scope); + + var error = result as ErrorExpression; + if (error != null) + Assert.Fail(error.Message); + + Assert.That(result, Is.Null); + } + + public static ExpressionBase Evaluate(string input, InterpreterScope scope) + where T : FunctionDefinitionExpression, new() + { + var result = CallFunction(input, scope); + + var error = result as ErrorExpression; + if (error != null) + Assert.Fail(error.Message); + + Assert.That(result, Is.Not.Null); + return result; + } + + public static void AssertEvaluateError(string input, InterpreterScope scope, string expectedError) + where T : FunctionDefinitionExpression, new() + { + var error = CallFunction(input, scope); + Assert.That(error, Is.InstanceOf()); + + var parseError = (ErrorExpression)error; + while (parseError.InnerError != null) + parseError = parseError.InnerError; + Assert.That(parseError.Message, Is.EqualTo(expectedError)); + } + } +}