diff --git a/Assets/Scripts/Agents/AgentUtils.cs b/Assets/Scripts/Agents/AgentUtils.cs new file mode 100644 index 0000000..2247d00 --- /dev/null +++ b/Assets/Scripts/Agents/AgentUtils.cs @@ -0,0 +1,40 @@ +using System; +using UnityEngine; + +namespace DialogosEngine +{ + public static class AgentUtils + { + public const string k_EndOfSequence = ""; + + public static float CalculateEchoReward(string expectedString, string guessedString) + { + // Validate input strings + if (string.IsNullOrEmpty(expectedString) || string.IsNullOrEmpty(guessedString)) + { + throw new ArgumentException("Expected and guessed strings must not be null or empty."); + } + if (!expectedString.EndsWith(k_EndOfSequence)) + { + throw new ArgumentException("Expected string must end with ''."); + } + + // Calculate Levenshtein distance + int levenshteinDistance = Lexer.LevenshteinDistance(expectedString, guessedString); + float maxStringLength = Mathf.Max(expectedString.Length, guessedString.Length); + float similarityScore = 1f - (float)levenshteinDistance / maxStringLength; + similarityScore = (similarityScore * 2f) - 1f; // Normalize to range [-1, 1] + + // Calculate length match score + float lengthDifference = Mathf.Abs(expectedString.Length - guessedString.Length); + float lengthMatchScore = 1f - Mathf.Min(2f * lengthDifference / maxStringLength, 1f); + lengthMatchScore = (lengthMatchScore * 2f) - 1f; // Normalize to range [-1, 1] + + // Combine similarity and length match scores + float combinedScore = (0.5f * similarityScore) + (0.5f * lengthMatchScore); + + // Ensure the final score is within the range [-1, 1] + return Mathf.Clamp(combinedScore, -1f, 1f); + } + } +} diff --git a/Assets/Scripts/Agents/AgentUtils.cs.meta b/Assets/Scripts/Agents/AgentUtils.cs.meta new file mode 100644 index 0000000..5c3566f --- /dev/null +++ b/Assets/Scripts/Agents/AgentUtils.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: 4a59aeab33b698e49891b7d70f7e6b26 \ No newline at end of file diff --git a/Assets/Scripts/Agents/EchoAgent.cs b/Assets/Scripts/Agents/EchoAgent.cs index bee0b2c..106c6b3 100644 --- a/Assets/Scripts/Agents/EchoAgent.cs +++ b/Assets/Scripts/Agents/EchoAgent.cs @@ -2,51 +2,40 @@ using Unity.MLAgents; using Unity.MLAgents.Actuators; using Unity.MLAgents.Sensors; -using UnityEngine; namespace DialogosEngine { public class EchoAgent : Agent { - string _CachedGuessedString; - const string k_EndOfSequence = ""; + string _CachedString; public override void OnEpisodeBegin() { - ClearConsole(); + Terminal.Instance.Buffer.Reset(); } public void FixedUpdate() { string expectedString = GetExpectedString(); - if (_CachedGuessedString != null) + if (_CachedString != null) { - float _reward; - if (_CachedGuessedString.EndsWith(k_EndOfSequence)) + float reward = AgentUtils.CalculateEchoReward(expectedString, _CachedString); + + if (_CachedString.EndsWith(AgentUtils.k_EndOfSequence)) { - _reward = CalculateReward(expectedString, _CachedGuessedString); - - _CachedGuessedString = _CachedGuessedString.Replace(k_EndOfSequence, ""); + _CachedString = _CachedString.Replace(AgentUtils.k_EndOfSequence, ""); - Terminal.Instance.Shell.Run(_CachedGuessedString); - - if(Terminal.Instance.IssuedError) - { - _reward -= 0.5f; // Penalize for bad commands - } - } - else - { - _reward = -1f; + Terminal.Instance.Shell.Run(_CachedString); } - SetReward(_reward); - _CachedGuessedString = null; + + SetReward(reward); + _CachedString = null; } } public override void CollectObservations(VectorSensor sensor) { - string _buffer = GetConsoleBuffer(); + string _buffer = Terminal.Instance.Buffer.GetLastLog(); float[] _vectorizedBuffer = Lexer.VectorizeUTF8(_buffer); foreach (var obs in _vectorizedBuffer) { @@ -57,39 +46,12 @@ public override void CollectObservations(VectorSensor sensor) public override void OnActionReceived(ActionBuffers actions) { float[] _actionArray = actions.ContinuousActions.Array; - _CachedGuessedString = Lexer.QuantizeUTF8(_actionArray); - } - - private void ClearConsole() - { - Terminal.Instance.Buffer.Reset(); - } - - private float CalculateReward(string expectedString, string guessedString) - { - int levenshteinDistance = Lexer.LevenshteinDistance(expectedString, guessedString); - float maxStringLength = Mathf.Max(expectedString.Length, guessedString.Length); - float similarityScore = 1f - (float)levenshteinDistance / maxStringLength; - similarityScore = (similarityScore * 2f) - 1f; // Normalize to range [-1, 1] - - float lengthDifference = Mathf.Abs(expectedString.Length - guessedString.Length); - float lengthMatchScore = 1f - Mathf.Min(2f * lengthDifference / maxStringLength, 1f); - lengthMatchScore = (lengthMatchScore * 2f) - 1f; // Normalize to range [-1, 1] - - float combinedScore = (0.5f * similarityScore) + (0.5f * lengthMatchScore); - return Mathf.Clamp(combinedScore, -1f, 1f); // Ensure final score is within [-1, 1] - } - - - private string GetConsoleBuffer() - { - return Terminal.Instance.Buffer.GetLastLog(); + _CachedString = Lexer.QuantizeUTF8(_actionArray); } private string GetExpectedString() { - // Implementation to get the expected string for the current step - return ""; + return "echo hello "; // Testing } } } diff --git a/Tests/AgentUtilsTests.cs b/Tests/AgentUtilsTests.cs new file mode 100644 index 0000000..108545c --- /dev/null +++ b/Tests/AgentUtilsTests.cs @@ -0,0 +1,166 @@ +namespace DialogosEngine.Tests +{ + [TestFixture] + public class AgentUtilsTests + { + [Test] + public void CalculateEchoReward_ShouldReturnPerfectScore_WhenStringsMatchExactly() + { + // Arrange + string expected = "hello"; + string guessed = "hello"; + TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and guessed string: '{guessed}'."); + + // Act + float reward = AgentUtils.CalculateEchoReward(expected, guessed); + TestContext.WriteLine($"Calculated reward: {reward}"); + + // Assert + Assert.That(reward, Is.EqualTo(1f), "The reward should be 1.0f when the expected and guessed strings match exactly."); + TestContext.WriteLine("Test passed: The calculated reward is 1.0f as expected."); + } + + [Test] + public void CalculateEchoReward_ShouldThrowArgumentException_WhenStringsAreEmpty() + { + // Arrange + string expected = ""; + string guessed = ""; + TestContext.WriteLine($"Testing CalculateEchoReward with empty strings."); + + // Act & Assert + var ex = Assert.Throws(() => AgentUtils.CalculateEchoReward(expected, guessed)); + TestContext.WriteLine($"Expected ArgumentException was thrown with message: {ex.Message}"); + + // Log the result + Assert.IsNotNull(ex, "An ArgumentException should be thrown for empty strings."); + TestContext.WriteLine("Test passed: ArgumentException is thrown as expected for empty strings."); + } + + [Test] + public void CalculateEchoReward_ShouldThrowArgumentException_WhenExpectedStringDoesNotEndWithEos() + { + // Arrange + string expected = "hello"; + string guessed = "hello"; + TestContext.WriteLine($"Testing CalculateEchoReward with expected string not ending with '': '{expected}'."); + + // Act & Assert + var ex = Assert.Throws(() => AgentUtils.CalculateEchoReward(expected, guessed)); + TestContext.WriteLine($"Expected ArgumentException was thrown with message: {ex.Message}"); + + // Log the result + Assert.IsNotNull(ex, "An ArgumentException should be thrown when the expected string does not end with ''."); + TestContext.WriteLine("Test passed: ArgumentException is thrown as expected when the expected string does not end with ''."); + } + + [Test] + public void CalculateEchoReward_ShouldReturnPartialScore_WhenStringsAreSimilar() + { + // Arrange + string expected = "hello"; + string guessed = "hallo"; + TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and guessed string: '{guessed}'."); + + // Act + float reward = AgentUtils.CalculateEchoReward(expected, guessed); + TestContext.WriteLine($"Calculated reward: {reward}"); + + // Assert + Assert.That(reward, Is.EqualTo(0.9f).Within(0.01f), "The reward should be approximately 0.9f for similar strings."); + TestContext.WriteLine("Test passed: The calculated reward is 0.9f as expected for similar strings."); + } + + [Test] + public void CalculateEchoReward_ShouldReflectSimilarityAndLengthMatch_WhenEosIsMissing() + { + // Arrange + string expected = "hello"; + string guessed = "hello"; + TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and guessed string missing '': '{guessed}'."); + + // Act + float reward = AgentUtils.CalculateEchoReward(expected, guessed); + TestContext.WriteLine($"Calculated reward: {reward}"); + + // Assert + Assert.That(reward, Is.EqualTo(-0.5f).Within(0.01f), "The reward should be -0.5f when the guessed string is missing the '' token, reflecting the similarity and length match."); + TestContext.WriteLine("Test passed: The calculated reward is -0.5f as expected, reflecting the similarity and length match."); + } + + [Test] + public void CalculateEchoReward_ShouldReturnSpecificLowerScore_WhenGuessedStringHasAdditionalCharacters() + { + // Arrange + string expected = "hello"; + string guessed = "hello there"; + TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and guessed string with additional characters: '{guessed}'."); + + // Act + float reward = AgentUtils.CalculateEchoReward(expected, guessed); + TestContext.WriteLine($"Calculated reward: {reward}"); + + // Assert + Assert.That(reward, Is.EqualTo(-0.125f).Within(0.01f), "The reward should be -0.125f when the guessed string has additional characters."); + TestContext.WriteLine("Test passed: The calculated reward is -0.125f as expected when the guessed string has additional characters."); + } + + [Test] + public void CalculateEchoReward_ShouldReturnSpecificMuchLowerScore_WhenGuessedStringIsSignificantlyLonger() + { + // Arrange + string expected = "hello"; + string guessed = "hello there, how are you doing today?"; + TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and a significantly longer guessed string: '{guessed}'."); + + // Act + float reward = AgentUtils.CalculateEchoReward(expected, guessed); + TestContext.WriteLine($"Calculated reward: {reward}"); + + // Assert + Assert.That(reward, Is.EqualTo(-0.7619048f).Within(0.01f), "The reward should be -0.7619048f when the guessed string is significantly longer than the expected string."); + TestContext.WriteLine("Test passed: The calculated reward is -0.7619048f as expected when the guessed string is significantly longer than the expected string."); + } + + [Test] + public void CalculateEchoReward_ShouldReturnPositiveScore_WhenGuessedStringIsLongerButSimilar() + { + // Arrange + string expected = "hello fr "; + string guessed = "hello friend"; + TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and a longer but similar guessed string: '{guessed}'."); + + // Act + float reward = AgentUtils.CalculateEchoReward(expected, guessed); + TestContext.WriteLine($"Calculated reward: {reward}"); + + // Assert + // The expected reward should be positive due to the high similarity despite the additional length. + // Update the expected value based on the logic of your CalculateEchoReward method. + // For example, if the expected reward is now 0.2f based on the new logic, update the test accordingly: + Assert.That(reward, Is.GreaterThan(0f), "The reward should be positive when the guessed string is longer but still similar to the expected string."); + TestContext.WriteLine($"Test passed: The calculated reward is positive as expected when the guessed string is longer but still similar to the expected string."); + } + + [Test] + public void CalculateEchoReward_ShouldHandleEmojisInStrings() + { + // Arrange + string expected = "hello👋"; + string guessed = "hello👋 friend🙂"; + TestContext.WriteLine($"Testing CalculateEchoReward with expected string containing an emoji: '{expected}' and guessed string with additional emoji: '{guessed}'."); + + // Act + float reward = AgentUtils.CalculateEchoReward(expected, guessed); + TestContext.WriteLine($"Calculated reward: {reward}"); + + // Assert + // The expected reward should account for the emojis as part of the string. + // Update the expected value based on the logic of your CalculateEchoReward method. + // For example, if the expected reward is now 0.1f based on the new logic, update the test accordingly: + Assert.That(reward, Is.GreaterThan(0f), "The reward should be positive when the guessed string contains emojis and is similar to the expected string."); + TestContext.WriteLine($"Test passed: The calculated reward is positive as expected when the guessed string contains emojis and is similar to the expected string."); + } + + } +} diff --git a/Tests/CommandArgTests.cs b/Tests/CommandArgTests.cs index 78c594e..f1111fb 100644 --- a/Tests/CommandArgTests.cs +++ b/Tests/CommandArgTests.cs @@ -254,7 +254,5 @@ public static void ParseCommand_EscapedBackslashes_HandlesCorrectly() // Assert Assert.That(result.String, Is.EqualTo(expected), "Escaped backslashes should be handled correctly."); } - - } }